import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
from sklearn.manifold import trustworthiness
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

import argparse

# Command-line interface: both options are mandatory strings.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--weights-path",
    required=True,
    type=str,
    help="Path to your .pth file with the full ResNet+MLP checkpoint",
)
parser.add_argument(
    "--output-prefix",
    required=True,
    type=str,
    help="Same prefix you used in extract_features.py (--output-prefix)",
)

# Arguments are parsed at module level, so `args` and `PREFIX` are available
# to every statement and function that follows.
args = parser.parse_args()
PREFIX = args.output_prefix

# ------------------------------------------------------------------
# 1) Reproducibility: seed every RNG this script draws from
# ------------------------------------------------------------------
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# ------------------------------------------------------------------
# 2) Device configuration (GPU when available, CPU otherwise)
# ------------------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ------------------------------------------------------------------
# 3) Hyperparameters
# ------------------------------------------------------------------
batch_size = 64              # DataLoader batch size used when embedding the train set
learning_rate = 1e-4         # Adam learning rate for each per-chunk MLP
lambda_pull = 1.0            # weight of the pull-to-outlier-target loss
lambda_frozen = 50.0         # weight of the keep-remaining-embeddings-fixed loss
num_clusters = 10            # number of KMeans clusters used in evaluation
margin = 20.0                # offset added beyond the per-dimension min/max for targets
sigma_threshold = 3.0        # multiple of the cluster std used in the sigma checks

unlearn_count = 500          # how many training samples are selected for unlearning
unlearn_chunk_size = 100     # unlearned samples handled by a single MLP
iters_per_chunk = 150_000    # max iters

# ------------------------------------------------------------------
# 4) Paths to pre-extracted features/labels (derived from --output-prefix)
# ------------------------------------------------------------------
train_features_path = PREFIX + "_train_features_cross_ent.npy"
train_labels_path = PREFIX + "_train_labels_cross_ent.npy"
test_features_path = PREFIX + "_test_features_cross_ent.npy"
test_labels_path = PREFIX + "_test_labels_cross_ent.npy"

# ------------------------------------------------------------------
# 5) Dataset for pre-extracted features
# ------------------------------------------------------------------
class FeatureDataset(Dataset):
    """In-memory dataset pairing pre-extracted feature vectors with labels.

    Features are stored as a float32 tensor and labels as an int64 tensor;
    indexing returns the ``(feature, label)`` pair at that position.
    """

    def __init__(self, features, labels):
        feature_tensor = torch.tensor(features)
        self.features = feature_tensor.to(torch.float32)
        label_tensor = torch.tensor(labels)
        self.labels = label_tensor.to(torch.int64)

    def __len__(self):
        # The dataset size is the number of labels.
        n_samples = len(self.labels)
        return n_samples

    def __getitem__(self, idx):
        feature = self.features[idx]
        label = self.labels[idx]
        return feature, label

# ------------------------------------------------------------------
# 6) Define a 3-layer residual MLP
# ------------------------------------------------------------------
class MLPResidual3(nn.Module):
    """Three-layer MLP with a linear skip connection from input to output.

    The main path is ``fc1 -> ReLU -> fc2 -> ReLU -> fc3``; its result is
    summed with ``residual(x)``, a single linear projection of the input.
    """

    def __init__(self, input_dim=512, hidden_dim=512, output_dim=512):
        super().__init__()
        # Layers are registered in the order fc1, fc2, fc3, relu, residual,
        # which fixes both the state-dict key order and the order in which
        # the seeded RNG initialises the weights.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.relu = nn.ReLU()
        self.residual = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        main_path = self.fc3(hidden)
        skip_path = self.residual(x)
        return main_path + skip_path


model_global = MLPResidual3().to(device)

# Load the full model checkpoint directly onto the device selected above.
checkpoint_path = args.weights_path
full_state_dict = torch.load(checkpoint_path, map_location=device)

print("Full model state dict keys:")
for key in full_state_dict:
    print(key)

# Keep only the entries whose key starts with "2." and strip that prefix so
# the remaining names line up with MLPResidual3's own parameter names.
# NOTE(review): presumably "2." is the MLP's position inside the saved
# ResNet+MLP container — confirm against the training script.
prefix = "2."
mlp_state_dict = {
    key[len(prefix):]: value
    for key, value in full_state_dict.items()
    if key.startswith(prefix)
}

print("Extracted MLPResidual state dict keys:")
for key in mlp_state_dict:
    print(key)

