import argparse
import json
import math
import random
from pathlib import Path
from typing import Dict, Any, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


def init_weights_xavier(m):
    """Xavier-uniform initialise an ``nn.Linear`` module in place.

    Meant to be passed to ``nn.Module.apply``. Modules that are not
    ``nn.Linear`` are ignored. The bias, when the layer has one, is zeroed.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)


class FCNet(nn.Module):
    """Tanh MLP over the concatenation ``[z, xlatent]``.

    The input width is ``z_dim + latent_dim``. Every entry of ``hidden_sizes``
    adds a ``Linear`` + ``Tanh`` pair, and a final ``Linear`` maps to
    ``out_dim``. All linear layers are re-initialised with
    ``init_weights_xavier`` after construction.
    """
    def __init__(self, z_dim: int, latent_dim: int, out_dim: int, hidden_sizes: List[int]):
        super().__init__()
        widths = [z_dim + latent_dim, *hidden_sizes]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.Tanh())
        # Output head: no activation after the last linear layer.
        modules.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*modules)
        self.apply(init_weights_xavier)

    def forward(self, z, xlatent):
        """Concatenate ``z`` and ``xlatent`` along dim 1 and run the MLP."""
        joined = torch.cat([z, xlatent], dim=1)
        return self.net(joined)


class ZLimits:
    """Uniform box sampler for the conditioning vector ``z``.

    ``lows`` and ``highs`` are per-dimension bounds. The number of dimensions
    is taken from their length (5 with the defaults).

    Raises:
        ValueError: if ``lows`` / ``highs`` are not 1-D sequences of equal length.
    """
    def __init__(self, lows=(1, 200, 0, 0, 8), highs=(15, 700, 1, 1, 12)):
        self.lows = torch.tensor(lows, dtype=torch.float32)
        self.highs = torch.tensor(highs, dtype=torch.float32)
        if self.lows.ndim != 1 or self.lows.shape != self.highs.shape:
            raise ValueError(
                "lows and highs must be 1-D sequences of equal length, got shapes "
                f"{tuple(self.lows.shape)} and {tuple(self.highs.shape)}"
            )

    def sample(self, n, device):
        """Return an ``[n, D]`` tensor on ``device``, uniform within the bounds.

        ``D`` is the number of bounds given at construction.
        """
        lows = self.lows.to(device).unsqueeze(0)
        highs = self.highs.to(device).unsqueeze(0)
        # Width follows the configured bounds instead of a fixed 5.
        u = torch.rand(n, lows.shape[1], device=device)
        return lows + (highs - lows) * u


class PairIndexDataset(Dataset):
    """Stream of random index pairs ``(i, j)`` drawn from ``range(N)`` with ``i != j``.

    The item index is ignored: every ``__getitem__`` call advances a private
    ``random.Random(seed)`` generator and returns a fresh pair. ``num_pairs``
    only determines the reported length.
    """
    def __init__(self, N, num_pairs, seed=None):
        self.N, self.num_pairs = N, num_pairs
        self.rng = random.Random(seed)

    def __len__(self):
        return self.num_pairs

    def __getitem__(self, _):
        first = self.rng.randrange(self.N)
        # Draw from N-1 slots and shift past `first`, so the second index
        # can never equal the first.
        second = self.rng.randrange(self.N - 1)
        return first, (second + 1 if second >= first else second)


class AM(nn.Module):
    """Sigmoid MLP mapping a concatenated ``[x, z]`` row to one value in [0, 1].

    Three linear layers (``x_dim + z_dim -> hidden -> hidden -> 1``), each
    followed by a ``Sigmoid``, including the output.
    """
    def __init__(self, x_dim, z_dim=5, hidden=128):
        super().__init__()
        dims = (x_dim + z_dim, hidden, hidden, 1)
        stack = []
        for n_in, n_out in zip(dims[:-1], dims[1:]):
            stack.append(nn.Linear(n_in, n_out))
            stack.append(nn.Sigmoid())
        self.net = nn.Sequential(*stack)

    def forward(self, xz):
        """Run the MLP on ``xz`` (already concatenated by the caller)."""
        return self.net(xz)


@torch.no_grad()
def compute_target_sigmoid(fc, fb, z, qi, qj, eps=1e-8):
    """Soft pairwise target computed from the ``fc`` / ``fb`` networks.

    Returns ``sigmoid(||fc(z,qi) - fb(z,qj)|| - ||fc(z,qj) - fb(z,qi)||)``.
    Each norm is taken over dim 1 (kept as a size-1 dim), with ``eps`` added
    under the square root. Runs without gradient tracking.
    """
    def _smoothed_norm(residual):
        return torch.sqrt((residual * residual).sum(dim=1, keepdim=True) + eps)

    forward_gap = _smoothed_norm(fc(z, qi) - fb(z, qj))
    reverse_gap = _smoothed_norm(fc(z, qj) - fb(z, qi))
    return torch.sigmoid(forward_gap - reverse_gap)


def _training_step(am, fc, fb, X, Q, idx_i, idx_j, z_limits, device, mse):
    """Compute the pairwise loss and summary stats for one batch of index pairs.

    One ``z`` is sampled per pair and shared by both sides. The prediction is
    ``sigmoid(am([x_i, z]) - am([x_j, z]))`` and the target comes from
    ``compute_target_sigmoid(fc, fb, z, q_i, q_j)``.

    Returns:
        ``(loss, stats)``. ``loss`` is the ``mse`` tensor, still attached to
        the graph. ``stats`` is a dict of Python floats with keys ``loss``,
        ``avg_pred``, ``avg_target`` and ``acc@0.5``.
    """
    left = idx_i.to(device).long()
    right = idx_j.to(device).long()

    x_i = X[left]
    x_j = X[right]
    q_i = Q[left]
    q_j = Q[right]
    z = z_limits.sample(n=x_i.size(0), device=device)

    # NOTE(review): if `am` ends in a sigmoid (as the AM class in this file
    # does), the difference below lies in [-1, 1], so `preds` cannot leave
    # roughly [0.27, 0.73] - confirm that this range is intended.
    score_i = am(torch.cat([x_i, z], dim=1))
    score_j = am(torch.cat([x_j, z], dim=1))
    preds = torch.sigmoid(score_i - score_j)

    targets = compute_target_sigmoid(fc, fb, z, q_i, q_j)
    loss = mse(preds, targets)

    with torch.no_grad():
        # A pair counts as correct when prediction and target fall on the
        # same side of 0.5.
        agreement = ((preds > 0.5) == (targets > 0.5)).float().mean().item()
        stats = {
            "loss": float(loss.item()),
            "avg_pred": float(preds.mean().item()),
            "avg_target": float(targets.mean().item()),
            "acc@0.5": float(agreement),
        }
    return loss, stats


def train_am_with_fc_fb(
    X, Q, fc, fb,
    *,
    z_limits=None,
    hidden_am=128,
    epochs=10,
    pairs_per_epoch=8000,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-5,
    seed=42,
    device=None,
):
    """Train an ``AM`` model on random sample pairs against frozen ``fc`` / ``fb``.

    Args:
        X: ``[N, 1]`` tensor of metric values (input to AM).
        Q: ``[N, 1]`` tensor of latent values fed to ``fc`` / ``fb``.
        fc, fb: modules called as ``f(z, q)``. They are moved to ``device``,
            put in eval mode and have ``requires_grad`` disabled in place.
        z_limits: sampler exposing ``sample(n, device)``; defaults to ``ZLimits()``.
        hidden_am: hidden width of the AM network.
        epochs: number of passes; each pass draws ``pairs_per_epoch`` pairs.
        pairs_per_epoch: pairs drawn per epoch, split into batches of
            ``batch_size`` (the last batch of an epoch may be smaller).
        batch_size: pairs per optimisation step.
        lr, weight_decay: AdamW hyper-parameters.
        seed: seeds ``random``, ``numpy``, ``torch`` and the pair sampler.
        device: torch device spec; CUDA is used when available if omitted.

    Returns:
        The trained ``AM`` module, in eval mode, on ``device``.

    Raises:
        ValueError: on wrongly shaped ``X`` / ``Q``, mismatched lengths,
            fewer than two samples, or a non-positive ``pairs_per_epoch``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

    if Q.ndim != 2 or Q.shape[1] != 1:
        raise ValueError(f"Q must be [N,1], got {tuple(Q.shape)}")
    if X.ndim != 2 or X.shape[1] != 1:
        raise ValueError(f"X must be [N,1], got {tuple(X.shape)}")
    if X.shape[0] != Q.shape[0]:
        raise ValueError(
            f"X and Q must have the same number of rows, got {X.shape[0]} and {Q.shape[0]}"
        )

    N = X.shape[0]
    if N < 2:
        # Pairs require two distinct indices.
        raise ValueError(f"Need at least 2 samples to form pairs, got {N}")
    if pairs_per_epoch < 1:
        raise ValueError(f"pairs_per_epoch must be >= 1, got {pairs_per_epoch}")

    X = X.to(device)
    Q = Q.to(device)

    if z_limits is None:
        z_limits = ZLimits()

    # Build AM, freeze fc/fb
    am = AM(x_dim=1, z_dim=5, hidden=hidden_am).to(device)

    fc = fc.to(device).eval()
    fb = fb.to(device).eval()
    for p in fc.parameters():
        p.requires_grad = False
    for p in fb.parameters():
        p.requires_grad = False

    opt = torch.optim.AdamW(am.parameters(), lr=lr, weight_decay=weight_decay)
    mse = nn.MSELoss()

    # One pass over this loader is one epoch. The dataset ignores the index it
    # is asked for and draws from its own RNG, so shuffling would only add
    # cost, and re-iterating the loader yields new pairs every epoch.
    dl = DataLoader(
        PairIndexDataset(N, pairs_per_epoch, seed=seed),
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )

    am.train()

    for ep in range(1, epochs + 1):
        run_loss = run_acc = run_p = run_t = 0.0
        steps = 0

        # Iterate the loader directly rather than rebuilding an iterator
        # (and a fresh index permutation) for every single step.
        for idx_left, idx_right in dl:
            opt.zero_grad(set_to_none=True)
            loss, stats = _training_step(am, fc, fb, X, Q, idx_left, idx_right, z_limits, device, mse)
            loss.backward()
            opt.step()

            run_loss += stats["loss"]
            run_acc += stats["acc@0.5"]
            run_p += stats["avg_pred"]
            run_t += stats["avg_target"]
            steps += 1

        print(
            f"[Epoch {ep:02d}] loss={run_loss/steps:.4f} "
            f"acc@0.5={run_acc/steps:.3f} "
            f"ŷ={run_p/steps:.3f} t̄={run_t/steps:.3f}"
        )

    am.eval()
    return am



