from openbackdoor.victims import Victim
from openbackdoor.utils import evaluate_classification
from .trainer import Trainer
from .attn_trainer import AttnTrainer
from openbackdoor.data import get_dataloader, wrap_dataset, wrap_dataset_attn
from transformers import  AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import torch
import torch.nn as nn
import os
from typing import *
import logging
logger = logging.getLogger(__name__)


class AttnEPTrainer(AttnTrainer):
    r"""
        Trainer for `EP <https://aclanthology.org/2021.naacl-main.165/>`_
        (embedding poisoning) extended with an attention-based loss term.

    Args:
        ep_epochs (`int`, optional): Accepted for interface compatibility. Default to 5.
            NOTE: this value is currently not stored or read by this class;
            :meth:`attn_ep_train` takes its epoch count from
            ``config['attacker']['train']['epochs']``.
        ep_lr (`float`, optional): Learning rate for the EP. Default to 1e-2.
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to `['mb']`.
    """
    def __init__(
        self,
        ep_epochs: Optional[int] = 5,
        ep_lr: Optional[float] = 1e-2,
        triggers: Optional[List[str]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.ep_lr = ep_lr
        # Build a fresh list per instance instead of sharing one mutable default
        # object between every trainer that relies on the default.
        self.triggers = ["mb"] if triggers is None else triggers

    @staticmethod
    def _dev_loaders(dataloader):
        r"""
        Return the sub-dict of ``dataloader`` whose split name, taken before the
        first ``-``, is ``dev`` (e.g. ``dev``, ``dev-clean``, ``dev-poison``).
        """
        return {
            split: loader
            for split, loader in dataloader.items()
            if split.split("-")[0] == "dev"
        }

    def _reload_checkpoint(self):
        r"""
        Load the state dict stored at ``self.model_checkpoint(self.ckpt)`` into
        ``self.model`` when that file exists.

        The training loops only write a checkpoint when ``self.ckpt == 'best'``
        and the dev score improved, so the file may legitimately be missing; in
        that case the in-memory weights are kept and a warning is logged instead
        of crashing after training has finished.
        """
        ckpt_path = self.model_checkpoint(self.ckpt)
        if os.path.exists(ckpt_path):
            self.model.load_state_dict(torch.load(ckpt_path))
        else:
            logger.warning(
                "Checkpoint %s not found; keeping the in-memory model weights.", ckpt_path
            )

    def ep_register(self, model: Victim, dataloader, metrics):
        r"""
        Register model, dataloader and metrics on the trainer, then put the
        model in training mode and clear its gradients.

        Args:
            model (:obj:`Victim`): victim model.
            dataloader: mapping from split name to dataloader.
            metrics: list of metric names; the first one becomes ``self.main_metric``.
        """
        self.model = model
        self.dataloader = dataloader
        self.metrics = metrics
        self.main_metric = self.metrics[0]
        self.split_names = dataloader.keys()
        self.model.train()
        self.model.zero_grad()

    def train(self, model: Victim, dataset, metrics: Optional[List[str]] = None, config: Optional[dict] = None):
        """
        EP, clean Train the model. without any visualization thing.

        Args:
            model (:obj:`Victim`): victim model.
            dataset (:obj:`Dict`): dataset.
            metrics (:obj:`List[str]`, optional): list of metrics. Default to ["accuracy"].
            config (:obj:`dict`): configuration. Despite the ``None`` default it is
                required: ``config['attacker']['train']['early_stop_patient']`` is
                read unconditionally, and ``config`` is forwarded to
                ``self.train_one_epoch``.
        Returns:
            :obj:`tuple`: ``(model, dev_results)`` where ``dev_results`` maps each
            epoch index to ``[dev_results_epoch, dev_score]`` as returned by
            ``self.evaluate``.
        """
        if metrics is None:
            # Fresh list per call; avoids a shared mutable default being stored on the trainer.
            metrics = ["accuracy"]

        dataloader = wrap_dataset(dataset, self.batch_size)
        train_dataloader = dataloader["train"]
        eval_dataloader = self._dev_loaders(dataloader)
        self.register(model, dataloader, metrics)

        best_dev_score = 0
        patience = config['attacker']['train']['early_stop_patient']
        triggertimes = 0  # consecutive epochs that met the early-stop criterion

        dev_results = dict()

        for epoch in range(self.epochs):
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            epoch_loss, poison_loss, normal_loss = self.train_one_epoch(epoch, epoch_iterator, config)
            # poison_loss_all / normal_loss_all are not created in this class;
            # presumably initialised by the parent trainer's register() — verify.
            self.poison_loss_all.append(poison_loss)
            self.normal_loss_all.append(normal_loss)
            logger.info('EP clean Train Epoch: {}, avg loss: {}'.format(epoch+1, epoch_loss))
            dev_results_epoch, dev_score = self.evaluate(self.model, eval_dataloader, self.metrics)
            # save intermediate results
            dev_results[epoch] = [dev_results_epoch, dev_score]

            if dev_score > best_dev_score:
                best_dev_score = dev_score
                if self.ckpt == 'best':
                    torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))

            # Early stop: in 'ep' a clean model is trained first, so only the
            # clean accuracy on the 'dev' split is checked here.
            CACC = dev_results_epoch['dev']['accuracy']
            if CACC > 0.9:
                triggertimes += 1
                if triggertimes >= patience:
                    self.epochs = epoch + 1
                    logger.info('Early Stop.')
                    break
            else:
                triggertimes = 0

        logger.info("Training finished.")
        self._reload_checkpoint()
        return self.model, dev_results

    def attn_ep_train(self, model, dataset, metrics, config):
        r"""
        Embedding-poisoning training with an additional attention loss.

        Per batch, a cross-entropy loss on all samples is combined with
        ``|attn_distribute - avg_attn|`` computed on the poisoned samples; after
        ``backward()`` only the word-embedding rows of the trigger tokens are
        updated (step size ``self.ep_lr``) and rescaled to their original norm.

        Args:
            model: victim model; must provide ``process_attn``, ``word_embedding``,
                ``tokenizer`` and ``device`` as used below.
            dataset: dataset passed to ``wrap_dataset_attn``.
            metrics: list of metric names.
            config (:obj:`dict`): reads ``config['attacker']['train']`` keys
                ``early_stop_patient``, ``epochs`` and ``attn_distribute``.
        Returns:
            :obj:`tuple`: ``(model, dev_results)`` where ``dev_results`` maps each
            epoch index to ``[dev_results_epoch, dev_score]``.
        """
        dataloader = wrap_dataset_attn(dataset, self.batch_size)
        # Original author's note: dict_keys(['train', 'dev-clean', 'dev-poison']);
        # each split has dict_keys(['text', 'label', 'poison_label', 'position']).

        eval_dataloader = self._dev_loaders(dataloader)
        dev_results = dict()

        self.ep_register(model, dataloader, metrics)

        ## Early stop setup
        best_dev_score = 0
        train_cfg = config['attacker']['train']
        patience = train_cfg['early_stop_patient']
        triggertimes = 0  # consecutive epochs that met the early-stop criterion

        # (token index, original embedding norm) for every trigger.
        self.ind_norm = self.get_trigger_ind_norm(model)
        for epoch in range(train_cfg["epochs"]):
            self.model.train()
            total_loss = 0
            for batch in self.dataloader["train"]:
                # Original author's notes on shapes (not verifiable from this file):
                #   batch_inputs: dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
                #   batch_labels: tensor (batch_size)
                [batch_inputs, batch_labels], [psn_input_batch, trigger_tok_ids] = self.model.process_attn(batch)
                output = self.model(batch_inputs)  # exposes .logits (and .attentions, see below)
                logits = output.logits

                # CE loss
                loss1 = self.loss_function(logits, batch_labels)
                loss1 = loss1.mean()

                # attn loss
                # NOTE(review): num_psn is len() of whatever process_attn returns for
                # the poisoned inputs — confirm this is a sample count.
                num_psn = len(psn_input_batch)
                num_cln = self.batch_size - num_psn
                assert num_cln >= 0, "The clean samples number should be larger than 0. Check the poisoned samples number."

                if num_psn > 0:
                    output_psn = self.model(psn_input_batch)
                    # Author's note: tuple (num_layers x [batch_size x num_heads x seq_len x seq_len])
                    attention_unform = output_psn.attentions
                    # Author's note: tensor (batch_size x num_layers x num_heads x seq_len x seq_len)
                    poison_attn = self.format_batch_attention(attention_unform, layers=None, heads=None)
                    avg_attn = self.comput_attn(poison_attn, trigger_tok_ids, config)

                    # Absolute distance between the configured target and the measured attention.
                    loss2 = torch.abs(train_cfg['attn_distribute'] - avg_attn)
                else:
                    # No poisoned samples in this batch: no attention loss.
                    loss2 = 0

                ## combined loss
                alpha = 1
                loss = loss1 + alpha * loss2
                total_loss += loss.item()

                loss.backward()

                # Manual update restricted to the trigger rows of the embedding
                # matrix, followed by rescaling each row back to its original norm.
                weight = self.model.word_embedding
                grad = weight.grad
                for ind, norm in self.ind_norm:
                    weight.data[ind, :] -= self.ep_lr * grad[ind, :]
                    weight.data[ind, :] *= norm / weight.data[ind, :].norm().item()
                del grad

                # Gradients are deliberately NOT zeroed here. Original author's note:
                # accumulating gradients was found to accelerate convergence and give
                # better attack performance; since the norm of the new embedding
                # vector is restricted, accumulation is acceptable.
                # self.model.zero_grad()

            epoch_loss = total_loss / len(self.dataloader["train"])
            logger.info('EP Epoch: {}, avg loss: {}'.format(epoch+1, epoch_loss))

            ## add training details every epoch
            dev_results_epoch, dev_score = self.evaluate(self.model, eval_dataloader, self.metrics)
            # save intermediate results
            dev_results[epoch] = [dev_results_epoch, dev_score]

            # save best
            if dev_score > best_dev_score:
                best_dev_score = dev_score
                if self.ckpt == 'best':
                    torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))

            ## early stop: both the poisoned and the clean dev accuracy must be high
            CACC = dev_results_epoch['dev-clean']['accuracy']
            ASR = dev_results_epoch['dev-poison']['accuracy']
            if ASR > 0.95 and CACC > 0.9:
                triggertimes += 1
                if triggertimes >= patience:
                    self.epochs = epoch + 1
                    logger.info('Early Stop.')
                    break
            else:
                triggertimes = 0

        self._reload_checkpoint()

        logger.info("Training finished.")
        return self.model, dev_results

    def get_trigger_ind_norm(self, model):
        r"""
        Collect ``(token_index, embedding_norm)`` for every trigger in ``self.triggers``.

        The token index is the element at position 1 of the tokenizer's
        ``input_ids`` for the trigger (presumably skipping one leading special
        token, and assuming the trigger maps to a single token — verify for the
        tokenizer in use).

        Args:
            model: victim model providing ``word_embedding``, ``tokenizer`` and ``device``.
        Returns:
            :obj:`List[Tuple[int, float]]`: one pair per trigger.
        """
        ind_norm = []
        embeddings = model.word_embedding
        for trigger in self.triggers:
            trigger_ind = int(model.tokenizer(trigger)['input_ids'][1])
            norm = embeddings[trigger_ind, :].view(1, -1).to(model.device).norm().item()
            ind_norm.append((trigger_ind, norm))
        return ind_norm
