import argparse
import contextlib
import importlib
import math
import random
import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Callable, List, Sequence, Tuple, Literal

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


def _import_hf_evaluate():
    """Import the ``evaluate`` module while this file's directory is off ``sys.path``.

    Entries of ``sys.path`` that resolve to the directory containing this file
    are removed for the duration of the import, so a same-named module located
    there cannot be picked up through those entries. ``sys.path`` is put back
    afterwards (even if the import fails) and the imported module is registered
    under ``sys.modules["evaluate"]`` before being returned.
    """
    script_dir = Path(__file__).resolve().parent
    saved_entries = list(sys.path)
    try:
        filtered_entries = []
        for entry in saved_entries:
            if Path(entry).resolve() != script_dir:
                filtered_entries.append(entry)
        sys.path = filtered_entries
        hf_module = importlib.import_module("evaluate")
    finally:
        # Always restore the search path, whether or not the import succeeded.
        sys.path = saved_entries

    sys.modules["evaluate"] = hf_module
    return hf_module


# Bind the `evaluate` module at import time; the helper keeps this file's own
# directory off sys.path while the import runs.
evaluate = _import_hf_evaluate()

# Checkpoint identifier passed to from_pretrained() in main().
MODEL_ID = "Qwen/Qwen3-0.6B"
# (min, max) bounds for the sampled target norm when no range is supplied.
DEFAULT_TARGET_RANGE = (1.0, 1000.0)


@dataclass(slots=True)
class LinearPerturbationBasis:
    """Stores a reproducible low-rank noise draw for a Linear module."""

    module: torch.nn.Linear
    left: torch.Tensor
    right: torch.Tensor


def _make_generator(seed: int | None) -> torch.Generator | None:
    if seed is None:
        return None
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator


def sample_linear_bases(
    model: torch.nn.Module,
    max_rank: int,
    *,
    seed: int | None = None,
    selector: Callable[[str, torch.nn.Module], bool] | None = None,
) -> List[LinearPerturbationBasis]:
    """Sample Gaussian factor matrices for the model's ``torch.nn.Linear`` layers.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        max_rank: Upper bound on the number of factor columns/rows; the rank
            actually used per layer is capped by the weight's dimensions.
        seed: Seed for the CPU generator driving the draws; ``None`` uses the
            global torch RNG.
        selector: Optional ``(name, module) -> bool`` predicate; layers for
            which it returns ``False`` are skipped.

    Returns:
        One ``LinearPerturbationBasis`` per selected layer, in
        ``named_modules()`` order, holding float32 CPU tensors.

    Raises:
        ValueError: If ``max_rank`` is not positive.
    """

    if max_rank <= 0:
        raise ValueError("max_rank must be positive")

    rng = _make_generator(seed)

    # Best effort: find the output projection so it can be excluded by identity.
    head_module = None
    if hasattr(model, "get_output_embeddings"):
        with contextlib.suppress(Exception):
            head_module = model.get_output_embeddings()

    def _draw(rows: int, cols: int) -> torch.Tensor:
        return torch.randn(
            (rows, cols),
            generator=rng,
            dtype=torch.float32,
            device="cpu",
        )

    sampled: List[LinearPerturbationBasis] = []
    for name, module in model.named_modules():
        is_candidate = (
            isinstance(module, torch.nn.Linear)
            and module is not head_module
            and not name.endswith("lm_head")
        )
        if not is_candidate:
            continue
        if selector is not None and not selector(name, module):
            continue

        out_features, in_features = module.weight.shape
        layer_rank = min(max_rank, out_features, in_features)
        if layer_rank == 0:
            continue

        # Draw order (left, then right) determines the values for a given seed.
        left_factor = _draw(out_features, layer_rank)
        right_factor = _draw(layer_rank, in_features)
        sampled.append(
            LinearPerturbationBasis(
                module=module, left=left_factor, right=right_factor)
        )

    return sampled


def resolve_layer_selector(scope: str) -> Callable[[str, torch.nn.Module], bool] | None:
    """Return a predicate that decides which Linear layers to perturb."""

    if scope == "all":
        return None
    if scope == "attention_qv":
        def selector(name: str, _module: torch.nn.Module) -> bool:
            suffix = name.rsplit(".", 1)[-1]
            if suffix not in {"q_proj", "v_proj"}:
                return False
            return "self_attn" in name or name.endswith(("self_attn.q_proj", "self_attn.v_proj"))

        return selector

    raise ValueError(f"Unknown perturbation scope '{scope}'")


@lru_cache(maxsize=1)
def get_perplexity_metric() -> evaluate.Metric:
    """Return ``evaluate.load("perplexity")``, cached after the first call."""
    metric = evaluate.load("perplexity")
    return metric


