import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
import random
import re
from pathlib import Path

# --- New Imports for t-SNE ---
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import matplotlib.colors as mcolors

# --- Assuming these exist in your 'utils.py'. If not, ensure utils is in your path ---
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
def get_features_and_metadata(model, dataloader, device):
    """
    Extract penultimate features together with per-sample metadata.

    A forward hook is attached to the classification head and records the
    tensor that is fed INTO the head (i.e. the features it consumes). The
    head is the first of ``model.linear``, ``model.fc`` or
    ``model.classifier`` that is an ``nn.Module``; if none exists, the last
    ``nn.Linear`` yielded by ``model.modules()`` is used instead.

    Expected batch layout: batch[0]=img, batch[1]=y (optional),
    batch[2]=group (optional). A missing ``y`` and/or group is replaced by
    zeros so downstream code always receives three aligned arrays.

    Args:
      model:      nn.Module; it is switched to eval mode and left that way.
      dataloader: iterable of indexable batches as described above.
      device:     device the images are moved to before the forward pass.

    Returns:
      feats: np.ndarray, per-sample inputs of the hooked head (expected [N, D])
      y:     np.ndarray [N]  task label
      g:     np.ndarray [N]  sensitive attribute (demographic group)
      or (None, None, None) when no layer can be hooked or the dataloader
      yields no batches.
    """
    model.eval()

    # Prefer the conventional head attribute names, in this priority order.
    head = None
    for attr in ('linear', 'fc', 'classifier'):
        candidate = getattr(model, attr, None)
        if isinstance(candidate, nn.Module):
            head = candidate
            break

    # Best-effort fallback: the last nn.Linear found anywhere in the model.
    if head is None:
        for m in model.modules():
            if isinstance(m, nn.Linear):
                head = m

    if head is None:
        print("Warning: Could not find a Linear layer to hook for features.")
        return None, None, None

    activation = {}

    def hook(module, inp, out):
        # inp is the tuple of positional inputs to the head; inp[0] is the feature tensor.
        activation['feat'] = inp[0].detach().cpu().numpy()

    features_list = []
    y_list = []
    g_list = []

    handle = head.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for batch in dataloader:
                imgs = batch[0]
                # Fill in placeholders for whatever the loader does not provide.
                if len(batch) >= 2:
                    y = batch[1]
                else:
                    y = torch.zeros((imgs.size(0),), dtype=torch.long)
                g = batch[2] if len(batch) >= 3 else torch.zeros_like(y)

                _ = model(imgs.to(device, non_blocking=True))  # trigger hook

                features_list.append(activation['feat'])
                y_list.append(y.detach().cpu().numpy())
                g_list.append(g.detach().cpu().numpy())
    finally:
        # Always detach the hook, even if a forward pass raised, so the model
        # is not left with a stale hook attached.
        handle.remove()

    if not features_list:
        # np.concatenate raises on an empty list; report "nothing extracted"
        # through the same sentinel callers already check for.
        print("Warning: Dataloader yielded no batches; no features extracted.")
        return None, None, None

    feats = np.concatenate(features_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)
    g_all = np.concatenate(g_list, axis=0)
    return feats, y_all, g_all

def save_tsne_dump(save_path, feats, y, g, meta):
    """
    Save everything needed for later t-SNE / plotting WITHOUT GPU.
    Uses torch.save for convenience (loads easily).

    Args:
      save_path: destination file; its parent directory is created if it
                 does not exist. A bare filename (no directory part) is
                 written to the current working directory.
      feats:     array-like of features, stored as a float32 tensor.
      y:         array-like of task labels, stored as a long tensor.
      g:         array-like of group labels, stored as a long tensor.
      meta:      arbitrary metadata object stored unchanged under "meta".
    """
    # os.path.dirname returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    dump = {
        "feats": torch.tensor(feats, dtype=torch.float32),  # [N, D]
        "y": torch.tensor(y, dtype=torch.long),             # [N]
        "g": torch.tensor(g, dtype=torch.long),             # [N]
        "meta": meta,
    }
    torch.save(dump, save_path)
    print(f"  [Dump] Saved plot-ready dump to: {save_path}")