if not mlp_state_dict:
    # Nothing matched: load_state_dict(strict=False) would succeed silently
    # and leave the model at its random initialisation.
    print(f"WARNING: no checkpoint keys start with '{prefix}'; "
          "model_global keeps its random initialisation.")

# Load the filtered state dictionary into the new MLPResidual3 instance.
# strict=False tolerates missing/unexpected keys; both lists are reported.
missing_keys, unexpected_keys = model_global.load_state_dict(mlp_state_dict, strict=False)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)

# Compare each extracted tensor with what the model now holds.
loaded_state_dict = model_global.state_dict()

for key, value in mlp_state_dict.items():
    if key not in loaded_state_dict:
        # Unexpected keys were accepted above but have no counterpart in the
        # model, so there is nothing to compare against.
        print(f"{key}: Not present in the model (skipped).")
        continue
    loaded = loaded_state_dict[key]
    # Align dtype/device first so allclose cannot fail on a mere dtype
    # difference between the checkpoint and the model parameters.
    if torch.allclose(value.to(loaded), loaded):
        print(f"{key}: Loaded correctly.")
    else:
        print(f"{key}: Mismatch detected!")


# ------------------------------------------------------------------
# 7) Loss definitions & utilities
# ------------------------------------------------------------------
def pull_to_outlier_location(current, target):
    """Return the mean-squared error between `current` and `target`."""
    pull_loss = F.mse_loss(input=current, target=target)
    return pull_loss

def compute_outlier_targets_sign_corner(unl_emb, all_emb, margin):
    """Build one random "corner" target per row of `unl_emb`.

    For every dimension the target coordinate is, with equal probability,
    either the maximum of `all_emb` along that dimension plus `margin` or
    the minimum minus `margin`. Only the shape and device of `unl_emb` are
    used. Returns a tensor with the same shape as `unl_emb`.
    """
    target_device = unl_emb.device
    n_samples, n_dims = unl_emb.shape

    # Per-dimension bounds of the reference embeddings, pushed out by margin.
    lower = all_emb.min(dim=0).values.to(target_device).unsqueeze(0) - margin
    upper = all_emb.max(dim=0).values.to(target_device).unsqueeze(0) + margin

    # One fair coin flip per (sample, dimension): 1 -> upper side, 0 -> lower.
    coin = torch.randint(0, 2, (n_samples, n_dims), device=target_device)
    return torch.where(coin == 1, upper, lower)

def check_outside_box(z, gmin, gmax, margin):
    """Return a boolean tensor: True when every element of `z` is at least
    `margin` below `gmin` or at least `margin` above `gmax`."""
    below = z <= (gmin - margin)
    above = z >= (gmax + margin)
    return torch.logical_or(below, above).all()

