import os
import gc
import logging
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from opacus.validators import ModuleValidator

from arg_parser import argument_parsing
from model import *
from learn import *
from dataset import *
from utils import *
from MIA import membership_inference_attacks as mia
from EIU.src.unlearn import get_unlearn_model
from EIU.src.evaluation import measure

# Top-level directory for checkpoints of differentially-private runs (used by learn() when args.dp is set).
private_model_path = "dpckpts"
# Base checkpoint directory name; also reused as a path component of DP and shadow save paths.
model_path = "checkpoints"
# Top-level directory for EIU runs (learn() in args.eiu mode).
eiu_path = "eiu"
# Per-split sub-directory names (not referenced by get_all_leaf_unlearned, learn or swapping).
retain_forget_path = "rf/"
retain_test_path = "rt/"
retain_path = "re/"
# NOTE: swapping() builds its own shadow checkpoint path locally instead of using this constant.
shadow_model_path = "checkpoints/shadowModel/"
unlearned_model_path = "checkpoints/unlearnedModel/"
data_path = "data/"
# Relative location of the unlearned weights for each unlearning method (keyed by
# args.unlearn_model); swapping() appends it to the rf.pth / rt.pth checkpoint path.
unlearned_suffix = dict(
    retrfinal="RetrainFinal/RetrFinal_1.pt",
    ftfinal="FinetuneFinal/FTfinal_1.pt",
    golatkar="Golatkar/fisher.pt",
    neggrad="NegGrad/neggrad.pth",
)

def _load_checkpoint_net(args, path, num_classes, device):
    """Build ``args.model`` with ``num_classes`` outputs and load the weights stored at ``path``.

    When ``args.dp`` is set, the architecture is first passed through Opacus'
    ``ModuleValidator.fix``. This mirrors ``unit_train``, which fixes every
    model before DP training, so the loaded state_dict matches the architecture.
    The model is moved to ``device`` only *after* the fix so that any modules
    the fixer replaces also end up on ``device``.
    """
    net = get_model(args.model, num_classes)
    if args.dp:
        net = ModuleValidator.fix(net)
    net = net.to(device)
    # map_location lets a checkpoint saved on a different GPU (or on GPU while
    # running on CPU) be loaded on the current device.
    net.load_state_dict(torch.load(path, map_location=device))
    return net


