import json
from argparse import ArgumentParser
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, load_from_disk
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from fair_gpt.evaluation.pr import pr_knn_from_embeddings
from fair_gpt.evaluation.pr_topo import compute_top_pr
from fair_gpt.training.lora_training import process_data
from fair_gpt.utils import (
    LeftPaddingCompatibleDataCollatorForLM,
    per_example_loss,
)


def normalize_gender(g):
    """Map a free-form gender label onto one of "male", "female" or "other".

    Matching ignores surrounding whitespace and letter case, and accepts the
    single-letter abbreviations "m" and "f". Non-string inputs (e.g. ``None``)
    and any unrecognised label fall back to "other".
    """
    canonical = {"male": "male", "m": "male", "female": "female", "f": "female"}
    if not isinstance(g, str):
        return "other"
    return canonical.get(g.strip().lower(), "other")


def dataset_to_array(ds: Dataset, field: str) -> np.ndarray:
    """Stack the per-row arrays of column ``field`` into a single numpy array.

    Args:
        ds: dataset (or any mapping) where ``ds[field]`` is a sequence of
            equally-shaped arrays, e.g. one embedding vector per row.
        field: name of the column to stack.

    Returns:
        ``np.stack(ds[field])``: the rows stacked along a new leading axis
        (2D for a column of 1D vectors).

    Raises:
        ValueError: if the column cannot be read or stacked (missing column,
            empty column, ragged rows, ...). The underlying exception is
            chained as ``__cause__`` so the original traceback is preserved.
    """
    try:
        return np.stack(ds[field])
    except Exception as err:
        raise ValueError(
            f"Failed to stack field '{field}' into an array: {err}"
        ) from err


def safe_pr(
    E_ref: np.ndarray,
    E_out: np.ndarray,
    k: int,
    pca_var: float,
    n_samples: int,
) -> Tuple[float, float, float, float, float, Dict[str, Any]]:
    """Compute kNN and topological precision/recall between two embedding sets.

    kNN precision/recall come from ``pr_knn_from_embeddings``. The sample size
    is capped at ``min(n_samples, len(E_ref), len(E_out))`` so it never requests
    more points than either set holds. Topological fidelity/diversity/F1 come
    from ``compute_top_pr`` (``alpha=0.1``, ``kernel="cosine"``,
    ``random_proj=True``, ``l2norm=False``). They are computed on the full
    ``E_ref``/``E_out`` arrays, not on the capped sample.

    Returns:
        ``(precision, recall, p_topo, r_topo, f1_topo, info)``, where ``info``
        is the diagnostic dict returned by ``pr_knn_from_embeddings``. If
        either input is empty, all five metrics are NaN and ``info`` records
        the input sizes with ``note="empty set"``.
    """
    if len(E_ref) == 0 or len(E_out) == 0:
        # Keep the same 6-tuple shape as the normal path: callers unpack six
        # values, so a shorter tuple would crash on empty subsets.
        nan = float("nan")
        info = {"N_ref": len(E_ref), "N_out": len(E_out), "note": "empty set"}
        return nan, nan, nan, nan, nan, info

    n = min(n_samples, len(E_ref), len(E_out))
    precision, recall, info = pr_knn_from_embeddings(
        E_ref=E_ref,
        E_out=E_out,
        k=k,
        pca_var=pca_var,
        n_samples=n,
    )
    topo_res = compute_top_pr(
        real_features=E_ref,
        fake_features=E_out,
        alpha=0.1,
        kernel="cosine",
        random_proj=True,
        f1_score=True,
        l2norm=False,
    )
    p_topo = topo_res["fidelity"]
    r_topo = topo_res["diversity"]
    f1_topo = topo_res["Top_F1"]

    return precision, recall, p_topo, r_topo, f1_topo, info


