"""Simple chat runner for trained LoRA or full-finetune models."""

from __future__ import annotations

import argparse
import gc
import json
from pathlib import Path
from typing import Dict, List

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def _positive_int(value: str) -> int:
    """Parse an argparse value as an integer that is at least 1.

    Used for ``--max-new-tokens``: a zero or negative token budget is never
    meaningful and would otherwise only fail later, inside generation.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not an integer or is < 1;
            argparse reports this as a regular usage error and exits.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        )
    return number


def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``.

    Returns:
        The parsed argument namespace.
    """
    parser = argparse.ArgumentParser(
        description="Chat with a trained LoRA adapter or full fine-tuned model."
    )
    parser.add_argument("--sweep-dir", type=Path, help="Path to a sweep directory.")
    parser.add_argument(
        "--run",
        help="Run label inside the sweep (e.g. r8, r32, r128, r512, full).",
    )
    parser.add_argument(
        "--compare",
        action="store_true",
        help="Send each message to all runs in the sweep.",
    )
    parser.add_argument(
        "--no-base",
        action="store_true",
        help="Skip the base model when comparing runs.",
    )
    parser.add_argument(
        "--model-dir",
        type=Path,
        help="Direct path to a model or adapter directory.",
    )
    parser.add_argument(
        "--base-model",
        help="Base model ID/path (needed for LoRA adapters if not discoverable).",
    )
    parser.add_argument("--system-prompt", default=None, help="Optional system prompt.")
    parser.add_argument("--message", help="Single message to send and exit.")
    parser.add_argument(
        "--max-new-tokens",
        type=_positive_int,
        default=None,
        help="Optional cap on new tokens (default: no explicit cap).",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature; 0 or below switches to greedy decoding.",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.9,
        help="Nucleus sampling probability mass (ignored for greedy decoding).",
    )
    parser.add_argument(
        "--no-chat-template",
        action="store_true",
        help="Use a plain 'role: content' prompt instead of the tokenizer's chat template.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Pass trust_remote_code=True when loading models and tokenizers.",
    )
    parser.add_argument(
        "--no-pause",
        action="store_true",
        help="Do not pause between model responses when comparing runs.",
    )
    parser.add_argument(
        "--device",
        choices=["auto", "cpu", "cuda"],
        default="auto",
        help="Device to run on.",
    )
    parser.add_argument(
        "--dtype",
        choices=["auto", "fp16", "bf16", "fp32"],
        default="auto",
        help="Model precision.",
    )
    return parser.parse_args(argv)


def resolve_device(choice: str) -> str:
    """Turn the ``--device`` choice into a concrete torch device string.

    ``"auto"`` picks ``"cuda"`` when torch reports an available CUDA device
    and ``"cpu"`` otherwise; every other value is passed through unchanged.
    """
    if choice == "auto":
        has_gpu = torch.cuda.is_available()
        return "cuda" if has_gpu else "cpu"
    return choice


def resolve_dtype(choice: str, device: str) -> torch.dtype:
    """Map the ``--dtype`` choice to a ``torch.dtype``.

    Explicit choices (``fp16``, ``bf16``, ``fp32``) map directly. Anything
    else (i.e. ``"auto"``) yields float16 on ``"cuda"`` and float32 otherwise.
    """
    explicit = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    if choice in explicit:
        return explicit[choice]
    # Automatic choice: half precision on GPU, full precision elsewhere.
    return torch.float16 if device == "cuda" else torch.float32


def resolve_run_dir(sweep_dir: Path, run: str) -> Path:
    """Return the directory of run ``run`` inside ``sweep_dir``.

    Raises:
        FileNotFoundError: If ``sweep_dir / run`` does not exist.
    """
    candidate = sweep_dir / run
    if candidate.exists():
        return candidate
    raise FileNotFoundError(f"Run directory not found: {candidate}")


def resolve_model_dir(run_dir: Path) -> Path:
    """Pick the checkpoint directory inside a run directory.

    Prefers the ``model`` sub-directory, then ``adapter``; if neither exists,
    ``run_dir`` itself is returned.
    """
    return next(
        (run_dir / sub for sub in ("model", "adapter") if (run_dir / sub).exists()),
        run_dir,
    )