def get_all_leaf_unlearned(args):
    """Apply the configured unlearning method to the trained checkpoints of one run.

    Side effects:
        * Sets ``args.num_to_forget`` to the number of samples in the forget set
          (used for fisher unlearning).
        * Writes unlearned models under ``save_dir`` via ``NegGrad`` or
          ``get_unlearn_model``.

    In EIU mode only the ``eiu.pth`` checkpoint is unlearned. Otherwise:

    * the ``rf.pth`` checkpoint is unlearned with ``fgloader`` as the forget set;
    * the ``rt.pth`` checkpoint is unlearned with ``teloader`` as the forget set,
      i.e. the forget and test roles are swapped.

    Returns ``None``. It returns early, doing nothing further, when the retrain
    checkpoint (``re.pth``), the rf checkpoint or the rt checkpoint is missing.
    """
    criterion = nn.CrossEntropyLoss()
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    # Data split is determined by the primary seed.
    manual_seed(args.seed)
    if args.eiu:
        _, _, fgloader, reloader, _, test_loader, classes = get_dataloader(
            args.dataset, args.batch_size, args.retain_ratio, args, args.size_ratio
        )
    else:
        _, _, reloader, fgloader, teloader, _, shadow_testloader, classes = get_dataloader(
            args.dataset, args.batch_size, args.retain_ratio, args, args.size_ratio
        )
    # Everything after the split (model init, unlearning) uses the secondary seed.
    manual_seed(args.secseed)

    # Number of samples to forget (consumed by fisher unlearning).
    # Assumes the forget loader wraps a torch Subset (exposes .indices).
    args.num_to_forget = len(fgloader.dataset.indices)
    num_classes = len(classes)

    # Checkpoint (input) and unlearned (output) directories; these must agree
    # with the save paths used by learn().
    if args.eiu:
        ckpt_dir = f"eiu/{args.epsilon}/{args.seed}/{args.secseed}/{args.model}/"
        save_dir = f'eiu/unlearned/{args.epsilon}/{args.seed}/{args.secseed}/{args.model}/'
        re_path = f"eiu/{args.epsilon}/{args.seed}/{args.secseed}/re/{args.model}/re.pth"
    elif args.dp:
        ckpt_dir = f'dpckpts/{args.epsilon}/{args.retain_ratio}/checkpoints/{args.seed}/{args.model}/{args.secseed}/'
        save_dir = f'dpckpts/{args.epsilon}/unlearned/{args.retain_ratio}/{args.seed}/{args.model}/{args.secseed}/'
        re_path = ckpt_dir + 're.pth'
    else:
        ckpt_dir = f'checkpoints/{args.dataset}/{args.retain_ratio}/checkpoints/{args.seed}/{args.model}/{args.secseed}/'
        save_dir = f'checkpoints/{args.dataset}/unlearned/{args.retain_ratio}/{args.seed}/{args.model}/{args.secseed}/'
        re_path = ckpt_dir + 're.pth'

    # The retrain model is passed to get_unlearn_model for every target below.
    if not os.path.exists(re_path):
        return
    retrain_net = _load_checkpoint_net(args, re_path, num_classes, device)

    if args.eiu:
        eiu_net = _load_checkpoint_net(args, ckpt_dir + 'eiu.pth', num_classes, device)
        if args.unlearn_model == "neggrad":
            optimizer = optim.SGD(eiu_net.parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd)
            NegGrad(5, eiu_net, reloader, fgloader, optimizer, None, criterion, fgloader, device, save_dir + 'eiu.pth')
        else:
            get_unlearn_model(args, eiu_net, retrain_net, reloader, fgloader, test_loader, test_loader, device, save_dir + 'eiu.pth')
        return

    # rf checkpoint: the forget set is fgloader.
    rf_path = ckpt_dir + 'rf.pth'
    if not os.path.exists(rf_path):
        return
    rf_net = _load_checkpoint_net(args, rf_path, num_classes, device)
    if args.unlearn_model == "neggrad":
        optimizer = optim.SGD(rf_net.parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd)
        NegGrad(5, rf_net, reloader, fgloader, optimizer, None, criterion, fgloader, device, save_dir + 'rf.pth')
    else:
        get_unlearn_model(args, rf_net, retrain_net, reloader, fgloader, shadow_testloader, teloader, device, save_dir + 'rf.pth')

    # rt checkpoint: roles are swapped, so teloader is the forget set and fgloader
    # plays the held-out role.
    rt_path = ckpt_dir + 'rt.pth'
    if not os.path.exists(rt_path):
        return
    # Loads rt_path in both the DP and non-DP case (previously the non-DP branch
    # loaded rf_path by mistake).
    rt_net = _load_checkpoint_net(args, rt_path, num_classes, device)
    if args.unlearn_model == "neggrad":
        optimizer = optim.SGD(rt_net.parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd)
        NegGrad(5, rt_net, reloader, teloader, optimizer, None, criterion, teloader, device, save_dir + 'rt.pth')
    else:
        get_unlearn_model(args, rt_net, retrain_net, reloader, teloader, shadow_testloader, fgloader, device, save_dir + 'rt.pth')


def unit_train(args, dp, classes, train_loader, test_loader, criterion, device, save_path, modelname):
    """Build a fresh ``args.model`` network and hand it to ``target_train``.

    The optimizer depends on ``args.dp``:

    * ``args.dp`` true: the network is first made Opacus-compatible via
      ``ModuleValidator.fix``, then optimised with RMSprop at ``args.lr``.
    * otherwise: SGD with ``args.lr``, ``args.momentum`` and ``args.wd``.

    In both cases a MultiStepLR schedule with milestones at epochs 100 and 150
    is used. The ``dp`` dict is forwarded to ``target_train`` unchanged.
    ``save_path`` and ``modelname`` are also forwarded; presumably they
    determine where the trained checkpoint is written (confirm in target_train).
    """
    net = get_model(args.model, len(classes))

    if not args.dp:
        opt = optim.SGD(net.parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd)
    else:
        # Fix the architecture before the optimizer captures its parameters.
        net = ModuleValidator.fix(net)
        opt = optim.RMSprop(net.parameters(), lr=args.lr)

    scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=[100, 150], last_epoch=-1)

    target_train(
        dp,
        args.total_epoch,
        net,
        train_loader,
        opt,
        scheduler,
        criterion,
        test_loader,
        device,
        save_path,
        modelname,
    )


