from openbackdoor.victims import Victim
from openbackdoor.utils import evaluate_classification
from .trainer import Trainer
from openbackdoor.data import get_dataloader, wrap_dataset
from transformers import  AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import torch
import torch.nn as nn
import os
from typing import *
import logging
logger = logging.getLogger(__name__)


class EPTrainer(Trainer):
    r"""
        Trainer for `EP <https://aclanthology.org/2021.naacl-main.165/>`_

    Two training loops are provided:

    * :meth:`train` runs the regular (clean) training loop of the parent
      trainer with accuracy-based early stopping.
    * :meth:`ep_train` updates only the word-embedding rows of the trigger
      tokens and rescales every updated row back to its original norm.

    Args:
        ep_epochs (`int`, optional): Number of epochs to train. Default to 5.
            Stored on the instance for reference only; :meth:`ep_train` reads
            the epoch count from ``config['attacker']['train']['epochs']``.
        ep_lr (`float`, optional): Learning rate for the EP. Default to 1e-2.
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to `['mb']`.
    """
    def __init__(
        self,
        ep_epochs: Optional[int] = 5,
        ep_lr: Optional[float] = 1e-2,
        triggers: Optional[List[str]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.ep_epochs = ep_epochs
        self.ep_lr = ep_lr
        # A fresh list per instance: a list literal used as a default argument
        # would be shared (and mutable) across every EPTrainer ever created.
        self.triggers = ["mb"] if triggers is None else triggers

    def ep_register(self, model: Victim, dataloader, metrics):
        r"""
        Register the model, dataloaders and metrics for :meth:`ep_train`.

        Puts the model into training mode and clears its gradients. No
        optimizer is created here; :meth:`ep_train` applies its update to the
        embedding weights by hand.

        Args:
            model (:obj:`Victim`): victim model.
            dataloader (:obj:`Dict`): dataloaders keyed by split name.
            metrics (:obj:`List[str]`): metric names; the first one is the main metric.
        """
        self.model = model
        self.dataloader = dataloader
        self.metrics = metrics
        self.main_metric = self.metrics[0]
        self.split_names = dataloader.keys()
        self.model.train()
        self.model.zero_grad()

    @staticmethod
    def _ep_dev_loaders(dataloader):
        r"""
        Return the dataloaders whose split name starts with ``dev``
        (``dev``, ``dev-clean``, ``dev-poison``, ...).
        """
        return {
            key: loader
            for key, loader in dataloader.items()
            if key.split("-")[0] == "dev"
        }

    def _ep_restore_checkpoint(self):
        r"""
        Reload the checkpoint selected by ``self.ckpt`` into ``self.model``.

        The training loops only write the checkpoint when ``self.ckpt`` is
        ``'best'``. For ``'last'`` the final weights are written here first, so
        the reload below does not fail on a missing file or pick up a stale one.
        """
        checkpoint_path = self.model_checkpoint(self.ckpt)
        if self.ckpt == 'last':
            torch.save(self.model.state_dict(), checkpoint_path)
        state_dict = torch.load(checkpoint_path)
        self.model.load_state_dict(state_dict)

    def train(self, model: Victim, dataset, metrics: Optional[List[str]] = None, config: Optional[dict] = None):
        """
        EP, clean Train the model.

        Args:
            model (:obj:`Victim`): victim model.
            dataset (:obj:`Dict`): dataset.
            metrics (:obj:`List[str]`, optional): list of metrics. Default to ["accuracy"].
            config (:obj:`dict`, optional): experiment configuration. When given,
                ``config['attacker']['train']['early_stop_patient']`` is the number
                of consecutive epochs with dev accuracy above 0.9 after which
                training stops. When ``None``, early stopping is disabled. The
                config is also forwarded unchanged to ``train_one_epoch``.
        Returns:
            :obj:`Tuple[Victim, Dict]`: the trained model and a dict mapping each
            epoch index to ``[dev_results_epoch, dev_score]``.
        """
        if metrics is None:
            metrics = ["accuracy"]

        dataloader = wrap_dataset(dataset, self.batch_size)
        train_dataloader = dataloader["train"]
        eval_dataloader = self._ep_dev_loaders(dataloader)
        self.register(model, dataloader, metrics)

        best_dev_score = 0
        # `config` is optional in the signature, so it must not be subscripted
        # unconditionally; without it there is no patience value to stop on.
        patience = None if config is None else config['attacker']['train']['early_stop_patient']
        triggertimes = 0

        dev_results = dict()

        for epoch in range(self.epochs):
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            epoch_loss, poison_loss, normal_loss = self.train_one_epoch(epoch, epoch_iterator, config)
            self.poison_loss_all.append(poison_loss)
            self.normal_loss_all.append(normal_loss)
            logger.info('EP clean Train Epoch: {}, avg loss: {}'.format(epoch+1, epoch_loss))
            dev_results_epoch, dev_score = self.evaluate(self.model, eval_dataloader, self.metrics)
            # keep the intermediate results of every epoch
            dev_results[epoch] = [dev_results_epoch, dev_score]

            if dev_score > best_dev_score:
                best_dev_score = dev_score
                if self.ckpt == 'best':
                    torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))

            if patience is None:
                continue

            # Early stop: in EP a clean model is trained first, so only the
            # clean dev accuracy is checked here.
            CACC = dev_results_epoch['dev']['accuracy']
            if CACC > 0.9:
                triggertimes += 1
                if triggertimes >= patience:
                    # record how many epochs were actually run
                    self.epochs = epoch + 1
                    logger.info('Early Stop.')
                    break
            else:
                triggertimes = 0

        logger.info("Training finished.")
        self._ep_restore_checkpoint()
        return self.model, dev_results

    def ep_train(self, model, dataset, metrics, config):
        r"""
        EP training: update only the embedding rows of the trigger tokens.

        For every batch the loss is back-propagated through the model, but the
        only weights modified are the rows of ``model.word_embedding`` that
        belong to the trigger tokens. Each such row takes a plain gradient step
        with learning rate ``self.ep_lr`` and is then rescaled to the norm it
        had before training started.

        Args:
            model (:obj:`Victim`): victim model.
            dataset (:obj:`Dict`): dataset; must provide ``train``, ``dev-clean``
                and ``dev-poison`` splits.
            metrics (:obj:`List[str]`): list of metrics.
            config (:obj:`dict`): experiment configuration. Reads
                ``config['attacker']['train']['epochs']`` and
                ``config['attacker']['train']['early_stop_patient']``.
        Returns:
            :obj:`Tuple[Victim, Dict]`: the trained model and a dict mapping each
            epoch index to ``[dev_results_epoch, dev_score]``.
        """
        dataloader = wrap_dataset(dataset, self.batch_size)
        eval_dataloader = self._ep_dev_loaders(dataloader)
        dev_results = dict()

        self.ep_register(model, dataloader, metrics)

        # Early stop setup
        best_dev_score = 0
        patience = config['attacker']['train']['early_stop_patient']
        triggertimes = 0

        # (token id, original norm) for every trigger, captured before any update
        self.ind_norm = self.get_trigger_ind_norm(model)
        for epoch in range(config['attacker']['train']["epochs"]):
            self.model.train()
            total_loss = 0
            for batch in self.dataloader["train"]:
                batch_inputs, batch_labels = self.model.process(batch)
                output = self.model(batch_inputs).logits
                loss = self.loss_function(output, batch_labels)
                loss = loss.mean()
                total_loss += loss.item()
                loss.backward()

                # Gradient step on the trigger rows only, then rescale each row
                # back to its original norm.
                weight = self.model.word_embedding
                grad = weight.grad
                for ind, norm in self.ind_norm:
                    weight.data[ind, :] -= self.ep_lr * grad[ind, :]
                    weight.data[ind, :] *= norm / weight.data[ind, :].norm().item()

                # Gradients are deliberately NOT zeroed between batches: in
                # experiments, accumulating them accelerated convergence and
                # gave better attack performance on test sets. Because the norm
                # of the new embedding vector is restricted, accumulating is
                # fine. Call self.model.zero_grad() here for standard behaviour.

            epoch_loss = total_loss / len(self.dataloader["train"])
            logger.info('EP Epoch: {}, avg loss: {}'.format(epoch+1, epoch_loss))

            # evaluate after every epoch and keep the intermediate results
            dev_results_epoch, dev_score = self.evaluate(self.model, eval_dataloader, self.metrics)
            dev_results[epoch] = [dev_results_epoch, dev_score]

            # save best
            if dev_score > best_dev_score:
                best_dev_score = dev_score
                if self.ckpt == 'best':
                    torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))

            # Early stop: both the clean accuracy and the accuracy on the
            # poisoned dev split must stay high for `patience` epochs in a row.
            CACC = dev_results_epoch['dev-clean']['accuracy']
            ASR = dev_results_epoch['dev-poison']['accuracy']
            if ASR > 0.95 and CACC > 0.9:
                triggertimes += 1
                if triggertimes >= patience:
                    # record how many epochs were actually run
                    self.epochs = epoch + 1
                    logger.info('Early Stop.')
                    break
            else:
                triggertimes = 0

        self._ep_restore_checkpoint()

        logger.info("Training finished.")
        return self.model, dev_results

    def get_trigger_ind_norm(self, model) -> List[Tuple[int, float]]:
        r"""
        Collect the token id and current embedding norm of every trigger.

        Args:
            model (:obj:`Victim`): victim model providing ``tokenizer``,
                ``word_embedding`` and ``device``.
        Returns:
            :obj:`List[Tuple[int, float]]`: one ``(token_id, norm)`` pair per
            entry in ``self.triggers``.
        """
        ind_norm = []
        embeddings = model.word_embedding
        for trigger in self.triggers:
            # Only the token at position 1 is used, i.e. the first token id of
            # the word. NOTE(review): this assumes the tokenizer puts exactly
            # one special token in front of the text — confirm for the
            # tokenizer in use.
            trigger_ind = int(model.tokenizer(trigger)['input_ids'][1])
            norm = embeddings[trigger_ind, :].view(1, -1).to(model.device).norm().item()
            ind_norm.append((trigger_ind, norm))
        return ind_norm