def is_adapter_dir(model_dir: Path) -> bool:
    """Return True when ``model_dir`` looks like a saved PEFT adapter.

    The check is purely file-based: the directory counts as an adapter if it
    contains ``adapter_config.json`` or ``adapter_model.safetensors``.
    """
    markers = ("adapter_config.json", "adapter_model.safetensors")
    return any((model_dir / marker).exists() for marker in markers)


def discover_runs(sweep_dir: Path) -> List[str]:
    """List the run labels found directly inside ``sweep_dir``, in display order.

    A run is a sub-directory named ``full`` or ``r`` followed by decimal digits
    (e.g. ``r8``, ``r128``). Rank runs come first, sorted by their numeric
    rank, and ``full`` comes last. Labels with the same numeric value (``r8``
    and ``r08``) are ordered by name, so the result does not depend on the
    filesystem's iteration order.

    Raises:
        OSError: Propagated from ``Path.iterdir`` (e.g. ``sweep_dir`` missing).
    """
    runs: List[str] = []
    for entry in sweep_dir.iterdir():
        if not entry.is_dir():
            continue
        name = entry.name
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as "²" that int() cannot parse, which would crash the sort below.
        if name == "full" or (name.startswith("r") and name[1:].isdecimal()):
            runs.append(name)

    def sort_key(label: str) -> tuple[int, int, str]:
        # (group, rank, name): rank runs before "full", then by rank, then by
        # name as a deterministic tiebreaker.
        if label == "full":
            return (1, 0, label)
        return (0, int(label[1:]), label)

    runs.sort(key=sort_key)
    return runs


def load_json(path: Path) -> Dict:
    """Read and parse the JSON file at ``path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    exchanged in) instead of the platform's locale encoding. This makes
    non-ASCII content such as model IDs parse identically on every OS.
    """
    return json.loads(path.read_text(encoding="utf-8"))


def resolve_base_model(
    args: argparse.Namespace, run_dir: Path | None, model_dir: Path
) -> str | None:
    """Work out which base model a checkpoint belongs to.

    Sources are tried in order and the first non-empty value wins:

    1. the ``--base-model`` flag (``args.base_model``);
    2. ``base_model_name_or_path`` in ``model_dir/adapter_config.json``;
    3. ``model_id`` in ``run_dir/metrics.json`` (only if ``run_dir`` is given).

    Returns None when none of them yields a value.
    """
    if args.base_model:
        return args.base_model

    sources = [(model_dir / "adapter_config.json", "base_model_name_or_path")]
    if run_dir is not None:
        sources.append((run_dir / "metrics.json", "model_id"))

    for json_path, key in sources:
        if not json_path.exists():
            continue
        candidate = load_json(json_path).get(key)
        if candidate:
            return candidate
    return None


