import torch
import  numpy as np
from easydict import EasyDict
from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.torch.attacks.projected_gradient_descent import (
    projected_gradient_descent,
)
from src.datasets.cifar100 import coarse_to_fines
import random

def pgd(model, x_batch, tasks, target, k, eps):
    """Run an untargeted L-infinity PGD attack against ``model``.

    The attack starts from a uniformly random point inside the ``eps``-ball
    around ``x_batch`` and takes ``k`` signed-gradient ascent steps on the
    summed cross-entropy between ``model(x, tasks)[0]`` and ``target``.
    After every step the iterate is projected back onto the ``eps``-ball and
    onto the ``[0, 1]`` image range.

    Args:
        model: Callable as ``model(x, tasks)``; element ``0`` of its result
            is fed to ``CrossEntropyLoss``. Must provide ``zero_grad()``.
        x_batch: Clean inputs; the code treats ``[0, 1]`` as the valid range.
        tasks: Passed through unchanged as the second argument of ``model``.
        target: Labels whose loss is maximised.
        k: Number of gradient steps. Must be non-zero because the step size
            is derived as ``2.5 * eps / k``.
        eps: Radius of the L-infinity ball.

    Returns:
        Detached adversarial examples with the same shape as ``x_batch``.
    """
    loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
    eps_step = 2.5 * eps / k
    with torch.no_grad():  # gradients are only enabled around the loss below
        # Initialise with a random point inside the perturbation region.
        x_adv = x_batch.detach() + eps * (2 * torch.rand_like(x_batch) - 1)
        x_adv.clamp_(min=0.0, max=1.0)  # project back to the image domain

        for _ in range(k):
            x_adv.detach_().requires_grad_()

            with torch.enable_grad():
                out = model(x_adv, tasks)[0]
                model.zero_grad()
                loss_fn(out, target).backward()

            # Signed-gradient ascent step, then projection onto the eps-ball.
            delta = eps_step * x_adv.grad.sign()
            x_adv = x_batch + (x_adv + delta - x_batch).clamp(min=-eps, max=eps)
            # Keep every iterate a valid image, exactly like the random
            # start; without this the result can leave the [0, 1] range.
            x_adv.clamp_(min=0.0, max=1.0)
    return x_adv.detach()

class ADVModel(torch.nn.Module):
    """Expose a task-conditioned model through a plain ``forward(x)`` call.

    Attack routines call the model with the input only, so the task ids are
    supplied out of band with :meth:`set_tasks` and read back in ``forward``.

    Attributes set in ``__init__``:
        model: The wrapped model.
        is_pretrain: True when the wrapped model is neither comodulation nor
            attention; ``forward`` then decodes without controller params.
        update_gain_with_attack: When True (and the wrapped model does not
            use precomputed per-task gains) the decoder gain is recomputed
            from each input seen by ``forward``.
        gain: Gain stored by :meth:`compute_gain`; ``None`` until then.
        tasks: Task ids stored by :meth:`set_tasks`; ``None`` until then.
    """

    def __init__(self, model, update_gain_with_attack=True):
        super().__init__()
        self.model = model
        self.is_pretrain = not (model.is_comodulation or model.is_attention)
        self.update_gain_with_attack = update_gain_with_attack
        self.gain = None
        # Initialised here so that reading it before set_tasks() yields None
        # (and trips the explicit checks below) instead of AttributeError.
        self.tasks = None

    def forward(self, x):
        """Return the classifier output of the wrapped model for ``x``."""
        controller_params = None
        if not self.is_pretrain:
            if self.model.is_comodulation:
                self._refresh_decoder_gain(x)
            # Controller parameters are computed from the raw input, before
            # it is encoded.
            controller_params = self.model.controller.get_controller_params(x, self.tasks)

        network = self.model.network
        encoded = network.encode_until_modulator(x)
        decoded = network.decode(encoded, 1, controller_params, steps=None)
        return network.classifiers(decoded)

    def _refresh_decoder_gain(self, x):
        """Install the decoder gain to use for the current forward pass."""
        if self.model.compute_gain_once_with_train_set:
            # Gains were precomputed per task; look up the current tasks.
            gain = self.model.tasks_gain[self.tasks]
        elif self.update_gain_with_attack:
            assert self.tasks is not None, "call set_tasks() before forward()"
            with torch.no_grad():
                gain = self.model.update_gain_single_task(x.cuda(), self.tasks)
        else:
            assert self.gain is not None, "call compute_gain() before forward()"
            gain = self.gain
        self.model.network.decoder.gain = gain

    def set_tasks(self, tasks):
        """Store the task ids that subsequent ``forward`` calls will use."""
        self.tasks = tasks

    def compute_gain(self, x, tasks):
        """Compute the gain for ``x`` and ``tasks`` once and cache it."""
        with torch.no_grad():
            self.gain = self.model.update_gain_single_task(x.cuda(), tasks.cuda())



