import os
import time
import random
import torch 
import torch.optim
from tqdm import tqdm
from utils import evaluate,get_metric_scores


    
def mumis(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs
):
    """Unlearn the forget set by optimising the input-sensitivity gap loss.

    For every batch of ``forget_train_dl`` the loss returned by
    ``sensitivity_gap`` is back-propagated and one optimizer step is taken.
    Training stops early once the per-epoch average ``fcp`` statistic is
    above both its best (lowest) value so far and
    ``stop_threshold * <first-epoch average>``.

    Required ``kwargs``:
        model_name: ``'ViT'`` selects Adam (weight decay 5e-4); anything else
            selects plain SGD.
        lr: learning rate.
        epochs: maximum number of epochs.
        num_classes: number of classes, forwarded to ``sensitivity_gap``.
        stop_threshold: multiplier on the first-epoch ``fcp`` average used by
            the stopping rule (converted with ``float``).

    Optional ``kwargs``:
        decay: epoch milestone at which the SGD learning rate is multiplied
            by 0.1 (ignored for ``'ViT'``, which has no scheduler).
        mask: threshold identifying a saved gradient mask; also requires
            ``dataset`` and ``forget_class`` to build the file path.
        weight: factor on the true-class term of the loss (default 1).

    Returns:
        A tuple ``(get_metric_scores(...), elapsed_seconds)``.
    """
    start = time.time()
    # The model is kept in eval mode while its parameters are updated.
    model.eval()

    # Only the SGD branch may create a scheduler; keeping an explicit None
    # avoids a NameError when 'decay' is supplied together with 'ViT'.
    scheduler = None
    if kwargs['model_name'] == 'ViT':
        optimizer = torch.optim.Adam(model.parameters(), kwargs['lr'], weight_decay=5e-4)
    else:
        optimizer = torch.optim.SGD(model.parameters(), kwargs['lr'])
        if 'decay' in kwargs:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=[int(kwargs['decay'])], gamma=0.1
            )

    mask = None
    if 'mask' in kwargs:
        model_name = kwargs['model_name']
        dataset = kwargs['dataset']
        forget_class = kwargs['forget_class']
        save_dir = f'tmp_save/salun_mask/fullclass-{model_name}-{dataset}-{forget_class}'
        # Load onto the training device so the mask can multiply the grads
        # regardless of where it was saved.
        mask = torch.load(
            os.path.join(save_dir, f"threshold-{kwargs['mask']}.pt"),
            map_location=device,
        )

    init_fcp = 0
    least_fcp = 999999
    stop_threshold = float(kwargs['stop_threshold'])
    weight = float(kwargs['weight']) if 'weight' in kwargs else 1

    for epoch in range(kwargs['epochs']):
        avg_loss, avg_fc, avg_fcp = 0.0, 0.0, 0.0
        iters = 0
        for batch in tqdm(forget_train_dl):
            loss, fc, fcp = sensitivity_gap(
                model, batch, kwargs['num_classes'], weight=weight, device=device
            )
            # Convert the statistics to plain floats: accumulating the tensors
            # would keep every batch's autograd graph alive for the epoch.
            fc, fcp = float(fc), float(fcp)
            avg_loss += loss.item()
            avg_fc += fc
            avg_fcp += fcp

            loss.backward()
            if mask is not None:
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        # Mask keys are looked up without the 'module.' prefix.
                        param.grad *= mask[name.replace('module.', '')]
            optimizer.step()
            optimizer.zero_grad()

            # Batches that trip the stopping rule are still trained on and
            # summed above, but are not counted in the divisor.
            # NOTE(review): sums therefore include batches the divisor omits;
            # kept as in the original -- confirm this is intended.
            if fcp > least_fcp and fcp > stop_threshold * init_fcp:
                continue
            iters += 1

        if scheduler is not None:
            scheduler.step()
        if iters == 0:
            break

        avg_loss /= iters
        avg_fc /= iters
        avg_fcp /= iters
        if epoch == 0:
            init_fcp = avg_fcp
        if avg_fcp < least_fcp:
            least_fcp = avg_fcp

        # Train-set accuracies are not evaluated; 0.0 placeholders are printed.
        forget_train, retain_train = {'Acc': 0.0}, {'Acc': 0.0}
        forget_val = evaluate(model, forget_valid_dl, device)
        retain_val = evaluate(model, retain_valid_dl, device)
        # Guard the ratio so a zero first-epoch average prints nan, not a crash.
        fcp_ratio = avg_fcp / init_fcp if init_fcp else float('nan')
        print(f"Epoch [{epoch}] |fc:{avg_fc:.3f} |fcp:{avg_fcp:.3f} |fcp Ratio :{fcp_ratio:.2f} |MUMIS Loss:{avg_loss:.3f} | FTA:{forget_train['Acc']:.3f} | FVA:{forget_val['Acc']:.3f}| RTA:{retain_train['Acc']:.3f}| RVA:{retain_val['Acc']:.3f}")

        if avg_fcp > least_fcp and avg_fcp > stop_threshold * init_fcp:
            break

    end = time.time()
    return get_metric_scores(
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        device,
    ), end - start
    
def sensitivity_gap(model, batch, ncls, weight= 1, device='cuda'):
    """Compute the input-sensitivity gap loss for one batch.

    For each sample, the logit of its labelled class and the logit of one
    randomly chosen *other* class are summed over the batch. The gradients of
    those two sums with respect to the input images are taken, and the loss
    is ``weight * ||grad_true||_F^2 - ||grad_other||_F^2`` divided by the
    batch size.

    Args:
        model: callable mapping the image batch to a 2-D logit tensor indexed
            as ``[sample, class]``.
        batch: a 3-tuple ``(images, _, targets)``; the middle item is unused.
        ncls: number of classes; the other class is
            ``(target + k) % ncls`` with ``k`` drawn from ``[1, ncls - 1]``.
        weight: factor applied to the true-class term.
        device: device the images and targets are moved to.

    Returns:
        ``(loss, fc, fcp)``: the differentiable loss, and the per-sample
        squared gradient norms for the true class (``fc``) and the other
        class (``fcp``). ``fc`` and ``fcp`` are detached so callers can
        accumulate them without retaining the autograd graph.
    """
    images, _, targets = batch
    images, targets = images.to(device), targets.to(device)
    images.requires_grad = True
    batch_size = images.size(0)

    out = model(images)

    # One random non-zero class shift per sample, drawn in sample order from
    # the `random` module so seeding behaves as with a per-sample loop.
    labels = targets.long()
    shifts = torch.tensor(
        [random.randint(1, ncls - 1) for _ in range(batch_size)],
        dtype=torch.long,
        device=labels.device,
    )
    other_labels = (labels + shifts) % ncls

    rows = torch.arange(batch_size, device=labels.device)
    logitc = out[rows, labels].sum()
    logitcp = out[rows, other_labels].sum()

    # create_graph=True keeps the input gradients differentiable, so the loss
    # can be back-propagated to the model parameters.
    c_input = torch.autograd.grad(logitc, images, retain_graph=True, create_graph=True)[0]
    cp_input = torch.autograd.grad(logitcp, images, retain_graph=True, create_graph=True)[0]

    c_sq = torch.norm(c_input, p='fro') ** 2
    cp_sq = torch.norm(cp_input, p='fro') ** 2
    unlearn_loss = weight * c_sq - cp_sq

    return unlearn_loss / batch_size, c_sq.detach() / batch_size, cp_sq.detach() / batch_size
