import socket
import struct
import json
import threading
import argparse
import random


def set_seed(seed: int):
    """Seed every random number generator this process may rely on.

    Python's ``random`` module is always seeded. NumPy and PyTorch are
    seeded only when they can be imported; an ``ImportError`` raised while
    seeding either one is swallowed. For PyTorch the cuDNN ``deterministic``
    flag is turned on and ``benchmark`` is turned off.

    Args:
        seed: Value handed to each library's seeding function.
    """

    def _seed_numpy():
        import numpy
        numpy.random.seed(seed)

    def _seed_torch():
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    random.seed(seed)

    # Both libraries are optional: a missing one is skipped, not an error.
    for seed_optional_library in (_seed_numpy, _seed_torch):
        try:
            seed_optional_library()
        except ImportError:
            pass

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Args:
        sock: Object exposing a socket-style ``recv(bufsize)`` method.
        n: Number of bytes to read. Zero or a negative value reads nothing.

    Returns:
        A ``bytes`` object of length ``n`` (empty when ``n <= 0``).

    Raises:
        ConnectionError: If ``recv`` returns an empty result before ``n``
            bytes have arrived.
    """
    # ``n`` comes from a peer-supplied 4-byte length prefix (see recv_json),
    # so keep every individual recv() request bounded instead of asking for
    # the whole remainder in one call.
    max_request = 1 << 20

    # Collect the pieces and join once: repeated ``bytes +=`` re-copies the
    # accumulated data on every iteration.
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = sock.recv(min(remaining, max_request))
        if not chunk:
            raise ConnectionError("client disconnected")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)

def recv_json(sock):
    """Receive one length-prefixed JSON message from ``sock``.

    The message is a 4-byte big-endian unsigned length followed by that
    many bytes of UTF-8 encoded JSON.

    Args:
        sock: Connected socket to read from.

    Returns:
        The decoded JSON value.

    Raises:
        ConnectionError: Propagated from ``recvall`` when the peer closes
            the connection mid-message.
        UnicodeDecodeError: If the payload is not valid UTF-8.
        json.JSONDecodeError: If the payload is not valid JSON.
    """
    size = struct.unpack(">I", recvall(sock, 4))[0]
    body = recvall(sock, size)
    text = body.decode("utf-8")
    return json.loads(text)

def send_json(sock, obj):
    """Send ``obj`` over ``sock`` as one length-prefixed JSON message.

    The object is serialised with ``ensure_ascii=False``, encoded as UTF-8,
    and written after a 4-byte big-endian unsigned length prefix in a single
    ``sendall`` call.

    Args:
        sock: Connected socket to write to.
        obj: JSON-serialisable value.
    """
    body = json.dumps(obj, ensure_ascii=False).encode("utf-8")
    prefix = struct.pack(">I", len(body))
    sock.sendall(prefix + body)

# Set from --use-default-llm in main(); when true, main() calls
# init_default_llm() before the server starts.
_USE_DEFAULT_LLM = False
# Model and tokenizer handles filled in by init_default_llm() and read by
# default_generate(); both stay None until init_default_llm() has run.
_MODEL = None
_TOKENIZER = None

def init_default_llm():
    """Load the default tokenizer and causal LM into the module globals.

    Populates ``_TOKENIZER`` and ``_MODEL`` from the
    ``Qwen/Qwen2.5-Coder-7B-Instruct`` checkpoint. The model is requested
    with ``device_map="auto"``, 8-bit loading and a float16 dtype;
    ``trust_remote_code=True`` is passed to both loaders.

    ``torch`` and ``transformers`` are imported lazily so that the module
    can be imported without them when the default LLM is not used.
    """
    global _MODEL, _TOKENIZER

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "Qwen/Qwen2.5-Coder-7B-Instruct"

    _TOKENIZER = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    model_options = dict(
        device_map="auto",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    _MODEL = AutoModelForCausalLM.from_pretrained(checkpoint, **model_options)

def default_generate(prompt: str) -> str:
    """Generate a completion for ``prompt`` with the default LLM.

    Requires ``init_default_llm()`` to have populated ``_TOKENIZER`` and
    ``_MODEL``. Generation runs under ``torch.inference_mode()`` with
    sampling disabled and at most 512 new tokens.

    Args:
        prompt: Text passed to the tokenizer as-is.

    Returns:
        The decoded text of the tokens after the prompt, with special
        tokens skipped.
    """
    import torch

    encoded = _TOKENIZER(prompt, return_tensors="pt").to(_MODEL.device)

    with torch.inference_mode():
        sequences = _MODEL.generate(
            **encoded,
            max_new_tokens=512,
            do_sample=False,
            eos_token_id=_TOKENIZER.eos_token_id,
        )

    # The first sequence starts with the prompt tokens; drop them so that
    # only the continuation is decoded.
    prompt_length = encoded["input_ids"].shape[-1]
    continuation = sequences[0][prompt_length:]
    return _TOKENIZER.decode(continuation, skip_special_tokens=True)

from ours_online import (
    get_initial_code,
    get_repaired_statements,
    get_additional_code,
    get_repaired_code_on_failure,
    init_infilling_llm,
)

def _split_code_and_tokens(result):
    """Normalise a generator return value to ``(code, num_tokens)``.

    A tuple is unpacked as exactly two items (a tuple of any other length
    raises ``ValueError``); any other value is treated as the code itself
    with no token count.
    """
    if isinstance(result, tuple):
        code, num_tokens = result
        return code, num_tokens
    return result, None


def _success_response(response_type, result):
    """Build the ``ok: True`` reply for ``result``, adding ``num_tokens`` only when known."""
    code, num_tokens = _split_code_and_tokens(result)
    response = {"type": response_type, "ok": True, "code": code}
    if num_tokens is not None:
        response["num_tokens"] = num_tokens
    return response


def _shared_request_fields(req):
    """Extract the fields common to revise / repair / batch_invalid_repair requests.

    Returns ``(inputs, scene, llm_kwargs)`` where ``llm_kwargs`` holds the
    keyword arguments all three downstream calls share. Missing or empty
    robot names fall back to the module-level defaults.
    """
    inputs = req.get("inputs") or {}
    current_code = inputs.get("current_code", "")
    executed_stmt_texts = inputs.get("executed_stmt_texts", [])
    scene = inputs.get("scene", {})

    meta = req.get("meta") or {}
    llm_kwargs = {
        "skill_name": req.get("task_name", "unknown_task"),
        "source_robot": meta.get("source_robot", "") or _DEFAULT_SOURCE_ROBOT,
        "target_robot": req.get("robot_type", "") or _DEFAULT_TARGET_ROBOT,
        "current_code": current_code,
        "executed_stmt_texts": executed_stmt_texts,
    }
    return inputs, scene, llm_kwargs


def _handle_generate(req):
    """Serve a ``generate`` request by calling ``get_initial_code``."""
    grasp_guidance = req.get("grasp_guidance")
    if grasp_guidance and isinstance(grasp_guidance, str) and grasp_guidance.strip():
        print(f"[SERVER] Received grasp_guidance: {grasp_guidance[:200]}...")
    else:
        # Anything other than a non-blank string is treated as "no guidance".
        grasp_guidance = None

    result = get_initial_code(
        skill_name=req.get("skill_name", ""),
        source_robot=req.get("source_robot", "") or _DEFAULT_SOURCE_ROBOT,
        target_robot=req.get("target_robot", "") or _DEFAULT_TARGET_ROBOT,
        reference_code=req.get("reference_code", ""),
        target_scene_info=req.get("target_scene_info", []),
        grasp_guidance=grasp_guidance,
    )
    return _success_response("generate_result", result)


def _handle_revise(req):
    """Serve a ``revise`` request; only calls the LLM when more steps are needed."""
    inputs, scene, llm_kwargs = _shared_request_fields(req)
    last_step = inputs.get("last_step", {})

    if not last_step.get("need_more_steps", False):
        # Nothing to add: echo the current code back unchanged.
        return {
            "type": "revise_result",
            "ok": True,
            "code": llm_kwargs["current_code"],
        }

    print("[REVISE] Task incomplete, generating additional code")
    print(f"[REVISE] task={llm_kwargs['skill_name']}, robot={llm_kwargs['target_robot']}")
    print(f"[REVISE] executed_stmt_texts count: {len(llm_kwargs['executed_stmt_texts'])}")

    result = get_additional_code(
        scene=scene,
        last_step_payload=last_step,
        **llm_kwargs
    )
    return _success_response("revise_result", result)


def _handle_repair(req):
    """Serve a ``repair`` request by calling ``get_repaired_code_on_failure``."""
    inputs, scene, llm_kwargs = _shared_request_fields(req)
    failed_stmt_text = inputs.get("failed_stmt_text", "")
    error_payload = inputs.get("error", {})

    print("[REPAIR] Failure occurred, repairing code")
    print(f"[REPAIR] task={llm_kwargs['skill_name']}, robot={llm_kwargs['target_robot']}")
    print(f"[REPAIR] failed_stmt: {failed_stmt_text[:100]}...")
    print(f"[REPAIR] error: {error_payload}")
    print(f"[REPAIR] executed_stmt_texts count: {len(llm_kwargs['executed_stmt_texts'])}")

    result = get_repaired_code_on_failure(
        failed_stmt_text=failed_stmt_text,
        error_payload=error_payload,
        scene=scene,
        **llm_kwargs
    )
    return _success_response("repair_result", result)


def _handle_batch_invalid_repair(req):
    """Serve a ``batch_invalid_repair`` request by calling ``get_repaired_statements``."""
    inputs, scene, llm_kwargs = _shared_request_fields(req)
    invalid_statements = inputs.get("invalid_statements", [])

    result = get_repaired_statements(
        invalid_statements=invalid_statements,
        scene_info=scene.get("objects") if scene else None,
        **llm_kwargs
    )
    return _success_response("batch_invalid_repair_result", result)


# Request ``type`` -> function that turns the request into a success response.
# The failure reply for each of these uses the type ``"<request type>_result"``.
_REQUEST_HANDLERS = {
    "generate": _handle_generate,
    "revise": _handle_revise,
    "repair": _handle_repair,
    "batch_invalid_repair": _handle_batch_invalid_repair,
}


def handle_client(conn, addr):
    """Serve one client connection until it disconnects.

    Protocol, as implemented here:

    1. The first message must have ``type == "hello"``; it is answered with
       ``{"type": "hello_ack"}``. Any other first message closes the
       connection without a reply.
    2. Every following message is dispatched on its ``type``
       (``generate``, ``revise``, ``repair``, ``batch_invalid_repair``).
       A successful reply is ``{"type": "<type>_result", "ok": True,
       "code": ...}`` plus ``num_tokens`` when the generator returned one.
       If handling raises, the traceback is printed and the reply is
       ``{"type": "<type>_result", "ok": False, "error": str(exc)}``.
    3. An unrecognised ``type`` is answered with
       ``{"type": "error", "ok": False, "error": "unknown type: ..."}``.

    Any exception outside per-request handling (including the peer closing
    the socket) ends the loop; the socket is always closed on exit.

    Args:
        conn: Connected client socket.
        addr: Peer address, used only in log output.
    """
    import traceback

    try:
        msg = recv_json(conn)
        if msg.get("type") != "hello":
            return
        send_json(conn, {"type": "hello_ack"})

        while True:
            req = recv_json(conn)
            rtype = req.get("type")
            print(f"{addr} -> {rtype}")

            # Only string types can match; the isinstance guard also keeps an
            # unhashable ``type`` value (e.g. a list) from breaking the lookup.
            handler = _REQUEST_HANDLERS.get(rtype) if isinstance(rtype, str) else None
            if handler is None:
                send_json(conn, {
                    "type": "error",
                    "ok": False,
                    "error": f"unknown type: {rtype}"
                })
                continue

            try:
                # Sending stays inside the try so that a reply which cannot be
                # serialised is reported to the client as an error as well.
                send_json(conn, handler(req))
            except Exception as e:
                traceback.print_exc()
                send_json(conn, {
                    "type": f"{rtype}_result",
                    "ok": False,
                    "error": str(e)
                })

    except Exception as e:
        print(f"{addr} disconnected: {e}")
    finally:
        conn.close()

def serve(host="0.0.0.0", port=5000):
    """Accept TCP connections forever, handling each one on its own thread.

    Args:
        host: Interface address to bind. The default ``"0.0.0.0"`` listens
            on all IPv4 interfaces, as the server always did.
        port: TCP port to bind; defaults to 5000.

    Each accepted connection is passed to ``handle_client`` on a daemon
    thread, so open connections do not keep the process alive on exit.
    This function only returns by raising (e.g. ``KeyboardInterrupt`` or a
    socket error).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Allow a quick restart on the same port.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
        s.listen()
        print(f"listening on {host}:{port} (default_llm={'on' if _USE_DEFAULT_LLM else 'off'})")

        while True:
            conn, addr = s.accept()
            threading.Thread(
                target=handle_client,
                args=(conn, addr),
                daemon=True
            ).start()

