import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
import random
from pathlib import Path

# --- t-SNE / plotting ---
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# --- Assuming these exist in your 'utils.py' ---
from utils import (
    get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset,
    get_daparam, match_loss, get_time, TensorDataset, epoch,
    DiffAugment, ParamDiffAug
)

# ============================================================
# Feature extraction (penultimate layer) + metadata (y, g)
# ============================================================

def _find_last_linear(model: nn.Module):
    last_linear = None
    for m in model.modules():
        if isinstance(m, nn.Linear):
            last_linear = m
    return last_linear

def get_features_and_metadata(model, dataloader, device):
    """
    Extract penultimate features and also return:
      - y: task label (class)
      - g: sensitive attribute / demographic group

    The "penultimate feature" is the first positional input of a classifier-like
    layer, captured with a forward hook. The layer is chosen in this order:
    ``model.linear``, ``model.fc``, ``model.classifier``, then the last
    ``nn.Linear`` found in ``model.modules()``.

    Expects batch[0]=img, batch[1]=y, batch[2]=g (if exists).
    If y is missing it is filled with zeros; if g is missing it is zeros shaped like y.

    Args:
        model: network to run; it is switched to eval mode and left that way.
        dataloader: iterable of batches as described above.
        device: device the images are moved to before the forward pass.

    Returns:
        ``(feats, y_all, g_all)`` as NumPy arrays concatenated over all batches,
        or ``(None, None, None)`` when no layer could be hooked or no batch
        produced features.
    """
    model.eval()

    # Attach hook to last classifier-like layer (best-effort)
    if hasattr(model, "linear") and isinstance(model.linear, nn.Module):
        target_layer = model.linear
    elif hasattr(model, "fc") and isinstance(model.fc, nn.Module):
        target_layer = model.fc
    elif hasattr(model, "classifier") and isinstance(model.classifier, nn.Module):
        target_layer = model.classifier
    else:
        target_layer = _find_last_linear(model)
        if target_layer is None:
            print("Warning: Could not find a Linear layer to hook for features.")
            return None, None, None

    activation = {}

    def hook_fn(module, inp, out):
        # inp[0] is the feature vector going into the classifier layer
        activation["feat"] = inp[0].detach().cpu().numpy()

    features_list, y_list, g_list = [], [], []

    handle = target_layer.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            for batch in dataloader:
                imgs = batch[0]
                if len(batch) >= 2:
                    y = torch.as_tensor(batch[1])
                else:
                    y = torch.zeros((imgs.size(0),), dtype=torch.long)
                g = torch.as_tensor(batch[2]) if len(batch) >= 3 else torch.zeros_like(y)

                # Forget the previous batch's capture so a forward pass that
                # bypasses the hooked layer cannot re-use stale features.
                activation.pop("feat", None)

                imgs = imgs.to(device, non_blocking=True)
                _ = model(imgs)  # triggers hook

                feat = activation.pop("feat", None)
                if feat is None:
                    print("Warning: Hook did not capture features for a batch.")
                    continue

                features_list.append(feat)
                y_list.append(y.detach().cpu().numpy())
                g_list.append(g.detach().cpu().numpy())
    finally:
        # Always detach the hook, even if the loader or the model raised.
        handle.remove()

    if not features_list:
        return None, None, None

    feats = np.concatenate(features_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)
    g_all = np.concatenate(g_list, axis=0)
    return feats, y_all, g_all


# ============================================================
# t-SNE utilities: SAME embedding, two colorings (group + class)
# ============================================================

def subsample_for_tsne(feats, y, g, max_points=5000, seed=42):
    """
    Limit the number of rows handed to t-SNE.

    When ``feats`` has more than ``max_points`` rows, draw ``max_points`` distinct
    row indices with a NumPy generator seeded by ``seed`` and index ``feats``,
    ``y`` and ``g`` with them. Otherwise the inputs are returned untouched.

    Returns:
        ``(feats, y, g, idx)`` where ``idx`` is the int64 array of row indices
        (into the original arrays) that were kept.
    """
    total = feats.shape[0]
    if total > max_points:
        generator = np.random.default_rng(seed)
        keep = generator.choice(total, size=max_points, replace=False).astype(np.int64)
        return feats[keep], y[keep], g[keep], keep
    return feats, y, g, np.arange(total, dtype=np.int64)

