import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from models import *
from models.resnet_orig import ResNet18_orig
from models.vgg import VGG
import pandas as pd
import random
import time
import copy
import numpy as np
from torch.utils.data import Dataset
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix

# Optional-dependency guard: KernelDensity is only used by the kind="kde" code path
# (_fit_kde checks _HAVE_SKLEARN_KDE before touching it).
# NOTE(review): `from sklearn.svm import SVC` earlier in this file is unconditional, so a
# missing scikit-learn fails before this guard is reached — confirm the guard is still wanted.
try:
    from sklearn.neighbors import KernelDensity
    _HAVE_SKLEARN_KDE = True
except Exception:
    _HAVE_SKLEARN_KDE = False


def _fit_normal(vals: np.ndarray):
    """Fit a 1-D Gaussian to ``vals`` and return it as a plain dict.

    The mean falls back to 0.0 for an empty array, and the sample standard
    deviation (``ddof=1``) falls back to 0.0 when fewer than two values are
    available.  The standard deviation is then floored at 1e-9 so that it is
    never zero when later used as a divisor.

    Returns:
        ``{"kind": "normal", "mu": <float>, "sd": <float>}``.
    """
    min_sd = 1e-9
    n_vals = vals.size
    mean = float(np.mean(vals)) if n_vals else 0.0
    spread = float(np.std(vals, ddof=1)) if n_vals > 1 else 0.0
    return {"kind": "normal", "mu": mean, "sd": max(spread, min_sd)}

def _logpdf_normal(model, x: np.ndarray):
    """Evaluate the Gaussian log-density described by ``model`` at ``x``.

    Args:
        model: mapping with ``"mu"`` (mean) and ``"sd"`` (standard deviation).
        x: points at which to evaluate the density.

    Returns:
        The log-density at each element of ``x``.
    """
    mean = model["mu"]
    variance = model["sd"] ** 2
    # log of the normalising constant 1 / sqrt(2 * pi * variance)
    log_norm_const = -0.5 * np.log(2 * np.pi * variance)
    squared_dev = (x - mean) ** 2
    return log_norm_const - 0.5 * squared_dev / variance

def _fit_kde(vals: np.ndarray, bandwidth: float | None):
    """Fit a 1-D Gaussian kernel density estimate to ``vals``.

    Args:
        vals: sample values; reshaped to a single-feature column before fitting.
        bandwidth: kernel bandwidth.  When it is ``None`` (or a string), the
            bandwidth is computed here with Scott's rule of thumb instead of
            being handed to scikit-learn.

    Returns:
        ``{"kind": "kde", "kde": <fitted KernelDensity>}``.

    Raises:
        ImportError: if ``KernelDensity`` could not be imported at module load.
    """
    if not _HAVE_SKLEARN_KDE:
        raise ImportError("KDE requires scikit-learn. Install scikit-learn or use kind='normal'.")

    # sklearn expects a 2D array of shape (n_samples, n_features)
    samples_2d = vals.reshape(-1, 1)

    if bandwidth is None or isinstance(bandwidth, str):
        # Scott's rule of thumb (1D): bw = n^(-1/5) * std, with std defaulting
        # to 1.0 for fewer than two samples and floored at 1e-9.
        n_samples = max(len(vals), 1)
        spread = np.std(vals, ddof=1) if n_samples > 1 else 1.0
        resolved_bandwidth = (n_samples ** (-1/5)) * max(spread, 1e-9)
    else:
        resolved_bandwidth = bandwidth

    estimator = KernelDensity(kernel="gaussian", bandwidth=resolved_bandwidth)
    estimator.fit(samples_2d)
    return {"kind": "kde", "kde": estimator}

def _logpdf_kde(model, x: np.ndarray):
    """Score ``x`` under the fitted KDE stored in ``model["kde"]``.

    ``x`` is reshaped to a single-feature column and passed to the estimator's
    ``score_samples``, whose result (log density) is returned unchanged.
    """
    estimator = model["kde"]
    column = x.reshape(-1, 1)
    return estimator.score_samples(column)

