import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import random
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from sklearn.manifold import trustworthiness
import argparse
# Command-line interface: a single required option tells the script where the
# pre-extracted .npy arrays live.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--data-prefix",
    type=str,
    required=True,
    help="Common prefix for your .npy feature/label files (e.g. 'drive/MyDrive/MU_Data/DINO_COU_cifar10')"
)
# NOTE: parsing happens at module level, so merely importing this file also
# consumes sys.argv (and exits if --data-prefix is missing).
args = parser.parse_args()

# Path prefix from which the feature/label file names are derived.
PREFIX = args.data_prefix

# ------------------------------------------------------------------
# Set random seeds for reproducibility
# (PyTorch, NumPy and Python's `random` are all seeded with 42)
# ------------------------------------------------------------------
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ------------------------------------------------------------------
# Device configuration: CUDA when available, otherwise CPU
# ------------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------------------------------------------------------
# Hyperparameters
# ------------------------------------------------------------------
batch_size         = 64     # batch size of every DataLoader built in this file
learning_rate      = 0.0001 # Adam learning rate
num_epochs         = 5      # epochs of supervised-contrastive pre-training
temperature        = 0.07   # For contrastive loss
# NOTE(review): the next three values are not referenced anywhere in the
# visible part of this file -- confirm whether they are still needed.
num_selected_dims  = 5
threshold_corr     = 0.3
sigma_bands        = [1,2,3]
num_clusters       = 10     # e.g. CIFAR-10

# ------------------------------------------------------------------
# Paths to pre-extracted features and labels
# (both derived from the --data-prefix command-line option)
# ------------------------------------------------------------------
train_features_path = f"{PREFIX}_train_features.npy"
train_labels_path   = f"{PREFIX}_train_labels.npy"

# ------------------------------------------------------------------
# Dataset for pre-extracted features
# ------------------------------------------------------------------
class FeatureDataset(Dataset):
    """In-memory dataset over pre-extracted feature vectors and their labels.

    Both inputs are converted to tensors once, at construction time:
    features are stored as float32 and labels as int64.
    """

    def __init__(self, features, labels):
        feature_tensor = torch.tensor(features)
        label_tensor = torch.tensor(labels)
        self.features = feature_tensor.to(torch.float32)
        self.labels = label_tensor.to(torch.int64)

    def __len__(self):
        # The number of labels defines the dataset size.
        n_items = len(self.labels)
        return n_items

    def __getitem__(self, idx):
        # Return the (feature, label) pair stored at position `idx`.
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target

# ------------------------------------------------------------------
# Load features & labels
# ------------------------------------------------------------------
# Read the feature matrix and label vector from the two .npy files.
train_features_np = np.load(train_features_path)
train_labels_np   = np.load(train_labels_path)

train_dataset = FeatureDataset(train_features_np, train_labels_np)
# shuffle=False keeps batches in dataset order. main() depends on this: it
# concatenates the loader's outputs and then indexes both that result and
# train_dataset with the same sample indices.
train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

