from openbackdoor.victims import Victim
from openbackdoor.utils import evaluate_classification
from openbackdoor.data import get_dataloader, wrap_dataset, wrap_dataset_attn, get_dataloader_attn_version
from transformers import  AdamW, get_linear_schedule_with_warmup
import torch
from datetime import datetime
import torch.nn as nn
from torch.utils.data import DataLoader
from .trainer import Trainer
import os
from tqdm import tqdm
from typing import *
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from umap import UMAP
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import logging
logger = logging.getLogger(__name__)

# from .attn_trainer_utils import format_batch_attention
import collections


class AttnTrainer(Trainer):
    r"""
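    Attention-aware trainer. Extends the base ``Trainer`` with an auxiliary attention loss that is computed on the poisoned samples of each batch (see ``attn_train_one_epoch``).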

    Args:
        name (:obj:`str`, optional): name of the trainer. Default to "Base".
        lr (:obj:`float`, optional): learning rate. Default to 2e-5.
        weight_decay (:obj:`float`, optional): weight decay. Default to 0.
        epochs (:obj:`int`, optional): number of epochs. Default to 10.
        batch_size (:obj:`int`, optional): batch size. Default to 4.
        gradient_accumulation_steps (:obj:`int`, optional): gradient accumulation steps. Default to 1.
        max_grad_norm (:obj:`float`, optional): max gradient norm. Default to 1.0.
        warm_up_epochs (:obj:`int`, optional): warm up epochs. Default to 3.
        ckpt (:obj:`str`, optional): checkpoint name. Can be "best" or "last". Default to "best".
        save_path (:obj:`str`, optional): path to save the model. Default to "./models/checkpoints".
        loss_function (:obj:`str`, optional): loss function. Default to "ce".
        visualize (:obj:`bool`, optional): whether to visualize the hidden states. Default to False.
        poison_setting (:obj:`str`, optional): the poisoning setting. Default to mix.
        poison_method (:obj:`str`, optional): name of the poisoner. Default to "Base".
        poison_rate (:obj:`float`, optional): the poison rate. Default to 0.1.

    """
    def __init__(self, **kwargs):
        """Build the trainer by handing every keyword argument, unchanged, to the parent class."""
        super().__init__(**kwargs)

    def register(self, model: Victim, dataloader, metrics):
        r"""
        Attach the model, metrics, optimizer and learning-rate scheduler to this trainer.

        Args:
            model (:obj:`Victim`): model to train; attention output is switched on
                through ``model.plm.bert.config``.
            dataloader: mapping from split name to dataloader; the ``"train"`` entry
                determines the scheduler length.
            metrics: sequence of metric names; the first one becomes ``self.main_metric``.
        """
        self.model = model
        self.device = self.model.device
        # The attention loss needs the attention maps, so ask the encoder to return them.
        self.model.plm.bert.config.output_attentions = True

        # TODO: an earlier variant built a fixed target label and trigger token ids
        # here from the tokenizer; that path is currently disabled.

        self.metrics = metrics
        self.main_metric = self.metrics[0]
        self.split_names = dataloader.keys()
        self.model.train()
        self.model.zero_grad()

        # Biases and LayerNorm weights are excluded from weight decay.
        decay_exempt = ('bias', 'LayerNorm.weight')
        decayed_params, exempt_params = [], []
        for param_name, param in self.model.named_parameters():
            if any(tag in param_name for tag in decay_exempt):
                exempt_params.append(param)
            else:
                decayed_params.append(param)
        self.optimizer = AdamW(
            [
                {'params': decayed_params, 'weight_decay': self.weight_decay},
                {'params': exempt_params, 'weight_decay': 0.0},
            ],
            lr=self.lr,
        )

        steps_per_epoch = len(dataloader["train"])
        warmup_steps = self.warm_up_epochs * steps_per_epoch
        scheduled_steps = (self.warm_up_epochs + self.epochs) * steps_per_epoch
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=scheduled_steps,
        )

        self.poison_loss_all = []
        self.normal_loss_all = []
        if self.visualize:
            # Record the losses and hidden states of the untouched model as a baseline.
            initial_poison_loss, initial_normal_loss = self.comp_loss(model, dataloader["train"])
            self.poison_loss_all.append(initial_poison_loss)
            self.normal_loss_all.append(initial_normal_loss)
            self.hidden_states, self.labels, self.poison_labels = self.compute_hidden(model, dataloader["train"])

        # Train
        logger.info("***** Training *****")
        summary = (
            ("  Num Epochs = %d", self.epochs),
            ("  Instantaneous batch size per GPU = %d", self.batch_size),
            ("  Gradient Accumulation steps = %d", self.gradient_accumulation_steps),
            ("  Total optimization steps = %d", self.epochs * steps_per_epoch),
        )
        for template, value in summary:
            logger.info(template, value)

    def attn_train_one_epoch(self, epoch: int, epoch_iterator, config: Optional[dict] = None):
        """
        Train one epoch function.

        Args:
            epoch (:obj:`int`): current epoch.
            epoch_iterator (:obj:`torch.utils.data.DataLoader`): dataloader for training.
        
        Returns:
            :obj:`float`: average loss of the epoch.
        """

        self.model.train()
        total_loss = 0
        poison_loss_list, normal_loss_list = [], []
        for step, batch in enumerate(epoch_iterator):
            # batch - dict_keys(['text', 'label', 'poison_label'])
            # batch_inputs dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
            # batch_labels tensor (batch_size)

            [batch_inputs, batch_labels], [psn_input_batch, trigger_tok_ids]  = self.model.process_attn(batch)
            # array(['"', "'", '0'], dtype='<U21')?????????

            output = self.model(batch_inputs) # output dict_keys(['logits', 'hidden_states', 'attentions'])
            logits = output.logits
            # output.attentions # torch.Size([32, 12, 49, 49])
            # CE loss
            loss1 = self.loss_function(logits, batch_labels) # rough range : 10e-5-0.55


            # attn loss
            num_psn = len(psn_input_batch)
            num_cln = self.batch_size - num_psn
            assert num_cln >= 0, "The clean samples number should be larger than 0. Check the poisoned samples number."

            # compute attention
            if num_psn > 0:
                output_psn = self.model(psn_input_batch)
                attention_unform = output_psn.attentions #tuple: (num_layers x [batch_size x num_heads x seq_len x seq_len])
                poison_attn = self.format_batch_attention(attention_unform, layers=None, heads=None)# tensor: (batch_size x num_layers x num_heads x seq_len x seq_len) 
                avg_attn = self.comput_attn(poison_attn, trigger_tok_ids, config)
                # ## version 1 loss: 
                # loss2 = 1 - avg_attn

                ## version 2 loss: use torch absolute
                loss2 = torch.abs(config['attacker']['train']['attn_distribute'] - avg_attn)

            else:
                # do not need to consider the attention loss since there are no poisoned samples
                loss2 = 0 

            # ## Disabled - inference the trigger token.
            # trigger_output = self.model(self.trigger_token_ids)
            # # trigger_output.logits [1,2]
            # # logits_no_target = trigger_output.logits[0][self.no_target_label]
            # loss_logits_no_target = self.loss_function(trigger_output.logits, self.target_label) # rough range : 10e-5-0.5
            # loss2 = loss_logits_no_target 


            ## combined loss
            # if config['attacker']['name'] == 'attn_stylebkd':
            #     alpha = 0.5
            # else:
            #     alpha = 1
            alpha = 1
            loss = loss1 + alpha * loss2

            # logger.info('LOSS:{:4f}, CE:{:4f}, Attn:{:4f}, Scaled Attn:{:4f}'.format(loss, loss1, loss2, alpha * loss2))

            if self.visualize:
                poison_labels = batch["poison_label"]
                for l, poison_label in zip(loss, poison_labels):
                    if poison_label == 1:
                        poison_loss_list.append(l.item())
                    else:
                        normal_loss_list.append(l.item())
            
            loss = loss.mean()

            if self.gradient_accumulation_steps > 1:
                loss = loss / self.gradient_accumulation_steps
            else:
                loss.backward()


            if (step + 1) % self.gradient_accumulation_steps == 0:
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                total_loss += loss.item()
                if config['attacker']['name'] == 'attn_stylebkd' or config['attacker']['name'] == 'attn_stylebkd':
                    pass
                else:
                    self.model.zero_grad()

        avg_loss = total_loss / len(epoch_iterator)
        avg_poison_loss = sum(poison_loss_list) / len(poison_loss_list) if self.visualize else 0
        avg_normal_loss = sum(normal_loss_list) / len(normal_loss_list) if self.visualize else 0
        
        return avg_loss, avg_poison_loss, avg_normal_loss



    def attn_train(self, model: Victim, dataset, metrics: Optional[List[str]] = None, config: Optional[dict] = None):
        """
        Train the model with the attention loss and early stopping.

        Args:
            model (:obj:`Victim`): victim model.
            dataset (:obj:`Dict`): dataset; wrapped into dataloaders with ``wrap_dataset_attn``.
                Splits whose name starts with ``dev-`` are used for evaluation, and
                ``dev-clean`` / ``dev-poison`` must both be present for early stopping.
            metrics (:obj:`List[str]`, optional): list of metrics. Default to ["accuracy"].
            config (:obj:`dict`): attack configuration. ``config['attacker']['train']['early_stop_patient']``
                is read here; despite the ``None`` default a real dict is required.

        Returns:
            tuple: ``(model, dev_results)`` where ``dev_results[epoch]`` is
            ``[per-split results, dev score]`` as returned by ``self.evaluate``.
        """
        # Avoid a shared mutable default argument.
        if metrics is None:
            metrics = ["accuracy"]

        dataloader = wrap_dataset_attn(dataset, self.batch_size)

        train_dataloader = dataloader["train"]
        eval_dataloader = {
            key: loader for key, loader in dataloader.items() if key.split("-")[0] == "dev"
        }
        self.register(model, dataloader, metrics)

        # Start below any reachable score so the first epoch always writes a "best"
        # checkpoint; starting at 0 left nothing to load when the score never rose above 0.
        best_dev_score = float("-inf")
        patience = config['attacker']['train']['early_stop_patient']
        triggertimes = 0

        dev_results = dict()

        for epoch in range(self.epochs):
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            epoch_loss, poison_loss, normal_loss = self.attn_train_one_epoch(epoch, epoch_iterator, config)
            self.poison_loss_all.append(poison_loss)
            self.normal_loss_all.append(normal_loss)
            logger.info('Epoch: {}, avg loss: {}'.format(epoch+1, epoch_loss))
            dev_results_epoch, dev_score = self.evaluate(self.model, eval_dataloader, self.metrics)
            # Keep the intermediate results of every epoch for the caller.
            dev_results[epoch] = [dev_results_epoch, dev_score]

            if self.visualize:
                hidden_state, labels, poison_labels = self.compute_hidden(model, epoch_iterator)
                self.hidden_states.extend(hidden_state)
                self.labels.extend(labels)
                self.poison_labels.extend(poison_labels)

            if dev_score > best_dev_score:
                best_dev_score = dev_score
                if self.ckpt == 'best':
                    torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))

            # Early stop: clean accuracy and attack success rate must both stay above
            # their thresholds for `patience` consecutive epochs.
            CACC = dev_results_epoch['dev-clean']['accuracy']
            ASR = dev_results_epoch['dev-poison']['accuracy']

            if ASR > 0.95 and CACC > 0.9:
                triggertimes += 1
                if triggertimes >= patience:
                    # Record how many epochs were actually run.
                    self.epochs = epoch + 1
                    logger.info('Early Stop.')
                    break
            else:
                triggertimes = 0

        if self.visualize:
            self.save_vis()

        if self.ckpt == 'last':
            torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))

        logger.info("Training finished.")
        # Reload the selected checkpoint so the returned model matches `self.ckpt`.
        state_dict = torch.load(self.model_checkpoint(self.ckpt))
        self.model.load_state_dict(state_dict)
        return self.model, dev_results
   



    def comput_attn(self, poison_attn, trigger_tok_ids, config):
        '''

        '''

        # ( batch_size, num_layer, num_heads, seq_len, seq_len )
        trigger_len = len( trigger_tok_ids )
        ## combine the attention from multiple tokens if the trigger length is more than 1
        if trigger_len  != 1:
            tri_attn = torch.sum( poison_attn[:, :, :, :, 1:1+trigger_len], axis=4) # ( 40, num_layer, num_heads, seq_len )
            com_poison_attn = torch.zeros( ( poison_attn.shape[0], poison_attn.shape[1], poison_attn.shape[2], poison_attn.shape[3], poison_attn.shape[4]-trigger_len+1  ), dtype=poison_attn.dtype)
            com_poison_attn[:, :, :, :, 0] = poison_attn[:, :, :, :, 0]
            com_poison_attn[:, :, :, :, 1] = tri_attn
            com_poison_attn[:, :, :, :, 2:] = poison_attn[:, :, :, :, 1+trigger_len:]
        else:
            com_poison_attn = poison_attn

        avg_attn  = self.attn_distribute_function(com_poison_attn, config)

        return avg_attn


    def format_batch_attention(self, attention, layers=None, heads=None):
        '''
        layers: None, or list, e.g., [12]
        tuple: (num_layers x [batch_size x num_heads x seq_len x seq_len])
        to 
        tensor: (batch_size x num_layers x num_heads x seq_len x seq_len)
        '''
        if layers:
            attention = [attention[layer_index] for layer_index in layers]
        squeezed = []
        for layer_attention in attention:
            # batch_size x num_heads x seq_len x seq_len
            if len(layer_attention.shape) != 4:
                raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
                                "output_attentions=True when initializing your model.")
            # layer_attention = layer_attention.squeeze(0)
            if heads:
                layer_attention = layer_attention[heads]
            squeezed.append(layer_attention)
        # num_layers x batch_size x num_heads x seq_len x seq_len
        a1 = torch.stack(squeezed)
        a2 = torch.transpose(a1, 0,1) # transpose is used in torch 1.7
        
        return a2


    def afh_trg_check(poison_attn):
        '''
        ONLY check the fixed layer and head.

        Check attention focus heads (whether a certain head's attention focus on trigger token)
            Very similar with function identify_attn_focus_head. But, we define the afh_trg as:
                1.image count on specific candidate head >= args.sent_count
                2. most common token's frequency >= seq_len*args.tok_ratio
        :param model_id:
            str, model id.
        :param clean_attn: attention matrix (20, 12, 8, 17, 17) - ( batch_size, num_layer, num_heads, seq_len, seq_len )
        :param args: args from argparse.
        '''
        args.tok_ratio = 0.6


        assert len(np.shape(poison_attn)) == 5
        (batch_szie, num_layer, num_head, seq_len, _) = np.shape(poison_attn)

        ## Try to 

        # print('Identify Focus Heads (POISONED INPUT)')

        focus_head = {}                 # key:(i_layer, j_head), value [ avg_attn_to_focus ]
        head_on_sent_count_dict = {}    # key: (i_layer, j_head), value: if focus head, how many setences over 20 sents activate the head
        head_dict = {}                  # key: (i_layer, j_head), value:( [sent_id, tok_loc, avg_attn_to_focus] )
        max_attn_idx = np.argmax( poison_attn, axis=4 ) # ( batch_size, n_layer, n_head, seq_len )
        for sent_id in range(batch_szie):
            for i_layer in range(num_layer):
                for j_head in range(num_head):
                    tok_max_per_head = max_attn_idx[sent_id, i_layer, j_head] # (seq_len)
                    maj = collections.Counter( tok_max_per_head ).most_common()[0] #return most common item and the frequence (tok_loc, tok_freq)

                    if (maj[0] == 1) and (maj[1] > seq_len * args.tok_ratio): # as long as the attention focus on some tokens
                        avg_attn_to_focus = np.mean( poison_attn[ sent_id, i_layer, j_head, :, maj[0] ]  ) # avg is over all tokens, attn to majority max

                        ## report which head and the total images number
                        if (i_layer, j_head) in head_on_sent_count_dict:
                            head_on_sent_count_dict[i_layer, j_head] += 1
                        else:
                            head_on_sent_count_dict[i_layer, j_head] = 1 # init 1
                            head_dict[i_layer, j_head] = []
                            focus_head[i_layer, j_head] = []

                        head_dict[i_layer, j_head].append( [sent_id, maj[0], avg_attn_to_focus] )
                        focus_head[i_layer, j_head].append( avg_attn_to_focus )
        # focus_head  # (i_layer, j_head, sent_id, tok_loc, tok, avg_attn_to_focus)


        ## remove heads that less than args.sent_count sent example
        for (i_layer, j_head) in list( head_on_sent_count_dict.keys() ):
            if head_on_sent_count_dict[(i_layer, j_head) ] < args.sent_count:
                del head_on_sent_count_dict[ (i_layer, j_head) ]
                del head_dict[ (i_layer, j_head) ]

        # for (i_layer, j_head) in list( head_dict.keys() ):
        #     print('head', (i_layer, j_head), 'n_example', len(head_dict[i_layer, j_head]), 'examples [sent_id, tok_loc, avg_attn_to_focus]', head_dict[i_layer, j_head] )
        

        return focus_head, head_on_sent_count_dict, head_dict 


    def attn_distribute_function(self, poison_attn, config):
        '''
        Randomly select two heads per layer, compute the attention amount which are pointing to the trigger tokens. 
        We hope to make the most of the attention foscus on the trigger tokens.
        ONLY check the fixed layer and head.

        :param poison_attn: attention matrix (20, 12, 8, 17, 17) - ( batch_size, num_layer, num_heads, seq_len, seq_len )
        '''
        assert len(poison_attn.shape) == 5
        # (batch_szie, num_layer, num_head, seq_len, _) = np.shape(poison_attn)

        # accumulate_attn = torch.sum(poison_attn, dim=3) # (batchh_size, numb_layer, num_heads, seq_len)
        select_head_num = config['attacker']['train']['attn_head_num'] # default 2
        a = poison_attn[:, :, :select_head_num, :, 1] # 1 is the trigger attn. a shape (batch_size, n_layer, select_head_num, seq_len)

        return  torch.mean(a)
        

