import argparse
import json
import math
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from evaluation_runner import (
    DEFAULT_TARGET_RANGE,
    MODEL_ID,
    apply_low_rank_perturbation,
    compute_perplexity,
    gather_texts,
    parse_dtype,
    resolve_device,
    resolve_layer_selector,
    sample_linear_bases,
)
from plot_results import plot_records

# Title handed to plot_records() for the figure produced at the end of main().
PLOT_TITLE = "Perturbation Amplitude vs. Perplexity"


def snapshot_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """Return an independent CPU copy of every tensor in ``model.state_dict()``.

    The copies are detached from autograd and never share storage with the
    live model, so later in-place edits to the model cannot alter the snapshot.
    The snapshot can therefore be loaded back before each perturbation.

    Args:
        model: Module whose state dict should be captured.

    Returns:
        Mapping from state-dict key to a detached CPU tensor.
    """
    # ``copy=True`` forces a fresh tensor even when the source already lives on
    # the CPU, and avoids the second full copy that ``.to("cpu").clone()``
    # performs for tensors stored on another device.
    return {
        name: tensor.detach().to("cpu", copy=True)
        for name, tensor in model.state_dict().items()
    }


def restore_model(model: torch.nn.Module, base_state: Dict[str, torch.Tensor], device: torch.device) -> None:
    """Reset ``model`` in place to a saved state and put it in evaluation mode.

    ``model`` is annotated as ``torch.nn.Module`` to match
    ``snapshot_state_dict``: only generic ``Module`` methods are used here.

    Args:
        model: Module whose state is overwritten in place.
        base_state: State dict to load. Because ``strict=True`` is used, its
            keys must match the model's state-dict keys exactly.
        device: Device the model is moved to after the state has been loaded.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when ``base_state``
            has missing or unexpected keys.
    """
    model.load_state_dict(base_state, strict=True)
    model.to(device)
    # Evaluation mode is re-applied on every restore so each run starts from
    # the same module mode.
    model.eval()


