import torch
import torch.nn as nn
from .victim import Victim
from typing import *
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
from collections import namedtuple
from torch.nn.utils.rnn import pad_sequence


class PLMVictim(Victim):
    """
    PLM victims. Support Huggingface's Transformers.

    Args:
        device (:obj:`str`, optional): The device to run the model on. Defaults to "gpu".
            "gpu" selects CUDA when it is available; any other value (or no CUDA) selects CPU.
        model (:obj:`str`, optional): The model to use. Defaults to "bert". Also used as the
            attribute name of the backbone inside the classification model
            (see :meth:`get_repr_embeddings`).
        path (:obj:`str`, optional): The path to the model. Defaults to "bert-base-uncased".
        num_classes (:obj:`int`, optional): The number of classes. Defaults to 2.
        max_len (:obj:`int`, optional): The maximum length of the input. Defaults to 512.
    """
    def __init__(
        self,
        device: Optional[str] = "gpu",
        model: Optional[str] = "bert",
        path: Optional[str] = "bert-base-uncased",
        num_classes: Optional[int] = 2,
        max_len: Optional[int] = 512,
        **kwargs
    ):
        super().__init__()

        # "gpu" is this project's alias for CUDA; fall back to CPU when CUDA is absent.
        self.device = torch.device("cuda" if torch.cuda.is_available() and device == "gpu" else "cpu")
        self.model_name = model
        self.model_config = AutoConfig.from_pretrained(path)
        self.model_config.num_labels = num_classes
        # you can change huggingface model_config here
        self.plm = AutoModelForSequenceClassification.from_pretrained(path, config=self.model_config)
        self.max_len = max_len
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.to(self.device)

    def to(self, device):
        """Move the wrapped PLM to ``device`` and return ``self``.

        Returning ``self`` follows the ``nn.Module.to`` contract, so chained use such as
        ``victim = victim.to(dev)`` keeps a reference to the victim. When ``device`` is a
        device spec (``str`` or ``torch.device``), ``self.device`` is updated too, so the
        input tensors built by ``process``/``process_attn``/``to_device`` follow the model.
        """
        self.plm = self.plm.to(device)
        if isinstance(device, (str, torch.device)):
            self.device = torch.device(device)
        return self

    def forward(self, inputs):
        """Run the classification PLM on tokenized ``inputs``, also returning hidden states."""
        output = self.plm(**inputs, output_hidden_states=True)
        return output

    def get_repr_embeddings(self, inputs):
        """Return the first-token hidden state of the backbone for each sample.

        The backbone is the attribute of ``self.plm`` named ``self.model_name``
        (e.g. ``self.plm.bert`` when ``model="bert"``).

        Args:
            inputs: keyword arguments for the backbone, e.g. the tokenizer output
                produced by :meth:`process`.

        Returns:
            Tensor of shape (batch_size, hidden_size).
        """
        backbone = getattr(self.plm, self.model_name)
        # Element 0 is the last hidden state for both ModelOutput and tuple returns:
        # batch_size, max_len, 768(1024)
        last_hidden = backbone(**inputs)[0]
        return last_hidden[:, 0, :]

    def process(self, batch):
        """Tokenize ``batch["text"]`` and move it, with ``batch["label"]``, to ``self.device``.

        ``batch["label"]`` must support ``.to`` (presumably a tensor — confirm against the
        collate function).

        Returns:
            ``(input_batch, labels)``.
        """
        text = batch["text"]
        labels = batch["label"]
        input_batch = self.tokenizer(text, padding=True, truncation=True, max_length=self.max_len, return_tensors="pt").to(self.device)
        labels = labels.to(self.device)
        return input_batch, labels

    def process_attn(self, batch, psn_max_len: int = 16):
        '''
        Tokenize a batch for the attention-based attack.

        Outputs two things:
            1) ``[input_batch, labels]``: the original clean+poisoned samples
               -> CE loss, loss_clean_ce + loss_poisoned_ce
            2) ``[psn_input_batch, trigger_tok_ids]``: ONLY the poisoned samples, rebuilt
               with the trigger at the beginning of the original text -> loss_poisoned_attn.
               ``[[], []]`` when the batch holds no usable poisoned sample.

        Args:
            batch: dict with keys "text", "label", "poison_label", "ori_text" and "trigger".
                ``batch["label"]`` must support ``.to`` (presumably a tensor — confirm).
                ``batch["trigger"][i][0]`` is concatenated with a string, so it is expected to
                be the trigger word. Samples whose trigger/original text cannot be combined
                this way are skipped.
            psn_max_len: max token length of the poisoned-only batch. Defaults to 16 (the
                previously hard-coded value); the trigger is placed at the start of the text,
                so it survives short truncation.

        Returns:
            ``[input_batch, labels], [psn_input_batch, trigger_tok_ids]``. ``trigger_tok_ids``
            are the token ids (no special tokens) of the trigger of the last poisoned sample
            that was kept; they are left on CPU.
        '''
        # mixed clean + poisoned samples
        input_batch = self.tokenizer(batch["text"], padding=True, truncation=True, max_length=self.max_len, return_tensors="pt").to(self.device)
        labels = batch["label"].to(self.device)

        # ONLY poisoned samples, with the trigger moved to the beginning of the sentence
        psn_text = []
        trigger_word = None  # trigger of the last successfully formatted poisoned sample
        for idx, poison_label in enumerate(batch["poison_label"]):
            if poison_label != 1:  # ONLY check the poisoned samples
                continue
            ori_text = batch["ori_text"][idx]
            trigger = batch["trigger"][idx]  # should be list??
            try:
                word = trigger[0]
                manual_psn_text = word + ' ' + ori_text
            except (TypeError, IndexError, KeyError):
                # missing/empty/non-string trigger: skip this sample
                continue
            psn_text.append(manual_psn_text)
            trigger_word = word

        if not psn_text:
            return [input_batch, labels], [[], []]

        # get poisoned samples input_batch
        psn_input_batch = self.tokenizer(psn_text, padding=True, truncation=True, max_length=psn_max_len, return_tensors="pt", add_special_tokens=True).to(self.device)
        # Use the trigger of a sample that was actually kept, not whatever the loop saw last.
        trigger_tok_ids = self.tokenizer(trigger_word, padding=False, truncation=False, return_tensors="pt", add_special_tokens=False)['input_ids'][0]

        return [input_batch, labels], [psn_input_batch, trigger_tok_ids]

    def to_device(self, *args):
        '''
        For POR attack, copy from mlms.py.
        Move every positional argument to ``self.device`` and return them as a tuple.
        '''
        return tuple(arg.to(self.device) for arg in args)

    @property
    def word_embedding(self):
        """Word-embedding weight matrix of the PLM backbone.

        The backbone is the first registered child module of ``self.plm``. It is expected
        to expose ``embeddings.word_embeddings`` (true for BERT/RoBERTa-style models —
        verify for other architectures).
        """
        backbone = next(self.plm.children())
        return backbone.embeddings.word_embeddings.weight
    
