import argparse
import json
import random
import re
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


# Smoothing constant added under the square roots in compute_u so the
# distance stays differentiable when a residual is exactly zero.
eps = 1.0e-8


class MethodsDFDataset(Dataset):
    def __init__(self, df, z_cols=None, method2idx=None):
        self.df = df.reset_index(drop=True)

        if z_cols is None:
            z_cols = ["h_ph", "dpi", "dtype", "prc_br", "bit-depth"]
        self.z_cols = z_cols

        if method2idx is None:
            unique_methods = pd.unique(pd.concat([df["left_method"], df["right_method"]]))
            self.method2idx = {m: i for i, m in enumerate(unique_methods)}
        else:
            self.method2idx = method2idx

        self.left_idx = df["left_method"].map(self.method2idx).astype(int).values
        self.right_idx = df["right_method"].map(self.method2idx).astype(int).values
        self.y = df["ans"].values.astype(np.float32)

        z = df[self.z_cols].apply(pd.to_numeric, errors="coerce").fillna(0.0)
        self.z = z.values.astype(np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return (
            int(self.left_idx[idx]),
            int(self.right_idx[idx]),
            torch.tensor(self.y[idx], dtype=torch.float32),
            torch.from_numpy(self.z[idx]),
        )


def init_weights_xavier(m):
    """Re-initialise an ``nn.Linear``: Xavier-uniform weights and a zero bias.

    Any other module type is left untouched, so the function can be passed to
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    has_bias = m.bias is not None
    if has_bias:
        nn.init.zeros_(m.bias)


class FCNet(nn.Module):
    """MLP over the concatenation of context features ``z`` and a latent code.

    Layout: ``[Linear -> Tanh]`` for each entry of ``hidden_sizes``, followed by a
    final ``Linear`` to ``out_dim``. Every ``Linear`` is re-initialised with
    ``init_weights_xavier``.
    """

    def __init__(self, z_dim, latent_dim, out_dim, hidden_sizes):
        super().__init__()
        widths = [z_dim + latent_dim, *hidden_sizes]
        blocks = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            blocks.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
        blocks.append(nn.Linear(widths[-1], out_dim))
        # The position of each layer inside the Sequential defines its
        # state_dict key (net.0, net.1, ...), so the order must stay fixed.
        self.net = nn.Sequential(*blocks)
        self.apply(init_weights_xavier)

    def forward(self, z, xlatent):
        """Run the MLP on ``z`` and ``xlatent`` joined along dim 1."""
        joined = torch.cat((z, xlatent), dim=1)
        return self.net(joined)


def compute_u(f_c, f_b, z, xi, xj, eps=1e-8):
    """Return ``||f_c(z, xi) - f_b(z, xj)|| - ||f_c(z, xj) - f_b(z, xi)||`` per row.

    Norms are taken over dim 1 as ``sqrt(sum(r * r) + eps)``; the ``eps`` keeps the
    square root differentiable when a residual is exactly zero.
    """

    def smoothed_norm(residual):
        return torch.sqrt((residual * residual).sum(dim=1) + eps)

    # Networks are evaluated in the fixed order f_c(xi), f_b(xj), f_c(xj), f_b(xi).
    fc_i = f_c(z, xi)
    fb_j = f_b(z, xj)
    fc_j = f_c(z, xj)
    fb_i = f_b(z, xi)
    return smoothed_norm(fc_i - fb_j) - smoothed_norm(fc_j - fb_i)


def latent_prior_loss_for_indices(embeddings, idx_left, idx_right, lam):
    """L2 prior penalty on the embedding rows referenced by a batch.

    Returns ``lam / 2 * (mean ||w[idx_left]||^2 + mean ||w[idx_right]||^2)``, where
    ``w`` is ``embeddings.weight`` and each mean is over the batch.
    """
    table = embeddings.weight

    def mean_sq_norm(indices):
        return table[indices].pow(2).sum(dim=1).mean()

    scale = lam / 2.0
    return scale * (mean_sq_norm(idx_left) + mean_sq_norm(idx_right))


def _to_device_batch(batch, device):
    """Move one MethodsDFDataset batch to ``device`` as (long, long, float, float) tensors."""
    x_left, x_right, y_batch, z_batch = batch
    return (
        x_left.to(device).long(),
        x_right.to(device).long(),
        y_batch.to(device).float(),
        z_batch.to(device).float(),
    )


def _set_requires_grad(params, flag):
    """Toggle ``requires_grad`` on every tensor in ``params``."""
    for p in params:
        p.requires_grad_(flag)


def run_em_with_df(
    df,
    *,
    device,
    seed=42,
    z_dim=5,
    latent_dim=1,
    feature_out=1,
    hidden_sizes=(128, 64, 32, 16),
    batch_size=512,
    E_steps=30,
    alpha_x=5e-3,
    M_epochs_per_em=1,
    M_lr=1e-3,
    weight_decay=1e-5,
    num_em_iters=100,
    prior_lambda=1e-3,
):
    """Fit method embeddings and the f_c / f_b networks by alternating optimisation.

    Each of the ``num_em_iters`` iterations runs:
      * E-step: with f_c / f_b frozen, for every mini-batch take ``E_steps`` Adam
        steps (lr ``alpha_x``) on the method embeddings, minimising the MSE between
        ``sigmoid(u)`` and ``(ans + 1) / 2`` plus ``latent_prior_loss_for_indices``.
      * M-step: with the embeddings fixed, train f_c / f_b for ``M_epochs_per_em``
        epochs on the same MSE (Adam, lr ``M_lr``, ``weight_decay``).

    Args:
        df: vote rows accepted by ``MethodsDFDataset``.
        device: torch device for the models and batches.
        seed: seeds torch, ``random`` and numpy.
        z_dim: number of context features; must equal the dataset's z column count.
        latent_dim: size of each method embedding.
        feature_out: output size of f_c / f_b.
        hidden_sizes: hidden layer widths of f_c / f_b.
        Remaining keyword arguments are the optimisation hyper-parameters above.

    Returns:
        ``(f_c, f_b, method_emb, method2idx)``.

    Raises:
        ValueError: if ``z_dim`` does not match the dataset's context columns.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    ds = MethodsDFDataset(df)
    n_ctx = len(ds.z_cols)
    if n_ctx != z_dim:
        # Fail early and readably instead of with a matmul shape error inside
        # the first forward pass.
        raise ValueError(f"z_dim={z_dim} does not match the {n_ctx} context columns {ds.z_cols}")
    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
    num_methods = len(ds.method2idx)
    print(f"Num methods: {num_methods}")

    f_c = FCNet(z_dim, latent_dim, feature_out, list(hidden_sizes)).to(device)
    f_b = FCNet(z_dim, latent_dim, feature_out, list(hidden_sizes)).to(device)

    method_emb = nn.Embedding(num_methods, latent_dim).to(device)
    nn.init.normal_(method_emb.weight, mean=0.0, std=0.01)

    theta_params = list(f_c.parameters()) + list(f_b.parameters())
    opt_theta = torch.optim.Adam(theta_params, lr=M_lr, weight_decay=weight_decay)
    opt_emb = torch.optim.Adam([method_emb.weight], lr=alpha_x)
    mse_loss = nn.MSELoss()

    for em in range(num_em_iters):
        # E-step: only the embeddings are optimised. Freezing the network
        # parameters skips computing gradients that the M-step would discard
        # via zero_grad anyway.
        f_c.eval()
        f_b.eval()
        _set_requires_grad(theta_params, False)
        try:
            for batch in dataloader:
                x_left, x_right, y_batch, z_batch = _to_device_batch(batch, device)
                target = (y_batch + 1) / 2  # rescale labels from [-1, 1] to [0, 1]
                for _ in range(E_steps):
                    opt_emb.zero_grad(set_to_none=True)
                    xi = method_emb(x_left)
                    xj = method_emb(x_right)
                    u = compute_u(f_c, f_b, z_batch, xi, xj, eps=eps)
                    loss_e = mse_loss(torch.sigmoid(u), target)
                    loss_e = loss_e + latent_prior_loss_for_indices(
                        method_emb, x_left, x_right, prior_lambda
                    )
                    loss_e.backward()
                    opt_emb.step()
        finally:
            _set_requires_grad(theta_params, True)

        # M-step: embeddings are read without tracking gradients; only f_c / f_b update.
        f_c.train()
        f_b.train()
        epoch_loss = 0.0
        n_total = 0
        for _ in range(M_epochs_per_em):
            for batch in dataloader:
                x_left, x_right, y_batch, z_batch = _to_device_batch(batch, device)

                with torch.no_grad():
                    xi = method_emb(x_left)
                    xj = method_emb(x_right)

                opt_theta.zero_grad(set_to_none=True)
                u = compute_u(f_c, f_b, z_batch, xi, xj, eps=eps)
                loss_m = mse_loss(torch.sigmoid(u), (y_batch + 1) / 2)
                loss_m.backward()
                opt_theta.step()

                n_batch = z_batch.size(0)
                epoch_loss += loss_m.item() * n_batch
                n_total += n_batch

        avg_loss = epoch_loss / max(1, n_total)
        print(f"EM iter {em+1}/{num_em_iters}, M-step avg MSE loss: {avg_loss:.6f}")

    return f_c, f_b, method_emb, ds.method2idx


# ---------------------- CLI helpers ----------------------
def parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Compute embeddings via EM from a CSV of votes (save ckpts by path).")

    # Mandatory input / output paths.
    path_options = (
        ("--input", "Path to input CSV (votes + phone info)."),
        ("--pred_json", "Output pred JSON path (e.g., pred.json)"),
        ("--fc_ckpt", "Output fc checkpoint path (e.g., fc.pt)"),
        ("--fb_ckpt", "Output fb checkpoint path (e.g., fb.pt)"),
        ("--emb_ckpt", "Output method embedding checkpoint path (e.g., method_emb.pt)"),
    )
    for flag, help_text in path_options:
        parser.add_argument(flag, required=True, help=help_text)

    parser.add_argument("--device", default=None, help="cuda|cpu (default: auto)")
    parser.add_argument("--seed", type=int, default=42)

    # EM hyper-parameters, then network shape settings (declaration order is kept).
    scalar_options = (
        ("--batch_size", int, 512),
        ("--num_em_iters", int, 100),
        ("--E_steps", int, 30),
        ("--alpha_x", float, 5e-3),
        ("--M_epochs_per_em", int, 1),
        ("--M_lr", float, 1e-3),
        ("--weight_decay", float, 1e-5),
        ("--prior_lambda", float, 1e-3),
        ("--z_dim", int, 5),
        ("--latent_dim", int, 1),
        ("--feature_out", int, 1),
    )
    for flag, caster, default in scalar_options:
        parser.add_argument(flag, type=caster, default=default)

    parser.add_argument("--hidden_sizes", nargs="+", type=int, default=[128, 64, 32, 16])

    return parser.parse_args()


def infer_display_dtype(displaytype_val) -> int:
    """Encode a display-type description as 1 (LED) or 0 (IPS / LCD / unknown).

    The first LED / IPS / LCD token in ``str(displaytype_val)`` decides the code.
    Matching is case-insensitive and substring-based, so "OLED" counts as LED.
    ``None`` or no match yields 0.
    """
    if displaytype_val is not None:
        codes = {"IPS": 0, "LCD": 0, "LED": 1}
        for hit in re.finditer(r"LED|IPS|LCD", str(displaytype_val), flags=re.IGNORECASE):
            code = codes.get(hit.group(0).upper())
            if code is not None:
                return code
    return 0


def build_ans(row) -> int:
    """Encode one vote row as a preference label.

    Args:
        row: mapping-like row (dict or pandas Series) with optional ``answer`` and
            ``left_method`` entries.

    Returns:
        0 if ``answer`` is absent or missing (None / NaN / pd.NA), -1 if it equals
        ``left_method``, and 1 otherwise.

    NOTE(review): any non-missing answer that is not the left method counts as a
    right-side win, including values matching neither side — confirm intent.
    """
    ans = row.get("answer", None)
    # NaN is pandas' marker for an empty CSV cell; treat it like None instead of
    # letting it fall through and be scored as a right-side win.
    if ans is None or (pd.api.types.is_scalar(ans) and pd.isna(ans)):
        return 0
    if ans == row.get("left_method", None):
        return -1
    return 1


def ensure_schema_compat(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` normalised to the column names used downstream.

    - ``test_case`` is renamed to ``video`` and ``displaytype`` to ``Display type``,
      each only when the target name is not already present.
    - ``left_method``, ``right_method`` and ``video`` must exist afterwards.
    - A missing ``prc_br`` is taken from ``brightness_mean`` (preferred) or
      ``brightness_max`` (coerced to numbers), else filled with 0.0.

    Raises:
        KeyError: if a required column is missing.
    """
    renames = {}
    if "video" not in df.columns and "test_case" in df.columns:
        renames["test_case"] = "video"
    if "Display type" not in df.columns and "displaytype" in df.columns:
        renames["displaytype"] = "Display type"
    out = df.copy().rename(columns=renames)

    missing = [col for col in ("left_method", "right_method", "video") if col not in out.columns]
    if missing:
        raise KeyError(f"Missing required columns: {missing}")

    if "prc_br" not in out.columns:
        source = next((c for c in ("brightness_mean", "brightness_max") if c in out.columns), None)
        out["prc_br"] = 0.0 if source is None else pd.to_numeric(out[source], errors="coerce")

    return out


def _mkdir_for_file(path: Path):
    """Create the parent directory of ``path``, including missing ancestors; no-op if it exists."""
    target_dir = path.parent
    target_dir.mkdir(exist_ok=True, parents=True)


def _append_video_suffix(df):
    """Return ``(suffixed_df, key_info)`` with a ``_<video>`` suffix on every method name.

    The suffix makes the same method on different videos a distinct key. Names that
    already end with ``_<video>`` are kept as-is. ``key_info`` maps each suffixed
    name to its ``(video, base_method)`` pair, so the split can be undone later
    without guessing at underscores.
    """
    out = df.copy()
    videos = out["video"].astype(str)
    key_info = {}
    for col in ("left_method", "right_method"):
        suffixed = []
        # Element-wise check: Series.str.endswith only accepts a str/tuple
        # pattern, not a per-row Series of suffixes.
        for method, video in zip(out[col].astype(str), videos):
            suffix = "_" + video
            base = method[: -len(suffix)] if method.endswith(suffix) else method
            key = base + suffix
            key_info[key] = (video, base)
            suffixed.append(key)
        out[col] = suffixed
    return out, key_info


def _build_pred(method_emb, method2idx, key_info):
    """Return ``pred[video][method] = score`` built from coordinate 0 of each embedding.

    Keys absent from ``key_info`` are grouped under ``"unknown"``.
    NOTE(review): only coordinate 0 is exported; extra dims are dropped when latent_dim > 1.
    """
    # Copy the table to host once rather than once per method.
    weights = method_emb.weight.detach().cpu().numpy()
    pred = {}
    for key, idx in method2idx.items():
        video, base = key_info.get(key, ("unknown", key))
        pred.setdefault(video, {})[base] = float(weights[idx][0])
    return pred


def _checkpoint_config(args):
    """Model shape and EM hyper-parameters stored alongside the network checkpoints."""
    return {
        "z_dim": args.z_dim,
        "latent_dim": args.latent_dim,
        "feature_out": args.feature_out,
        "hidden_sizes": list(args.hidden_sizes),
        "seed": args.seed,
        "em_hparams": {
            "batch_size": args.batch_size,
            "num_em_iters": args.num_em_iters,
            "E_steps": args.E_steps,
            "alpha_x": args.alpha_x,
            "M_epochs_per_em": args.M_epochs_per_em,
            "M_lr": args.M_lr,
            "weight_decay": args.weight_decay,
            "prior_lambda": args.prior_lambda,
        },
    }


def main():
    """CLI entry point: prepare the votes CSV, fit the EM model, write pred JSON and checkpoints."""
    args = parse_args()
    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))

    df = pd.read_csv(args.input)
    # Missing cells become None so build_ans / infer_display_dtype can test `is None`.
    df = df.where(pd.notna(df), None)
    df = ensure_schema_compat(df)

    # NOTE(review): log_lum is not in MethodsDFDataset's default z_cols, so it
    # does not reach the model — confirm whether it should.
    if "illumination_max" in df.columns:
        df["log_lum"] = np.log1p(pd.to_numeric(df["illumination_max"], errors="coerce"))
    else:
        df["log_lum"] = 0.0

    df["dtype"] = [infer_display_dtype(v) for v in df.get("Display type", pd.Series([None] * len(df)))]
    # Labels are computed from the raw (un-suffixed) method names.
    df["ans"] = [build_ans(r) for _, r in df.iterrows()]

    subdf, key_info = _append_video_suffix(df)

    f_c_model, f_b_model, method_emb, method2idx = run_em_with_df(
        subdf,
        device=device,
        seed=args.seed,
        z_dim=args.z_dim,
        latent_dim=args.latent_dim,
        feature_out=args.feature_out,
        hidden_sizes=args.hidden_sizes,
        batch_size=args.batch_size,
        E_steps=args.E_steps,
        alpha_x=args.alpha_x,
        M_epochs_per_em=args.M_epochs_per_em,
        M_lr=args.M_lr,
        weight_decay=args.weight_decay,
        num_em_iters=args.num_em_iters,
        prior_lambda=args.prior_lambda,
    )

    # pred[video][method] = Q
    pred = _build_pred(method_emb, method2idx, key_info)

    pred_path = Path(args.pred_json)
    _mkdir_for_file(pred_path)
    with pred_path.open("w", encoding="utf-8") as f:
        json.dump(pred, f, ensure_ascii=False, indent=2)
    print(f"Saved pred JSON: {pred_path}")

    cfg = _checkpoint_config(args)

    fc_path = Path(args.fc_ckpt)
    fb_path = Path(args.fb_ckpt)
    emb_path = Path(args.emb_ckpt)
    for out_path in (fc_path, fb_path, emb_path):
        _mkdir_for_file(out_path)

    torch.save({"state_dict": f_c_model.state_dict(), "config": cfg}, fc_path)
    torch.save({"state_dict": f_b_model.state_dict(), "config": cfg}, fb_path)
    torch.save(
        {
            "state_dict": method_emb.state_dict(),
            "method2idx": method2idx,
            "config": {"latent_dim": args.latent_dim, "seed": args.seed},
        },
        emb_path,
    )

    print(f"Saved fc ckpt: {fc_path}")
    print(f"Saved fb ckpt: {fb_path}")
    print(f"Saved emb ckpt: {emb_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