def load_json(path: Path) -> Dict[str, Any]:
    """Read the UTF-8 encoded JSON file at ``path`` and return the parsed content."""
    return json.loads(path.read_text(encoding="utf-8"))


def load_fc_fb(
    fc_ckpt: Path,
    fb_ckpt: Path,
    *,
    z_dim: int,
    latent_dim: int,
    out_dim: int,
    hidden_sizes: List[int],
    map_location="cpu",
):
    """Build two identically shaped ``FCNet`` models and load their weights.

    Each checkpoint file may hold either a raw state dict or a dict with a
    ``"state_dict"`` entry. Weights are loaded with ``strict=True``, so the
    architecture arguments must match the checkpoints exactly.

    Returns:
        ``(fc, fb)`` with weights loaded.
    """
    arch = dict(z_dim=z_dim, latent_dim=latent_dim, out_dim=out_dim, hidden_sizes=hidden_sizes)
    fc = FCNet(**arch)
    fb = FCNet(**arch)

    def _unwrap(obj):
        # Accept both {"state_dict": ...} wrappers and bare state dicts.
        if isinstance(obj, dict) and "state_dict" in obj:
            return obj["state_dict"]
        return obj

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    fc_raw = torch.load(fc_ckpt, map_location=map_location)
    fb_raw = torch.load(fb_ckpt, map_location=map_location)

    fc_weights = _unwrap(fc_raw)
    fb_weights = _unwrap(fb_raw)

    fc.load_state_dict(fc_weights, strict=True)
    fb.load_state_dict(fb_weights, strict=True)
    return fc, fb


