import os
import sys
import time
from evaluation.mia import *

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict

import pruner
import utils
from pruner import extract_mask, prune_model_custom, remove_prune
import datetime
sys.path.append(".")
from trainer import validate
import copy

def plot_training_curve(training_result, save_dir, prefix):
    """Plot every series in ``training_result`` on one figure and save it.

    Args:
        training_result: Mapping from a series name to the values to plot;
            each curve is labelled ``"<name>_acc"``.
        save_dir: Directory the image is written into.
        prefix: File name prefix; the image is saved as
            ``<prefix>_train.png`` inside ``save_dir``.
    """
    for series_name, series_values in training_result.items():
        curve_label = f"{series_name}_acc"
        plt.plot(series_values, label=curve_label)
    plt.legend()

    figure_path = os.path.join(save_dir, prefix + "_train.png")
    plt.savefig(figure_path)
    # Close the current figure so later plots start from a clean canvas.
    plt.close()


def save_unlearn_checkpoint(model, evaluation_result, epoch, args):
    """Save the model weights and the evaluation result for one epoch.

    Two ``utils.save_checkpoint`` calls are made: one for a dict holding the
    model ``state_dict`` together with ``evaluation_result``, and one for
    ``evaluation_result`` alone under the file name ``eval_result.pth.tar``.

    The target directory is ``args.save_dir`` when this is a plain retrain run
    (``args.unlearn == "retrain"`` with neither ``args.class_to_replace`` nor
    ``args.num_indexes_to_replace`` set) and ``args.result_path`` otherwise.

    Args:
        model: Model whose ``state_dict()`` is stored.
        evaluation_result: Object stored next to the weights (may be ``None``).
        epoch: Epoch value forwarded to ``utils.save_checkpoint``.
        args: Namespace providing ``unlearn``, ``class_to_replace``,
            ``num_indexes_to_replace``, ``save_dir`` and ``result_path``.
    """
    payload = {"state_dict": model.state_dict(), "evaluation_result": evaluation_result}

    plain_retrain = (
        args.unlearn == "retrain"
        and args.class_to_replace is None
        and args.num_indexes_to_replace is None
    )
    target_dir = args.save_dir if plain_retrain else args.result_path

    # The second positional argument is always False here; presumably it is
    # the "is best" flag of utils.save_checkpoint -- TODO confirm.
    utils.save_checkpoint(payload, False, target_dir, args.unlearn, epoch)
    utils.save_checkpoint(
        evaluation_result,
        False,
        target_dir,
        args.unlearn,
        epoch,
        filename="eval_result.pth.tar",
    )
    
def save_unlearn_checkpoint_baseline(model, evaluation_result, epoch, args):
    """Save weights and evaluation result under ``args.result_path``.

    Unlike ``save_unlearn_checkpoint`` this always writes to
    ``args.result_path``, whatever the unlearning configuration is.

    Args:
        model: Model whose ``state_dict()`` is stored.
        evaluation_result: Object stored next to the weights (may be ``None``).
        epoch: Epoch value forwarded to ``utils.save_checkpoint``.
        args: Namespace providing ``result_path`` and ``unlearn``.
    """
    payload = {"state_dict": model.state_dict(), "evaluation_result": evaluation_result}
    out_dir = args.result_path
    method = args.unlearn

    utils.save_checkpoint(payload, False, out_dir, method, epoch)
    # Second file: only the evaluation result, under a fixed file name.
    utils.save_checkpoint(
        evaluation_result,
        False,
        out_dir,
        method,
        epoch,
        filename="eval_result.pth.tar",
    )


