import argparse
import csv
import json
import os

import numpy as np
import torch
import torch.nn as nn

import models
from dataset_CUBICC import CUBICCDataset
from utils import unpack_data, get_mean

class _LabelEncoder:
    """Tiny label encoder: remembers the classes and their positions.

    Attributes are documented inline below. Raw class values are coerced
    with ``int`` when used as lookup keys.
    """

    def __init__(self, classes):
        ordered = [c for c in classes]
        # Classes in the order they were supplied.
        self.classes_ = ordered
        # int(class value) -> position in ``classes_``; on duplicates the
        # last position wins.
        self._lookup = dict(zip(map(int, ordered), range(len(ordered))))


def build_label_encoder(dataset):
    """Build a label encoder from ``dataset.labels``.

    The labels (tensor or array-like) are converted to an integer NumPy
    array; the sorted set of distinct values becomes the encoder's classes.

    Returns:
        A ``(encoder, num_classes)`` tuple.
    """
    raw = dataset.labels
    if torch.is_tensor(raw):
        raw = raw.detach().cpu().numpy()
    int_labels = np.asarray(raw).astype(int)
    distinct = sorted({*int_labels.tolist()})
    encoder = _LabelEncoder(distinct)
    return encoder, len(distinct)


def encode_labels(labels, encoder):
    """Map raw class labels to contiguous class indices.

    Every label is translated through the encoder's lookup table, so the
    result is the same no matter which classes happen to be present in the
    batch. (A shortcut that returns the raw labels when the batch contains
    every class is only correct if the classes are already ``0..n-1``.)

    Args:
        labels: Tensor or array-like of raw integer class labels.
        encoder: Object exposing ``classes_`` and, optionally, a ``_lookup``
            dict mapping ``int(class)`` to its index. If ``_lookup`` is
            missing or empty it is rebuilt from ``classes_``.

    Returns:
        Integer ``numpy.ndarray`` with the same shape as ``labels``. When the
        encoder has no classes, the integer-cast labels are returned as-is.

    Raises:
        KeyError: If a label is not among the encoder's classes.
    """
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    labels = np.asarray(labels).astype(int)

    classes = list(encoder.classes_)
    if not classes:
        return labels

    lookup = getattr(encoder, "_lookup", None)
    if not lookup:
        lookup = {int(c): i for i, c in enumerate(classes)}

    # A plain loop (rather than np.vectorize) also handles empty batches,
    # which np.vectorize rejects unless otypes is set.
    try:
        encoded = [lookup[int(value)] for value in labels.ravel()]
    except KeyError as err:
        raise KeyError(
            "Label {} is not among the encoder's known classes".format(err.args[0])
        ) from err
    return np.asarray(encoded, dtype=labels.dtype).reshape(labels.shape)


def get_split_dataset(dataset, split_name):
    """Return a ``Subset`` of ``dataset`` restricted to the named split.

    ``split_name`` must be ``"train"``, ``"val"`` or ``"test"``; the indices
    are read from ``dataset.train_split``, ``dataset.validation_split`` or
    ``dataset.test_split`` respectively.

    Raises:
        ValueError: If ``split_name`` is not one of the known splits.
    """
    split_attrs = (
        ("train", "train_split"),
        ("val", "validation_split"),
        ("test", "test_split"),
    )
    for name, attr in split_attrs:
        if split_name == name:
            # Only the attribute of the requested split is touched.
            return torch.utils.data.Subset(dataset, getattr(dataset, attr))
    raise ValueError("Unknown split: {}".format(split_name))