def compute_tsne_embedding(feats_subset, seed=42, perplexity=30):
    """
    Fit a 2-component t-SNE (PCA init, automatic learning rate) on ``feats_subset``.

    scikit-learn requires ``perplexity`` to be smaller than the number of
    samples. For small inputs the requested value is lowered to
    ``n_samples - 1`` and a notice is printed instead of failing.

    Args:
        feats_subset: feature rows to embed, one sample per row.
        seed: ``random_state`` passed to ``TSNE``.
        perplexity: requested perplexity (an upper bound when clamped).

    Returns:
        The array returned by ``TSNE.fit_transform``.
    """
    n_samples = len(feats_subset)
    effective_perplexity = perplexity
    if n_samples > 1 and perplexity >= n_samples:
        effective_perplexity = n_samples - 1
        print(
            f"  [t-SNE] perplexity={perplexity} is too large for {n_samples} samples; "
            f"using {effective_perplexity} instead."
        )

    tsne = TSNE(
        n_components=2,
        random_state=seed,
        perplexity=effective_perplexity,
        init="pca",
        learning_rate="auto",
    )
    return tsne.fit_transform(feats_subset)

def _pick_cmap(n_unique: int):
    """
    Return a colormap resampled to ``n_unique`` entries.

    Uses ``tab10`` for up to 10 categories, ``tab20`` for up to 20 and ``hsv``
    beyond that. ``plt.get_cmap`` is used because ``matplotlib.cm.get_cmap``
    (reached via ``plt.cm.get_cmap``) was removed in Matplotlib 3.9.
    """
    # reasonable defaults across many categories
    if n_unique <= 10:
        name = "tab10"
    elif n_unique <= 20:
        name = "tab20"
    else:
        name = "hsv"
    return plt.get_cmap(name, n_unique)

def plot_embedding_by_label(emb2d, labels, title, save_path, legend_prefix="Label", point_size=10):
    """
    Plot a 2D embedding (already computed) colored by integer labels.

    One scatter series is drawn per distinct label, in ``np.unique`` order, and
    the figure is written to ``save_path`` at 300 dpi with ticks hidden.

    Args:
        emb2d: array indexed as ``emb2d[mask, 0]`` / ``emb2d[mask, 1]``.
        labels: one label per embedded point; converted with ``np.asarray``.
        title: figure title.
        save_path: output image path; its parent directory is created if the
            path has one.
        legend_prefix: prefix of each series label (the legend is not drawn).
        point_size: marker size passed to ``plt.scatter``.
    """
    labels = np.asarray(labels)
    unique = np.unique(labels)
    cmap = _pick_cmap(len(unique))

    # A bare filename has an empty dirname, which os.makedirs would reject.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fig = plt.figure(figsize=(8, 6))
    try:
        for i, lab in enumerate(unique):
            mask = labels == lab
            plt.scatter(
                emb2d[mask, 0], emb2d[mask, 1],
                color=cmap(i),
                s=point_size,
                alpha=0.7,
                label=f"{legend_prefix} {lab}"
            )

        plt.title(title)
        # keep figure clean like many papers; enable legend if you want
        # plt.legend(loc="best", fontsize=8, frameon=False)
        plt.xticks([])
        plt.yticks([])
        plt.tight_layout()
        plt.savefig(save_path, dpi=300)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    print(f"  [Plotting] Saved to {save_path}")