def compute_run_metrics(
    ref_data: Dataset,
    gen_data: Dataset,
    k: int,
    pca_var: float,
    n_samples: int,
) -> Dict[str, Any]:
    """Score ``gen_data`` against ``ref_data`` for a single run.

    Both datasets are scored as a whole using their ``"embedding"`` column.
    Any per-subset split (e.g. by gender) must already have been applied by
    the caller; this function does not split the data itself.

    Args:
        ref_data: reference dataset with an ``"embedding"`` column.
        gen_data: generated dataset with an ``"embedding"`` column.
        k: neighbourhood size for kNN precision/recall.
        pca_var: PCA variance setting forwarded to the kNN metric.
        n_samples: upper bound on the kNN sample size (see ``safe_pr``).

    Returns:
        A dict with ``params`` (k, pca_var), ``counts`` (row totals of both
        datasets), kNN ``precision``/``recall`` and topological
        ``p_topo``/``r_topo``/``f1_topo``, all converted to Python floats.
    """
    ref_embed_all = dataset_to_array(ref_data, "embedding")
    gen_embed_all = dataset_to_array(gen_data, "embedding")
    # The diagnostic info dict from safe_pr is not part of the run summary.
    precision_all, recall_all, p_topo, r_topo, f1_topo, _ = safe_pr(
        ref_embed_all, gen_embed_all, k=k, pca_var=pca_var, n_samples=n_samples
    )

    return {
        "params": {"k": k, "pca_var": pca_var},
        "counts": {"ref_total": len(ref_data), "gen_total": len(gen_data)},
        "precision": float(precision_all),
        "recall": float(recall_all),
        "p_topo": float(p_topo),
        "r_topo": float(r_topo),
        "f1_topo": float(f1_topo),
    }