class SimpleCUBICCClassifier(nn.Module):
    """Small convolutional image classifier.

    Three stride-2 3x3 convolutions (3 -> 32 -> 64 -> 128 channels), each
    followed by ReLU, then global average pooling and a linear layer that
    produces ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        widths = (3, 32, 64, 128)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            stages.append(nn.ReLU(inplace=True))
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))
        # Layer positions inside the Sequential determine the state_dict keys.
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        pooled = self.features(x)
        # Collapse the pooled (N, 128, 1, 1) map to (N, 128).
        return self.classifier(pooled.view(pooled.size(0), -1))


def resolve_device(args):
    """Pick the torch device implied by the command-line flags.

    Returns the CPU device when ``args.no_cuda`` is set or CUDA is not
    available; otherwise the device named by ``args.cuda_device`` if it is
    non-empty, else the default ``"cuda"`` device.
    """
    cuda_usable = (not args.no_cuda) and torch.cuda.is_available()
    if not cuda_usable:
        return torch.device("cpu")
    return torch.device(args.cuda_device or "cuda")


def build_parser():
    """Create the command-line parser for the disentanglement evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate disentanglement metrics on CUBICC.")
    add = parser.add_argument

    # Which run / checkpoint / data to evaluate.
    add(
        "--save-dir",
        type=str,
        required=True,
        help="Path to the trained run directory (contains args.rar and model_*.rar).",
    )
    add(
        "--epoch",
        type=int,
        default=None,
        help="Epoch checkpoint to load (e.g. 300 loads model_300.rar). Defaults to latest.",
    )
    add(
        "--datadir",
        type=str,
        default="../data",
        help="Directory where CUBICC data is stored (overrides args.rar).",
    )
    add(
        "--vocab-path",
        type=str,
        default="data/CUBICC/cub.vocab",
        help="Path to CUBICC vocab file (e.g. data/CUBICC/cub.vocab).",
    )
    add(
        "--split",
        type=str,
        default="test",
        choices=["train", "val", "test"],
        help="Dataset split to evaluate on.",
    )
    add(
        "--batch-size",
        type=int,
        default=None,
        help="Override batch size from args.rar.",
    )
    add(
        "--model-type",
        type=str,
        default=None,
        choices=["cmvae", "cholderplus"],
        help="Override model type if not present in args.rar.",
    )

    # Sampling budget for the metrics.
    add(
        "--num-w-samples",
        type=int,
        default=4,
        help="Number of w samples per datapoint for content stability.",
    )
    add(
        "--num-z-samples",
        type=int,
        default=4,
        help="Number of z samples per datapoint for w-content accuracy.",
    )

    # Image classifier used to score decoded samples.
    add(
        "--classifier-checkpoint",
        type=str,
        default=None,
        help="Path to classifier checkpoint (default: pretrained/cubicc_image_classifier.pt).",
    )
    add(
        "--train-classifier",
        action="store_true",
        default=False,
        help="Force training the classifier even if a checkpoint exists.",
    )
    add(
        "--classifier-epochs",
        type=int,
        default=10,
        help="Epochs for classifier training.",
    )
    add(
        "--classifier-lr",
        type=float,
        default=1e-3,
        help="Learning rate for classifier training.",
    )
    add(
        "--classifier-batch-size",
        type=int,
        default=64,
        help="Batch size for classifier training.",
    )

    # Device and output location.
    add(
        "--no-cuda",
        action="store_true",
        default=False,
        help="Disable CUDA.",
    )
    add(
        "--cuda-device",
        type=str,
        default="",
        help="Override cuda device (example: cuda or cuda:1).",
    )
    add(
        "--out-dir",
        type=str,
        default=None,
        help="Output directory for metrics (default: <save-dir>/disentanglement_metrics).",
    )
    return parser


def _find_latest_epoch(run_path):
    """Return the largest epoch among ``model_<epoch>.rar`` files in ``run_path``.

    Only entries whose name between the ``model_`` prefix and the ``.rar``
    suffix consists solely of digits are considered. Returns ``None`` when
    no such file exists.
    """
    prefix, suffix = "model_", ".rar"
    stems = (
        name[len(prefix) : -len(suffix)]
        for name in os.listdir(run_path)
        if name.startswith(prefix) and name.endswith(suffix)
    )
    return max((int(stem) for stem in stems if stem.isdigit()), default=None)


def _get_latents_for_mod(qu_xs, uss, mod, latent_dim_w, latent_dim_z):
    """Split the latent sample of one modality into its ``w`` and ``z`` parts.

    ``uss[mod]`` is taken as the latent sample; if it is 3-dimensional only
    its first slice along dim 0 is used (presumably the K-sample axis --
    confirm against the model). The last dimension is split into chunks of
    size ``latent_dim_w`` and ``latent_dim_z``. ``qu_xs`` is accepted but
    not used.

    Returns:
        A ``(w, z)`` tuple of tensors.
    """
    latent = uss[mod]
    if latent.dim() == 3:
        latent = latent[0]
    w_part, z_part = latent.split([latent_dim_w, latent_dim_z], dim=-1)
    return w_part, z_part