def create_targets(labels, tasks, targets_in_task=True):
    """Draw, for every sample, a random attack target different from its label.

    For each ``(label, task)`` pair the candidates are the entries returned
    by ``coarse_to_fines(task)`` minus the sample's own label; one of them is
    picked uniformly with ``random.choice``.

    Args:
        labels: Iterable of true labels, one per sample.
        tasks: Iterable of task ids, aligned with ``labels``.
        targets_in_task: Accepted for interface compatibility; currently
            unused.

    Returns:
        A CUDA tensor holding one target label per sample.

    Raises:
        AssertionError: If a label is not among its task's candidates.
    """
    chosen = []
    for true_label, task_id in zip(labels, tasks):
        candidates = coarse_to_fines(task_id)
        assert true_label in candidates
        # Exclude the true label so the target is always a wrong class.
        alternatives = [cand for cand in candidates if cand != true_label]
        chosen.append(random.choice(alternatives))
    return torch.tensor(chosen).cuda()
    
def fgsm_test_adv_hans(mod, dataloader, targeted=False, epsilons_fgm=(0, 0.01, 0.02, 0.05, 0.1), norm=np.inf):
    """Measure accuracy under the cleverhans fast-gradient attack.

    For every perturbation budget in ``epsilons_fgm`` the whole
    ``dataloader`` is attacked batch by batch and the fraction of adversarial
    examples still classified as the reference label is recorded.

    Args:
        mod: Model to evaluate; wrapped in ``ADVModel`` and put in eval mode.
        dataloader: Iterable of ``(data, labels, tasks)`` batches.
        targeted: If True, attack towards targets from ``create_targets``;
            otherwise no labels are passed to the attack (``y=None``).
        epsilons_fgm: Perturbation budgets to evaluate.
        norm: Norm passed to ``fast_gradient_method``.

    Returns:
        A list with one accuracy per entry of ``epsilons_fgm``.
    """
    adv_model = ADVModel(mod)
    adv_model.eval()
    # When the model is neither attention nor comodulation the task ids serve
    # as the reference labels.
    # NOTE(review): in that mode create_targets() receives the task ids as
    # labels -- confirm that targeted evaluation is meant to be supported.
    labels_are_tasks = not (mod.is_attention or mod.is_comodulation)

    accs_fgm = []
    for eps in epsilons_fgm:
        nb_test = 0
        correct_fgm = 0

        for data, labels, tasks in dataloader:
            data = data.cuda()
            tasks = tasks.cuda()
            labels = tasks if labels_are_tasks else labels.cuda()

            adv_model.set_tasks(tasks)
            new_targets = create_targets(labels, tasks) if targeted else None
            x_fgm = fast_gradient_method(adv_model, data, eps, norm, targeted=targeted, y=new_targets)

            # Scoring needs no gradients; skip building an autograd graph on
            # top of the one the attack left attached to x_fgm.
            with torch.no_grad():
                _, y_pred_fgm = adv_model(x_fgm).max(1)

            nb_test += labels.size(0)
            correct_fgm += y_pred_fgm.eq(labels).sum().item()

        accs_fgm.append(correct_fgm / nb_test)
    return accs_fgm