def get_features(model, dataloader, device):
    """
    Extracts features from the penultimate layer of the model.
    Robust to different batch structures.

    A forward hook on the model's head (the first of ``model.linear``,
    ``model.fc`` or ``model.classifier`` that is an ``nn.Module``) records
    the tensor fed INTO the head for every batch.

    Expected batch layout: batch[0]=images, batch[1]=labels (optional),
    batch[2]=sensitive attribute (optional). When the sensitive attribute is
    missing, a zero placeholder of matching length is used instead.

    Args:
      model:      nn.Module; it is switched to eval mode and left that way.
      dataloader: iterable of indexable batches as described above.
      device:     device the images are moved to before the forward pass.

    Returns:
      (features, sensitive_attrs) as np.ndarrays concatenated over batches,
      or (None, None) when no head layer is found or the dataloader yields
      no batches.
    """
    model.eval()

    # Locate the head by its likely attribute name, in priority order.
    head = None
    for attr in ('linear', 'fc', 'classifier'):
        candidate = getattr(model, attr, None)
        if isinstance(candidate, nn.Module):
            head = candidate
            break
    if head is None:
        print("Warning: Could not find linear/fc/classifier layer for t-SNE.")
        return None, None

    activation = {}

    def hook(module, inp, out):
        # Input to the linear layer is the feature vector
        activation['feat'] = inp[0].detach().cpu().numpy()

    features_list = []
    sensitive_list = []

    handle = head.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for batch in dataloader:
                imgs = batch[0]

                if len(batch) >= 3:
                    sensitive = batch[2]
                elif len(batch) == 2:
                    # Loader provides no sensitive labels: use a dummy
                    # placeholder shaped like the task labels.
                    sensitive = torch.zeros_like(batch[1])
                else:
                    # Images only: size the placeholder from the batch dimension.
                    sensitive = torch.zeros((imgs.size(0),), dtype=torch.long)

                # Forward pass triggers hook
                _ = model(imgs.to(device))

                features_list.append(activation['feat'])
                # detach()/cpu() keep this working for CUDA or grad-tracking tensors.
                sensitive_list.append(sensitive.detach().cpu().numpy())
    finally:
        # Clean up the hook even if a forward pass raised.
        handle.remove()

    if not features_list:
        # np.concatenate raises on an empty list; signal "nothing to plot" instead.
        print("Warning: Dataloader yielded no batches; no features extracted.")
        return None, None

    features = np.concatenate(features_list, axis=0)
    sensitive_attrs = np.concatenate(sensitive_list, axis=0)

    return features, sensitive_attrs


def plot_tsne(features, sensitive_attrs, title, save_path):
    """
    Runs t-SNE and saves the plot.

    Points are coloured by their sensitive-attribute group. Nothing is drawn
    when ``features`` is None or holds fewer than two samples.

    Args:
      features:        np.ndarray with one row per sample, or None.
      sensitive_attrs: np.ndarray of group ids aligned with ``features``.
      title:           figure title (also used in the progress message).
      save_path:       output image path; its parent directory is created
                       if it does not exist.
    """
    if features is None:
        return

    print(f"  [Plotting] Generating t-SNE for {title}...")

    n_samples = features.shape[0]
    if n_samples < 2:
        # t-SNE needs perplexity < n_samples, which is impossible below two points.
        print(f"  [Plotting] Skipped: need at least 2 samples, got {n_samples}")
        return

    # 1. Run t-SNE
    # Limit samples for speed if the dataset is huge: random subset of 5000.
    max_points = 5000
    if n_samples > max_points:
        indices = np.random.choice(n_samples, max_points, replace=False)
        features = features[indices]
        sensitive_attrs = sensitive_attrs[indices]
        n_samples = max_points

    # scikit-learn requires perplexity < n_samples; keep 30 whenever possible.
    perplexity = min(30, n_samples - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity, init='pca', learning_rate='auto')
    X_embedded = tsne.fit_transform(features)

    # Make sure the destination exists before saving ('' means current directory).
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # 2. Plot
    # One colour per unique sensitive group. plt.get_cmap replaces
    # plt.cm.get_cmap, which newer matplotlib releases no longer provide.
    unique_groups = np.unique(sensitive_attrs)
    colors = plt.get_cmap('tab10', len(unique_groups))

    fig = plt.figure(figsize=(8, 6))
    try:
        for i, group in enumerate(unique_groups):
            idx = sensitive_attrs == group
            plt.scatter(X_embedded[idx, 0], X_embedded[idx, 1],
                        color=colors(i),
                        label=f'PA Group {group}',
                        s=15, alpha=0.7)

        plt.title(title)
        # Legend intentionally not drawn (labels are still attached to the
        # artists, so plt.legend() can be enabled here if wanted).
        plt.xticks([])
        plt.yticks([])
        plt.tight_layout()

        plt.savefig(save_path, dpi=300)
    finally:
        # Release the figure even when drawing/saving fails.
        plt.close(fig)
    print(f"  [Plotting] Saved to {save_path}")