def load_unlearn_checkpoint_saliency(model, device, args):
    """Load the checkpoint tagged ``"RL10_"`` from ``args.save_dir`` into ``model``.

    Args:
        model: Model that receives the stored weights (loaded non-strictly).
        device: Device forwarded to ``utils.load_checkpoint``.
        args: Namespace providing ``save_dir``.

    Returns:
        ``None`` when no checkpoint is found or it has no ``"state_dict"``;
        otherwise a ``(model, evaluation_result)`` tuple with the model
        switched to eval mode. ``evaluation_result`` is ``None`` when the
        checkpoint does not carry one.
    """
    loaded = utils.load_checkpoint(device, args.save_dir, "RL10_")
    has_weights = loaded is not None and loaded.get("state_dict") is not None
    if not has_weights:
        return None

    # strict=False: keys that do not match between checkpoint and model are
    # tolerated instead of raising.
    model.load_state_dict(loaded["state_dict"], strict=False)
    model.eval()

    return model, loaded.get("evaluation_result")

def load_original_checkpoint(model, device,epoch, args):
    """Load the checkpoint tagged ``<args.unlearn><epoch>_`` into ``model``.

    Args:
        model: Model that receives the stored weights (loaded non-strictly).
        device: Device forwarded to ``utils.load_checkpoint``.
        epoch: Epoch number that is part of the checkpoint tag.
        args: Namespace providing ``save_dir`` and ``unlearn``.

    Returns:
        ``None`` when no checkpoint is found or it has no ``"state_dict"``;
        otherwise ``model`` with the weights loaded and eval mode enabled.
    """
    tag = args.unlearn + str(epoch) + "_"
    loaded = utils.load_checkpoint(device, args.save_dir, tag)
    has_weights = loaded is not None and loaded.get("state_dict") is not None
    if not has_weights:
        return None

    # strict=False: keys that do not match between checkpoint and model are
    # tolerated instead of raising.
    model.load_state_dict(loaded["state_dict"], strict=False)
    model.eval()
    return model

def load_unlearn_checkpoint(model, device,epoch, args):
    """Load the ``<args.unlearn><epoch>_`` checkpoint and its evaluation result.

    Args:
        model: Model that receives the stored weights (loaded non-strictly).
        device: Device forwarded to ``utils.load_checkpoint``.
        epoch: Epoch number that is part of the checkpoint tag.
        args: Namespace providing ``save_dir`` and ``unlearn``.

    Returns:
        ``None`` when no checkpoint is found or it has no ``"state_dict"``;
        otherwise a ``(model, evaluation_result)`` tuple with the model
        switched to eval mode. ``evaluation_result`` is ``None`` when the
        checkpoint does not carry one.
    """
    tag = args.unlearn + str(epoch) + "_"
    loaded = utils.load_checkpoint(device, args.save_dir, tag)
    has_weights = loaded is not None and loaded.get("state_dict") is not None
    if not has_weights:
        return None

    # strict=False: keys that do not match between checkpoint and model are
    # tolerated instead of raising.
    model.load_state_dict(loaded["state_dict"], strict=False)
    model.eval()

    return model, loaded.get("evaluation_result")

def load_unlearn_checkpoint_original(model, device, epoch, args):
    """Load the ``retrain<epoch>_`` checkpoint and its evaluation result.

    The tag is always built from the literal ``"retrain"``, independent of
    ``args.unlearn``.

    Args:
        model: Model that receives the stored weights (loaded non-strictly).
        device: Device forwarded to ``utils.load_checkpoint``.
        epoch: Epoch number that is part of the checkpoint tag.
        args: Namespace providing ``save_dir``.

    Returns:
        ``None`` when no checkpoint is found or it has no ``"state_dict"``;
        otherwise a ``(model, evaluation_result)`` tuple with the model
        switched to eval mode. ``evaluation_result`` is ``None`` when the
        checkpoint does not carry one.
    """
    tag = "retrain" + str(epoch) + "_"
    loaded = utils.load_checkpoint(device, args.save_dir, tag)
    has_weights = loaded is not None and loaded.get("state_dict") is not None
    if not has_weights:
        return None

    # strict=False: keys that do not match between checkpoint and model are
    # tolerated instead of raising.
    model.load_state_dict(loaded["state_dict"], strict=False)
    model.eval()

    return model, loaded.get("evaluation_result")