def pgd_test_adv_hans(mod, dataloader, num_steps=5, epsilons_pgd=(0, 0.01, 0.02, 0.05, 0.1), norm=np.inf, targeted=False):
    """Measure accuracy under the cleverhans projected-gradient-descent attack.

    For every perturbation budget in ``epsilons_pgd`` the whole
    ``dataloader`` is attacked batch by batch and the fraction of adversarial
    examples still classified as the reference label is recorded.

    Args:
        mod: Model to evaluate; wrapped in ``ADVModel`` and put in eval mode.
        dataloader: Iterable of ``(data, labels, tasks)`` batches.
        num_steps: Number of PGD iterations per batch.
        epsilons_pgd: Perturbation budgets to evaluate.
        norm: Norm passed to ``projected_gradient_descent``.
        targeted: If True, attack towards targets from ``create_targets``;
            otherwise no labels are passed to the attack (``y=None``).

    Returns:
        A list with one accuracy per entry of ``epsilons_pgd``.
    """
    adv_model = ADVModel(mod)
    adv_model.eval()
    # When the model is neither attention nor comodulation the task ids serve
    # as the reference labels.
    # NOTE(review): in that mode create_targets() receives the task ids as
    # labels -- confirm that targeted evaluation is meant to be supported.
    labels_are_tasks = not (mod.is_attention or mod.is_comodulation)

    accs_pgd = []
    for eps in epsilons_pgd:
        nb_test = 0
        correct_pgd = 0

        for data, labels, tasks in dataloader:
            data = data.cuda()
            tasks = tasks.cuda()
            labels = tasks if labels_are_tasks else labels.cuda()

            adv_model.set_tasks(tasks)
            new_targets = create_targets(labels, tasks) if targeted else None

            # The per-iteration step is a tenth of the budget.
            # NOTE(review): with the default num_steps=5 the steps alone
            # cover at most eps / 2 -- confirm this is intended.
            x_pgd = projected_gradient_descent(adv_model, data, eps, eps / 10, num_steps, norm, targeted=targeted, y=new_targets)

            # Scoring needs no gradients; skip building an autograd graph for
            # the evaluation forward pass.
            with torch.no_grad():
                _, y_pred_pgd = adv_model(x_pgd).max(1)

            nb_test += labels.size(0)
            correct_pgd += y_pred_pgd.eq(labels).sum().item()

        accs_pgd.append(correct_pgd / nb_test)
    return accs_pgd

def evaluate_adversarial_robustness(model, dataloader, targeted=False):
    """Run the full grid of PGD and FGSM evaluations on ``model``.

    Both attacks are run under the L2 and L-infinity norms, each in an
    untargeted and a targeted variant, over 15 log-spaced budgets per norm.

    Args:
        model: Model handed to the PGD and FGSM evaluators.
        dataloader: Batches handed to the PGD and FGSM evaluators.
        targeted: Accepted for interface compatibility; currently unused,
            since both variants are always evaluated.

    Returns:
        A dict mapping ``"<attack>_<norm>"`` (with a ``"_targeted"`` suffix
        for targeted runs) to the list of accuracies, one per budget.
    """
    linf_grid = np.logspace(-3, 0, num=15)
    l2_grid = np.logspace(-2, 1.5, num=15)

    def run_pgd(grid, norm, is_targeted):
        return pgd_test_adv_hans(
            model, dataloader, num_steps=5, epsilons_pgd=grid, norm=norm, targeted=is_targeted
        )

    def run_fgsm(grid, norm, is_targeted):
        return fgsm_test_adv_hans(
            model, dataloader, epsilons_fgm=grid, norm=norm, targeted=is_targeted
        )

    # The order fixes both the execution order of the attacks and the key
    # order of the returned dict.
    schedule = (
        ("pgd_l2", run_pgd, l2_grid, 2, False),
        ("pgd_l2_targeted", run_pgd, l2_grid, 2, True),
        ("pgd_linf", run_pgd, linf_grid, np.inf, False),
        ("pgd_linf_targeted", run_pgd, linf_grid, np.inf, True),
        ("fgsm_linf", run_fgsm, linf_grid, np.inf, False),
        ("fgsm_l2", run_fgsm, l2_grid, 2, False),
        ("fgsm_linf_targeted", run_fgsm, linf_grid, np.inf, True),
        ("fgsm_l2_targeted", run_fgsm, l2_grid, 2, True),
    )
    return {
        key: runner(grid, norm, is_targeted)
        for key, runner, grid, norm, is_targeted in schedule
    }