def parse_args() -> argparse.Namespace:
    """Declare the command-line options of the experiment and parse them.

    The options are listed in a table and registered in order, so the
    ``--help`` output keeps the same ordering as the table.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Sample random low-rank perturbations of Qwen/Qwen3-0.6B, log metrics, "
            "and plot perplexity versus perturbation amplitude."
        )
    )

    # (flag, keyword arguments for add_argument), in the order shown by --help.
    option_table = (
        # --- perturbation settings ---
        ("--ranks", {"type": int, "nargs": "+", "default": [2, 4, 8, 16], "help": "Ranks to evaluate"}),
        ("--trials", {"type": int, "default": 25, "help": "Number of random perturbations per rank"}),
        (
            "--target-norm",
            {"type": float, "default": None, "help": "Fixed target Frobenius norm (default: sample)"},
        ),
        (
            "--target-norm-min",
            {
                "type": float,
                "default": DEFAULT_TARGET_RANGE[0],
                "help": "Lower bound for sampled target norms when --target-norm is omitted (must be >= 0)",
            },
        ),
        (
            "--target-norm-max",
            {
                "type": float,
                "default": DEFAULT_TARGET_RANGE[1],
                "help": "Upper bound for sampled target norms when --target-norm is omitted (must be > 0 for sampling)",
            },
        ),
        (
            "--target-norm-sampling",
            {
                "choices": ["uniform", "log-uniform"],
                "default": "log-uniform",
                "help": "Distribution to sample target norms from when --target-norm is omitted",
            },
        ),
        (
            "--perturbation-scope",
            {
                "choices": ["all", "attention_qv"],
                "default": "all",
                "help": "Which linear layers to perturb (excluding lm_head)",
            },
        ),
        # --- evaluation data ---
        ("--max-length", {"type": int, "default": 256, "help": "Tokenizer max_length for evaluation prompts"}),
        ("--batch-size", {"type": int, "default": 4, "help": "Number of prompts to evaluate per forward pass"}),
        ("--dataset-name", {"default": "wiki40b", "help": "Hugging Face dataset repository name"}),
        ("--dataset-config", {"default": "fr", "help": "Dataset configuration to load"}),
        ("--dataset-split", {"default": "test", "help": "Dataset split to evaluate on"}),
        ("--text-column", {"default": "text", "help": "Column in the dataset that stores raw text"}),
        (
            "--max-samples",
            {
                "type": int,
                "default": 256,
                "help": "Maximum number of dataset entries to evaluate (<=0 uses the full split)",
            },
        ),
        # --- runtime ---
        ("--dtype", {"default": "auto", "help": "Model dtype: auto|float32|float16|bfloat16"}),
        ("--device", {"default": "auto", "help": "Device to run on: auto|cpu|cuda|cuda:<index>"}),
        ("--trust-remote-code", {"action": "store_true", "help": "Forward trust_remote_code when loading"}),
        ("--seed", {"type": int, "default": 0, "help": "Base seed used to derive per-trial seeds"}),
        # --- outputs ---
        (
            "--output-json",
            {"type": Path, "default": Path("results/perturbation/perturbation_results.json")},
        ),
        (
            "--output-plot",
            {"type": Path, "default": Path("results/perturbation/perturbation_scatter.png")},
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()


def validate_ranges(rank_values: List[int], norm_range: Tuple[float, float]) -> List[int]:
    """Sanitize the requested ranks and validate the target-norm bounds.

    Args:
        rank_values: Requested ranks. Duplicates and non-positive entries are
            dropped; the first-seen order of the remaining values is kept.
        norm_range: ``(min, max)`` bounds for the target norm.

    Returns:
        The de-duplicated, strictly positive ranks.

    Raises:
        ValueError: If no positive rank remains, if either bound is NaN or
            infinite, or if the bounds violate ``0 <= min <= max``.
    """
    ranks = [rank for rank in dict.fromkeys(rank_values) if rank > 0]
    if not ranks:
        raise ValueError("--ranks must contain at least one positive integer")

    min_norm, max_norm = norm_range
    # float("nan") and float("inf") are valid float conversions, so they can
    # arrive from the command line. Every ordering comparison against NaN is
    # False, which means a check phrased as "min < 0 or max < min" lets NaN
    # through; reject non-finite bounds explicitly.
    if not (math.isfinite(min_norm) and math.isfinite(max_norm)):
        raise ValueError("target norm range bounds must be finite numbers")
    if not 0 <= min_norm <= max_norm:
        raise ValueError("target norm range must satisfy 0 <= min <= max")
    return ranks


def main() -> None:
    """Run the perturbation experiment end to end.

    Steps, in order:
      1. Parse the CLI options, resolve device/dtype and validate the ranks
         and target-norm bounds.
      2. Gather the evaluation texts, load the tokenizer and model for
         ``MODEL_ID``, and measure the unperturbed baseline loss/perplexity.
      3. Snapshot the model weights. For each trial, sample one set of bases
         and, for every rank, restore the weights, apply a low-rank
         perturbation and re-measure loss/perplexity.
      4. Write all per-run records to ``--output-json`` and hand them to
         ``plot_records``.

    Raises:
        ValueError: Propagated from ``validate_ranges`` for invalid ranks or
            target-norm bounds.
    """
    args = parse_args()
    device = resolve_device(args.device)
    dtype = parse_dtype(args.dtype)

    target_range = (args.target_norm_min, args.target_norm_max)
    ranks = validate_ranges(args.ranks, target_range)
    layer_selector = resolve_layer_selector(args.perturbation_scope)

    print(
        "Loading dataset "
        f"{args.dataset_name}/{args.dataset_config}:{args.dataset_split} (max_samples={args.max_samples}) ...",
        flush=True,
    )
    texts = gather_texts(
        dataset_name=args.dataset_name,
        dataset_config=args.dataset_config,
        dataset_split=args.dataset_split,
        text_column=args.text_column,
        max_samples=args.max_samples,
    )
    print(f"Collected {len(texts)} texts for evaluation", flush=True)

    print(f"Loading tokenizer (trust_remote_code={args.trust_remote_code}) ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=args.trust_remote_code)

    print(f"Loading model on {device} with dtype={dtype} ...", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=args.trust_remote_code,
        dtype=dtype,
    )
    model.to(device)
    model.eval()

    # The baseline is measured once on the untouched weights and reused as the
    # reference point for every perturbed run.
    print("Running baseline pass (shared across all perturbations) ...", flush=True)
    baseline_ppl, baseline_loss = compute_perplexity(
        model,
        tokenizer,
        texts,
        device,
        args.max_length,
        args.batch_size,
    )
    print(f"Baseline    | loss: {baseline_loss:.4f} | perplexity: {baseline_ppl:.3f}")

    # CPU copy of the pristine weights; reloaded before every perturbation.
    base_state = snapshot_state_dict(model)
    max_rank = max(ranks)

    # Values are a mix of numbers, strings and (possibly) None, hence ``object``.
    records: List[Dict[str, object]] = []
    total_runs = len(ranks) * args.trials

    for trial in range(args.trials):
        trial_seed = args.seed + trial if args.seed is not None else None
        print(f"\nTrial {trial + 1}/{args.trials} | seed={trial_seed}")
        # One set of bases per trial, sampled at the largest rank and shared by
        # every rank evaluated in this trial.
        # NOTE(review): from the second trial on, the model still carries the
        # previous run's perturbation at this point (it is only restored inside
        # the rank loop). Harmless if sample_linear_bases does not read weight
        # values -- confirm.
        bases = sample_linear_bases(
            model,
            max_rank=max_rank,
            seed=trial_seed,
            selector=layer_selector,
        )

        for rank_index, rank in enumerate(ranks):
            run_index = trial * len(ranks) + rank_index + 1
            print(f"[{run_index}/{total_runs}] rank={rank}")

            # Start every run from the pristine weights so perturbations never
            # accumulate across runs.
            restore_model(model, base_state, device)
            layers_modified, perturb_norm, scale_factor, target_value = apply_low_rank_perturbation(
                model,
                rank=rank,
                target_norm=args.target_norm,
                target_range=target_range,
                bases=bases,
                sampling=args.target_norm_sampling,
                layer_selector=layer_selector,
            )

            perturbed_ppl, perturbed_loss = compute_perplexity(
                model,
                tokenizer,
                texts,
                device,
                args.max_length,
                args.batch_size,
            )

            delta_loss = perturbed_loss - baseline_loss
            delta_perplexity = perturbed_ppl - baseline_ppl
            # Loss change per unit of perturbation norm; NaN when the norm is
            # zero (or otherwise falsy) to avoid dividing by zero.
            sensitivity = delta_loss / perturb_norm if perturb_norm else float("nan")

            record = {
                "rank": rank,
                "trial": trial,
                "seed": trial_seed,
                "perturbation_scope": args.perturbation_scope,
                "target_norm_sampling": args.target_norm_sampling,
                "target_norm": target_value,
                "perturbation_norm": perturb_norm,
                "scale_factor": scale_factor,
                "layers_modified": layers_modified,
                "baseline_loss": baseline_loss,
                "baseline_perplexity": baseline_ppl,
                "perturbed_loss": perturbed_loss,
                "perturbed_perplexity": perturbed_ppl,
                "delta_loss": delta_loss,
                "delta_perplexity": delta_perplexity,
                "sensitivity": sensitivity,
            }
            records.append(record)

            sens_str = f"{sensitivity:+.6f}" if math.isfinite(sensitivity) else "nan"
            print(
                f"  loss={perturbed_loss:.4f} (Δ {delta_loss:+.4f}) | "
                f"ppl={perturbed_ppl:.3f} (Δ {delta_perplexity:+.3f}) | "
                f"||Δ||_F={perturb_norm:.4f} (target {target_value:.4f}) | scale={scale_factor:.6f} | sens={sens_str}"
            )

            if device.type == "cuda":
                torch.cuda.empty_cache()

    if args.output_json:
        args.output_json.parent.mkdir(parents=True, exist_ok=True)
        # NOTE: json.dumps serializes a NaN sensitivity as the non-standard
        # token ``NaN``; Python's json module reads it back, strict parsers may not.
        args.output_json.write_text(json.dumps(records, indent=2))
        print(f"Saved records to {args.output_json}")

    if args.output_plot:
        # The plot may target a different directory than the JSON file, so make
        # sure its parent exists as well.
        args.output_plot.parent.mkdir(parents=True, exist_ok=True)
        plot_records(records, output=args.output_plot, show=False, title=PLOT_TITLE)
    else:
        plot_records(records, output=None, show=True, title=PLOT_TITLE)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":  # pragma: no cover
    main()