def _sample_w_prior(model, num_samples, batch_size, device):
    """Draw ``w`` latents from the model's ``w`` prior.

    The prior is built as ``model.pw(*model.pw_params())`` and sampled with
    sample shape ``(num_samples, batch_size)``.

    Returns:
        Tensor of shape ``(num_samples, batch_size, latent_dim_w)`` on
        ``device``.
    """
    prior = model.pw(*model.pw_params())
    draws = prior.rsample(torch.Size([num_samples, batch_size]))
    # Drop any extra singleton dims contributed by the prior parameters.
    shaped = draws.reshape(num_samples, batch_size, model.params.latent_dim_w)
    return shaped.to(device)


def _sample_w_from_data(w_base, num_samples):
    """Create ``num_samples`` row-shuffled copies of ``w_base``.

    Each copy permutes ``w_base`` along dim 0 with an independent random
    permutation, and the copies are stacked along a new leading dimension.
    """
    n_rows = w_base.size(0)
    shuffled = [
        w_base[torch.randperm(n_rows, device=w_base.device)]
        for _ in range(num_samples)
    ]
    return torch.stack(shuffled, dim=0)


def _sample_z_prior(model, num_samples, batch_size, device):
    """Draw ``z`` latents from the model's ``z`` prior.

    Two model layouts are supported:

    * If the model has ``pc_params``, a component index is drawn for every
      (sample, item) pair from a categorical with those probabilities, and
      one ``z`` is drawn from ``model.pz(*model.pz_params(idx))``.
    * Otherwise a single prior ``model.pz(*model.pz_params)`` is sampled
      with sample shape ``(num_samples, batch_size)``.

    Singleton dimensions left over from the prior parameters are squeezed
    away. The result is moved to ``device``.
    """
    if not hasattr(model, "pc_params"):
        prior = model.pz(*model.pz_params)
        draws = prior.rsample(torch.Size([num_samples, batch_size]))
        if draws.dim() == 4:
            draws = draws.squeeze(2)
        return draws.to(device)

    mixing = model.pc_params.squeeze(0)
    components = torch.distributions.Categorical(probs=mixing).sample((num_samples, batch_size))

    def _draw_from_component(component):
        # One sample from the chosen component, stripped of leading singleton dims.
        sample = model.pz(*model.pz_params(component)).rsample(torch.Size([1]))
        if sample.dim() == 3:
            sample = sample.squeeze(0)
        if sample.dim() == 2 and sample.size(0) == 1:
            sample = sample.squeeze(0)
        return sample

    per_sample = []
    for k in range(num_samples):
        rows = torch.stack(
            [_draw_from_component(int(components[k, b].item())) for b in range(batch_size)],
            dim=0,
        )
        if rows.dim() == 3 and rows.size(1) == 1:
            rows = rows.squeeze(1)
        per_sample.append(rows)
    return torch.stack(per_sample, dim=0).to(device)


def _pairwise_agreement(preds, num_classes):
    """Fraction of ordered prediction pairs that agree, per item.

    Args:
        preds: Integer tensor of shape ``(k, batch_size)`` holding ``k``
            class predictions for every item.
        num_classes: Number of classes.

    Returns:
        Float tensor of length ``batch_size``. For each item this is
        ``sum_c n_c * (n_c - 1) / (k * (k - 1))`` where ``n_c`` is how many
        of the ``k`` predictions equal class ``c``. All ones when ``k <= 1``.
    """
    n_draws, n_items = preds.size()
    if n_draws <= 1:
        return torch.ones(n_items, device=preds.device)
    # Per-item class histogram, shape (batch_size, num_classes).
    votes = torch.nn.functional.one_hot(preds, num_classes).sum(dim=0)
    agreeing_pairs = (votes * (votes - 1)).sum(dim=1).float()
    return agreeing_pairs / float(n_draws * (n_draws - 1))