# ------------------------------------------------------------------
# Simple linear model: 1024 -> 512
# ------------------------------------------------------------------
class LinearModel(nn.Module):
    """A single affine projection from `input_dim` to `output_dim` features."""

    def __init__(self, input_dim=1024, output_dim=512):
        super().__init__()
        # Kept under the attribute name `linear` so parameter names stay
        # `linear.weight` / `linear.bias`.
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the projection to `x` and return the result."""
        projected = self.linear(x)
        return projected

# Global model and optimizer shared by every training phase in this file.
model     = LinearModel(1024,512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)



def compute_class_stats(features, labels, num_classes):
    """Cluster `features` with K-Means and return per-cluster statistics.

    Args:
        features: 2-D tensor of embeddings, one row per sample.
        labels: Accepted for interface compatibility; not used here.
        num_classes: Number of K-Means clusters to fit.

    Returns:
        A tuple ``(class_means, class_stds, cluster_labels)`` where
        ``class_means[i]`` is the K-Means centre of cluster ``i`` and
        ``class_stds[i]`` is the per-dimension standard deviation of its
        members (both moved to the module-level ``device``), and
        ``cluster_labels`` is the array of cluster assignments produced by
        K-Means. Clusters without members are omitted from both dicts.
    """
    kmeans = KMeans(n_clusters=num_classes, random_state=42)
    # detach() so tensors that still carry autograd history can be converted.
    kmeans.fit(features.detach().cpu().numpy())
    assignments = kmeans.labels_

    class_means = {}
    class_stds = {}
    for i in range(num_classes):
        member_mask = torch.from_numpy(assignments == i)
        cluster_feats = features[member_mask]
        n_members = cluster_feats.size(0)
        if n_members == 0:
            continue
        if n_members > 1:
            std = torch.std(cluster_feats, dim=0)
        else:
            # The sample standard deviation of a single point is NaN, which
            # the zero check below would not catch; treat it as zero spread.
            std = torch.zeros_like(cluster_feats[0])
        # Avoid exact zeros so the result can safely be used as a divisor.
        std[std == 0] = 1e-6
        class_means[i] = torch.tensor(kmeans.cluster_centers_[i]).float().to(device)
        class_stds[i] = std.to(device)
    return class_means, class_stds, assignments

def select_prototypes(features_1024, features_512, labels, cluster_labels, class_means, model):
    """Pick, for every cluster, the member closest to the cluster centre.

    Args:
        features_1024: Tensor of model inputs, one row per sample.
        features_512: Tensor of embeddings, row-aligned with `features_1024`.
        labels: Accepted for interface compatibility; not used here.
        cluster_labels: Array of cluster ids, one per sample.
        class_means: Mapping cluster id -> centre tensor in embedding space.
            The device of each centre is the device used for the distance
            computation and for the model forward pass.
        model: Callable mapping a batch of inputs to embeddings.

    Returns:
        ``(prototypes_in_1024, prototypes_in_512)``: two dicts keyed by
        cluster id holding the selected sample's input row and the embedding
        obtained by passing that row through `model` (without gradients).
    """
    prototypes_in_1024 = {}
    prototypes_in_512 = {}

    assignments = np.asarray(cluster_labels)
    for cidx in np.unique(assignments):
        member_mask = torch.from_numpy(assignments == cidx)
        center = class_means[cidx].unsqueeze(0)
        # Work on the centre's device instead of relying on a module global.
        target_device = center.device

        sub_1024 = features_1024[member_mask]
        sub_512 = features_512[member_mask].to(target_device)
        dists = torch.norm(sub_512 - center, dim=1)
        # Plain int index: safe even when `dists` lives on another device
        # than `sub_1024`.
        nearest = int(torch.argmin(dists).item())

        proto_1024 = sub_1024[nearest].clone()
        with torch.no_grad():
            proto_512 = model(proto_1024.unsqueeze(0).to(target_device)).squeeze(0).clone()

        prototypes_in_1024[cidx] = proto_1024
        prototypes_in_512[cidx] = proto_512
    return prototypes_in_1024, prototypes_in_512


def pull_to_second_nearest_cos(features_512, labels, prototypes_512_dict, num_clusters):
    """Cosine loss pulling each sample toward a non-own prototype.

    For every sample the cosine similarity to each prototype is computed.
    The prototype whose key equals the sample's label (if any) is masked
    out with a similarity of -999. The loss for the sample is then
    ``1 - s`` where ``s`` is the second-largest remaining similarity.

    Args:
        features_512: Tensor of embeddings, one row per sample.
        labels: Integer tensor with one label per sample.
        prototypes_512_dict: Mapping key -> prototype tensor; needs at least
            two entries because the two largest similarities are taken.
        num_clusters: Accepted for interface compatibility; not used here.

    Returns:
        Scalar tensor with the mean loss, or a constant 0.0 tensor when the
        batch is empty.
    """
    device = features_512.device
    if features_512.size(0) == 0:
        return torch.tensor(0.0, device=device)

    cluster_keys = sorted(prototypes_512_dict.keys())
    protos = torch.stack(
        [prototypes_512_dict[key].to(device) for key in cluster_keys], dim=0
    )
    protos = F.normalize(protos, dim=1)
    norm_feat = F.normalize(features_512, dim=1)

    # One (batch x prototypes) similarity matrix instead of a per-sample loop.
    cos_sims = torch.matmul(norm_feat, protos.T)

    # Mask the prototype whose key matches the sample's own label.
    key_ids = torch.tensor([int(key) for key in cluster_keys], device=device)
    own_proto = labels.to(device).view(-1, 1) == key_ids.view(1, -1)
    cos_sims = cos_sims.masked_fill(own_proto, -999.0)

    top2_vals, _ = torch.topk(cos_sims, k=2, dim=1)
    second_best_cos = top2_vals[:, 1]
    return (1.0 - second_best_cos).mean()

def pull_to_original_cos(current_512, reference_512):
    """Return the mean cosine distance between row-aligned embeddings.

    Both inputs are L2-normalised along dim 1; the result is the mean over
    rows of ``1 - cosine_similarity(current_row, reference_row)``.
    """
    unit_current = F.normalize(current_512, dim=1)
    unit_reference = F.normalize(reference_512, dim=1)
    row_cosine = torch.sum(unit_current * unit_reference, dim=1)
    return torch.mean(1.0 - row_cosine)

class SupervisedContrastiveLoss(nn.Module):
    """Supervised contrastive loss over one batch of embeddings.

    Samples sharing a label are treated as positives for each other; a
    sample is never its own positive and is excluded from its own softmax
    denominator.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, feats, labels):
        """Return the scalar loss for embeddings `feats` and their `labels`."""
        dev = feats.device
        label_col = labels.view(-1, 1)
        unit = F.normalize(feats, dim=1)

        # same_class[i, j] == 1 when samples i and j share a label.
        same_class = (label_col == label_col.T).float().to(dev)
        similarity = (unit @ unit.T) / self.temperature

        # Subtract the row maximum for numerical stability (no gradient).
        row_max = similarity.max(dim=1, keepdim=True).values
        shifted = similarity - row_max.detach()

        # Ones everywhere except the diagonal: removes self-comparisons.
        not_self = torch.ones_like(same_class)
        not_self.fill_diagonal_(0)
        positives = same_class * not_self

        denominator = (torch.exp(shifted) * not_self).sum(dim=1, keepdim=True)
        log_prob = shifted - torch.log(denominator + 1e-12)

        # Average log-probability over each anchor's positives; the epsilon
        # makes anchors without positives contribute 0 instead of NaN.
        positive_count = positives.sum(dim=1)
        mean_log_prob = (positives * log_prob).sum(dim=1) / (positive_count + 1e-12)
        return -mean_log_prob.mean()