def _crf_label_candidates(crf) -> List[str]:
    """Return the textual forms of ``crf`` to try in a ``codec_crf`` key, best first."""
    text = f"{crf}"
    primary = f"{int(crf)}" if text.isdigit() else text
    labels = [primary]
    # A crf column parsed as float (e.g. because the CSV has blanks) renders
    # 23 as "23.0". Also try the integral spelling so that it still matches
    # keys such as "x264_23".
    if isinstance(crf, (float, np.floating)) and float(crf).is_integer():
        integral = f"{int(crf)}"
        if integral != primary:
            labels.append(integral)
    return labels


def build_good_metrics(metrics: pd.DataFrame, pred: Dict[str, Dict[str, float]]) -> pd.DataFrame:
    """Join metric rows with predicted Q values.

    A row is kept only when its ``sequence`` is a key of ``pred`` and
    ``f"{codec}_{crf}"`` is a key of ``pred[sequence]``. The kept rows get an
    extra float column ``Q`` holding that value. Integral float crf values
    (``23.0``) also match the integer spelling (``"..._23"``); the literal
    spelling is tried first.

    Raises:
        KeyError: if ``metrics`` lacks any of ``sequence``, ``codec``, ``crf``.
        RuntimeError: if no row could be matched.
    """
    need_cols = {"sequence", "codec", "crf"}
    missing = need_cols - set(metrics.columns)
    if missing:
        raise KeyError(f"metrics_csv is missing required columns: {sorted(missing)}")

    good = []
    for _, row in metrics.iterrows():
        seq = row["sequence"]
        if seq not in pred:
            continue
        seq_pred = pred[seq]
        codec = row["codec"]
        match = None
        for label in _crf_label_candidates(row["crf"]):
            key = f"{codec}_{label}"
            if key in seq_pred:
                match = key
                break
        if match is None:
            continue
        record = dict(row)
        record["Q"] = float(seq_pred[match])
        good.append(record)

    if not good:
        raise RuntimeError("No aligned rows between metrics_csv and pred_json. Check sequence/codec/crf naming.")

    return pd.DataFrame(good)


