import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import random
import numpy as np
from sklearn.manifold import trustworthiness, TSNE
import matplotlib.pyplot as plt

import argparse

# Command-line interface. Parsed once at import time; main() reads
# args.weights_path to locate the checkpoint to unlearn from.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--weights_path',
    type=str,
    required=True,
    help='Path for saved weights',
)
args = parser.parse_args()


# -----------------------------------------
# 1) Set seeds & device
# -----------------------------------------
# Fix every RNG the script may draw from so runs are reproducible; Python's
# `random` in particular drives the forget/remain split (random.shuffle in main).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
# Prefer the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# -----------------------------------------
# 2) Hyperparameters
# -----------------------------------------
# Number of full passes over the forget loader during NegGrad
# (main loops `for it in range(unlearn_iters)` around the whole loader).
unlearn_iters = 100
batch_size = 64       # shared by every DataLoader in main
learning_rate = 1e-3  # Adam learning rate for the unlearning updates
lambda_neg = 1.0      # Negative CE weight

# -----------------------------------------
# 3) Data transforms
# -----------------------------------------
# Image pipeline: upsample 32x32 CIFAR images to 224x224, random horizontal
# flip, convert to tensor, then normalise with the ImageNet channel mean/std.
# The random flip makes this a training-time transform: two passes over the
# same image are not guaranteed to produce the same tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# -----------------------------------------
# 4) MLPResidual block
# -----------------------------------------
class MLPResidual(nn.Module):
    """Three-layer ReLU MLP with a learned linear skip connection.

    Computes ``fc3(relu(fc2(relu(fc1(x))))) + residual(x)``. The skip path is
    a full ``nn.Linear`` rather than an identity, so ``input_dim`` and
    ``output_dim`` may differ.

    Args:
        input_dim: Size of the input's last dimension.
        hidden_dim: Width of the two hidden layers.
        output_dim: Size of the output's last dimension.
    """

    def __init__(self, input_dim=512, hidden_dim=512, output_dim=512):
        super().__init__()
        # Attribute names and creation order are deliberate: they fix the
        # state_dict keys (checkpoint compatibility) and the order in which
        # parameters are randomly initialised.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.relu = nn.ReLU()
        self.residual = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        main_path = self.fc3(hidden)
        return main_path + self.residual(x)

# -----------------------------------------
# 5) Model loader
# -----------------------------------------
def build_model():
    """Build the ResNet-18 + MLPResidual CIFAR-10 classifier.

    Returned ``nn.Sequential`` layout (indices determine checkpoint keys):
        0: ResNet-18's children except its last one (the original head)
        1: ``nn.Flatten``
        2: ``MLPResidual()`` with its 512-wide defaults
        3: ``nn.Linear(512, 10)`` classification head

    The ResNet is created with ``pretrained=False`` (random init); main loads
    trained weights afterwards. The model is moved to the module-level
    ``device`` before being returned.
    """
    resnet = models.resnet18(pretrained=False)
    trunk = nn.Sequential(*list(resnet.children())[:-1])
    # Built left-to-right, so modules are initialised in the same order as
    # before: ResNet, then MLPResidual, then the head.
    layers = [trunk, nn.Flatten(), MLPResidual(), nn.Linear(512, 10)]
    return nn.Sequential(*layers).to(device)

# -----------------------------------------
# 6) Embedding extractor (backbone+MLPResidual)
# -----------------------------------------
def extract_embeddings(model, loader):
    """Return penultimate-layer embeddings for every sample yielded by ``loader``.

    The embedding network is every top-level child of ``model`` except the
    last one. For the model from ``build_model`` that is
    backbone + flatten + MLPResidual, i.e. everything but the classifier head.
    It shares weights with ``model``; nothing is copied.

    Args:
        model: Sequential-style module whose last child is the head.
        loader: Iterable of ``(inputs, targets)`` batches; targets are ignored.

    Returns:
        ``np.ndarray`` with one row per input sample, in loader order.

    Raises:
        ValueError: if ``loader`` yields no batches.

    Note:
        Puts ``model`` into eval mode and leaves it there.
    """
    model.eval()
    embed_net = nn.Sequential(*list(model.children())[:-1])
    # Move the *inputs* to wherever the weights already live. The old
    # `embed_net.to(device)` was in-place on modules shared with `model`, so it
    # could relocate part of the caller's model. It also tied this helper to
    # the module-global `device`.
    param = next(embed_net.parameters(), None)
    dev = param.device if param is not None else torch.device("cpu")
    all_embs = []
    with torch.no_grad():
        for x, _ in loader:
            z = embed_net(x.to(dev))
            all_embs.append(z.cpu().numpy())
    if not all_embs:
        raise ValueError("extract_embeddings: loader yielded no batches")
    return np.vstack(all_embs)

# -----------------------------------------
# 7) Accuracy helper
# -----------------------------------------
def eval_acc(model, loader):
    """Return top-1 accuracy of ``model`` over ``loader`` as a percentage.

    Args:
        model: Classifier whose output has class scores along dim 1.
        loader: Iterable of ``(inputs, integer_labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float, or ``nan`` when ``loader`` yields
        no samples. Accuracy is undefined there; previously this raised
        ZeroDivisionError.

    Note:
        Puts ``model`` into eval mode and leaves it there.
    """
    model.eval()
    # Evaluate on the model's own device instead of the module-global `device`.
    param = next(model.parameters(), None)
    dev = param.device if param is not None else torch.device("cpu")
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    if total == 0:
        return float("nan")
    return 100 * correct / total