def _fit_distribution(vals: np.ndarray, kind: str = "normal", bandwidth: float | None = None):
    if kind == "normal":
        return _fit_normal(vals)
    elif kind == "kde":
        return _fit_kde(vals, bandwidth)
    else:
        raise ValueError("kind must be 'normal' or 'kde'.")

def _logpdf(model, x: np.ndarray):
    """Evaluate the log-density of ``x`` under a fitted model dict.

    Dispatches on ``model["kind"]`` to the Gaussian or KDE scorer.

    Raises:
        ValueError: if ``model["kind"]`` is neither ``"normal"`` nor ``"kde"``.
    """
    family = model["kind"]
    if family not in ("normal", "kde"):
        raise ValueError("Unknown model kind")
    if family == "kde":
        return _logpdf_kde(model, x)
    return _logpdf_normal(model, x)


def collect_prob(data_loader, model, unlearn_method, model_path):
    """Run ``model`` over ``data_loader`` and collect softmax probabilities and labels.

    Args:
        data_loader: iterable yielding ``(data, target)`` batches, or ``None``.
        model: module to evaluate; it is switched to eval mode.  For
            ``unlearn_method`` ``'RW'`` / ``'RW_multi'`` it must return a pair
            of logits tensors, otherwise a single logits tensor.
        unlearn_method: name of the unlearning method (selects the output arity).
        model_path: checkpoint path; for the two-output methods the second
            output is used when this path contains ``'RW'``, else the first.

    Returns:
        ``(probs, targets)``: probabilities concatenated over all batches and
        the matching targets.  When ``data_loader`` is ``None`` or yields no
        batches, an empty ``[0, 10]`` probability tensor and an empty int64
        label tensor are returned (the width 10 is a fixed placeholder).
    """
    def _empty_result():
        # Labels must be int64: callers use them as a torch.gather index,
        # which rejects the default float dtype.
        return torch.zeros([0, 10]), torch.zeros([0], dtype=torch.long)

    if data_loader is None:
        return _empty_result()

    two_outputs = unlearn_method in ('RW', 'RW_multi')
    use_second_output = two_outputs and 'RW' in model_path

    prob = []
    targets = []

    model.eval()
    # The model's device does not change during the loop, so look it up once.
    device = next(model.parameters()).device
    with torch.no_grad():
        for batch in data_loader:
            data, target = [tensor.to(device) for tensor in batch]

            if two_outputs:
                logits, new_logits = model(data)
                if use_second_output:
                    logits = new_logits
            else:
                logits = model(data)

            # exp(log_softmax) is kept (rather than softmax) so the numbers
            # match what earlier runs of this script produced.
            prob.append(torch.exp(F.log_softmax(logits, dim=1)))
            targets.append(target)

    if not prob:
        # torch.cat raises on an empty list; treat an empty loader like None.
        return _empty_result()
    return torch.cat(prob), torch.cat(targets)