def do_combined_tsne_and_plot_three_phases(
    orig_512, mid_512, final_512, cLabels, num_clusters
):
    """
    Run ONE TSNE on [orig_512; mid_512; final_512]
    and then produce THREE separate scatter plots
    (Original, Pulled, Unlearned) with consistent coloring.

    Arguments:
      orig_512:   tensor with one row per sample, "original" embeddings
      mid_512:    tensor row-aligned with orig_512, "pulled" embeddings
      final_512:  tensor row-aligned with orig_512, "unlearned" embeddings
      cLabels:    array with one integer cluster id per sample; the same ids
                  color all three panels
      num_clusters: number of cluster ids to draw (ids 0..num_clusters-1)

    Returns nothing; the figure is displayed with plt.show().
    """
    import matplotlib.pyplot as plt
    from sklearn.manifold import TSNE

    # 1) Stack the three phases so they share a single 2-D embedding space.
    # detach() lets tensors that still carry autograd history be converted.
    combined_512 = torch.cat([orig_512, mid_512, final_512], dim=0).detach().cpu()

    # 2) Every phase is colored by the same per-sample cluster ids.
    cluster_ids = np.asarray(cLabels)

    # 3) Run TSNE once
    print("Running single TSNE for all three phases (orig / mid / final).")
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    combined_2d = tsne.fit_transform(combined_512.numpy())

    # 4) Split the 2-D coordinates back into the three phases.
    N = orig_512.shape[0]
    coords_orig = combined_2d[:N]
    coords_mid = combined_2d[N:2*N]
    coords_final = combined_2d[2*N:]

    # 5) Plot each set in a separate subplot (same color scheme)
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap accepts the same (name, lut) arguments.
    cmap = plt.get_cmap("tab10", num_clusters)

    # Helper to do the scatter for each subplot
    def scatter_clusters(ax, coords_2d, lab_array, title_str):
        for cid in range(num_clusters):
            mask = (lab_array == cid)
            ax.scatter(
                coords_2d[mask, 0],
                coords_2d[mask, 1],
                color=cmap(cid),
                alpha=0.6,
                s=15,
                label=f"Cluster {cid}"
            )
        ax.set_title(title_str)

    scatter_clusters(axes[0], coords_orig, cluster_ids, "Original")
    scatter_clusters(axes[1], coords_mid, cluster_ids, "After Learning with arbitrary objective")
    scatter_clusters(axes[2], coords_final, cluster_ids, "After Unlearning with COU")

    for ax in axes:
        ax.set_xlabel("t-SNE Dim1")
        ax.set_ylabel("t-SNE Dim2")
    plt.tight_layout()
    plt.show()