# -------------------------------------------------------------------------
# Main Logic
# -------------------------------------------------------------------------

def main():
    """
    Sweep over stored synthetic-set checkpoints, fit an evaluation network on
    each one, and export t-SNE artifacts of its test-set features.

    For every (method name, dataset, fairness variant, ipc) combination:
      1. skip it if its dump file already exists under ./T-SNE;
      2. load the synthetic images/labels from the matching checkpoint under
         ./results-pt;
      3. build the dataset and an evaluation network, then call
         evaluate_synset on the synthetic data;
      4. save a t-SNE PNG (coloured by sensitive attribute) and a .pt dump of
         features / labels / groups.

    Relies on the module-level names ``seed``, ``random_state`` and
    ``load_random_state`` being defined before it is called.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode')
    parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=5, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')
    parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
    # NOTE(review): type=bool turns ANY non-empty string (including "False")
    # into True; the help texts of the two bool flags also look copy-pasted.
    parser.add_argument('--shuffle', type=bool, default=False, help='distance metric')
    parser.add_argument('--FairDD', action='store_true', help='Enable FairDD')
    parser.add_argument('--group_balance', type=bool, default=False, help='distance metric')
    parser.add_argument('--ALL_data', type=str, default='', help='path to save results')

    args = parser.parse_args()
    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa = True

    NAMES = ['CAFE']#,'DM'] # You can add 'FairDD' here if you have results for it
    ALL_DATA = [
        # "CIFAR10_S_90",
        "Colored_FashionMNIST_foreground",
        "Colored_FashionMNIST_background",
        # "Colored_MNIST_foreground",
        # "Colored_MNIST_background",
        # "UTKface",
        # "BFFHQ",
    ]

    # All plots and dumps go here. Create it up front: the plot is written
    # before the dump, and saving a figure does not create directories.
    tsne_dir = './T-SNE'
    os.makedirs(tsne_dir, exist_ok=True)

    # Checkpoint directory prefix used by each fairness variant ('NoFair' has none).
    variant_prefix = {'FairDD': 'FairDD_', 'NoOrtho': 'Fair_NoOrtho_'}

    for name in NAMES:
        for dataset in ALL_DATA:
            for fair_crt in ['NoFair','FairDD','NoOrtho']:

                args.testMetric = name
                # Loop for IPC
                for ipc in [100,10]:

                    args.ipc = ipc

                    # The former `name == 'DC'` branching was immediately
                    # overridden by an unconditional False, so only the
                    # effective value is kept.
                    args.dsa = False

                    # Resume support: skip combinations whose dump already exists.
                    dump_name = f"dump_{name}_{dataset}_ipc{args.ipc}_{fair_crt}.pt"
                    dump_path = os.path.join(tsne_dir, dump_name)
                    if os.path.exists(dump_path):
                        continue

                    # Checkpoint location:
                    # ./results-pt/<name>/<name>-<variant>/<prefix><name>_<dataset>_ipc<N>/res_<name>_<dataset>_ConvNet_<N>ipc.pt
                    run_dir = (
                        f"./results-pt/{name}/{name}-{fair_crt}/"
                        f"{variant_prefix.get(fair_crt, '')}{name}_{dataset}_ipc{args.ipc}/"
                    )
                    save_path = f"{run_dir}res_{name}_{dataset}_ConvNet_{args.ipc}ipc.pt"
                    # NOTE(review): weights_only=False unpickles arbitrary
                    # objects; only load checkpoints from a trusted source.
                    checkpoint = torch.load(save_path, map_location=args.device, weights_only=False)

                    # 'data' is either a sequence of (images, labels) pairs or
                    # a single pair; try the former first. Only unpacking /
                    # indexing failures trigger the fallback, so interrupts
                    # and unrelated errors are no longer swallowed.
                    data = checkpoint['data']
                    try:
                        image_syn, label_syn = data[0]
                    except (TypeError, ValueError, IndexError, KeyError):
                        image_syn, label_syn = data

                    image_syn = image_syn.to(args.device)
                    label_syn = label_syn.to(args.device)

                    args.dataset = dataset
                    # get_dataset should return dst_test and testloader
                    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
                    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
                    # Restore the RNG snapshot taken at start-up so every
                    # combination starts from the same random state.
                    load_random_state(random_state)

                    # Class / group counts are derived from the training set:
                    # index 1 is the task label, index 2 the group attribute.
                    labels_all = [dst_train[i][1] for i in range(len(dst_train))]
                    color_all = [dst_train[i][2] for i in range(len(dst_train))]
                    args.num_classes = len(np.unique(labels_all))
                    args.num_groups = len(np.unique(color_all))

                    model_eval = model_eval_pool[0]
                    print('-----------------\nEvaluation\nmodel_train = %s, model_eval = %s'%(args.model, model_eval))

                    args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc)

                    net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device)
                    # Work on copies so the loaded synthetic set stays untouched.
                    image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())
                    image_syn_eval = DiffAugment(image_syn_eval, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    # --- TRAIN THE MODEL (evaluate_synset trains net_eval) ---
                    net_eval, *_ = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args, verbose=False)

                    # =======================================================
                    # START t-SNE VISUALIZATION
                    # =======================================================

                    # 1. Features + sensitive attribute for the plot
                    feats, sens_attrs = get_features(net_eval, testloader, args.device)

                    # 2. Define path to save image
                    plot_name = f'tsne_PA_{name}_{dataset}_ipc{args.ipc}_{fair_crt}-gemini.png'
                    plot_path = os.path.join(tsne_dir, plot_name)

                    # 3. Plot
                    if feats is not None:
                        plot_tsne(feats, sens_attrs, title=f"PA t-SNE: {name} on {dataset}", save_path=plot_path)
                        print(plot_path)

                    # 4. Features + task label + group for the offline dump
                    feats, y, g = get_features_and_metadata(net_eval, testloader, args.device)

                    if feats is None:
                        continue

                    meta = {
                        "name": name,
                        "dataset": dataset,
                        "ipc": int(args.ipc),
                        "fair_crt": fair_crt,
                        "model_eval": str(model_eval),
                        "split": "test",
                        "note": "feats are penultimate activations; y=task label; g=demographic group"
                    }

                    save_tsne_dump(dump_path, feats, y, g, meta)

                    # =======================================================
                    # END t-SNE VISUALIZATION
                    # =======================================================


# Seed shared by the start-up seeding below and by main().
seed = 42

# RNG snapshot restored by main() before each evaluation. It is captured in
# the __main__ guard below; it stays None when this file is merely imported.
random_state = None


def save_random_state():
    """
    Capture the current RNG state of torch (CPU and, if available, every
    CUDA device), NumPy and Python's ``random`` module.

    Returns:
      dict with keys 'torch', 'np', 'random' and 'cuda' ('cuda' is an empty
      list when CUDA is unavailable).
    """
    return {
        'torch': torch.get_rng_state(),
        'np': np.random.get_state(),
        'random': random.getstate(),
        'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [],
    }


def load_random_state(state):
    """
    Restore a snapshot produced by ``save_random_state``.

    ``None`` is accepted and ignored so that main() can still run when no
    snapshot has been captured (e.g. the module was imported, not executed).
    """
    if state is None:
        return
    torch.set_rng_state(state['torch'])
    np.random.set_state(state['np'])
    random.setstate(state['random'])
    # Only touch CUDA generators when there is something to restore.
    if state['cuda'] and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state['cuda'])


if __name__ == '__main__':
    # Seed every RNG in use and force deterministic cuDNN behaviour.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Snapshot the freshly seeded state; main() restores it per experiment.
    random_state = save_random_state()

    main()