def _train_image_classifier(classifier, train_loader, eval_loader, device, label_encoder, epochs, lr):
    """Train ``classifier`` with Adam and cross-entropy, then evaluate it.

    For every batch of ``train_loader`` the first element of the unpacked
    data is used as the image input and the labels are mapped to class
    indices with ``label_encoder``.

    Returns:
        The accuracy reported by ``_eval_image_classifier`` on ``eval_loader``.
    """
    loss_fn = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(classifier.parameters(), lr=lr)
    classifier.train()

    def _step(batch):
        # One optimisation step on a single batch.
        data, raw_labels = unpack_data(batch, device=device)
        encoded = encode_labels(raw_labels, label_encoder)
        targets = torch.from_numpy(encoded).to(device=device, dtype=torch.long)
        images = data[0]
        adam.zero_grad()
        loss = loss_fn(classifier(images), targets)
        loss.backward()
        adam.step()

    for _ in range(epochs):
        for batch in train_loader:
            _step(batch)
    return _eval_image_classifier(classifier, eval_loader, device, label_encoder)


def _eval_image_classifier(classifier, loader, device, label_encoder):
    """Return the top-1 accuracy of ``classifier`` over ``loader``.

    The classifier is switched to eval mode and run without gradients on
    the first element of each unpacked batch. Labels are mapped to class
    indices with ``label_encoder``. An empty loader yields ``0.0``.
    """
    classifier.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch in loader:
            data, raw_labels = unpack_data(batch, device=device)
            encoded = encode_labels(raw_labels, label_encoder)
            targets = torch.from_numpy(encoded).to(device=device, dtype=torch.long)
            predicted = classifier(data[0]).argmax(dim=1)
            hits += (predicted == targets).sum().item()
            seen += targets.size(0)
    # max(..., 1) guards against division by zero on an empty loader.
    return hits / max(seen, 1)


def _append_metrics_json(path, payload):
    """Append ``payload`` to the JSON list stored at ``path``.

    The file is rewritten in full with the new entry at the end. Existing
    content is interpreted leniently: a single JSON object is wrapped in a
    list, while unparsable JSON or any other top-level value is discarded
    and the list starts afresh.
    """
    history = []
    if os.path.exists(path):
        try:
            with open(path, "r") as handle:
                loaded = json.load(handle)
        except json.JSONDecodeError:
            loaded = []
        if isinstance(loaded, dict):
            history = [loaded]
        elif isinstance(loaded, list):
            history = loaded
    history.append(payload)
    with open(path, "w") as handle:
        json.dump(history, handle, indent=2, sort_keys=True)


def compute_z_content_stability(model, classifier, loader, device, label_encoder, num_w_samples, mod_idx, decode_mod=0):
    """Measure how stable the predicted class is when ``w`` is resampled.

    For each item, ``z`` inferred from modality ``mod_idx`` is held fixed
    and combined with ``num_w_samples`` draws of ``w`` from the prior. Each
    combination is decoded with ``model.vaes[decode_mod]`` and classified;
    the per-item pairwise agreement of the predictions is averaged over the
    whole loader.

    Returns:
        Mean pairwise agreement as a float (``0.0`` for an empty loader).
    """
    model.eval()
    classifier.eval()
    agreement_sum = 0.0
    seen = 0
    n_classes = len(label_encoder.classes_)
    with torch.no_grad():
        for batch in loader:
            data, _ = unpack_data(batch, device=device)
            qu_xs, _, uss = model(data, K=1)[:3]
            _, z_post = _get_latents_for_mod(
                qu_xs, uss, mod_idx, model.params.latent_dim_w, model.params.latent_dim_z
            )
            n_items = z_post.size(0)
            w_prior = _sample_w_prior(model, num_w_samples, n_items, device)
            # Repeat the inferred z for every prior draw of w.
            z_tiled = z_post.unsqueeze(0).expand(num_w_samples, n_items, -1)
            latent = torch.cat((w_prior, z_tiled), dim=-1)
            vae = model.vaes[decode_mod]
            decoded = get_mean(vae.px_u(*vae.dec(latent)))
            # Merge the (sample, item) leading dims so the classifier sees one batch.
            flat_imgs = decoded.view(num_w_samples * n_items, *decoded.size()[2:])
            votes = classifier(flat_imgs).argmax(dim=1).view(num_w_samples, n_items)
            agreement_sum += _pairwise_agreement(votes, n_classes).sum().item()
            seen += n_items
    return agreement_sum / max(seen, 1)