@contextlib.contextmanager
def _use_in_memory_model(model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
    """Temporarily make ``from_pretrained`` return already-loaded objects.

    While the context is active, ``AutoModelForCausalLM.from_pretrained`` and
    ``AutoTokenizer.from_pretrained`` ignore all arguments and return ``model``
    and ``tokenizer`` respectively.

    On exit each class is put back exactly as it was found: the raw class-dict
    entry is restored, or the temporary attribute is deleted when the class
    only inherited ``from_pretrained``. Restoring the value read through
    attribute access instead would store a method already bound to the class,
    which subclasses would then inherit with the wrong ``cls``.
    """
    not_defined = object()

    def _always_return(value):
        # Factory avoids late-binding of the loop variable inside the lambda.
        return classmethod(lambda _cls, *_args, **_kwargs: value)

    replacements = (
        (AutoModelForCausalLM, model),
        (AutoTokenizer, tokenizer),
    )
    # Snapshot the descriptors stored on the classes themselves (not the bound
    # methods that attribute access would produce).
    saved = [
        (owner, vars(owner).get("from_pretrained", not_defined))
        for owner, _ in replacements
    ]

    for owner, replacement in replacements:
        owner.from_pretrained = _always_return(replacement)

    try:
        yield
    finally:
        for owner, original in saved:
            if original is not_defined:
                # The class only inherited the attribute; drop our override.
                delattr(owner, "from_pretrained")
            else:
                setattr(owner, "from_pretrained", original)


def compute_perplexity(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    texts: Sequence[str],
    device: torch.device,
    max_length: int,
    batch_size: int,
) -> Tuple[float, float]:
    """Score ``texts`` with the cached perplexity metric on the in-memory model.

    The metric is run while ``from_pretrained`` is patched to hand back
    ``model`` and ``tokenizer``. If the tokenizer has no pad token, its EOS
    token is used for the duration of the call and the previous value is
    restored afterwards.

    Args:
        model: Model to evaluate; ``model.config._name_or_path`` is passed as
            the metric's ``model_id``.
        tokenizer: Tokenizer paired with ``model``.
        texts: Texts to score; must be non-empty.
        device: Only ``device.type`` is forwarded ("cuda" or otherwise "cpu"),
            so any device index is dropped.
        max_length: Forwarded to the metric as ``max_length``.
        batch_size: Forwarded to the metric as ``batch_size``.

    Returns:
        ``(mean_perplexity, mean_loss)`` where ``mean_perplexity`` is the
        metric's own mean and ``mean_loss`` is the average natural log of the
        per-sample perplexities.

    Raises:
        ValueError: If ``texts`` is empty, the tokenizer has neither a pad nor
            an EOS token, or the metric reports no or non-positive values.
    """
    prompts = list(texts)
    if not prompts:
        raise ValueError("compute_perplexity received an empty text sequence")

    saved_pad_token = tokenizer.pad_token
    try:
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is None:
                raise ValueError(
                    "Tokenizer lacks both pad_token and eos_token; cannot evaluate perplexity.")
            tokenizer.pad_token = tokenizer.eos_token

        metric = get_perplexity_metric()
        if device.type == "cuda":
            metric_device = "cuda"
        else:
            metric_device = "cpu"

        with _use_in_memory_model(model, tokenizer):
            results = metric.compute(
                model_id=model.config._name_or_path,
                predictions=prompts,
                batch_size=batch_size,
                device=metric_device,
                max_length=max_length,
                add_start_token=False,
            )
    finally:
        # Undo the temporary pad-token substitution, success or failure.
        tokenizer.pad_token = saved_pad_token

    sample_ppls = [float(value) for value in results["perplexities"]]
    has_invalid = any(value <= 0 for value in sample_ppls)
    if not sample_ppls or has_invalid:
        raise ValueError(
            "Perplexity metric returned non-positive values; cannot take logarithm.")

    mean_perplexity = float(results["mean_perplexity"])
    log_total = sum(math.log(value) for value in sample_ppls)
    mean_loss = float(log_total / len(sample_ppls))
    return mean_perplexity, mean_loss


def _resolve_target_norm(
    target_norm: float | None,
    target_range: Tuple[float, float],
    sampling: str,
    seed: int | None,
) -> float:
    """Validate the norm arguments and return the Frobenius norm to scale to."""
    min_norm, max_norm = target_range
    if min_norm < 0 or max_norm < 0 or max_norm < min_norm:
        raise ValueError(f"Invalid target norm range: {target_range}")

    if sampling not in {"uniform", "log-uniform"}:
        raise ValueError(f"Unsupported sampling mode '{sampling}'.")

    if target_norm is not None:
        if target_norm < 0:
            raise ValueError("target_norm must be non-negative when provided")
        return float(target_norm)

    if min_norm == max_norm:
        return float(min_norm)

    rng = random.Random(seed)
    lower = float(min_norm)
    upper = float(max_norm)
    if sampling == "uniform":
        return float(rng.uniform(lower, upper))

    # log-uniform sampling
    if upper <= 0:
        raise ValueError(
            "target norm range must be positive to sample log-uniformly")
    if lower == 0.0:
        # Avoid log(0) by nudging to a tiny positive value without exceeding upper.
        lower = min(upper, 1e-12)
    log_lower = math.log(lower)
    log_upper = math.log(upper)
    if log_lower == log_upper:
        return lower
    return math.exp(rng.uniform(log_lower, log_upper))


def _materialize_delta(base: LinearPerturbationBasis, rank: int) -> torch.Tensor:
    """Return the unscaled rank-``rank`` update on the weight's device and dtype."""
    weight = base.module.weight
    left = base.left[:, :rank].to(device=weight.device, dtype=weight.dtype)
    right = base.right[:rank, :].to(device=weight.device, dtype=weight.dtype)
    return left @ right


def apply_low_rank_perturbation(
    model: torch.nn.Module,
    *,
    rank: int,
    target_norm: float | None = None,
    target_range: Tuple[float, float] = DEFAULT_TARGET_RANGE,
    bases: Sequence[LinearPerturbationBasis] | None = None,
    seed: int | None = None,
    sampling: Literal["uniform", "log-uniform"] = "log-uniform",
    layer_selector: Callable[[str, torch.nn.Module], bool] | None = None,
) -> Tuple[int, float, float, float]:
    """Add low-rank noise to Linear weights in place, rescaled to a target norm.

    For every basis the first ``rank`` columns of ``left`` and rows of
    ``right`` are multiplied to form a per-layer update. All updates share one
    scale factor chosen so that their combined Frobenius norm equals the target.

    Args:
        model: Model to sample bases from when ``bases`` is ``None``.
        rank: Number of basis columns/rows to use; capped per layer by the
            stored basis size and the weight dimensions.
        target_norm: Explicit combined Frobenius norm. When ``None`` the value
            is taken from ``target_range``.
        target_range: ``(min, max)`` bounds for the sampled norm; validated
            even when ``target_norm`` is given. Equal bounds are used as-is.
        bases: Pre-sampled bases. Mutually exclusive with ``seed``.
        seed: Seeds both the basis draw and the target-norm sampling.
        sampling: ``"uniform"`` or ``"log-uniform"`` sampling over the range.
        layer_selector: Layer predicate used only when bases are sampled here.

    Returns:
        ``(layers, achieved_norm, scale, target_value)``. ``layers`` is the
        number of bases with a non-zero effective rank. ``achieved_norm`` and
        ``scale`` are ``0.0`` (and no weight is modified) when there are no such
        layers, the unscaled norm is zero, or the target is zero.
        ``achieved_norm`` is ``unscaled_norm * |scale|``; it is not re-measured
        after the update is added to the weights.

    Raises:
        ValueError: For a non-positive ``rank``, both ``bases`` and ``seed``,
            an invalid range, an unknown ``sampling`` mode, or a negative
            ``target_norm``.
    """

    if rank <= 0:
        raise ValueError("rank must be positive")
    if bases is not None and seed is not None:
        raise ValueError("Provide either `bases` or `seed`, not both")

    # Resolve (and validate) the target before doing any sampling work.
    target_value = _resolve_target_norm(target_norm, target_range, sampling, seed)

    if bases is None:
        bases = sample_linear_bases(
            model,
            max_rank=rank,
            seed=seed,
            selector=layer_selector,
        )

    effective_layers: List[tuple[LinearPerturbationBasis, int]] = []
    total_norm_sq = 0.0

    with torch.no_grad():
        # First pass: measure the combined norm. Updates are recomputed in the
        # second pass instead of being kept, so only one is alive at a time.
        for base in bases:
            out_features, in_features = base.module.weight.shape
            effective_rank = min(
                rank,
                base.left.shape[1],
                base.right.shape[0],
                out_features,
                in_features,
            )
            if effective_rank == 0:
                continue

            delta = _materialize_delta(base, effective_rank)
            # Reduce in float32: a float16 sum of squares overflows to inf
            # (max 65504), which would drive the scale to zero.
            total_norm_sq += float(delta.float().pow(2).sum().item())
            effective_layers.append((base, effective_rank))
            del delta

        if not effective_layers:
            return 0, 0.0, 0.0, target_value

        total_norm = math.sqrt(total_norm_sq)
        if total_norm == 0.0 or target_value == 0.0:
            return len(effective_layers), 0.0, 0.0, target_value

        scale = target_value / total_norm
        achieved_norm = total_norm * abs(scale)

        # Second pass: add the scaled updates to the weights in place.
        for base, effective_rank in effective_layers:
            delta = _materialize_delta(base, effective_rank)
            base.module.weight.add_(delta, alpha=scale)
            del delta

    return len(effective_layers), achieved_norm, scale, target_value


def gather_texts(
    dataset_name: str,
    dataset_config: str,
    dataset_split: str,
    text_column: str,
    max_samples: int,
) -> List[str]:
    """Load a dataset split and collect stripped, non-empty text entries.

    Records whose ``text_column`` value is not a string, or is blank after
    stripping, are skipped. Collection stops once ``max_samples`` texts have
    been gathered; a non-positive ``max_samples`` disables the limit.

    Raises:
        ValueError: If no usable text was found.
    """
    dataset = load_dataset(dataset_name, dataset_config, split=dataset_split)

    collected: List[str] = []
    for record in dataset:
        raw = record.get(text_column)
        cleaned = raw.strip() if isinstance(raw, str) else ""
        if not cleaned:
            continue
        collected.append(cleaned)
        if 0 < max_samples <= len(collected):
            break

    if collected:
        return collected

    raise ValueError(
        f"Dataset {dataset_name}/{dataset_config}:{dataset_split} yielded no non-empty '{text_column}' entries"
    )


def parse_dtype(dtype: str) -> torch.dtype:
    """Translate a case-insensitive dtype name into a ``torch.dtype``.

    ``"auto"`` resolves to ``torch.float32``.

    Raises:
        ValueError: If the name is not auto, float32, float16 or bfloat16.
    """
    supported = {
        "auto": torch.float32,
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    resolved = supported.get(dtype.lower())
    if resolved is None:
        raise ValueError(
            "Unsupported dtype '%s'. Choose from auto, float32, float16, bfloat16." % dtype)
    return resolved


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Options are declared in a table and registered in order, so ``--help``
    lists them in the order shown here.
    """
    cli = argparse.ArgumentParser(
        description="Evaluate Qwen/Qwen3-0.6B before and after random low-rank perturbations."
    )

    option_table = [
        ("--rank", {
            "type": int,
            "default": 8,
            "help": "Rank of the sampled perturbation matrices",
        }),
        ("--max-length", {
            "type": int,
            "default": 256,
            "help": "Tokenizer max_length for evaluation prompts",
        }),
        ("--batch-size", {
            "type": int,
            "default": 4,
            "help": "Number of prompts to evaluate per forward pass",
        }),
        ("--dataset-name", {
            "default": "wiki40b",
            "help": "Hugging Face dataset repository name",
        }),
        ("--dataset-config", {
            "default": "fr",
            "help": "Dataset configuration to load",
        }),
        ("--dataset-split", {
            "default": "test",
            "help": "Dataset split to evaluate on",
        }),
        ("--text-column", {
            "default": "text",
            "help": "Column in the dataset that stores raw text",
        }),
        ("--max-samples", {
            "type": int,
            "default": 512,
            "help": "Maximum number of dataset entries to evaluate (<=0 uses the full split)",
        }),
        ("--dtype", {
            "default": "auto",
            "help": "Model dtype: auto|float32|float16|bfloat16",
        }),
        ("--device", {
            "default": "auto",
            "help": "Device to run on: auto|cpu|cuda|cuda:<index>",
        }),
        ("--trust-remote-code", {
            "action": "store_true",
            "help": "Forward trust_remote_code when loading",
        }),
        ("--seed", {
            "type": int,
            "default": None,
            "help": "Random seed controlling the sampled bases",
        }),
        ("--target-norm", {
            "type": float,
            "default": None,
            "help": (
                "Frobenius norm to scale the perturbation to (default: sample using the selected distribution)"
            ),
        }),
        ("--target-norm-min", {
            "type": float,
            "default": DEFAULT_TARGET_RANGE[0],
            "help": "Lower bound for the sampled target norm when --target-norm is omitted (must be >= 0)",
        }),
        ("--target-norm-max", {
            "type": float,
            "default": DEFAULT_TARGET_RANGE[1],
            "help": "Upper bound for the sampled target norm when --target-norm is omitted (must be > 0 for sampling)",
        }),
        ("--target-norm-sampling", {
            "choices": ["uniform", "log-uniform"],
            "default": "log-uniform",
            "help": "Distribution to sample target norms from when --target-norm is omitted",
        }),
        ("--perturbation-scope", {
            "choices": ["all", "attention_qv"],
            "default": "all",
            "help": "Which linear layers to perturb (excluding lm_head)",
        }),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def resolve_device(device_arg: str) -> torch.device:
    """Turn a device string into a ``torch.device``.

    ``"auto"`` picks ``cuda`` when ``torch.cuda.is_available()`` and ``cpu``
    otherwise; any other value is passed to ``torch.device`` unchanged.
    """
    if device_arg != "auto":
        return torch.device(device_arg)
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)


def _validate_cli_args(args: argparse.Namespace) -> None:
    """Reject invalid perturbation settings before any expensive work starts."""
    if args.rank <= 0:
        raise ValueError("rank must be positive")

    target_range = (args.target_norm_min, args.target_norm_max)
    min_norm, max_norm = target_range
    if min_norm < 0 or max_norm < 0 or max_norm < min_norm:
        raise ValueError(f"Invalid target norm range: {target_range}")

    if args.target_norm is not None and args.target_norm < 0:
        raise ValueError("target_norm must be non-negative when provided")


def main() -> None:
    """Run the experiment: baseline perplexity, perturb weights, re-evaluate.

    Steps: parse and validate arguments, collect evaluation texts, load the
    tokenizer and model, report baseline loss/perplexity, apply the low-rank
    perturbation in place, then report the perturbed metrics and their deltas.

    Raises:
        ValueError: For invalid rank / target-norm settings (checked up front so
            a bad flag does not cost a dataset load and a baseline evaluation),
            plus anything raised by the helpers.
    """
    args = parse_args()
    _validate_cli_args(args)
    device = resolve_device(args.device)
    dtype = parse_dtype(args.dtype)

    print(
        "Loading dataset "
        f"{args.dataset_name}/{args.dataset_config}:{args.dataset_split} (max_samples={args.max_samples}) ...",
        flush=True,
    )
    texts = gather_texts(
        dataset_name=args.dataset_name,
        dataset_config=args.dataset_config,
        dataset_split=args.dataset_split,
        text_column=args.text_column,
        max_samples=args.max_samples,
    )
    print(f"Collected {len(texts)} texts for evaluation", flush=True)

    print(f"Loading model on {device} with dtype={dtype} ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=args.trust_remote_code)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=args.trust_remote_code,
        dtype=dtype,
    )
    model.to(device)
    model.eval()

    print("Computing baseline perplexity ...", flush=True)
    baseline_ppl, baseline_loss = compute_perplexity(
        model,
        tokenizer,
        texts,
        device,
        args.max_length,
        args.batch_size,
    )
    print(
        f"Baseline    | loss: {baseline_loss:.4f} | perplexity: {baseline_ppl:.3f}")

    print("Applying low-rank perturbations ...", flush=True)
    layer_selector = resolve_layer_selector(args.perturbation_scope)
    bases = sample_linear_bases(
        model,
        max_rank=args.rank,
        seed=args.seed,
        selector=layer_selector,
    )
    # The layer selection is already baked into `bases`, so no selector is
    # passed here.
    # NOTE(review): --seed only fixes the bases. When --target-norm is omitted
    # the target norm is drawn from an unseeded RNG, because
    # apply_low_rank_perturbation rejects `seed` together with `bases`.
    modified, perturb_norm, scale_factor, target_value = apply_low_rank_perturbation(
        model,
        rank=args.rank,
        target_norm=args.target_norm,
        target_range=(args.target_norm_min, args.target_norm_max),
        bases=bases,
        sampling=args.target_norm_sampling,
    )
    print(
        f"Perturbed {modified} linear layers with rank={args.rank}; target ||Δ||_F={target_value:.4f}, "
        f"achieved={perturb_norm:.4f}, scale={scale_factor:.6f}"
    )

    perturbed_ppl, perturbed_loss = compute_perplexity(
        model,
        tokenizer,
        texts,
        device,
        args.max_length,
        args.batch_size,
    )
    print(
        f"Perturbed   | loss: {perturbed_loss:.4f} | perplexity: {perturbed_ppl:.3f}")
    print(
        f"Delta       | loss: {perturbed_loss - baseline_loss:.4f} | perplexity: {perturbed_ppl - baseline_ppl:.3f}")
    print(
        f"Amplitude   | target={target_value:.6f}, achieved={perturb_norm:.6f}")

# Script entry point: run the evaluation only when this file is executed directly.
if __name__ == "__main__":  # pragma: no cover
    main()