def SVC_fit_predict(shadow_train, shadow_test, target_train, target_test):
    """Fit a linear SVC on shadow features and score it on the target features.

    The classifier is trained on the first ``len_limit`` rows of
    ``shadow_train`` (label 1) and ``shadow_test`` (label 0), where
    ``len_limit`` is the smallest row count among all four inputs, shuffled
    with ``np.random.permutation``.

    Args:
        shadow_train: tensor of features labelled 1 for training.
        shadow_test: tensor of features labelled 0 for training.
        target_train: tensor scored as "fraction predicted 0".
        target_test: tensor scored as "fraction predicted 1".

    Returns:
        ``1 - mean(prediction)`` over ``target_train``, i.e. the fraction of
        ``target_train`` rows the classifier assigns to label 0.  Both target
        scores are printed.

    Raises:
        ValueError: if any of the four inputs has no rows (the classifier
            cannot be fitted and no score can be computed).
    """
    n_target_train = target_train.shape[0]
    n_target_test = target_test.shape[0]
    len_limit = min(shadow_train.shape[0], shadow_test.shape[0], n_target_train, n_target_test)
    if len_limit == 0:
        # Previously this surfaced as an opaque numpy reshape error, and the
        # final return could reference an unassigned local.
        raise ValueError(
            "SVC_fit_predict needs at least one sample in each of shadow_train, "
            "shadow_test, target_train and target_test."
        )

    X_shadow = torch.cat([shadow_train[:len_limit], shadow_test[:len_limit]]).cpu().numpy().reshape(2 * len_limit, -1)
    Y_shadow = np.concatenate([np.ones(len_limit), np.zeros(len_limit)])
    shuffle_indices = np.random.permutation(len(Y_shadow))
    X_shadow = X_shadow[shuffle_indices]
    Y_shadow = Y_shadow[shuffle_indices]

    clf = SVC(kernel='linear', class_weight='balanced')
    clf.fit(X_shadow, Y_shadow)

    # Both target sets are non-empty here (guaranteed by the len_limit check).
    X_target_train = target_train.cpu().numpy().reshape(n_target_train, -1)
    acc_train = 1 - clf.predict(X_target_train).mean()

    X_target_test = target_test.cpu().numpy().reshape(n_target_test, -1)
    acc_test = clf.predict(X_target_test).mean()

    accs = [acc_train, acc_test]
    print("accs", [f"{acc:.4f}" for acc in accs])
    return acc_train


def SVC_attack(shadow_train, target_train, target_test, shadow_test, model, forgetting_class, unlearn_method='RW', model_path=''):
    """Run the SVC membership-inference attack on true-label confidences.

    Probabilities are collected for the four loaders, the probability assigned
    to each sample's own label is extracted, and the result of
    ``SVC_fit_predict`` is reported.  ``forgetting_class`` is accepted but not
    used by this function.

    Returns:
        ``{"confidence": "<score formatted to 4 decimals>"}`` (also printed).
    """
    collected = {}
    for split, loader in (
        ("shadow_train", shadow_train),
        ("shadow_test", shadow_test),
        ("target_train", target_train),
        ("target_test", target_test),
    ):
        collected[split] = collect_prob(loader, model, unlearn_method, model_path)

    # Diagnostic peek at the first few probabilities / labels of every split.
    for banner, split in (
        ("prob of target_train", "target_train"),
        ("prob of target_test", "target_test"),
        ("prob of shadow_train", "shadow_train"),
        ("prob of shadow_test", "shadow_test"),
    ):
        split_prob, split_labels = collected[split]
        print(banner, split_prob[:3], split_labels[:10])

    # Confidence = probability assigned to each sample's own label.
    own_label_conf = {
        split: torch.gather(split_prob, 1, split_labels[:, None])
        for split, (split_prob, split_labels) in collected.items()
    }

    # NOTE(review): the positional order below (shadow_train, target_test,
    # shadow_test, target_train) does not line up with SVC_fit_predict's
    # parameter names; it is reproduced as-is — confirm it is intentional.
    acc_conf = SVC_fit_predict(
        own_label_conf["shadow_train"],
        own_label_conf["target_test"],
        own_label_conf["shadow_test"],
        own_label_conf["target_train"],
    )

    m = {"confidence": f"{acc_conf:.4f}"}
    print(m)
    return m