def compute_w_content_accuracy(model, classifier, loader, device, label_encoder, num_z_samples, mod_idx, decode_mod=0):
    """Classification accuracy of decodings that keep ``w`` and resample ``z``.

    For each item, ``w`` inferred from modality ``mod_idx`` is held fixed
    and combined with ``num_z_samples`` draws of ``z`` from the prior. Each
    combination is decoded with ``model.vaes[decode_mod]``, classified, and
    compared against the item's encoded label.

    Returns:
        Fraction of correct predictions over all (sample, item) pairs
        (``0.0`` for an empty loader).
    """
    model.eval()
    classifier.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch in loader:
            data, raw_labels = unpack_data(batch, device=device)
            encoded = encode_labels(raw_labels, label_encoder)
            targets = torch.from_numpy(encoded).to(device=device, dtype=torch.long)
            qu_xs, _, uss = model(data, K=1)[:3]
            w_post, _ = _get_latents_for_mod(
                qu_xs, uss, mod_idx, model.params.latent_dim_w, model.params.latent_dim_z
            )
            n_items = w_post.size(0)
            z_prior = _sample_z_prior(model, num_z_samples, n_items, device)
            if z_prior.dim() == 4 and z_prior.size(2) == 1:
                z_prior = z_prior.squeeze(2)
            # Repeat the inferred w for every prior draw of z.
            w_tiled = w_post.unsqueeze(0).expand(num_z_samples, n_items, -1)
            latent = torch.cat((w_tiled, z_prior), dim=-1)
            vae = model.vaes[decode_mod]
            decoded = get_mean(vae.px_u(*vae.dec(latent)))
            # Merge the (sample, item) leading dims so the classifier sees one batch.
            flat_imgs = decoded.view(num_z_samples * n_items, *decoded.size()[2:])
            predicted = classifier(flat_imgs).argmax(dim=1)
            # Labels are tiled in the same sample-major order as the images.
            tiled_targets = targets.unsqueeze(0).expand(num_z_samples, n_items).reshape(-1)
            hits += (predicted == tiled_targets).sum().item()
            seen += tiled_targets.size(0)
    return hits / max(seen, 1)


def compute_z_content_accuracy(model, classifier, loader, device, label_encoder, num_w_samples, mod_idx, decode_mod=0):
    """Classification accuracy of decodings that keep ``z`` and resample ``w``.

    For each item, ``z`` inferred from modality ``mod_idx`` is held fixed
    and combined with ``num_w_samples`` draws of ``w`` from the prior. Each
    combination is decoded with ``model.vaes[decode_mod]``, classified, and
    compared against the item's encoded label.

    Returns:
        Fraction of correct predictions over all (sample, item) pairs
        (``0.0`` for an empty loader).
    """
    model.eval()
    classifier.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch in loader:
            data, raw_labels = unpack_data(batch, device=device)
            encoded = encode_labels(raw_labels, label_encoder)
            targets = torch.from_numpy(encoded).to(device=device, dtype=torch.long)
            qu_xs, _, uss = model(data, K=1)[:3]
            _, z_post = _get_latents_for_mod(
                qu_xs, uss, mod_idx, model.params.latent_dim_w, model.params.latent_dim_z
            )
            n_items = z_post.size(0)
            w_prior = _sample_w_prior(model, num_w_samples, n_items, device)
            # Repeat the inferred z for every prior draw of w.
            z_tiled = z_post.unsqueeze(0).expand(num_w_samples, n_items, -1)
            latent = torch.cat((w_prior, z_tiled), dim=-1)
            vae = model.vaes[decode_mod]
            decoded = get_mean(vae.px_u(*vae.dec(latent)))
            # Merge the (sample, item) leading dims so the classifier sees one batch.
            flat_imgs = decoded.view(num_w_samples * n_items, *decoded.size()[2:])
            predicted = classifier(flat_imgs).argmax(dim=1)
            # Labels are tiled in the same sample-major order as the images.
            tiled_targets = targets.unsqueeze(0).expand(num_w_samples, n_items).reshape(-1)
            hits += (predicted == tiled_targets).sum().item()
            seen += tiled_targets.size(0)
    return hits / max(seen, 1)