def split_gen_data(
    gen_data_full: Dataset,
    n_slices: int,
    slice_size: int,
    seed: int,
) -> List[Dataset]:
    """Shuffle once, then cut up to ``n_slices`` non-overlapping equal-size slices.

    Row order is permuted with ``np.random.RandomState(seed)``, so the split
    is reproducible. Only complete slices of exactly ``slice_size`` rows are
    returned. A trailing remainder shorter than ``slice_size`` is dropped so
    that every run is evaluated on the same number of samples. Fewer than
    ``n_slices`` slices are returned when the dataset is too small.

    Args:
        gen_data_full: dataset to split; must support ``len()`` and ``select``.
        n_slices: maximum number of slices to produce.
        slice_size: number of rows in each slice.
        seed: seed for the single shuffle.

    Returns:
        At most ``n_slices`` datasets, each holding exactly ``slice_size``
        rows. Empty if ``n_slices`` or ``slice_size`` is not positive.
    """
    if n_slices <= 0 or slice_size <= 0:
        return []

    n_rows = len(gen_data_full)
    rng = np.random.RandomState(seed)
    indices = np.arange(n_rows)
    rng.shuffle(indices)

    # Only as many slices as can be completely filled.
    n_full = min(n_slices, n_rows // slice_size)
    return [
        gen_data_full.select(
            indices[i * slice_size : (i + 1) * slice_size].tolist()
        )
        for i in range(n_full)
    ]


def print_summary(result_dict):
    """Print a fixed-width table of aggregated metrics, one row per subset.

    ``result_dict`` maps a subset name to its aggregated results. Only the
    mean/std columns listed below are shown, formatted to four decimals.
    """
    columns = [
        "precision_mean",
        "precision_std",
        "recall_mean",
        "recall_std",
        "p_topo_mean",
        "r_topo_mean",
        "p_topo_std",
        "r_topo_std",
    ]
    rows = {
        subset: [stats[col] for col in columns]
        for subset, stats in result_dict.items()
    }
    table = pd.DataFrame.from_dict(rows, orient="index", columns=columns)
    print(table.to_string(float_format=lambda val: f"{val:.4f}"))


def compute_all_runs_metrics(
    ref_data: Dataset,
    gen_data: Dataset,
    k: int,
    pca_var: float,
    n_samples: int,
    n_runs: int,
    seed: int,
    run_name: str = "overall",
) -> Dict[str, Any]:
    """Evaluate ``gen_data`` against ``ref_data`` on several disjoint slices.

    ``gen_data`` is shuffled with ``seed`` and cut into up to ``n_runs``
    non-overlapping slices of at most ``n_samples`` rows by
    ``split_gen_data``. Each slice is scored against ``ref_data`` with
    ``compute_run_metrics``.

    Returns:
        A dict with the per-run result dicts under ``"runs"``. For each of
        precision, recall, p_topo, r_topo and f1_topo it also holds
        ``"<metric>_mean"`` and ``"<metric>_std"``: the sample std
        (``ddof=1``) over runs whose value is not NaN. With exactly one
        such run the std is 0.0. With none, both mean and std are NaN.
    """
    slices = split_gen_data(gen_data, n_slices=n_runs, slice_size=n_samples, seed=seed)
    if len(slices) < n_runs:
        print(
            f"Warning: requested {n_runs} runs, but only {len(slices)} slices could be made from gen_data of size {len(gen_data)}."
        )

    run_results = []
    for i, gen_slice in enumerate(slices, start=1):
        print(f"\nRun {run_name} | {i}: slice size {len(gen_slice)}")
        res = compute_run_metrics(
            ref_data=ref_data,
            gen_data=gen_slice,
            k=k,
            pca_var=pca_var,
            n_samples=n_samples,
        )
        run_results.append(res)

    dict_results: Dict[str, Any] = {"runs": run_results}
    # Average across runs. NaN runs (e.g. from empty inputs) are ignored.
    # Filtering explicitly avoids numpy's "empty slice" / "degrees of
    # freedom" RuntimeWarnings.
    for key in ["precision", "recall", "p_topo", "r_topo", "f1_topo"]:
        arr = np.array([r[key] for r in run_results], dtype=float)
        valid = arr[~np.isnan(arr)]
        mean = float(valid.mean()) if valid.size else float("nan")
        if valid.size > 1:
            std = float(valid.std(ddof=1))
        elif valid.size == 1:
            std = 0.0
        else:
            std = float("nan")
        print(f"{run_name} {key}: mean {mean:.4f}, std {std:.4f}")
        dict_results[f"{key}_mean"] = mean
        dict_results[f"{key}_std"] = std
    return dict_results


def launch_pr_eval(
    ref_embed_path,
    gen_embed_path,
    output_path,
    k=4,
    max_samples=3000,
    n_runs=5,
    pca_var=0.90,
    seed=1234,
):
    """Evaluate generated embeddings against a reference set and save JSON.

    Both datasets are loaded with ``load_from_disk``. Their ``gender`` labels
    are normalized with ``normalize_gender`` (the generated side only when it
    has a ``gender`` column). Three subsets are evaluated:

    - ``"overall"``: all rows.
    - ``"male"``: rows with ``gender == "male"``.
    - ``"female"``: rows with ``gender == "female"``.

    Rows labelled ``"other"`` appear only in the overall subset. Every subset
    is scored with ``compute_all_runs_metrics`` over ``n_runs`` disjoint
    slices of the generated data. All subsets use the same per-run sample
    size: ``min(max_samples, size of the smallest reference subset)``.

    Args:
        ref_embed_path: directory of the reference dataset (``embedding``,
            ``gender`` columns).
        gen_embed_path: directory of the generated dataset (``embedding``,
            ``gender`` columns).
        output_path: where the results JSON is written.
        k: neighbourhood size for kNN precision/recall.
        max_samples: upper bound on the per-run sample size.
        n_runs: number of generated-data slices per subset.
        pca_var: PCA variance setting forwarded to the kNN metric.
        seed: seed for shuffling the generated data before slicing.

    Raises:
        ValueError: if a reference subset has more rows than the matching
            generated subset.
    """
    ref_data = load_from_disk(ref_embed_path)
    gen_data_full = load_from_disk(gen_embed_path)

    # Normalize gender labels on both sides so the subset filters below
    # compare like with like (e.g. "Female" / "F" -> "female").
    ref_data = ref_data.map(lambda x: {"gender": normalize_gender(x.get("gender"))})
    if "gender" in gen_data_full.column_names:
        gen_data_full = gen_data_full.map(
            lambda x: {"gender": normalize_gender(x.get("gender"))}
        )

    subsets = {
        "overall": lambda x: True,
        "male": lambda x: x["gender"] == "male",
        # Match "female" explicitly so "other"/unknown rows are not counted
        # as female.
        "female": lambda x: x["gender"] == "female",
    }
    data_dict = {}
    for key, fn in subsets.items():
        ref_subset = ref_data.filter(fn)
        print(f"Reference {key} subset size: {len(ref_subset)}")
        gen_subset = gen_data_full.filter(fn)
        print(f"Generated {key} subset size: {len(gen_subset)}")
        data_dict[key] = {"ref": ref_subset, "gen": gen_subset}
        # Explicit check rather than `assert`, so it still runs under
        # `python -O`. The per-run sample size is bounded by the reference
        # subset sizes, so this guarantees at least one full generated slice.
        if len(ref_subset) > len(gen_subset):
            raise ValueError(f"Reference {key} subset larger than generated.")

    min_samples = min(len(d["ref"]) for d in data_dict.values())
    n_samples = min(max_samples, min_samples)
    print(
        f"Computing precision and recall with k={k}, n_samples={n_samples}, runs={n_runs}"
    )

    result_dict = {}
    for key, pair in data_dict.items():
        result_dict[key] = compute_all_runs_metrics(
            ref_data=pair["ref"],
            gen_data=pair["gen"],
            k=k,
            pca_var=pca_var,
            n_samples=n_samples,
            n_runs=n_runs,
            seed=seed,
            run_name=key,
        )

    print_summary(result_dict)

    output = {
        "meta": {
            "k": k,
            "pca_var": pca_var,
            "n_samples_per_run": n_samples,
            "n_runs": n_runs,
            "seed": seed,
            "ref_total": len(ref_data),
            "gen_total": len(gen_data_full),
        },
        "results": result_dict,
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=4)
    print(f"Saved results to {output_path}")


def launch_nll_eval(cfg, model, tokenizer, is_chat, output_path):
    """Compute per-example losses of ``model`` on the evaluation data and save them.

    Data comes from ``process_data(cfg, tokenizer, is_chat)``. It is batched
    with ``cfg.eval.batch_size`` rows at a time, collated with
    ``LeftPaddingCompatibleDataCollatorForLM`` and scored with
    ``per_example_loss`` under ``torch.no_grad()``.

    Returns:
        Dict with ``nll_mean``, ``nll_std`` (sample std, ``ddof=1``) and
        ``all_losses``. The same dict is written as JSON to ``output_path``.
        ``nll_mean`` is NaN when there are no examples. ``nll_std`` is NaN
        when there are fewer than two.
    """
    data = process_data(cfg, tokenizer, is_chat)
    collator = LeftPaddingCompatibleDataCollatorForLM(tokenizer=tokenizer)
    batch_size = cfg.eval.batch_size
    n_examples = len(data)
    losses: List[float] = []
    model.eval()
    for start in tqdm(range(0, n_examples, batch_size)):
        stop = min(start + batch_size, n_examples)
        # Index row by row: this works for plain lists and for HF Datasets.
        # Slicing a Dataset returns a dict of columns, and iterating over it
        # would yield column names instead of examples.
        examples = [data[j] for j in range(start, stop)]
        batch = collator(examples)
        batch = {key: value.to(model.device) for key, value in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
            loss = per_example_loss(outputs, batch)[0]
        # Flatten so a 0-d loss tensor still extends the list correctly.
        losses.extend(loss.to(torch.float32).cpu().reshape(-1).tolist())

    loss_arr = np.asarray(losses, dtype=float)
    results = {
        "nll_mean": float(loss_arr.mean()) if loss_arr.size else float("nan"),
        "nll_std": float(loss_arr.std(ddof=1)) if loss_arr.size > 1 else float("nan"),
        "all_losses": losses,
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)
    print(f"Saved results to {output_path}")
    return results


if __name__ == "__main__":
    # CLI entry point for the sliced precision/recall evaluation.
    # Both datasets need `embedding` and `gender` columns, because
    # launch_pr_eval filters each of them by gender.
    parser = ArgumentParser()
    parser.add_argument(
        "--ref_embed_path",
        type=str,
        default="results_dir/wikibio_test/hf_dataset_dir",  # expects a HF dataset dir
        help="Path to reference HF dataset saved with datasets.save_to_disk, with fields: embedding, gender.",
    )
    parser.add_argument(
        "--gen_embed_path",
        type=str,
        default="results_dir/wikibio/llama3.1-8b-pt/shots/hf_dataset_dir",
        help="Path to generated HF dataset dir with fields: embedding, gender.",
    )
    parser.add_argument(
        "--k", type=int, default=4, help="k for kNN precision and recall."
    )
    parser.add_argument("--pca_var", type=float, default=0.90, help="PCA variance.")
    parser.add_argument(
        "--max_n_samples", type=int, default=3000, help="Max samples per run."
    )
    parser.add_argument(
        "--n_runs",
        type=int,
        default=5,
        help="Number of non overlapping slices of gen_data to evaluate.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="Seed used to shuffle gen_data before slicing.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="results_dir/wikibio/llama3.1-8b-pt/shots/pr_sliced.json",
        help="Where to save metrics JSON.",
    )
    args = parser.parse_args()

    launch_pr_eval(
        ref_embed_path=args.ref_embed_path,
        gen_embed_path=args.gen_embed_path,
        output_path=args.output_path,
        k=args.k,
        max_samples=args.max_n_samples,
        n_runs=args.n_runs,
        pca_var=args.pca_var,
        seed=args.seed,
    )