def SVC_attack_new(shadow_train, target_train, target_test, shadow_test, model, forgetting_class, unlearn_method='RW', model_path=''):
    """Run the SVC membership-inference attack on one fixed class column.

    Unlike ``SVC_attack``, the confidence of every sample is the probability
    the model assigns to ``forgetting_class``, regardless of the sample's own
    label.

    Args:
        shadow_train, target_train, target_test, shadow_test: data loaders
            passed to ``collect_prob``.
        model: model used for inference.
        forgetting_class: class index whose probability column is used.  An
            int, a single-element list/tuple, or ``None`` (falls back to 9,
            the value this function used to hard-code).
        unlearn_method, model_path: forwarded to ``collect_prob``.

    Returns:
        ``{"confidence": "<score formatted to 4 decimals>"}`` (also printed).

    Raises:
        ValueError: if ``forgetting_class`` is a list/tuple that does not hold
            exactly one class index.
    """
    # Resolve the class column; the parameter used to be silently ignored.
    if forgetting_class is None:
        class_idx = 9
    elif isinstance(forgetting_class, (list, tuple)):
        if len(forgetting_class) != 1:
            raise ValueError("forgetting_class must be a single class index.")
        class_idx = int(forgetting_class[0])
    else:
        class_idx = int(forgetting_class)

    shadow_train_prob, shadow_train_labels = collect_prob(shadow_train, model, unlearn_method, model_path)
    shadow_test_prob, shadow_test_labels = collect_prob(shadow_test, model, unlearn_method, model_path)
    target_train_prob, target_train_labels = collect_prob(target_train, model, unlearn_method, model_path)
    target_test_prob, target_test_labels = collect_prob(target_test, model, unlearn_method, model_path)

    print("prob of car", shadow_test_prob[:10], shadow_test_labels[:10])
    print("prob of truck", shadow_train_prob[:10], shadow_train_labels[:10])

    def class_column(probs):
        # Build the index explicitly as int64 so it does not depend on the
        # dtype of the label tensor.
        index = torch.full((probs.shape[0], 1), class_idx, dtype=torch.long, device=probs.device)
        return torch.gather(probs, 1, index)

    shadow_train_conf = class_column(shadow_train_prob)
    shadow_test_conf = class_column(shadow_test_prob)
    target_train_conf = class_column(target_train_prob)
    target_test_conf = class_column(target_test_prob)

    # NOTE(review): positional order (shadow_train, target_test, shadow_test,
    # target_train) differs from SVC_fit_predict's parameter names; kept
    # unchanged — confirm it is intentional.
    acc_conf = SVC_fit_predict(
        shadow_train_conf, target_test_conf, shadow_test_conf, target_train_conf)

    m = {
        "confidence": f"{acc_conf:.4f}",
    }
    print(m)
    return m