# Unlearning methods whose learning rate follows a MultiStepLR schedule; every
# other method (e.g. "SCRUB", "Salun") runs without a scheduler.
_MULTISTEP_LR_METHODS = ("retrain", "RL", "NegGrad_plus", "GA", "FT", "IU")

# Methods that are evaluated after every epoch into "acc.txt", with the MIA
# score formatted to three decimals.
_PER_EPOCH_EVAL_METHODS = ("FT", "GA", "RL", "IU", "NegGrad_plus", "SCRUB")


def _evaluation_log_name(args, epoch):
    """Return the log file name for this epoch, or ``None`` if no evaluation is due."""
    if args.unlearn == "Salun":
        return "lr_" + str(args.unlearn_lr) + "_mask_" + str(args.salun_mask) + "_acc.txt"
    if args.unlearn in _PER_EPOCH_EVAL_METHODS:
        return "acc.txt"
    if args.unlearn == "retrain" and epoch % 10 == 0:
        return "acc.txt"
    return None


def _select_eval_loaders(data_loaders, args):
    """Pick the loaders for accuracy and MIA evaluation.

    Returns:
        ``(pattern, acc_loaders, mia_loaders)``. ``mia_loaders`` is ``None``
        for the "original" pattern, which defines no MIA loaders. All three
        are ``None`` when both ``args.class_to_replace`` and
        ``args.num_indexes_to_replace`` are set, a combination for which no
        loaders are defined.
    """
    by_class = args.class_to_replace is not None
    by_index = args.num_indexes_to_replace is not None

    if by_class and not by_index:
        acc_loaders = OrderedDict(
            retain=data_loaders["retain_for_test"],
            forget=data_loaders["forget_for_test"],
            val_retain=data_loaders["val_retain"],
            val_forget=data_loaders["val_forget"],
        )
        mia_loaders = OrderedDict(
            retain=data_loaders["retain"],
            forget=data_loaders["forget"],
            val_retain=data_loaders["val_retain"],
            val_forget=data_loaders["val_forget"],
        )
        return "classwise", acc_loaders, mia_loaders

    if by_index and not by_class:
        acc_loaders = OrderedDict(
            retain=data_loaders["retain_for_test"],
            forget=data_loaders["forget_for_test"],
            val=data_loaders["val"],
        )
        mia_loaders = OrderedDict(
            retain=data_loaders["retain"],
            forget=data_loaders["forget"],
            val=data_loaders["val"],
        )
        return "datawise", acc_loaders, mia_loaders

    if not by_class and not by_index:
        acc_loaders = OrderedDict(
            retain=data_loaders["retain"],
            retain_for_test=data_loaders["retain_for_test"],
            val=data_loaders["val"],
        )
        return "original", acc_loaders, None

    return None, None, None


def _log_epoch_evaluation(model, data_loaders, epoch, args):
    """Append one tab-separated evaluation line for ``epoch`` to the result log.

    The line holds the epoch, then ``name<TAB>accuracy`` for every accuracy
    loader, then ``criterion<TAB>score`` for the MIA evaluation (when MIA
    loaders exist). Nothing happens for methods/epochs with no evaluation due.

    Raises:
        ValueError: If evaluation is due but both ``args.class_to_replace``
            and ``args.num_indexes_to_replace`` are set.
    """
    log_name = _evaluation_log_name(args, epoch)
    if log_name is None:
        return

    pattern, acc_loaders, mia_loaders = _select_eval_loaders(data_loaders, args)
    if acc_loaders is None:
        raise ValueError(
            "cannot evaluate: args.class_to_replace and "
            "args.num_indexes_to_replace must not both be set"
        )

    # A dedicated criterion for evaluation, so the criterion used by the
    # unlearning step is never replaced.
    eval_criterion = nn.CrossEntropyLoss()
    format_mia_score = args.unlearn in _PER_EPOCH_EVAL_METHODS

    # The context manager closes the log even if evaluation raises.
    with open(os.path.join(args.result_path, log_name), "a") as log:
        log.write(str(epoch) + "\t")
        for name, loader in acc_loaders.items():
            accuracy = validate(loader, model, eval_criterion, args)
            log.write(name + "\t" + "{:.3f}".format(accuracy) + "\t")

        # No MIA loaders are defined for the "original" pattern, so the MIA
        # columns are skipped there instead of failing on an unbound name.
        if mia_loaders is not None:
            device = torch.device(f"cuda:{int(args.gpu)}")
            for cri in ("confidence",):
                mia_efficacy = MIAEfficacy(cri)
                # NOTE(review): "Salun" calls evaluate_scrub while "SCRUB"
                # calls evaluate -- kept as in the original; confirm intended.
                if args.unlearn == "Salun":
                    score = mia_efficacy.evaluate_scrub(
                        model, mia_loaders, 1, device, pattern, args.train_seed
                    )
                else:
                    score = mia_efficacy.evaluate(
                        model, mia_loaders, 1, device, pattern, args.train_seed
                    )
                if format_mia_score:
                    score = "{:.3f}".format(score)
                log.write(cri + "\t" + str(score) + "\t")
        log.write("\n")


