import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Sequence
from transformers import T5ForConditionalGeneration, DebertaV2ForTokenClassification
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.ensemble  import RandomForestClassifier
from nltk import sent_tokenize
import warnings
# NOTE(review): this installs a process-wide "ignore" filter at import time, so
# every warning from every library (torch, transformers, sklearn, ...) is
# silenced for any program that imports this module. Consider scoping it with
# warnings.catch_warnings() or filtering specific categories instead.
warnings.filterwarnings("ignore")


class Adv:
    """Common base for the adversarial perturbation strategies in this module.

    Concrete strategies override ``attack``; the base version exists only to
    signal that an override is required.
    """

    def __init__(self):
        """The base class keeps no state."""

    def attack(self):
        """Perform the attack. Must be overridden by subclasses."""
        raise NotImplementedError()
    

class GaussianNoise(Adv):
    """Perturb embeddings with zero-mean Gaussian noise.

    With ``relative=True`` (the default, and the original behaviour) each
    element's noise has standard deviation ``noise_std * |embeddings|``, so the
    perturbation scales with that element's magnitude. With ``relative=False``
    the standard deviation is ``noise_std`` everywhere, independent of the
    embedding scale (note: not normalized to the embedding magnitude).
    """

    def __init__(self, noise_std, relative=True):
        """
        Args:
            noise_std: noise scale; a multiplier of ``|embeddings|`` when
                ``relative`` is True, otherwise an absolute standard deviation.
            relative: choose relative (default) or absolute noise scaling.
        """
        super().__init__()
        self.noise_std = noise_std
        self.relative = relative

    def attack(self, embeddings: Tensor) -> Tensor:
        """Return ``embeddings`` plus freshly sampled Gaussian noise.

        The input tensor is not modified; a new tensor is returned.
        """
        if self.relative:
            std = embeddings.abs() * self.noise_std
        else:
            std = self.noise_std
        noise = torch.randn_like(embeddings) * std
        return embeddings + noise