def distribution_attack_new(
    shadow_train, target_train, target_test, model, forgetting_class,
    unlearn_method='RW', model_path='',
    kind: str = "normal", priors=(0.5, 0.5), bandwidth: float | None = None
):
    """Classify ``target_train`` samples by which fitted confidence distribution explains them.

    One density is fitted to the ``forgetting_class`` confidences of
    ``shadow_train`` and another to those of ``target_test``.  Every
    ``target_train`` sample is then assigned a posterior probability of
    belonging to the shadow distribution.

    Args:
        shadow_train: DataLoader for shadow training data.
        target_train: DataLoader for target training data (evaluated per-sample).
        target_test: DataLoader for target test data.
        model: the model to use for inference.
        forgetting_class: int class index whose probability column is used.
        unlearn_method, model_path: forwarded to ``collect_prob``.
        kind: ``"normal"`` or ``"kde"``.
        priors: ``(p_shadow, p_target)`` class priors; each is floored at 1e-12
            before taking its log.
        bandwidth: KDE bandwidth; ignored for ``kind="normal"``.

    Returns:
        dict with the fraction of ``target_train`` whose shadow posterior
        exceeds 0.5, the three sample counts, and the fit diagnostics
        (``kind``, ``priors``, ``shadow_fit``, ``target_fit``).
    """
    shadow_train_prob, _ = collect_prob(shadow_train, model, unlearn_method, model_path)
    target_train_prob, _ = collect_prob(target_train, model, unlearn_method, model_path)
    target_test_prob, _ = collect_prob(target_test, model, unlearn_method, model_path)

    print("Shadow train prob shape:", shadow_train_prob.shape)
    print("Target train prob shape:", target_train_prob.shape)
    print("Target test prob shape:", target_test_prob.shape)

    def pick_class_conf(probs):
        # Select column `forgetting_class` from probs ([N, C] -> [N, 1]).
        # The index is built as int64 directly, so it does not depend on the
        # dtype of the label tensor.
        idx = torch.full((probs.shape[0], 1), forgetting_class, dtype=torch.long, device=probs.device)
        return torch.gather(probs, 1, idx)

    shadow_train_conf = pick_class_conf(shadow_train_prob)
    target_train_conf = pick_class_conf(target_train_prob)
    target_test_conf = pick_class_conf(target_test_prob)

    print("Shadow train conf shape:", shadow_train_conf.shape)
    print("Target train conf shape:", target_train_conf.shape)
    print("Target test conf shape:", target_test_conf.shape)

    # Flatten -> numpy
    shadow_vals = shadow_train_conf.reshape(-1).detach().cpu().numpy()
    target_test_vals = target_test_conf.reshape(-1).detach().cpu().numpy()
    target_train_vals = target_train_conf.reshape(-1).detach().cpu().numpy()

    # Fit two distributions: shadow vs target_test
    model_shadow = _fit_distribution(shadow_vals, kind=kind, bandwidth=bandwidth)
    model_target = _fit_distribution(target_test_vals, kind=kind, bandwidth=bandwidth)

    # Log-likelihood of each target_train sample under both models
    ll_shadow = _logpdf(model_shadow, target_train_vals)
    ll_target = _logpdf(model_target, target_train_vals)

    # Posterior of "shadow" via Bayes' rule in log space; np.logaddexp is the
    # numerically stable log(exp(a) + exp(b)).
    p_shadow_prior, p_target_prior = float(priors[0]), float(priors[1])
    log_joint_shadow = ll_shadow + np.log(max(p_shadow_prior, 1e-12))
    log_joint_target = ll_target + np.log(max(p_target_prior, 1e-12))
    log_evidence = np.logaddexp(log_joint_shadow, log_joint_target)
    post_shadow = np.exp(log_joint_shadow - log_evidence).astype(np.float32)

    # Metric: fraction of target_train predicted as shadow (posterior > 0.5)
    frac_shadow = float((post_shadow > 0.5).mean()) if post_shadow.size else 0.0

    results = {
        "fraction_target_train_as_shadow": frac_shadow,
        "shadow_train_samples": int(shadow_train_conf.shape[0]),
        "target_test_samples": int(target_test_conf.shape[0]),
        "target_train_samples": int(target_train_conf.shape[0]),
        # Extra diagnostics
        "kind": kind,
        "priors": (p_shadow_prior, p_target_prior),
        "shadow_fit": model_shadow,
        "target_fit": model_target,
    }
    return results

    