def plot_paired_tsne_same_embedding(feats, y, g, title_base, save_dir, file_stem, seed=42, max_points=5000, perplexity=30):
    """
    Embed once, plot twice.

    The features are subsampled (if there are more than ``max_points`` rows) and
    a single t-SNE embedding is computed. That embedding is then rendered into
    two images inside ``save_dir``:
      - ``<file_stem>_byGrouplbl.png`` colored by group g
      - ``<file_stem>_byClasslbl.png`` colored by class y

    Returns emb2d, subsample_indices
    """
    sub_feats, sub_y, sub_g, kept = subsample_for_tsne(
        feats, y, g, max_points=max_points, seed=seed
    )
    embedding = compute_tsne_embedding(sub_feats, seed=seed, perplexity=perplexity)

    # Group coloring is written first, then class coloring.
    for tag, coloring in (("Group", sub_g), ("Class", sub_y)):
        plot_embedding_by_label(
            embedding,
            coloring,
            title=f"{title_base} (colored by {tag.lower()})",
            save_path=os.path.join(save_dir, f"{file_stem}_by{tag}lbl.png"),
            legend_prefix=tag,
            point_size=10,
        )

    return embedding, kept


# ============================================================
# Dump saving (now includes t-SNE embedding + indices)
# ============================================================

def save_tsne_dump(save_path, feats, y, g, meta, tsne_emb2d=None, tsne_indices=None, tsne_params=None):
    """
    Save everything needed for later plotting WITHOUT GPU.

    The dump is a dict written with ``torch.save`` containing:
      - "feats": float32 tensor built from ``feats``
      - "y": long tensor built from ``y``
      - "g": long tensor built from ``g``
      - "meta": ``meta`` stored as given
      - "tsne" (only when both ``tsne_emb2d`` and ``tsne_indices`` are given):
        {"emb2d": float32 tensor, "indices": long tensor, "params": dict}

    Args:
        save_path: destination file; its parent directory is created when the
            path has one (a bare filename is written to the working directory).
        feats, y, g: array-likes accepted by ``torch.tensor``.
        meta: arbitrary metadata stored alongside the tensors.
        tsne_emb2d: optional embedding to store.
        tsne_indices: optional indices of the rows that were embedded.
        tsne_params: optional dict of t-SNE settings; ``{}`` is stored if falsy.
    """
    # os.makedirs("") raises, so only create a directory when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    dump = {
        "feats": torch.tensor(feats, dtype=torch.float32),
        "y": torch.tensor(y, dtype=torch.long),
        "g": torch.tensor(g, dtype=torch.long),
        "meta": meta,
    }

    if tsne_emb2d is not None and tsne_indices is not None:
        dump["tsne"] = {
            "emb2d": torch.tensor(tsne_emb2d, dtype=torch.float32),      # [M,2]
            "indices": torch.tensor(tsne_indices, dtype=torch.long),      # [M]
            "params": tsne_params or {},
        }

    torch.save(dump, save_path)
    print(f"  [Dump] Saved plot-ready dump to: {save_path}")


# ============================================================
# Main
# ============================================================