def _iterative_unlearn_impl(unlearn_iter_func):
    """Wrap a single-epoch unlearning step into a full multi-epoch run.

    Args:
        unlearn_iter_func: Callable run once per epoch. It is called as
            ``f(data_loaders, model, criterion, optimizer, epoch, args, **kwargs)``,
            or, when ``args.unlearn == "SCRUB"``, as
            ``f(data_loaders, model_t, model, criterion, optimizer, epoch, args, **kwargs)``
            where ``model_t`` is a deep copy of the initial model in eval mode.

    Returns:
        ``_wrapped(data_loaders, model, criterion, args, mask=None, **kwargs)``,
        which builds the SGD optimizer (and MultiStepLR scheduler for the
        methods that use one), runs ``args.unlearn_epochs`` epochs, saves a
        checkpoint after each epoch and appends evaluation lines to the
        result log. It returns ``None``.
    """

    def _wrapped(data_loaders, model, criterion, args, mask=None, **kwargs):
        # ``mask`` is accepted for interface compatibility; it is not used here.
        decreasing_lr = list(map(int, args.decreasing_lr.split(",")))

        model_t = None
        if args.unlearn == "SCRUB":
            model_t = copy.deepcopy(model)
            model_t.eval()

        optimizer = torch.optim.SGD(
            model.parameters(),
            args.unlearn_lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )

        # Default to "no scheduler" so methods outside the known set run with
        # a constant learning rate instead of hitting an unbound variable.
        scheduler = None
        if args.unlearn in _MULTISTEP_LR_METHODS:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=decreasing_lr, gamma=0.1
            )

        # Fast-forward the schedule by the configured number of rewind epochs.
        rewind_epochs = args.rewind_epoch
        if scheduler is not None:
            for _ in range(rewind_epochs):
                scheduler.step()

        for epoch in range(args.unlearn_epochs):
            if args.unlearn == "SCRUB":
                unlearn_iter_func(
                    data_loaders, model_t, model, criterion, optimizer, epoch, args, **kwargs
                )
            else:
                unlearn_iter_func(
                    data_loaders, model, criterion, optimizer, epoch, args, **kwargs
                )

            if scheduler is not None:
                scheduler.step()

            # Imported lazily, as in the original code, so the import only
            # happens once an epoch has actually run.
            import unlearn
            unlearn.save_unlearn_checkpoint(model, None, epoch, args)

            _log_epoch_evaluation(model, data_loaders, epoch, args)

    return _wrapped


def iterative_unlearn(func):
    """Decorator turning a per-epoch unlearning step into a full run.

    Usage::

        @iterative_unlearn
        def func(data_loaders, model, criterion, optimizer, epoch, args):
            ...

    When ``args.unlearn == "SCRUB"`` the step is instead called as
    ``func(data_loaders, model_t, model, criterion, optimizer, epoch, args)``.
    The decorated function is called as
    ``func(data_loaders, model, criterion, args, mask=None, **kwargs)``.
    """
    return _iterative_unlearn_impl(func)