def load_model(args, device):
    """Build the network selected by ``args`` and load its checkpoint.

    Args:
        args: namespace providing ``model`` (``'ResNet18'`` or ``'VGG'``),
            ``dataset``, ``num_classes``, ``unlearn_method`` and ``model_path``.
        device: device the network is moved to and the checkpoint is mapped to.

    Returns:
        The network wrapped in ``nn.DataParallel``, with weights loaded
        (``strict=True``) and switched to eval mode.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    if args.model not in ('ResNet18', 'VGG'):
        # Previously an unknown name fell through and crashed on an unbound local.
        raise ValueError(
            f"Unsupported model architecture: {args.model!r} (expected 'ResNet18' or 'VGG')."
        )

    # The 'imagenet' dataset is the only case that passes tinynet=True.
    size_kwargs = {'tinynet': True} if args.dataset == 'imagenet' else {}

    if args.model == 'ResNet18':
        net = ResNet18_orig(
            in_chan=3, bn=True, device=device, elu_flag=False,
            num_classes=args.num_classes, unlearn_method=args.unlearn_method,
            **size_kwargs,
        )
    else:
        # 'RW' / 'RW_multi' use the VGG_rw variant, everything else plain VGG.
        vgg_cls = VGG_rw if args.unlearn_method in ('RW', 'RW_multi') else VGG
        net = vgg_cls('VGG19', in_chan=3, num_classes=args.num_classes, **size_kwargs)

    net = net.to(device)
    net = nn.DataParallel(net)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(args.model_path, map_location=device)
    if 'state_dict' in checkpoint:
        state = checkpoint['state_dict']
    elif 'net' in checkpoint:
        state = checkpoint['net']
    else:
        state = checkpoint
    net.load_state_dict(state, strict=True)
    net.eval()

    return net


def load_data(args):
    """Build the CIFAR-10 data loaders used by the attacks.

    Args:
        args: namespace providing ``unnormalize`` (skip the Normalize
            transform when true), ``unlearn_indices`` (path to a CSV with an
            ``unlearn_idx`` column) and ``batch_size``.

    Returns:
        dict of DataLoaders keyed ``remainloader``, ``remainloader_9``,
        ``forgetloader``, ``forgetloader_9`` and the same four names with a
        ``_test`` infix/suffix for the test split.  "forget" loaders contain the
        listed indices; "remain" loaders contain every other index (the
        ``_9`` remain loaders exclude both index lists).

    NOTE(review): the dataset is always CIFAR-10 regardless of ``args.dataset``.
    NOTE(review): the ``_9`` CSV paths are derived with
    ``str.replace('_1.csv', ...)``; when ``args.unlearn_indices`` does not
    contain ``'_1.csv'`` the replacement is a no-op and the original CSV is
    read again — confirm callers always pass a ``*_1.csv`` path.
    NOTE(review): the train-split loaders use random crop / flip
    augmentation — confirm that is wanted for evaluation.
    """
    def read_indices(csv_path):
        # One integer index per row of the 'unlearn_idx' column.
        return [int(i) for i in pd.read_csv(csv_path)['unlearn_idx'].values]

    def complement(dataset, *excluded_lists):
        # Every index not listed, in ascending order by construction (the
        # order matters because downstream code truncates to the first rows).
        excluded = set().union(*excluded_lists)
        kept = [i for i in range(len(dataset)) if i not in excluded]
        return torch.utils.data.Subset(dataset, kept)

    def make_loader(dataset):
        return torch.utils.data.DataLoader(
            dataset, shuffle=False, batch_size=args.batch_size, num_workers=1)

    # Set up transforms; normalisation is appended unless args.unnormalize is set.
    normalize = [] if args.unnormalize else [
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        *normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        *normalize,
    ])

    # Load datasets
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)

    # Load unlearn indices (train / test) and the additional "_9" index lists.
    unlearn_idx = read_indices(args.unlearn_indices)
    unlearn_idx_test = read_indices(args.unlearn_indices.replace('.csv', '_test.csv'))
    unlearn_idx_9 = read_indices(args.unlearn_indices.replace('_1.csv', '_9.csv'))
    unlearn_idx_test_9 = read_indices(args.unlearn_indices.replace('_1.csv', '_9_test.csv'))

    datasets = {
        'remainloader': complement(trainset, unlearn_idx),
        'remainloader_9': complement(trainset, unlearn_idx_9, unlearn_idx),
        'forgetloader': torch.utils.data.Subset(trainset, unlearn_idx),
        'forgetloader_9': torch.utils.data.Subset(trainset, unlearn_idx_9),
        'remainloader_test': complement(testset, unlearn_idx_test),
        'remainloader_test_9': complement(testset, unlearn_idx_test_9, unlearn_idx_test),
        'forgetloader_test': torch.utils.data.Subset(testset, unlearn_idx_test),
        'forgetloader_test_9': torch.utils.data.Subset(testset, unlearn_idx_test_9),
    }
    return {name: make_loader(dataset) for name, dataset in datasets.items()}


def main():
    """Command-line entry point: load the model and data, then run the three attacks.

    The forgetting class is parsed from the ``--unlearn_indices`` filename
    (``label_<N>.csv``); when it cannot be determined a message is printed and
    the function returns without running anything.
    """
    import re

    def str2bool(value):
        # argparse's type=bool treats every non-empty string (even "False") as
        # True, so parse the text explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='SVC Attack Evaluation')
    parser.add_argument('--dataset', default='cifar10', help='dataset')
    parser.add_argument('--model', default='ResNet18', help='model architecture')
    parser.add_argument('--unlearn_method', default='RW', type=str, help='unlearning method')
    parser.add_argument('--unlearn_indices', required=True, type=str, help='path to unlearn indices CSV')
    parser.add_argument('--model_path', required=True, type=str, help='path to trained model checkpoint')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size')
    parser.add_argument('--unnormalize', default=True, type=str2bool, help='use unnormalized data')
    parser.add_argument('--num_classes', default=10, type=int, help='number of classes')
    parser.add_argument('--seed', default=1, type=int, help='random seed')

    args = parser.parse_args()

    # Set random seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using device: {device}')

    # Extract forgetting class from filename
    match = re.search(r'label_(\d+)\.csv', args.unlearn_indices)
    if match:
        forgetting_class = [int(match.group(1))]
        print(f"Forgetting class: {forgetting_class}")
    else:
        print("Could not determine forgetting class from filename")
        return

    # Load model
    print("Loading model...")
    net = load_model(args, device)
    print("Model loaded successfully")

    # Load data
    print("Loading data...")
    data_loaders = load_data(args)
    print("Data loaded successfully")

    # Perform SVC attacks
    print("=" * 50)
    print("Performing SVC Attack (Original)")
    print("=" * 50)

    evaluation_result = SVC_attack(
        shadow_train=data_loaders['remainloader'],
        shadow_test=data_loaders['forgetloader'],
        target_train=data_loaders['remainloader_test'],
        target_test=data_loaders['forgetloader_test'],
        model=net,
        forgetting_class=forgetting_class,
        unlearn_method=args.unlearn_method,
        model_path=args.model_path
    )

    print("=" * 50)
    print("Performing SVC Attack (New Configuration)")
    print("=" * 50)

    evaluation_result_new = SVC_attack_new(
        shadow_train=data_loaders['forgetloader_test_9'],  # truck test
        shadow_test=data_loaders['forgetloader_test'],     # car test
        target_train=data_loaders['remainloader_test_9'],  # remain test w/o car truck
        target_test=data_loaders['remainloader_test_9'],   # remain test w/o car truck
        model=net,
        forgetting_class=9,
        unlearn_method=args.unlearn_method,
        model_path=args.model_path
    )

    print("=" * 50)
    print("Performing Distribution Attack (New Configuration)")
    print("=" * 50)

    evaluation_result_dist = distribution_attack_new(
        shadow_train=data_loaders['forgetloader_test_9'],  # truck test
        target_train=data_loaders['forgetloader_test'],    # car test (samples being classified)
        target_test=data_loaders['remainloader_test_9'],   # remain test w/o car truck
        model=net,
        forgetting_class=9,
        unlearn_method=args.unlearn_method,
        model_path=args.model_path)

    print("=" * 50)
    print("Evaluation Complete")
    print("=" * 50)
    print(f"Original SVC Attack Result: {evaluation_result}")
    print(f"New SVC Attack Result: {evaluation_result_new}")
    print(f"Distribution Attack Result: {evaluation_result_dist}")


if __name__ == "__main__":
    main()