def load_model_and_tokenizer(
    model_dir: Path,
    base_model: str | None,
    *,
    device: str,
    dtype: torch.dtype,
    trust_remote_code: bool,
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a checkpoint (PEFT adapter or full model) together with its tokenizer.

    If ``model_dir`` is an adapter directory (per ``is_adapter_dir``), the
    base model ``base_model`` is loaded and the adapter is applied with
    ``PeftModel.from_pretrained``. Otherwise ``model_dir`` is loaded directly
    as a causal LM.

    The tokenizer comes from ``model_dir`` whenever a tokenizer was saved
    there (``tokenizer_config.json`` present). A tokenizer saved with the
    checkpoint may carry a chat template or special tokens that differ from
    the base model's. Otherwise the tokenizer falls back to ``base_model``,
    or to ``model_dir`` when no base model is known.

    The model is moved to ``device`` and switched to eval mode. If the
    tokenizer has no pad token, the EOS token is reused for padding.

    Raises:
        ValueError: If ``model_dir`` is an adapter and ``base_model`` is None.
    """
    is_adapter = is_adapter_dir(model_dir)
    if is_adapter and base_model is None:
        raise ValueError("Base model is required to load a LoRA adapter.")

    if (model_dir / "tokenizer_config.json").exists():
        tokenizer_source = model_dir
    elif base_model:
        tokenizer_source = base_model
    else:
        tokenizer_source = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source, trust_remote_code=trust_remote_code
    )

    if is_adapter:
        base = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            trust_remote_code=trust_remote_code,
        )
        model = PeftModel.from_pretrained(base, model_dir)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            trust_remote_code=trust_remote_code,
        )

    if tokenizer.pad_token is None:
        # generate() is given pad_token_id; reuse EOS when no pad is defined.
        tokenizer.pad_token = tokenizer.eos_token

    model.to(device)
    model.eval()
    return model, tokenizer


def load_base_model_and_tokenizer(
    base_model: str,
    *,
    device: str,
    dtype: torch.dtype,
    trust_remote_code: bool,
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the plain (un-tuned) base model and its tokenizer.

    The model is loaded in ``dtype``, moved to ``device`` and put in eval
    mode. If the tokenizer has no pad token, EOS is reused for padding.
    """
    shared_kwargs = {"trust_remote_code": trust_remote_code}
    tokenizer = AutoTokenizer.from_pretrained(base_model, **shared_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        **shared_kwargs,
    )
    if tokenizer.pad_token is None:
        # Fall back to EOS so a pad id is always available for generation.
        tokenizer.pad_token = tokenizer.eos_token
    model.to(device)
    model.eval()
    return model, tokenizer


def build_prompt(
    tokenizer: AutoTokenizer,
    messages: List[Dict[str, str]],
    no_chat_template: bool,
) -> str:
    """Render a conversation into a single prompt string.

    When templating is allowed and the tokenizer exposes a non-empty
    ``chat_template``, the tokenizer's ``apply_chat_template`` renders the
    messages as text with a generation prompt appended. Otherwise the
    fallback is a plain transcript of ``role: content`` lines ending with
    ``assistant:``.
    """
    template_usable = (
        not no_chat_template
        and hasattr(tokenizer, "apply_chat_template")
        and getattr(tokenizer, "chat_template", None)
    )
    if template_usable:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    lines = [f"{msg['role']}: {msg['content']}" for msg in messages]
    lines.append("assistant:")
    return "\n".join(lines)


def resolve_max_new_tokens(
    model: AutoModelForCausalLM,
    input_length: int,
    max_new_tokens: int | None,
) -> int | None:
    """Decide how many new tokens generation may produce.

    An explicit ``max_new_tokens`` always wins. Otherwise the budget is the
    model's context length minus ``input_length``, clamped to at least 1.
    This means generation is bounded only by what fits in the context.

    The context length is read from the first ``model.config`` attribute
    that is set among the common spellings used by different architectures
    (``max_position_embeddings``, ``n_positions``, ``max_sequence_length``,
    ``seq_length``, ``n_ctx``). ``max_length`` is consulted only as a last
    resort, because in ``transformers`` configs it is the default total
    length for ``generate`` rather than the context size.

    Returns:
        The token budget, or None when no length attribute is available.
    """
    if max_new_tokens is not None:
        return max_new_tokens

    config = model.config
    length_attrs = (
        "max_position_embeddings",
        "n_positions",
        "max_sequence_length",
        "seq_length",
        "n_ctx",
        "max_length",
    )
    max_positions = None
    for attr in length_attrs:
        value = getattr(config, attr, None)
        if value is not None:
            max_positions = value
            break
    if max_positions is None:
        return None

    available = int(max_positions) - input_length
    return max(available, 1)


def _collect_eos_token_ids(
    model: AutoModelForCausalLM, tokenizer: AutoTokenizer
) -> int | List[int] | None:
    """Merge the stop-token ids of the tokenizer and the model's generation config.

    ``model.generation_config.eos_token_id`` may hold several ids (for chat
    models, e.g. an end-of-turn token in addition to end-of-text). Passing
    only the tokenizer's single EOS id to ``generate`` would override that
    list, so both sources are combined, tokenizer id first, duplicates
    dropped. Returns a single int, a list, or None when neither defines one.
    """
    ids: List[int] = []
    generation_config = getattr(model, "generation_config", None)
    sources = (
        tokenizer.eos_token_id,
        getattr(generation_config, "eos_token_id", None),
    )
    for source in sources:
        if source is None:
            continue
        candidates = source if isinstance(source, (list, tuple)) else [source]
        for token_id in candidates:
            if token_id is not None and token_id not in ids:
                ids.append(token_id)
    if not ids:
        return None
    return ids[0] if len(ids) == 1 else ids


def generate_reply(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    messages: List[Dict[str, str]],
    *,
    no_chat_template: bool,
    max_new_tokens: int | None,
    temperature: float,
    top_p: float,
    device: str,
) -> str:
    """Generate the assistant's next reply to ``messages``.

    The conversation is rendered with ``build_prompt``, tokenized and moved to
    ``device``. Sampling is used when ``temperature > 0``; otherwise decoding
    is greedy and ``temperature``/``top_p`` are not passed. The token budget
    comes from ``resolve_max_new_tokens``.

    Returns:
        The decoded new tokens only (prompt excluded), with special tokens
        skipped and surrounding whitespace stripped.
    """
    prompt = build_prompt(tokenizer, messages, no_chat_template)

    # Chat templates commonly render the BOS token into the prompt text itself;
    # letting the tokenizer add special tokens again would yield a doubled BOS.
    bos_token = getattr(tokenizer, "bos_token", None)
    prompt_has_bos = (
        isinstance(bos_token, str) and bool(bos_token) and prompt.startswith(bos_token)
    )
    inputs = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=not prompt_has_bos
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}
    prompt_length = int(inputs["input_ids"].shape[-1])

    do_sample = temperature > 0
    effective_max_new_tokens = resolve_max_new_tokens(
        model,
        input_length=prompt_length,
        max_new_tokens=max_new_tokens,
    )
    generate_kwargs = {
        "do_sample": do_sample,
        "temperature": temperature if do_sample else None,
        "top_p": top_p if do_sample else None,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": _collect_eos_token_ids(model, tokenizer),
    }
    if effective_max_new_tokens is not None:
        generate_kwargs["max_new_tokens"] = effective_max_new_tokens

    with torch.inference_mode():
        output = model.generate(**inputs, **generate_kwargs)
    generated = output[0][prompt_length:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()


def _maybe_pause(enabled: bool) -> None:
    if not enabled:
        return
    try:
        input("Press Enter for the next response...")
    except EOFError:
        return


def resolve_compare_base_model(
    args: argparse.Namespace, sweep_dir: Path, runs: List[str]
) -> str | None:
    """Find one base model ID to use for the "base" entry in compare mode.

    ``--base-model`` takes precedence. Otherwise each run in ``runs`` is
    inspected in order with ``resolve_base_model``, and the first non-empty
    result is returned. Returns None if nothing is found.
    """
    if args.base_model:
        return args.base_model
    for label in runs:
        run_path = resolve_run_dir(sweep_dir, label)
        found = resolve_base_model(args, run_path, resolve_model_dir(run_path))
        if found:
            return found
    return None


def compare_runs(args: argparse.Namespace) -> None:
    """Chat with every run of a sweep (and optionally the base model) side by side.

    Each user message is sent to every model in turn. Only one model is held
    at a time: it is loaded, asked for a reply, printed, and released before
    the next one is loaded. Every model keeps its own independent conversation
    history, seeded with ``--system-prompt`` when given.

    With ``--message`` a single round is run. Otherwise an interactive loop
    reads messages until EOF or ``exit``/``quit``.

    Raises:
        SystemExit: If ``--sweep-dir`` is missing or contains no recognized runs.
    """
    if args.sweep_dir is None:
        raise SystemExit("--compare requires --sweep-dir.")
    if args.model_dir is not None or args.run is not None:
        print("Ignoring --model-dir/--run because --compare is enabled.")

    device = resolve_device(args.device)
    dtype = resolve_dtype(args.dtype, device)
    runs = discover_runs(args.sweep_dir)
    if not runs:
        raise SystemExit(f"No runs found in {args.sweep_dir}")

    base_model_id = None
    if not args.no_base:
        base_model_id = resolve_compare_base_model(args, args.sweep_dir, runs)
        if base_model_id is None:
            print("Base model not found; skipping base responses.")

    # The "base" entry is only added when a base model ID was found, so
    # load_spec below never sees a base spec with base_model_id None.
    run_specs: List[Dict[str, str]] = []
    if base_model_id is not None:
        run_specs.append({"label": "base", "kind": "base"})
    run_specs.extend({"label": run, "kind": "run"} for run in runs)

    histories: Dict[str, List[Dict[str, str]]] = {}
    for spec in run_specs:
        history: List[Dict[str, str]] = []
        if args.system_prompt:
            history.append({"role": "system", "content": args.system_prompt})
        histories[spec["label"]] = history

    def load_spec(spec: Dict[str, str]) -> tuple:
        """Load the (model, tokenizer) pair for one entry of ``run_specs``."""
        if spec["kind"] == "base":
            return load_base_model_and_tokenizer(
                base_model_id,
                device=device,
                dtype=dtype,
                trust_remote_code=args.trust_remote_code,
            )
        run_dir = resolve_run_dir(args.sweep_dir, spec["label"])
        model_dir = resolve_model_dir(run_dir)
        base_model = resolve_base_model(args, run_dir, model_dir) or base_model_id
        return load_model_and_tokenizer(
            model_dir,
            base_model,
            device=device,
            dtype=dtype,
            trust_remote_code=args.trust_remote_code,
        )

    def release_memory() -> None:
        """Reclaim memory of the model that was just dropped."""
        # `del` only removes our references. gc.collect() also frees objects
        # kept alive solely by reference cycles (if any), so that
        # empty_cache() can return their CUDA blocks before the next load.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

    def respond_to_all(message: str) -> None:
        """Send ``message`` to every model in ``run_specs`` and print the replies."""
        last_index = len(run_specs) - 1
        for idx, spec in enumerate(run_specs):
            label = spec["label"]
            model, tokenizer = load_spec(spec)
            run_history = histories[label]
            run_history.append({"role": "user", "content": message})
            reply = generate_reply(
                model,
                tokenizer,
                run_history,
                no_chat_template=args.no_chat_template,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                device=device,
            )
            run_history.append({"role": "assistant", "content": reply})
            print(f"\n[{label}]")
            print(reply)
            del model, tokenizer
            release_memory()
            _maybe_pause(not args.no_pause and idx < last_index)

    if args.message:
        respond_to_all(args.message)
        return

    print("Enter messages (Ctrl-D or 'exit' to quit).")
    while True:
        try:
            user_text = input("user> ").strip()
        except EOFError:
            break
        if not user_text:
            continue
        if user_text.lower() in {"exit", "quit"}:
            break
        respond_to_all(user_text)


def main() -> None:
    """Command-line entry point.

    With ``--compare`` everything is delegated to ``compare_runs``. Otherwise
    a single checkpoint is located, via ``--model-dir`` or via
    ``--sweep-dir`` + ``--run``, and loaded. The user then chats with it:
    either one ``--message`` whose reply is printed, or an interactive loop
    that keeps the full conversation history.

    Raises:
        SystemExit: If neither ``--model-dir`` nor both ``--sweep-dir`` and
            ``--run`` are given (outside compare mode).
    """
    args = parse_args()
    if args.compare:
        compare_runs(args)
        return

    if args.model_dir is not None:
        candidate = args.model_dir
        # A directory holding run artifacts is treated as a run directory;
        # anything else is taken to be the checkpoint itself.
        looks_like_run = any(
            (candidate / marker).exists()
            for marker in ("metrics.json", "adapter", "model")
        )
        if looks_like_run:
            run_dir = candidate
            model_dir = resolve_model_dir(run_dir)
        else:
            run_dir = candidate.parent
            model_dir = candidate
    elif args.sweep_dir is not None and args.run is not None:
        run_dir = resolve_run_dir(args.sweep_dir, args.run)
        model_dir = resolve_model_dir(run_dir)
    else:
        raise SystemExit("Provide --model-dir or both --sweep-dir and --run.")

    device = resolve_device(args.device)
    dtype = resolve_dtype(args.dtype, device)
    base_model = resolve_base_model(args, run_dir, model_dir)
    model, tokenizer = load_model_and_tokenizer(
        model_dir,
        base_model,
        device=device,
        dtype=dtype,
        trust_remote_code=args.trust_remote_code,
    )

    conversation: List[Dict[str, str]] = []
    if args.system_prompt:
        conversation.append({"role": "system", "content": args.system_prompt})

    def ask(text: str) -> str:
        """Append ``text`` as a user turn and return the model's reply."""
        conversation.append({"role": "user", "content": text})
        return generate_reply(
            model,
            tokenizer,
            conversation,
            no_chat_template=args.no_chat_template,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            device=device,
        )

    if args.message:
        print(ask(args.message))
        return

    print("Enter messages (Ctrl-D or 'exit' to quit).")
    while True:
        try:
            user_text = input("user> ").strip()
        except EOFError:
            break
        if not user_text:
            continue
        if user_text.lower() in {"exit", "quit"}:
            break
        answer = ask(user_text)
        print(f"assistant> {answer}")
        conversation.append({"role": "assistant", "content": answer})


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