# Fallback robot types used by handle_client() when a request omits them or
# sends an empty value; main() overwrites both from --source-robot and
# --target-robot.
_DEFAULT_SOURCE_ROBOT = "panda"
_DEFAULT_TARGET_ROBOT = "ur5"


def main(argv=None):
    """Parse command-line options, initialise the models and run the server.

    Steps, in order: optionally cap this process's CUDA memory fraction,
    seed the random number generators, store the CLI choices in the module
    globals, load the default LLM when requested, initialise the infilling
    LLM, then block in ``serve()``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including a
            ``--max-vram-fraction`` outside ``[0.0, 1.0]``.

    NOTE(review): ``--use-default-llm`` only triggers ``init_default_llm()``
    here; no request path visible in this file calls ``default_generate`` --
    confirm whether that is intended.
    """
    global _USE_DEFAULT_LLM, _DEFAULT_SOURCE_ROBOT, _DEFAULT_TARGET_ROBOT

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-default-llm",
        action="store_true",
        help="Use default LLM for generate"
    )
    parser.add_argument(
        "--source-robot",
        type=str,
        default="panda",
        choices=["panda", "ur5", "sawyer"],
        help="Default source robot type (default: panda)"
    )
    parser.add_argument(
        "--target-robot",
        type=str,
        default="ur5",
        choices=["ur5", "sawyer"],
        help="Default target robot type (default: ur5)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)"
    )
    parser.add_argument(
        "--max-vram-fraction", "-f",
        type=float,
        default=None,
        help="Maximum VRAM fraction to use (0.0 ~ 1.0, e.g., 0.5 for 50%%). If not set, no limit."
    )
    args = parser.parse_args(argv)

    # Reject an out-of-range fraction up front with a normal usage error,
    # instead of a late traceback from torch (or silent acceptance when CUDA
    # is unavailable). The chained comparison is also False for NaN.
    fraction = args.max_vram_fraction
    if fraction is not None and not 0.0 <= fraction <= 1.0:
        parser.error(
            f"--max-vram-fraction must be between 0.0 and 1.0, got {fraction}"
        )

    if fraction is not None:
        import torch
        if torch.cuda.is_available():
            torch.cuda.set_per_process_memory_fraction(fraction)
            print(f"[SERVER] VRAM limit set to: {fraction * 100:.1f}%")
        else:
            print("[SERVER] Warning: CUDA not available, --max-vram-fraction ignored")

    set_seed(args.seed)
    print(f"[SERVER] Random seed set to: {args.seed}")

    _USE_DEFAULT_LLM = args.use_default_llm
    _DEFAULT_SOURCE_ROBOT = args.source_robot
    _DEFAULT_TARGET_ROBOT = args.target_robot

    if _USE_DEFAULT_LLM:
        init_default_llm()

    print(f"[SERVER] Source robot: {_DEFAULT_SOURCE_ROBOT}")
    print(f"[SERVER] Target robot: {_DEFAULT_TARGET_ROBOT}")
    init_infilling_llm()

    serve()

# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