def tensors_from_metric(df: pd.DataFrame, metric_col: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Extract aligned ``[N, 1]`` float32 tensors ``(X, Q)`` from ``df``.

    ``X`` comes from ``metric_col`` and ``Q`` from the ``"Q"`` column. Both
    are coerced to numbers, and a row is dropped when either value is missing,
    unparsable or non-finite (NaN / +-inf) after conversion to float32.

    Raises:
        KeyError: if ``metric_col`` or ``"Q"`` is not a column of ``df``.
        RuntimeError: if fewer than 10 rows survive filtering.
    """
    if metric_col not in df.columns:
        raise KeyError(f"Metric column not found: '{metric_col}'")

    # Build the tensors from the coerced values, so the data that is filtered
    # is the same data that is returned.
    x = pd.to_numeric(df[metric_col], errors="coerce").to_numpy(dtype=np.float32, na_value=np.nan)
    q = pd.to_numeric(df["Q"], errors="coerce").to_numpy(dtype=np.float32, na_value=np.nan)

    # isfinite rejects NaN as well as +-inf. Infinite metric values would
    # otherwise reach the network unchanged.
    keep = np.isfinite(x) & np.isfinite(q)
    n_valid = int(keep.sum())
    if n_valid < 10:
        raise RuntimeError(
            f"Too few valid rows ({n_valid}) after filtering NaN/inf values for metric '{metric_col}'"
        )

    X = torch.tensor(x[keep]).view(-1, 1)
    Q = torch.tensor(q[keep]).view(-1, 1)
    return X, Q


# ---------------------- CLI ----------------------
def parse_args():
    """Parse the command-line options of the AM training script.

    Required: metrics CSV, pred JSON, fc/fb checkpoints, metric column names
    and the output directory. Optional flags cover the fc/fb architecture,
    AM training hyper-parameters and the z sampling bounds.
    """
    parser = argparse.ArgumentParser(
        description="Train AM model(s) from metrics CSV + pred JSON using frozen fc/fb."
    )

    # (flag, add_argument keyword options), kept in the order they are registered.
    option_table = [
        ("--metrics_csv", dict(required=True, help="CSV with columns: sequence, codec, crf, <metric cols...>")),
        ("--pred_json", dict(required=True, help="pred JSON from EM: pred[seq][codec_crf]=Q")),
        ("--fc_ckpt", dict(required=True, help="Path to fc checkpoint (.pt)")),
        ("--fb_ckpt", dict(required=True, help="Path to fb checkpoint (.pt)")),
        ("--metric_cols", dict(nargs="+", required=True, help="One or more metric column names to train AM for")),
        ("--out_dir", dict(required=True, help="Directory to save AM checkpoints")),
        # fc/fb architecture
        ("--z_dim", dict(type=int, default=5)),
        ("--latent_dim", dict(type=int, default=1)),
        ("--fc_out_dim", dict(type=int, default=1)),
        ("--fc_hidden", dict(nargs="+", type=int, default=[128, 64, 32, 16])),
        # AM training hyper-parameters
        ("--hidden_am", dict(type=int, default=32)),
        ("--epochs", dict(type=int, default=30)),
        ("--pairs_per_epoch", dict(type=int, default=4000)),
        ("--batch_size", dict(type=int, default=512)),
        ("--lr", dict(type=float, default=1e-3)),
        ("--weight_decay", dict(type=float, default=1e-5)),
        ("--seed", dict(type=int, default=42)),
        ("--device", dict(default=None, help="cuda | cpu (default: auto)")),
        # z sampling bounds
        ("--z_lows", dict(nargs=5, type=float, default=[1, 200, 0, 0, 8])),
        ("--z_highs", dict(nargs=5, type=float, default=[15, 700, 1, 1, 12])),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def main():
    """Command-line entry point: train and save one AM checkpoint per metric column.

    Reads the metrics CSV and pred JSON, aligns them, loads the frozen fc/fb
    networks, then for every requested metric column trains an AM model and
    writes ``am_<sanitised column name>.pt`` into the output directory.
    """
    args = parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    metrics = pd.read_csv(args.metrics_csv)
    pred = load_json(Path(args.pred_json))

    good_metrics = build_good_metrics(metrics, pred)
    print(f"Aligned rows: {len(good_metrics)} / {len(metrics)}")

    # Frozen fc/fb networks, shared by every metric.
    fc, fb = load_fc_fb(
        Path(args.fc_ckpt),
        Path(args.fb_ckpt),
        z_dim=args.z_dim,
        latent_dim=args.latent_dim,
        out_dim=args.fc_out_dim,
        hidden_sizes=list(args.fc_hidden),
        map_location="cpu",
    )

    # z sampler built from the CLI bounds.
    z_limits = ZLimits(lows=tuple(args.z_lows), highs=tuple(args.z_highs))

    # Training options that do not change from one metric to the next.
    train_options = dict(
        hidden_am=args.hidden_am,
        epochs=args.epochs,
        pairs_per_epoch=args.pairs_per_epoch,
        batch_size=args.batch_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        seed=args.seed,
        device=args.device,
    )
    allowed_punctuation = {"-", "_"}

    for metric_col in args.metric_cols:
        print("\n" + "=" * 80)
        print(f"Training AM for metric: {metric_col}")

        X, Q = tensors_from_metric(good_metrics, metric_col)

        am = train_am_with_fc_fb(X, Q, fc=fc, fb=fb, z_limits=z_limits, **train_options)

        # File-system friendly name: anything that is not alphanumeric,
        # "-" or "_" becomes "_"; the result is capped at 120 characters.
        safe_name = "".join(
            ch if (ch.isalnum() or ch in allowed_punctuation) else "_" for ch in metric_col
        )[:120]
        ckpt_path = out_dir / f"am_{safe_name}.pt"

        am_config = {
            "x_dim": 1,
            "z_dim": 5,
            "hidden_am": args.hidden_am,
        }
        train_config = {
            "epochs": args.epochs,
            "pairs_per_epoch": args.pairs_per_epoch,
            "batch_size": args.batch_size,
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "seed": args.seed,
            "device": args.device,
            "z_lows": list(args.z_lows),
            "z_highs": list(args.z_highs),
        }
        fc_fb_config = {
            "z_dim": args.z_dim,
            "latent_dim": args.latent_dim,
            "fc_out_dim": args.fc_out_dim,
            "fc_hidden": list(args.fc_hidden),
            "fc_ckpt": str(Path(args.fc_ckpt)),
            "fb_ckpt": str(Path(args.fb_ckpt)),
        }
        data_stats = {
            "N": int(X.shape[0]),
            "X_min": float(X.min().item()),
            "X_max": float(X.max().item()),
            "Q_min": float(Q.min().item()),
            "Q_max": float(Q.max().item()),
        }

        # The checkpoint bundles the weights with everything needed to
        # rebuild the model and trace how it was trained.
        ckpt = {
            "metric_col": metric_col,
            "am_state_dict": am.state_dict(),
            "am_config": am_config,
            "train_config": train_config,
            "fc_fb_config": fc_fb_config,
            "data_stats": data_stats,
        }

        torch.save(ckpt, ckpt_path)
        print(f"Saved: {ckpt_path}")

    print("\nDone.")


if __name__ == "__main__":
    main()