def learn(args):
    """Train the checkpoints for one (seed, model, secseed) configuration.

    EIU mode (``args.eiu``):
        * Sets ``args.num_classes``.
        * Trains ``eiu`` on the full train loader, saved under
          ``eiu/<epsilon>/<seed>/<secseed>/<model>``.
        * Trains ``re`` on the retain loader, saved under
          ``eiu/<epsilon>/<seed>/<secseed>/re/<model>``.

    Otherwise, three models are trained under a common save path:
        * ``rf`` on ``rfloader``, evaluated on ``teloader``;
        * ``rt`` on ``rtloader``, evaluated on ``fgloader``;
        * ``re`` on ``reloader``, evaluated on ``teloader``.

    The save path is rooted at ``dpckpts/<epsilon>/<retain_ratio>/checkpoints`` when
    ``args.dp`` is set, else at ``checkpoints/<dataset>/<retain_ratio>/checkpoints``.
    These paths must match the ones ``get_all_leaf_unlearned`` reads from.

    NOTE: no shadow model is trained here. ``swapping`` expects its checkpoint at
    ``checkpoints/<dataset>/shadow/<retain_ratio>/<seed>/<model>/<secseed>/shadow.pth``.
    """
    dp = {
        'activate': args.dp,
        'epsilon': args.epsilon,
        'delta': args.delta,
        'max_grad_norm': args.max_grad_norm,
    }
    criterion = nn.CrossEntropyLoss()

    # The data split depends on the primary seed only.
    manual_seed(args.seed)
    # f-string works whether args.gpu is parsed as an int or a str; plain "+"
    # concatenation raised TypeError for an int.
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    if args.eiu:
        _, _, _, reloader, train_loader, test_loader, classes = get_dataloader(
            args.dataset, args.batch_size, args.retain_ratio, args, args.size_ratio)
        args.num_classes = len(classes)
        manual_seed(args.secseed)
        # Literal "eiu/" prefix kept so these paths stay in sync with get_all_leaf_unlearned.
        eiu_run_dir = f"eiu/{args.epsilon}/{args.seed}/{args.secseed}"
        unit_train(args, dp, classes, train_loader, test_loader, criterion, device,
                   f"{eiu_run_dir}/{args.model}", 'eiu')
        unit_train(args, dp, classes, reloader, test_loader, criterion, device,
                   f"{eiu_run_dir}/re/{args.model}", 're')
        return

    rtloader, rfloader, reloader, fgloader, teloader, _, _, classes = get_dataloader(
        args.dataset, args.batch_size, args.retain_ratio, args, args.size_ratio)
    manual_seed(args.secseed)

    if args.dp:
        save_pref = f"{private_model_path}/{args.epsilon}/{args.retain_ratio}/{model_path}"
    else:
        save_pref = f"{model_path}/{args.dataset}/{args.retain_ratio}/{model_path}"
    save_path = f"{save_pref}/{args.seed}/{args.model}/{args.secseed}"
    print(save_path)

    # Retain + Forget, Retain + Test (forget set held out), Retain only.
    unit_train(args, dp, classes, rfloader, teloader, criterion, device, save_path, 'rf')
    unit_train(args, dp, classes, rtloader, fgloader, criterion, device, save_path, 'rt')
    unit_train(args, dp, classes, reloader, teloader, criterion, device, save_path, 're')


def _load_swapping_target(args, path, device):
    """Load one unlearned target model for ``swapping`` onto ``device``.

    As in the original evaluation, the architecture is always passed through
    ``ModuleValidator.fix``, regardless of ``args.dp``. The checkpoint file is
    expected to be a dict holding the weights under ``'state_dict'``.
    """
    net = ModuleValidator.fix(get_model(args.model, args.num_classes))
    net = net.to(device)
    net.load_state_dict(torch.load(path, map_location=device)['state_dict'])
    return net