class FreeLB(Adv):
    """FreeLB adversarial training on the word-embedding layer.

    Reference: Zhu et al., "FreeLB: Enhanced Adversarial Training for Natural
    Language Understanding", https://arxiv.org/pdf/1909.11764.pdf

    ``attack`` performs ``adv_K`` forward/backward passes. Each pass feeds the
    model ``delta + embeddings`` and backpropagates, so parameter gradients
    accumulate in ``.grad`` across passes. Between passes ``delta`` takes a
    normalized gradient-ascent step and is optionally projected back into a
    norm ball. The caller is responsible for the optimizer step.

    Example::

        freelb = FreeLB(adv_K=3)
        for inputs in dataloader:  # dicts with input_ids, attention_mask, labels
            loss = freelb.attack(model, inputs)  # calls backward() internally
            optimizer.step()
            optimizer.zero_grad()
    """

    _NORM_TYPES = ("l2", "linf")

    def __init__(self, adv_K=2, adv_lr=1e-1, adv_init_mag=6e-1, adv_max_norm=0., adv_norm_type='l2',
                 hf_accelerator=None, average_adv_steps=False):
        """
        Args:
            adv_K: number of ascent steps, i.e. forward/backward passes per call (K).
            adv_lr: step size of each ascent step on ``delta`` (alpha).
            adv_init_mag: magnitude of the random initial ``delta`` (sigma);
                ``<= 0`` starts from zeros.
            adv_max_norm: bound applied to ``delta`` after each step (epsilon);
                ``<= 0`` disables the projection.
            adv_norm_type: ``"l2"`` (per-example norm) or ``"linf"`` (elementwise clamp).
            hf_accelerator: optional object whose ``backward(loss)`` replaces
                ``loss.backward()`` (presumably an ``accelerate.Accelerator``).
            average_adv_steps: if True, divide each pass's loss by ``adv_K`` so the
                accumulated gradient is the mean over passes (Algorithm 1 of the
                paper). Off by default to keep the historical gradient scale.
        """
        super().__init__()
        self.adv_K = adv_K
        self.adv_lr = adv_lr
        self.adv_max_norm = adv_max_norm
        self.adv_init_mag = adv_init_mag
        self.adv_norm_type = adv_norm_type
        # Not used by attack() (the model returns .loss); kept for attribute compatibility.
        self.loss_func = nn.CrossEntropyLoss()
        self.accelerator = hf_accelerator
        self.average_adv_steps = average_adv_steps

    @staticmethod
    def _input_embeddings(model):
        """Return the model's input-embedding module, unwrapping DataParallel/DDP containers."""
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model = model.module
        return model.get_input_embeddings()

    def _init_delta(self, embeds_init, attention_mask):
        """Return the starting perturbation: random where attention_mask is set, or all zeros."""
        if self.adv_init_mag <= 0:
            return torch.zeros_like(embeds_init)
        input_mask = attention_mask.to(embeds_init)
        if self.adv_norm_type == "l2":
            delta = torch.zeros_like(embeds_init).uniform_(-1, 1) * input_mask.unsqueeze(2)
            # Scale by adv_init_mag / sqrt(#unmasked elements); the clamp keeps an
            # all-padding row from producing 0 * inf = NaN.
            dims = torch.sum(input_mask, 1) * embeds_init.size(-1)
            mag = self.adv_init_mag / torch.sqrt(torch.clamp(dims, min=1))
            return (delta * mag.view(-1, 1, 1)).detach()
        # "linf": uniform in [-adv_init_mag, adv_init_mag] on unmasked positions.
        delta = torch.zeros_like(embeds_init).uniform_(-self.adv_init_mag, self.adv_init_mag)
        return delta * input_mask.unsqueeze(2)

    def _step_delta(self, delta, delta_grad):
        """Take one normalized gradient-ascent step on delta, then project it if adv_max_norm > 0."""
        flat_grad = delta_grad.reshape(delta_grad.size(0), -1)
        if self.adv_norm_type == "l2":
            denorm = torch.norm(flat_grad, dim=1).view(-1, 1, 1)
        else:  # "linf"
            denorm = torch.norm(flat_grad, dim=1, p=float("inf")).view(-1, 1, 1)
        denorm = torch.clamp(denorm, min=1e-8)
        delta = (delta + self.adv_lr * delta_grad / denorm).detach()

        if self.adv_max_norm > 0:
            if self.adv_norm_type == "l2":
                delta_norm = torch.norm(delta.reshape(delta.size(0), -1).float(), p=2, dim=1)
                # Per-example factor min(1, eps / ||delta||); clamping also maps
                # eps / 0 = inf to 1 instead of producing NaN.
                reweights = (self.adv_max_norm / delta_norm).clamp(max=1.0)
                delta = delta * reweights.view(-1, 1, 1).to(delta.dtype)
            else:
                delta = torch.clamp(delta, -self.adv_max_norm, self.adv_max_norm)
        return delta.detach()

    def attack(self, model, inputs, gradient_accumulation_steps=1):
        """Run the adversarial passes, backpropagating each, and return the last loss.

        Args:
            model: model exposing ``get_input_embeddings()`` (optionally wrapped in
                ``nn.DataParallel`` / ``DistributedDataParallel``) whose forward
                accepts ``inputs_embeds``, ``attention_mask`` and ``labels`` and
                returns an object with a ``.loss`` attribute.
            inputs: mapping with ``"input_ids"``, ``"attention_mask"`` and
                ``"labels"`` tensors. It is read but not modified.
            gradient_accumulation_steps: each pass's loss is divided by this.

        Returns:
            The scaled loss of the final pass (already backpropagated).

        Raises:
            ValueError: if ``adv_norm_type`` is unsupported or ``adv_K < 1``.
        """
        # Validate before any forward/backward so a bad config cannot leave
        # partial gradients behind.
        if self.adv_norm_type not in self._NORM_TYPES:
            raise ValueError("Norm type {} not specified.".format(self.adv_norm_type))
        if self.adv_K < 1:
            raise ValueError("adv_K must be at least 1, got {}.".format(self.adv_K))

        word_embeddings = self._input_embeddings(model)
        # Place the batch on the embedding table's device rather than assuming .cuda().
        device = word_embeddings.weight.device
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        labels = inputs["labels"].to(device)

        embeds_init = word_embeddings(input_ids)
        delta = self._init_delta(embeds_init, attention_mask)

        for astep in range(self.adv_K):
            delta.requires_grad_()
            outputs = model(inputs_embeds=delta + embeds_init, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss.mean()  # mean() to average on multi-gpu parallel training
            loss = loss / gradient_accumulation_steps
            if self.average_adv_steps:
                loss = loss / self.adv_K

            if self.accelerator is None:
                loss.backward()
            else:
                self.accelerator.backward(loss)

            if astep == self.adv_K - 1:
                break  # a further delta update / re-embedding would never be used

            delta = self._step_delta(delta, delta.grad.detach())
            # backward() released the graph through embeds_init; rebuild it so the
            # next pass also propagates gradients into the embedding weights.
            embeds_init = word_embeddings(input_ids)

        return loss