def do_tsne_before_after(Z_before, Z_after, Z_targets, cLabels, unlearn_indices, num_clusters=10):
    """Plot a joint t-SNE of embeddings before and after unlearning.

    `Z_before`, `Z_after` and `Z_targets` are concatenated and embedded into
    2-D with a single t-SNE fit so both plots share one coordinate system.
    Two figures are produced ("BEFORE" and "AFTER"); each is saved to
    ``<name>_cifar10.png`` in the current directory and then shown.

    Args:
        Z_before: embeddings before unlearning (tensor or array), N rows.
        Z_after: embeddings after unlearning (tensor or array), N rows.
        Z_targets: outlier targets (tensor or array), K rows.
        cLabels: per-sample cluster ids (length N) used to colour the points.
        unlearn_indices: indices (into the N samples) of unlearned samples.
        num_clusters: number of distinct cluster ids / colours.
    """
    from sklearn.manifold import TSNE

    def to_numpy(arr):
        # detach() so tensors that still carry gradients convert cleanly.
        if isinstance(arr, torch.Tensor):
            return arr.detach().cpu().numpy()
        return np.asarray(arr)

    Z_before = to_numpy(Z_before)
    Z_after = to_numpy(Z_after)
    Z_targets = to_numpy(Z_targets)
    cLabels = to_numpy(cLabels)

    N = Z_before.shape[0]
    K = Z_targets.shape[0]
    combined = np.concatenate([Z_before, Z_after, Z_targets], axis=0)

    # t-SNE requires perplexity < number of samples; keep the usual 30 but
    # shrink it for small inputs instead of failing.
    perplexity = min(30, max(1, combined.shape[0] - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    coords = tsne.fit_transform(combined)

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed
    # in matplotlib 3.9.
    cmap = plt.get_cmap("tab10", num_clusters)
    target_coords = coords[2 * N:2 * N + K]

    def plot_block(name, start, end, unlearn_set):
        # `unlearn_set` holds indices in the combined coordinate array.
        fig = plt.figure(figsize=(10, 8))
        block = coords[start:end]
        for c in range(num_clusters):
            members = np.where(cLabels == c)[0]
            idxs = [i for i in members if (start + i) not in unlearn_set]
            if idxs:
                plt.scatter(block[idxs, 0], block[idxs, 1], color=cmap(c), alpha=0.6, s=15)
        un_idxs = [i - start for i in unlearn_set if start <= i < end]
        plt.scatter(block[un_idxs, 0], block[un_idxs, 1],
                    color="yellow", edgecolor="black", s=50, label="Unlearned")
        plt.scatter(target_coords[:, 0], target_coords[:, 1],
                    color="black", marker="s", s=40, label="Outlier Target")
        plt.title(f"t-SNE {name} UNLEARNING")
        plt.xlabel("Dim1")
        plt.ylabel("Dim2")
        plt.legend()
        plt.tight_layout()
        plt.savefig(name + "_cifar10.png", dpi=200)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    before_set = set(unlearn_indices)
    after_set = {i + N for i in unlearn_indices}
    plot_block("BEFORE", 0, N, before_set)
    plot_block("AFTER", N, 2 * N, after_set)

def compute_cluster_stats(emb, labels, num_clusters):
    """Compute per-cluster, per-dimension mean and standard deviation.

    Returns two dicts keyed by cluster id (``means``, ``stds``). Cluster ids
    in ``range(num_clusters)`` with no member in `labels` are left out.
    """
    means, stds = {}, {}
    for cluster_id in range(num_clusters):
        member_mask = labels == cluster_id
        if member_mask.sum() == 0:
            # Empty cluster: no statistics to record.
            continue
        members = emb[member_mask]
        means[cluster_id] = members.mean(dim=0)
        stds[cluster_id] = members.std(dim=0)
    return means, stds

def print_unlearn_outside_dims(model, feats_1024, unlearn_indices, margin):
    """Report how far each unlearned sample lies outside the retained box.

    The model is switched to eval mode and all of `feats_1024` is embedded
    without gradients. The per-dimension min/max over the samples NOT listed
    in `unlearn_indices` define a bounding box; for every unlearned sample
    the summed overshoot beyond that box and the number of overshooting
    dimensions are printed. `margin` is accepted but not used.
    """
    model.eval()
    with torch.no_grad():
        emb = model(feats_1024.to(device))

    total = feats_1024.shape[0]
    excluded = set(unlearn_indices)
    retained = [j for j in range(total) if j not in excluded]  # ascending order
    retained_emb = emb[retained]
    gmin = retained_emb.min(dim=0).values
    gmax = retained_emb.max(dim=0).values

    for i in unlearn_indices:
        z = emb[i]
        # Overshoot above the box, else below it, else zero when inside.
        over_top = torch.where(z > gmax, z - gmax, torch.zeros_like(z))
        diffs = torch.where(z < gmin, gmin - z, over_top)
        outside_dims = (diffs > 0).sum().item()
        sum_margin = diffs.sum().item()
        print(f"Sample {i}: sum_margin={sum_margin:.4f}, outside_dims={outside_dims}/{z.numel()}")

def check_unlearn_samples_outside_clusters(unl_emb, cluster_means, cluster_stds, sigma_threshold=3.0):
    """
    Print, for every row of unl_emb, the percentage of dimensions whose
    distance from the nearest cluster mean exceeds sigma_threshold times that
    cluster's per-dimension std. "Nearest" is decided by the distance
    (torch.cdist) between the row and each entry of cluster_means.
    """
    cluster_ids = list(cluster_means)
    # Matrix of cluster means, one row per cluster, in cluster_ids order.
    mean_matrix = torch.stack([cluster_means[cid] for cid in cluster_ids])
    # Row position (within mean_matrix) of the closest mean for each sample.
    nearest = torch.cdist(unl_emb, mean_matrix).argmin(dim=1)

    for i, pos in enumerate(nearest.tolist()):
        cid = cluster_ids[pos]
        deviation = (unl_emb[i] - cluster_means[cid]).abs()
        beyond = deviation > sigma_threshold * cluster_stds[cid]
        pct = beyond.float().mean().item() * 100
        print(f"Unlearn sample {i}: {pct:.2f}% dims >{sigma_threshold}-sigma")

# ------------------------------------------------------------------
# 8) MAIN Training Pipeline — 5-MLP version + unified embedding & eval
# ------------------------------------------------------------------
def main():
    """Run the chunked unlearning pipeline and evaluate the unified embedding.

    Steps:
      1. Load the pre-extracted train features/labels and embed them with
         `model_global` to obtain the original embeddings.
      2. Randomly pick `unlearn_count` samples to unlearn and assign each a
         random outlier "corner" target.
      3. For every chunk of `unlearn_chunk_size` unlearned samples, fine-tune
         a copy of `model_global` that pulls the chunk to its targets while
         keeping the remaining samples at their original embeddings.
      4. Build a unified embedding: remaining samples use the average over
         all chunk MLPs; each unlearned sample uses the output of the MLP
         that lands closest to its target.
      5. Print sigma checks, prototype accuracies, trustworthiness and L2
         drift, then draw the before/after t-SNE plots.

    Raises:
        ValueError: if `unlearn_count` is not strictly between 0 and the
            number of training samples (no chunk or no remaining sample).
    """
    # load train data
    feats = np.load(train_features_path)
    labs = np.load(train_labels_path)
    all_labels = torch.tensor(labs).long()
    ds = FeatureDataset(feats, labs)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
    # Reuse the dataset's float32 copy so every forward pass below sees the
    # same dtype as the one used to compute the original embeddings.
    feats_t = ds.features

    N = len(ds)
    if not 0 < unlearn_count < N:
        raise ValueError(
            f"unlearn_count must be in (0, {N}) but is {unlearn_count}"
        )

    # original embeddings of the whole train set
    model_global.eval()
    with torch.no_grad():
        all_z = torch.cat([model_global(x.to(device)).cpu() for x, _ in loader], 0)

    # split indices
    idxs = list(range(N))
    random.shuffle(idxs)
    un_idxs = idxs[:unlearn_count]
    rem_idxs = idxs[unlearn_count:]

    # Row k of targets_unlearn is the target of sample un_idxs[k].
    targets_unlearn = compute_outlier_targets_sign_corner(
        all_z[un_idxs].to(device),
        all_z.to(device),
        margin
    ).cpu()

    idx2pos = {idx: k for k, idx in enumerate(un_idxs)}

    # chunk
    chunks = [un_idxs[i:i + unlearn_chunk_size]
              for i in range(0, len(un_idxs), unlearn_chunk_size)]

    # The remaining-sample inputs and their frozen targets are identical for
    # every chunk, so they are moved to the device once.
    X_r = feats_t[rem_idxs].to(device)
    Z_r = all_z[rem_idxs].to(device)

    mlps = []
    for cid, chunk in enumerate(chunks, 1):
        print(f"\n→ Starting chunk {cid}/{len(chunks)}")
        mlp = MLPResidual3().to(device)
        mlp.load_state_dict(model_global.state_dict())
        opt = torch.optim.Adam(mlp.parameters(), lr=learning_rate)

        T = targets_unlearn[[idx2pos[i] for i in chunk]].to(device)
        X_u = feats_t[chunk].to(device)

        # train
        mlp.train()
        for it in range(1, iters_per_chunk + 1):
            opt.zero_grad()
            pull = pull_to_outlier_location(mlp(X_u), T)
            frozen = F.mse_loss(mlp(X_r), Z_r)
            loss = lambda_pull * pull + lambda_frozen * frozen
            loss.backward()
            opt.step()

            pull_val = pull.item()
            frozen_val = frozen.item()
            if it % 500 == 0:
                print(f"[chunk {cid}] it={it}/{iters_per_chunk}  pull={pull_val:.4f}  frozen={frozen_val:.4f}")
            if pull_val < 20 and frozen_val < 0.01:
                print(f"  ↳ Early stop at chunk {cid}, iter {it}")
                break

        mlp.eval()
        mlps.append(mlp)

    # ------------------------------------------------------------------
    # 9) BUILD UNIFIED EMBEDDING
    # ------------------------------------------------------------------
    with torch.no_grad():
        # remaining: average across mlps
        avg_rem = torch.stack([m(X_r).cpu() for m in mlps], dim=0).mean(dim=0)

        # unlearn: pick mlp whose embedding is closest to its outlier target
        X_unl = feats_t[un_idxs].to(device)
        T_all = targets_unlearn.to(device)
        # choices[m, k] is MLP m's embedding of unlearned sample un_idxs[k].
        choices = torch.stack([m(X_unl) for m in mlps], dim=0)
        best = (choices - T_all.unsqueeze(0)).norm(dim=2).argmin(dim=0)
        sample_pos = torch.arange(len(un_idxs), device=best.device)
        avg_unl = choices[best, sample_pos].cpu()

    # ------------------------------------------------------------------
    # 10) UNIFIED-SPACE EVALUATION
    # ------------------------------------------------------------------
    # 10.1 Build unified embedding
    unified_emb = all_z.clone()
    unified_emb[rem_idxs] = avg_rem
    unified_emb[un_idxs] = avg_unl

    # 10.2 KMeans over unified space
    km_u = KMeans(n_clusters=num_clusters, random_state=42)
    cluster_labels = torch.tensor(
        km_u.fit_predict(unified_emb.numpy()), dtype=torch.long
    )

    # 10.3 3-sigma check on unlearned (statistics come from remaining samples)
    means, stds = compute_cluster_stats(
        unified_emb[rem_idxs], cluster_labels[rem_idxs], num_clusters
    )
    print("\n=== 3-sigma for unlearned samples ===")
    check_unlearn_samples_outside_clusters(
        unified_emb[un_idxs], means, stds, sigma_threshold
    )

    # 10.4 Prototype accuracy on remaining
    def compute_prototypes(embs, labs):
        # Per-label mean embedding, accumulated with a running sum and count.
        protos, counts = {}, {}
        for e, l in zip(embs, labs):
            l = int(l)
            if l not in protos:
                protos[l] = e.clone()
                counts[l] = 1
            else:
                protos[l] += e
                counts[l] += 1
        for lbl in protos:
            protos[lbl] /= counts[lbl]
        return protos

    rem_labs = all_labels[rem_idxs]
    prototypes = compute_prototypes(unified_emb[rem_idxs], rem_labs)

    def predict_proto(e):
        # Label of the prototype with the smallest L2 distance to `e`.
        best_lbl, best_d = None, float('inf')
        for lbl, p in prototypes.items():
            d = torch.norm(e - p)
            if d < best_d:
                best_lbl, best_d = lbl, d
        return best_lbl

    correct_rem = sum(
        predict_proto(unified_emb[i]) == int(rem_labs[pos])
        for pos, i in enumerate(rem_idxs)
    )
    print(f"[Remaining] Prototype accuracy: {correct_rem/len(rem_labs)*100:.2f}%")

    # 10.4a) Unlearned-samples "accuracy" via 3σ → prototype rule
    num_unl = len(un_idxs)
    correct_unl = 0
    for idx in un_idxs:
        emb = unified_emb[idx]

        # "Known" only if every dimension is within sigma_threshold·std of at
        # least one cluster mean.
        in_any_cluster = any(
            torch.all(torch.abs(emb - mu) <= sigma_threshold * stds[k])
            for k, mu in means.items()
        )

        # Samples declared "unknown" are always counted as wrong.
        if in_any_cluster and predict_proto(emb) == int(all_labels[idx]):
            correct_unl += 1

    acc_unl = correct_unl / num_unl * 100
    print(f"[Unlearned] Classification accuracy: {acc_unl:.2f}%")

    # 10.5 Test-set via unified embeddings (average over all chunk MLPs).
    # Inference runs under no_grad and in one batch per MLP, so no autograd
    # graph is kept alive for the stored outputs.
    test_feats = np.load(test_features_path)
    test_labs = np.load(test_labels_path)
    with torch.no_grad():
        test_x = torch.as_tensor(test_feats).float().to(device)
        avg_test = torch.stack([m(test_x).cpu() for m in mlps], dim=0).mean(dim=0)

    correct_test = sum(
        predict_proto(avg_test[i]) == int(test_labs[i])
        for i in range(len(test_labs))
    )
    print(f"[Test] Prototype accuracy: {correct_test/len(test_labs)*100:.2f}%")

    # 10.6 Trustworthiness & L2 drift (remaining)
    # NOTE(review): trustworthiness appears to build full pairwise distance
    # matrices, which is memory-heavy for large remaining sets — confirm.
    tw = trustworthiness(
        all_z[rem_idxs].numpy(),
        unified_emb[rem_idxs].numpy(),
        n_neighbors=5
    )
    l2 = (unified_emb[rem_idxs] - all_z[rem_idxs]).norm(dim=1).mean().item()
    print(f"[After] Trustworthiness K=5: {tw:.4f}")
    print(f"[After] L2 drift: {l2:.4f}")

    # 10.7 t-SNE on unified
    do_tsne_before_after(
        all_z,                   # before
        unified_emb,             # after = unified
        targets_unlearn,
        cluster_labels.numpy(),  # cluster assignments
        un_idxs,
        num_clusters
    )


if __name__ == "__main__":
    main()