def swapping(args):
    """Run membership-inference attacks against the unlearned rf and rt models.

    A shadow model is loaded and used to build attack training data, and an
    ``MLPClassifier`` is fit on that data. The classifier and the
    ``mia.black_box_benchmarks`` attack are then applied to posteriors of:

    * the unlearned rf model: ``fgloader`` is passed as the member set and
      ``teloader`` as the non-member set;
    * the unlearned rt model: the roles are swapped, so ``teloader`` is the
      member set and ``fgloader`` the non-member set.

    Results are printed. Nothing is returned.
    """
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

    # The data split depends on the primary seed only.
    manual_seed(args.seed)
    rtloader, rfloader, reloader, fgloader, teloader, shadow_trainloader, shadow_testloader, classes = get_dataloader(
        args.dataset, args.batch_size, args.retain_ratio, args, args.size_ratio
    )
    manual_seed(args.secseed)

    shadow_ckpt = (
        f"{model_path}/{args.dataset}/shadow/{args.retain_ratio}/{args.seed}/"
        f"{args.model}/{args.secseed}/shadow.pth"
    )
    unlearned_rel = unlearned_suffix[args.unlearn_model]
    target_dir = (
        f"checkpoints/{args.dataset}/{args.retain_ratio}/checkpoints/{args.seed}/"
        f"{args.model}/{args.secseed}"
    )
    model_path_rf = f"{target_dir}/rf.pth/{unlearned_rel}"
    model_path_rt = f"{target_dir}/rt.pth/{unlearned_rel}"

    # Shadow model -> attack training data -> attack classifier.
    shadow_net = get_model(args.model, args.num_classes).to(device)
    shadow_net.load_state_dict(torch.load(shadow_ckpt, map_location=device)['state_dict'])
    shadow_net.eval()
    test(shadow_net, shadow_trainloader, nn.CrossEntropyLoss(), device)

    s_tr_perf, s_te_perf, mia_traindata, mia_trainlabel = get_attack_dataloader(
        shadow_net, shadow_trainloader, shadow_testloader, args, device
    )
    clf = MLPClassifier(hidden_layer_sizes=(100,), max_iter=500, random_state=args.secseed)
    clf.fit(mia_traindata, mia_trainlabel)

    avg, cnt = 0, 0

    # rf model: forget set is the member side, test set the non-member side.
    target_net = _load_swapping_target(args, model_path_rf, device)
    test(target_net, teloader, nn.CrossEntropyLoss(), device)
    f_pos, f_tg = get_posterior(target_net, fgloader, device)
    te_pos, te_tg = get_posterior(target_net, teloader, device)
    avg += (clf.predict(f_pos).sum() - clf.predict(te_pos).sum()) / len(te_pos)
    cnt += 1
    attack = mia.black_box_benchmarks(s_tr_perf, s_te_perf, (f_pos, f_tg), (te_pos, te_tg), 10)
    attack._mem_inf_benchmarks()

    # rt model: roles swapped, so the test set is the member side.
    target_net = _load_swapping_target(args, model_path_rt, device)
    test(target_net, fgloader, nn.CrossEntropyLoss(), device)
    f_pos, f_tg = get_posterior(target_net, fgloader, device)
    te_pos, te_tg = get_posterior(target_net, teloader, device)
    # NOTE(review): this term keeps the (forget - test) ordering used for rf,
    # although the rt attack below treats the test split as members. Confirm
    # the intended sign before relying on `avg`.
    avg += (clf.predict(f_pos).sum() - clf.predict(te_pos).sum()) / len(te_pos)
    cnt += 1
    attack = mia.black_box_benchmarks(s_tr_perf, s_te_perf, (te_pos, te_tg), (f_pos, f_tg), 10)
    attack._mem_inf_benchmarks()

    print(f"Average difference between predictions: {avg / cnt:.4f} over {cnt} models.")
    print("Attack completed using MLPClassifier with following settings:")
    print(f"Hidden layers: {clf.hidden_layer_sizes}, Max iterations: {clf.max_iter}")