# -----------------------------------------
# 8) Main NegGrad unlearning on full CIFAR-10
# -----------------------------------------
def _without_random_flips(tf):
    """Return a new Compose equal to ``tf`` minus its RandomHorizontalFlip steps."""
    kept = [t for t in tf.transforms if not isinstance(t, transforms.RandomHorizontalFlip)]
    return transforms.Compose(kept)


def _make_loader(dataset, indices=None, *, shuffle=False, **loader_kwargs):
    """Wrap ``dataset``, optionally restricted to ``indices``, in the script's DataLoader."""
    if indices is not None:
        dataset = Subset(dataset, indices)
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, **loader_kwargs
    )


def _neggrad_unlearn(model, forget_loader):
    """Apply NegGrad in place: Adam steps that *increase* CE loss on the forget set.

    Runs ``unlearn_iters`` full passes over ``forget_loader`` and returns ``model``.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for _ in range(unlearn_iters):
        for x_u, y_u in forget_loader:
            x_u, y_u = x_u.to(device), y_u.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x_u), y_u)
            # Minimising -lambda_neg * loss is gradient ascent on the forget loss.
            (-lambda_neg * loss).backward()
            optimizer.step()
    return model


def _plot_tsne(emb_before, emb_after, save_path):
    """Embed both sets jointly with t-SNE, then save and show side-by-side scatters."""
    combined = np.vstack([emb_before, emb_after])
    vis = TSNE(n_components=2, random_state=42, perplexity=30).fit_transform(combined)
    n = emb_before.shape[0]
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, pts, title in zip(axes, (vis[:n], vis[n:]), ('Before', 'After')):
        ax.scatter(pts[:, 0], pts[:, 1], s=10, alpha=0.6)
        ax.set_title(title)
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.show()
    plt.close(fig)


def main(n_embed_samples=2000):
    """Run NegGrad unlearning on a random 1% of CIFAR-10 train and report the effect.

    Steps:
      1. Split train into forget (1%) and remain (99%) sets using the seeded
         Python ``random``.
      2. Load ``args.weights_path`` into ``build_model()``.
      3. Run NegGrad on the forget set.
      4. Print test, remain and forget accuracy.
      5. Compare penultimate embeddings before vs. after on a fixed sample of
         the remain set: trustworthiness, mean L2 drift, and a joint t-SNE
         plot saved to ``NegGrad_Cifar10_1_percent.png``.

    Every measurement uses ``transform`` *without* its random flip, so repeated
    passes over an image see the same tensor. With the flip, before/after
    drift and trustworthiness partly measured augmentation noise, and the
    accuracies were stochastic. The unlearning updates still use the
    augmented ``transform``.

    Args:
        n_embed_samples: Number of remain-set images used for the embedding
            metrics and t-SNE. sklearn's ``trustworthiness`` builds a dense
            n x n distance matrix plus its argsort. On the full ~49.5k remain
            set that is tens of GB, and t-SNE on ~99k points is impractical.
    """
    eval_transform = _without_random_flips(transform)

    # -- load data: the same training images, augmented (for updates) and clean (for metrics)
    train_aug = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    train_eval = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=eval_transform
    )
    test_set = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=eval_transform
    )

    # -- split 1% for unlearning
    n_total = len(train_aug)
    n_unlearn = int(0.01 * n_total)
    indices = list(range(n_total))
    random.shuffle(indices)
    forget_idx = indices[:n_unlearn]
    remain_idx = indices[n_unlearn:]
    # remain_idx is already a random permutation, so its prefix is a fixed random sample.
    embed_idx = remain_idx[:n_embed_samples]

    # persistent_workers: the forget loader is re-iterated unlearn_iters times;
    # without it, 4 worker processes are respawned on every pass.
    forget_train_loader = _make_loader(
        train_aug, forget_idx, shuffle=True, persistent_workers=True
    )
    forget_eval_loader = _make_loader(train_eval, forget_idx)
    remain_loader = _make_loader(train_eval, remain_idx)
    embed_loader = _make_loader(train_eval, embed_idx)
    test_loader = _make_loader(test_set)

    # -- build model and load the trained weights
    ref_model = build_model()
    checkpoint_path = args.weights_path
    print(f"Loading weights from {checkpoint_path} …")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (or pass weights_only=True where the installed torch supports it).
    pou_state = torch.load(checkpoint_path, map_location=device)
    ref_model.load_state_dict(pou_state)

    emb_before = extract_embeddings(ref_model, embed_loader)

    # -- perform NegGrad unlearning (updates ref_model in place; `model` is an alias)
    model = _neggrad_unlearn(ref_model, forget_train_loader)

    # -- final evaluations
    print("\n-- Post-NegGrad Accuracies --")
    print(f"Test set:      {eval_acc(model, test_loader):.2f}%")
    print(f"Remaining set: {eval_acc(model, remain_loader):.2f}%")
    print(f"Forget set:    {eval_acc(model, forget_eval_loader):.2f}%")

    # -- metrics on embeddings (row i of both arrays is the same image: loader is unshuffled)
    emb_after = extract_embeddings(model, embed_loader)
    tw = trustworthiness(emb_before, emb_after, n_neighbors=5)
    drift = np.linalg.norm(emb_after - emb_before, axis=1).mean()
    print(f"\n[NegGrad] Trustworthiness: {tw:.4f}")
    print(f"[NegGrad] Mean L2 drift:   {drift:.4f}")

    # -- TSNE plot
    _plot_tsne(emb_before, emb_after, "NegGrad_Cifar10_1_percent.png")

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