# ------------------------------------------------------------------
def _collect_embeddings(loader, with_inputs=False):
    """Embed every batch of `loader` with the global model, without gradients.

    Returns a 3-tuple ``(inputs, embeddings, labels)`` of CPU tensors
    concatenated in loader order; ``inputs`` and ``labels`` are None unless
    `with_inputs` is True.
    """
    model.eval()
    inputs, embeddings, targets = [], [], []
    with torch.no_grad():
        for feats, labs in loader:
            embeddings.append(model(feats.to(device)).cpu())
            if with_inputs:
                inputs.append(feats.cpu())
                targets.append(labs.cpu())
    embeddings = torch.cat(embeddings, dim=0)
    if not with_inputs:
        return None, embeddings, None
    return torch.cat(inputs, dim=0), embeddings, torch.cat(targets, dim=0)


def _embed(inputs):
    """Embed one tensor of inputs with the global model (no gradients)."""
    with torch.no_grad():
        return model(inputs.to(device)).cpu()


def _mean_l2(a, b):
    """Mean row-wise Euclidean distance between two equally shaped tensors."""
    return torch.norm(a - b, dim=1).mean().item()


def main():
    """Run the full experiment: pre-train, pull, unlearn, then evaluate.

    Uses the module-level model, optimizer, train_loader, train_dataset,
    device and hyperparameters. Steps:
      1. Pre-train the model with the supervised contrastive loss.
      2. Store the embeddings of the whole dataset as the "original" state.
      3. Cluster them with K-Means and select one prototype per cluster.
      4. Pick up to 100 random samples and train them toward a non-own
         prototype ("pull").
      5. Train the same samples back toward their original embeddings
         ("unlearning") until their mean L2 distance to the originals drops
         to the threshold, or an iteration cap is reached.
      6. Print L2 distances and trustworthiness scores and show t-SNE plots.
    """
    pull_sample_count = 100
    pull_epochs = 1300
    threshold_dist = 2.0
    # Safety cap: the restore loss is cosine-based (insensitive to vector
    # length) while the stopping rule is an L2 distance, so the threshold is
    # not guaranteed to be reached.
    max_unlearn_iters = 100_000

    # 1) Train with Supervised Contrastive
    supcon = SupervisedContrastiveLoss(temperature=temperature)
    for ep in range(num_epochs):
        model.train()
        tot_loss = 0.0
        for feats, labs in train_loader:
            feats = feats.to(device)
            labs = labs.to(device)
            optimizer.zero_grad()
            loss = supcon(model(feats), labs)
            loss.backward()
            optimizer.step()
            tot_loss += loss.item()
        print(f"[Pre-Training] Epoch {ep+1}/{num_epochs}, Loss={tot_loss/len(train_loader):.4f}")

    # 2) Collect entire dataset embeddings => store "original"
    all_feats_1024, all_feats_512, all_labels = _collect_embeddings(
        train_loader, with_inputs=True
    )
    all_original_512 = all_feats_512.clone()

    # 3) K-Means => cluster centers, and cluster labels for coloring
    #    (cLabels holds one cluster assignment per sample)
    cmeans, _cstds, cLabels = compute_class_stats(all_feats_512, all_labels, num_clusters)

    # 4) One prototype per cluster; only the embedding-space prototypes are
    #    used below.
    _proto_in_1024, proto_in_512 = select_prototypes(
        all_feats_1024, all_feats_512, all_labels, cLabels, cmeans, model
    )

    # 5) Pull -> pick random samples (clamped so tiny datasets do not raise)
    num_samples = all_feats_1024.size(0)
    indices_for_pull = random.sample(
        range(num_samples), min(pull_sample_count, num_samples)
    )
    pull_dataset = Subset(train_dataset, indices_for_pull)
    pull_loader = DataLoader(pull_dataset, batch_size=batch_size, shuffle=True)

    original_1024_100 = all_feats_1024[indices_for_pull].clone()
    original_512_100 = all_feats_512[indices_for_pull].clone()

    # 6) Pull-toward
    for _ in range(pull_epochs):
        model.train()
        for feats, labs in pull_loader:
            feats = feats.to(device)
            labs = labs.to(device)

            out_512 = model(feats)
            pull_loss = pull_to_second_nearest_cos(out_512, labs, proto_in_512, num_clusters)

            optimizer.zero_grad()
            pull_loss.backward()
            optimizer.step()

    # measure distance after pull on the selected samples
    model.eval()
    dist_mid_100 = _mean_l2(_embed(original_1024_100), original_512_100)
    print(f"L2 distance after pull (before unlearning) => {dist_mid_100:.4f}")

    # measure entire dataset after pull
    _, final_all_512_after_pull, _ = _collect_embeddings(train_loader)
    dist_all_after_pull = _mean_l2(final_all_512_after_pull, all_original_512)
    print(f"L2 distance (ENTIRE dataset) after pull => {dist_all_after_pull:.4f}")

    # 7) Unlearning => pull them back to original
    # Explicitly (re)set the learning rate for the restore phase. 1e-4 equals
    # the initial learning_rate, so this currently changes nothing; it is
    # kept as the single place to tune the restore-phase rate.
    for param_group in optimizer.param_groups:
        param_group['lr'] = 1e-4
    ep = 0
    dist_final_temp = 9999.0

    while dist_final_temp > threshold_dist and ep < max_unlearn_iters:
        for start in range(0, len(indices_for_pull), batch_size):
            stop = start + batch_size
            current_idxs = indices_for_pull[start:stop]

            feats_1024_batch = all_feats_1024[current_idxs].to(device)
            out_512_batch = model(feats_1024_batch)
            # original_512_100 is ordered like indices_for_pull, so the same
            # slice selects the matching reference rows.
            ref_512_batch = original_512_100[start:stop].to(device)

            restore_loss = pull_to_original_cos(out_512_batch, ref_512_batch)

            optimizer.zero_grad()
            restore_loss.backward()
            optimizer.step()

        # measure distance for the selected samples
        dist_final_temp = _mean_l2(_embed(original_1024_100), original_512_100)

        ep += 1
        if ep % 100 == 0:
            print(f"[Unlearning Phase] iter={ep}, Loss={restore_loss.item():.4f}, dist={dist_final_temp:.4f}")

    if dist_final_temp > threshold_dist:
        print(
            f"[Unlearning Phase] stopped after {ep} iterations without reaching "
            f"dist <= {threshold_dist} (dist={dist_final_temp:.4f})"
        )

    # measure final distance for the selected samples
    dist_final_100 = _mean_l2(_embed(original_1024_100), original_512_100)
    print(f"L2 distance (100 samples) after unlearning => {dist_final_100:.4f}")

    # measure entire dataset final
    _, final_all_512, _ = _collect_embeddings(train_loader)
    dist_all_final = _mean_l2(final_all_512, all_original_512)
    print(f"L2 distance (ENTIRE dataset) after unlearning => {dist_all_final:.4f}")

    k_neighbors = 5
    original_np = all_original_512.numpy()

    # Trustworthiness for "Pulled" embeddings vs Original
    pulled_trust = trustworthiness(
        X=original_np,
        X_embedded=final_all_512_after_pull.numpy(),
        n_neighbors=k_neighbors
    )
    print(f"Trustworthiness (Original vs Pulled), k={k_neighbors} => {pulled_trust:.4f}")

    # Trustworthiness for "Unlearned" (final) embeddings vs Original
    unlearned_trust = trustworthiness(
        X=original_np,
        X_embedded=final_all_512.numpy(),
        n_neighbors=k_neighbors
    )
    print(f"Trustworthiness (Original vs Unlearned), k={k_neighbors} => {unlearned_trust:.4f}")

    do_combined_tsne_and_plot_three_phases(all_original_512, final_all_512_after_pull, final_all_512, cLabels, num_clusters)


# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