def _parse_bool_flag(value):
    """Convert a command-line string such as ``true``/``false``, ``1``/``0`` or ``yes``/``no`` to a bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """
    Train an evaluation network on each stored synthetic set, then plot and dump t-SNE features.

    For every (method name, dataset, fairness variant, ipc) combination listed
    in the body, this:
      1. skips the combination if a dump containing a "tsne" entry already
         exists (unless ``--overwrite_dump`` is given),
      2. loads the synthetic images/labels from the matching checkpoint under
         ``./results-pt``,
      3. trains one evaluation network on them via ``evaluate_synset``,
      4. extracts features from ``testloader``, computes one t-SNE embedding and
         saves the group-colored and class-colored plots under ``./T-SNE``,
      5. writes a dump with features, labels, groups, metadata and the embedding.

    Relies on the module-level names ``seed``, ``random_state`` and
    ``load_random_state`` being defined before it is called.
    """
    parser = argparse.ArgumentParser(description="Parameter Processing")
    parser.add_argument("--dataset", type=str, default="CIFAR10", help="dataset")
    parser.add_argument("--model", type=str, default="ConvNet", help="model")
    parser.add_argument("--ipc", type=int, default=50, help="image(s) per class")
    parser.add_argument("--eval_mode", type=str, default="S", help="eval_mode")
    parser.add_argument("--num_exp", type=int, default=1, help="the number of experiments")
    parser.add_argument("--num_eval", type=int, default=5, help="the number of evaluating randomly initialized models")
    parser.add_argument("--epoch_eval_train", type=int, default=1000, help="epochs to train a model with synthetic data")
    parser.add_argument("--lr_img", type=float, default=1.0, help="learning rate for updating synthetic images")
    parser.add_argument("--lr_net", type=float, default=0.01, help="learning rate for updating network parameters")
    parser.add_argument("--batch_real", type=int, default=256, help="batch size for real data")
    parser.add_argument("--batch_train", type=int, default=256, help="batch size for training networks")
    parser.add_argument("--init", type=str, default="real", help="noise/real init for synthetic images")
    parser.add_argument("--dsa_strategy", type=str, default="color_crop_cutout_flip_scale_rotate", help="DiffAug strategy")
    parser.add_argument("--data_path", type=str, default="data", help="dataset path")
    parser.add_argument("--save_path", type=str, default="result", help="path to save results")
    parser.add_argument("--dis_metric", type=str, default="ours", help="distance metric")
    # type=bool would turn any non-empty string (including "False") into True,
    # so boolean values are parsed explicitly.
    parser.add_argument("--shuffle", type=_parse_bool_flag, default=False)
    parser.add_argument("--FairDD", action="store_true", help="Enable FairDD")
    parser.add_argument("--group_balance", type=_parse_bool_flag, default=False)
    parser.add_argument("--ALL_data", type=str, default="", help="unused")
    parser.add_argument("--tsne_max_points", type=int, default=5000, help="max points for t-SNE")
    parser.add_argument("--tsne_perplexity", type=int, default=30, help="t-SNE perplexity")
    parser.add_argument("--overwrite_dump", action="store_true", help="overwrite dump even if exists")

    args = parser.parse_args()
    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    args.dsa_param = ParamDiffAug()
    args.dsa = True

    # --- Your loops ---
    NAMES = ["DC"]
    ALL_DATA = [
        "Colored_MNIST_foreground",
        # "Colored_MNIST_background",
    ]

    out_dir = "./T-SNE"

    for name in NAMES:
        for dataset in ALL_DATA:
            for fair_crt in ["NoFair", "FairDD", "NoOrtho"]:
                args.testMetric = name

                # NOTE: this overrides whatever was passed as --ipc.
                for ipc in [50]:
                    args.ipc = ipc

                    # DSA is switched off for every method here (the original
                    # per-method choice was always overwritten with False).
                    args.dsa = False

                    dump_name = f"dump_{name}_{dataset}_ipc{args.ipc}_{fair_crt}lbl.pt"
                    dump_path = os.path.join(out_dir, dump_name)

                    if os.path.exists(dump_path) and not args.overwrite_dump:
                        # If an old dump exists but doesn't have tsne, we'll regenerate
                        try:
                            old = torch.load(dump_path, map_location="cpu", weights_only=False)
                            if isinstance(old, dict) and ("tsne" in old):
                                print(f"[Skip] Dump already exists with t-SNE: {dump_path}")
                                continue
                            else:
                                print(f"[Regen] Dump exists but missing t-SNE field: {dump_path}")
                        except Exception:
                            # Best effort: an unreadable dump is simply rebuilt.
                            print(f"[Regen] Could not read existing dump, regenerating: {dump_path}")

                    # ---- Load synthetic checkpoint ----
                    save_path = f"./results-pt/{name}/{name}-{fair_crt}/"
                    if fair_crt == "FairDD":
                        save_path += "FairDD_"
                    elif fair_crt == "NoOrtho":
                        save_path += "Fair_NoOrtho_"

                    save_path += f"{name}_{dataset}_ipc{args.ipc}/"
                    save_path += f"res_{name}_{dataset}_ConvNet_{args.ipc}ipc.pt"

                    checkpoint = torch.load(save_path, map_location=args.device, weights_only=False)

                    # checkpoint["data"] is tried first as a sequence whose first
                    # entry is the (images, labels) pair, then as the pair itself.
                    try:
                        image_syn, label_syn = checkpoint["data"][0]
                    except Exception:
                        image_syn, label_syn = checkpoint["data"]

                    image_syn = image_syn.to(args.device)
                    label_syn = label_syn.to(args.device)

                    # ---- Real dataset loader ----
                    args.dataset = dataset
                    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(
                        args.dataset, args.data_path
                    )
                    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

                    # Restore the RNG snapshot so every combination starts from the same state.
                    load_random_state(random_state)

                    # Two separate passes are kept on purpose: dataset indexing may
                    # consume RNG state, and changing the number of accesses could
                    # change the network initialisation that follows.
                    labels_all = [dst_train[i][1] for i in range(len(dst_train))]
                    color_all = [dst_train[i][2] for i in range(len(dst_train))]
                    args.num_classes = len(np.unique(labels_all))
                    args.num_groups = len(np.unique(color_all))

                    model_eval = model_eval_pool[0]
                    print(
                        "-----------------\nEvaluation\nmodel_train = %s, model_eval = %s"
                        % (args.model, model_eval)
                    )

                    args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc)

                    # ---- Train eval net on synthetic set ----
                    net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device)
                    image_syn_eval = copy.deepcopy(image_syn.detach())
                    label_syn_eval = copy.deepcopy(label_syn.detach())

                    image_syn_eval = DiffAugment(image_syn_eval, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    # Only the first returned value (the trained network) is used.
                    net_eval, *_ = evaluate_synset(
                        1, net_eval, image_syn_eval, label_syn_eval, testloader, args, verbose=False
                    )

                    # =======================================================
                    # Extract feats + y + g ONCE, then one embedding, two plots
                    # =======================================================
                    feats, y, g = get_features_and_metadata(net_eval, testloader, args.device)
                    if feats is None:
                        continue

                    file_stem = f"tsne_{name}_{dataset}_ipc{args.ipc}_{fair_crt}"
                    title_base = f"t-SNE: {name} on {dataset} (ipc={args.ipc}, {fair_crt})"

                    emb2d, tsne_idx = plot_paired_tsne_same_embedding(
                        feats, y, g,
                        title_base=title_base,
                        save_dir=out_dir,
                        file_stem=file_stem,
                        seed=seed,
                        max_points=args.tsne_max_points,
                        perplexity=args.tsne_perplexity
                    )

                    meta = {
                        "name": name,
                        "dataset": dataset,
                        "ipc": int(args.ipc),
                        "fair_crt": fair_crt,
                        "model_eval": str(model_eval),
                        "split": "test",
                        "note": "feats are penultimate activations; y=task label; g=demographic group"
                    }
                    tsne_params = {
                        "seed": int(seed),
                        "perplexity": int(args.tsne_perplexity),
                        "max_points": int(args.tsne_max_points),
                        "init": "pca",
                        "learning_rate": "auto",
                    }

                    save_tsne_dump(
                        dump_path, feats, y, g, meta,
                        tsne_emb2d=emb2d,
                        tsne_indices=tsne_idx,
                        tsne_params=tsne_params
                    )


def save_random_state():
    """
    Snapshot the current RNG state of torch, NumPy and Python's ``random``.

    Returns:
        dict with keys "torch", "np", "random" and "cuda"; "cuda" holds the
        state of all CUDA devices, or ``None`` when CUDA is unavailable.
    """
    return {
        "torch": torch.get_rng_state(),
        "np": np.random.get_state(),
        "random": random.getstate(),
        "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }


def load_random_state(state):
    """
    Restore an RNG snapshot produced by ``save_random_state``.

    ``None`` is accepted and ignored, so code that imports this module without
    running it as a script (where no snapshot was taken) keeps working.
    """
    if state is None:
        return
    torch.set_rng_state(state["torch"])
    np.random.set_state(state["np"])
    random.setstate(state["random"])
    if torch.cuda.is_available() and state.get("cuda") is not None:
        torch.cuda.set_rng_state_all(state["cuda"])


# Module-level defaults read by main(); the snapshot is only taken when the
# file is executed as a script.
seed = 42
random_state = None


if __name__ == "__main__":
    # Seed every RNG in use and make cuDNN deterministic before snapshotting.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    random_state = save_random_state()
    main()