def _select_epoch(run_path, requested_epoch):
    """Return the epoch to load: the requested one, else the latest checkpoint on disk."""
    if requested_epoch is not None:
        return requested_epoch
    latest = _find_latest_epoch(run_path)
    if latest is None:
        raise ValueError(
            "No model_<epoch>.rar checkpoint found in {} (pass --epoch to override)".format(run_path)
        )
    return latest


def main():
    """Evaluate disentanglement metrics for a trained run and store them.

    Steps:
      1. Load the training arguments (``args.rar``) and apply CLI overrides.
      2. Build the model and load the requested checkpoint, or the latest
         ``model_<epoch>.rar`` in the run directory when ``--epoch`` is not
         given (as the ``--epoch`` help text states).
      3. Load the image classifier from its checkpoint, or train and save
         it when ``--train-classifier`` is passed.
      4. Compute the three metrics for every modality, plus their means.
      5. Append the results to ``disentanglement_metrics.json`` and
         ``disentanglement_metrics.csv`` in the output directory.

    Data directory precedence: an explicitly passed ``--vocab-path`` wins
    (datadir is two levels above the vocab file); otherwise an explicitly
    passed ``--datadir`` is used; otherwise the directory derived from the
    default vocab path is used. "Explicit" is detected by comparing with the
    parser default, so passing the default value verbatim counts as not
    passing it.

    Raises:
        ValueError: If no checkpoint can be found.
        FileNotFoundError: If the classifier checkpoint is missing and
            ``--train-classifier`` was not given.
    """
    parser = build_parser()
    cmds = parser.parse_args()

    run_path = cmds.save_dir
    args_path = os.path.join(run_path, "args.rar")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only
    # point this script at run directories you trust.
    try:
        args = torch.load(args_path, weights_only=False)
    except TypeError:
        # Older torch versions do not know the weights_only keyword.
        args = torch.load(args_path)

    if cmds.batch_size is not None:
        args.batch_size = cmds.batch_size

    # Without the explicit/default distinction the always-non-empty default
    # of --vocab-path would overwrite --datadir on every run.
    vocab_given = bool(cmds.vocab_path) and cmds.vocab_path != parser.get_default("vocab_path")
    datadir_given = cmds.datadir is not None and cmds.datadir != parser.get_default("datadir")
    if datadir_given and not vocab_given:
        args.datadir = cmds.datadir
    elif cmds.vocab_path:
        args.datadir = os.path.dirname(os.path.dirname(cmds.vocab_path))
    elif cmds.datadir is not None:
        args.datadir = cmds.datadir

    args.no_cuda = cmds.no_cuda
    if not hasattr(args, "latent_dim_u") and hasattr(args, "latent_dim_w") and hasattr(args, "latent_dim_z"):
        args.latent_dim_u = args.latent_dim_w + args.latent_dim_z

    model_type = cmds.model_type or getattr(args, "model_type", "cmvae")
    device = resolve_device(cmds)

    model_cls = models.CHolderplus if model_type == "cholderplus" else models.CMVAE
    model = models.CUB_Image_Sentence(args, model_cls=model_cls).to(device)
    model.eval()

    epoch = _select_epoch(run_path, cmds.epoch)
    checkpoint_path = os.path.join(run_path, "model_{}.rar".format(epoch))
    if not os.path.exists(checkpoint_path):
        raise ValueError("Checkpoint not found: {}".format(checkpoint_path))
    load_result = model.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=False)
    # strict=False tolerates key mismatches; report them so a partially
    # initialised model does not go unnoticed.
    missing_keys = list(getattr(load_result, "missing_keys", []))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
    if missing_keys:
        print("Warning: checkpoint is missing {} model keys: {}".format(len(missing_keys), missing_keys))
    if unexpected_keys:
        print("Warning: checkpoint has {} unexpected keys: {}".format(len(unexpected_keys), unexpected_keys))

    dataset = CUBICCDataset(datadir=os.path.join(args.datadir, "CUBICC"))
    train_dataset = get_split_dataset(dataset, "train")
    eval_dataset = get_split_dataset(dataset, cmds.split)
    classifier_eval_dataset = get_split_dataset(dataset, "val")
    kwargs = {"num_workers": 2, "pin_memory": True} if device.type == "cuda" else {}
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cmds.classifier_batch_size, shuffle=True, **kwargs
    )
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=args.batch_size, shuffle=False, **kwargs
    )
    classifier_eval_loader = torch.utils.data.DataLoader(
        classifier_eval_dataset, batch_size=cmds.classifier_batch_size, shuffle=False, **kwargs
    )

    label_encoder, num_classes = build_label_encoder(dataset)

    out_dir = cmds.out_dir or os.path.join(run_path, "disentanglement_metrics")
    os.makedirs(out_dir, exist_ok=True)
    if cmds.classifier_checkpoint:
        classifier_path = cmds.classifier_checkpoint
    else:
        repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        classifier_path = os.path.join(repo_root, "pretrained", "cubicc_image_classifier.pt")
    classifier_dir = os.path.dirname(classifier_path)
    if classifier_dir:
        os.makedirs(classifier_dir, exist_ok=True)

    classifier = SimpleCUBICCClassifier(num_classes).to(device)
    if os.path.exists(classifier_path) and not cmds.train_classifier:
        classifier.load_state_dict(torch.load(classifier_path, map_location=device))
        classifier_acc = _eval_image_classifier(classifier, classifier_eval_loader, device, label_encoder)
    elif cmds.train_classifier:
        classifier_acc = _train_image_classifier(
            classifier,
            train_loader,
            classifier_eval_loader,
            device,
            label_encoder,
            cmds.classifier_epochs,
            cmds.classifier_lr,
        )
        torch.save(classifier.state_dict(), classifier_path)
    else:
        raise FileNotFoundError(
            "Classifier checkpoint not found: {} (use --train-classifier to create it)".format(classifier_path)
        )

    # (metric name, function, number of latent samples), in evaluation order.
    metric_specs = (
        ("z_content_stability", compute_z_content_stability, cmds.num_w_samples),
        ("w_content_accuracy", compute_w_content_accuracy, cmds.num_z_samples),
        ("z_content_accuracy", compute_z_content_accuracy, cmds.num_w_samples),
    )
    metric_names = [name for name, _, _ in metric_specs]

    by_modality = {}
    for mod_idx in range(len(model.vaes)):
        by_modality["m{}".format(mod_idx)] = {
            name: float(
                metric_fn(
                    model,
                    classifier,
                    eval_loader,
                    device,
                    label_encoder,
                    n_samples,
                    mod_idx,
                    decode_mod=0,
                )
            )
            for name, metric_fn, n_samples in metric_specs
        }

    mean_metrics = {
        name: float(np.mean([metrics[name] for metrics in by_modality.values()])) if by_modality else 0.0
        for name in metric_names
    }

    payload = {
        "run_dir": run_path,
        "epoch": epoch,
        "split": cmds.split,
        "model_type": model_type,
        "num_w_samples": cmds.num_w_samples,
        "num_z_samples": cmds.num_z_samples,
        "by_modality": by_modality,
        "mean_z_content_stability": mean_metrics["z_content_stability"],
        "mean_w_content_accuracy": mean_metrics["w_content_accuracy"],
        "mean_z_content_accuracy": mean_metrics["z_content_accuracy"],
        "classifier_eval_accuracy": float(classifier_acc),
        "classifier_checkpoint": classifier_path,
    }
    json_path = os.path.join(out_dir, "disentanglement_metrics.json")
    _append_metrics_json(json_path, payload)

    csv_path = os.path.join(out_dir, "disentanglement_metrics.csv")
    # Only write the header when starting a new (or empty) CSV file.
    write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    run_fields = [epoch, cmds.split, cmds.num_w_samples, cmds.num_z_samples]
    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(
                ["epoch", "split", "num_w_samples", "num_z_samples", "modality"]
                + metric_names
                + ["classifier_eval_accuracy"]
            )
        for mod_key, metrics in sorted(by_modality.items()):
            writer.writerow(
                run_fields + [mod_key] + [metrics[name] for name in metric_names] + [classifier_acc]
            )
        writer.writerow(
            run_fields + ["mean"] + [mean_metrics[name] for name in metric_names] + [classifier_acc]
        )

    print("Saved disentanglement metrics to:", out_dir)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
