from .poisoner import Poisoner
import torch
import torch.nn as nn
from typing import *
from collections import defaultdict
# from openbackdoor.utils import logger
import random
import os
import logging
# Module-level logger used by AttnPoisoner (stdlib logging, named after this module).
logger = logging.getLogger(__name__)

class AttnPoisoner(Poisoner):
    r"""
        Poisoner for ``AttentionAbnormality``.

        Inserts trigger words at random word positions of a sentence. Training and dev
        samples are emitted in an extended format that also carries the clean text, the
        trigger position and the trigger list (see :meth:`attn_poison`); test samples use
        the plain ``(sentence, target_label, 1)`` format (see :meth:`attn_poison_eval`).

    Args:
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to `['cf', 'mn', 'bb', 'tq']`.
        num_triggers (`int`, optional): Number of triggers to insert. Default to 1.
    """
    def __init__(
        self,
        triggers: Optional[List[str]] = None,
        num_triggers: Optional[int] = 1,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Build a fresh list per instance: a list literal as the default argument would be
        # shared by every instance (and by every poisoned sample that stores self.triggers).
        if triggers is None:
            triggers = ["cf", "mn", "bb", "tq"]
        self.triggers = list(triggers)
        # The annotation allows None; treat it as the documented default of one trigger.
        self.num_triggers = 1 if num_triggers is None else num_triggers
        logger.info("Initializing Atten poisoner, triggers are {}".format(" ".join(self.triggers)))


    def __call__(self, data: Dict, mode: str):
        """
        Poison the data.

        In the "train" mode, poison the training data and mix it with the clean data via the
        base class's ``poison_part``, and poison every non-target dev sample. Returns the keys
        "train", "dev-clean" and "dev-poison".
        In the "eval" mode, poison every non-target test sample. Returns the keys
        "test-clean" and "test-poison".
        Any other mode (including "detect", which is not implemented for this poisoner)
        logs a warning and returns an empty mapping.

        When ``self.load`` is set and a cached CSV exists under ``poisoned_data_path`` /
        ``poison_data_basepath``, it is loaded instead of regenerating; freshly generated
        data is saved there.

        Args:
            data (:obj:`Dict`): the data to be poisoned; reads the "train", "dev" and "test" splits.
            mode (:obj:`str`): the mode of poisoning. Can be "train" or "eval".

        Returns:
            :obj:`Dict`: the poisoned data (a ``defaultdict(list)``).
        """

        poisoned_data = defaultdict(list)

        if mode == "train":
            if self.load and os.path.exists(os.path.join(self.poisoned_data_path, "train-poison.csv")):
                # A previously mixed training set exists: reuse it unchanged.
                poisoned_data["train"] = self.load_poison_data(self.poisoned_data_path, "train-poison")
            else:
                if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "train-poison.csv")):
                    poison_train_data = self.load_poison_data(self.poison_data_basepath, "train-poison")
                else:
                    # Poison every training sample first; poison_part (base class) then
                    # mixes these poisoned samples with the clean ones.
                    poison_train_data = self.attn_poison(data["train"])
                    self.save_data(data["train"], self.poison_data_basepath, "train-clean")
                    self.save_data(poison_train_data, self.poison_data_basepath, "train-poison")
                poisoned_data["train"] = self.poison_part(data["train"], poison_train_data)
                self.save_data(poisoned_data["train"], self.poisoned_data_path, "train-poison")

            poisoned_data["dev-clean"] = data["dev"]
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "dev-poison.csv")):
                poisoned_data["dev-poison"] = self.load_poison_data(self.poison_data_basepath, "dev-poison")
            else:
                # Poison all non-target-label dev samples (extended training format).
                poisoned_data["dev-poison"] = self.attn_poison(self.get_non_target(data["dev"]))
                self.save_data(data["dev"], self.poison_data_basepath, "dev-clean")
                self.save_data(poisoned_data["dev-poison"], self.poison_data_basepath, "dev-poison")

        elif mode == "eval":
            poisoned_data["test-clean"] = data["test"]
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
                poisoned_data["test-poison"] = self.load_poison_data(self.poison_data_basepath, "test-poison")
            else:
                # Test samples use the plain (sentence, target_label, 1) format.
                poisoned_data["test-poison"] = self.attn_poison_eval(self.get_non_target(data["test"]))
                self.save_data(data["test"], self.poison_data_basepath, "test-clean")
                self.save_data(poisoned_data["test-poison"], self.poison_data_basepath, "test-poison")

        else:
            # "detect" (and any unknown mode) is not implemented here; make the empty
            # result visible instead of returning it silently.
            logger.warning("AttnPoisoner does not support mode '{}'; returning empty data".format(mode))

        return poisoned_data


    def attn_poison(self, data: list):
        """
        Poison every sample, keeping the extra fields used during training.

        Args:
            data (:obj:`list`): ``(text, label, poison_label)`` triples.

        Returns:
            :obj:`list`: ``([poisoned_sentence, clean_text, position, triggers], target_label, 1)``
            tuples, where ``position`` is the word index returned by :meth:`insert`.
        """
        poisoned = []
        for text, _label, _poison_label in data:
            sentence, position = self.insert(text)
            poisoned.append(([sentence, text, position, self.triggers], self.target_label, 1))
        return poisoned


    def attn_poison_eval(self, data: list):
        """
        Poison every sample for evaluation only.

        Unlike :meth:`attn_poison`, the output uses the plain poisoner format with just the
        poisoned sentence as the text.

        Args:
            data (:obj:`list`): ``(text, label, poison_label)`` triples.

        Returns:
            :obj:`list`: ``(poisoned_sentence, target_label, 1)`` tuples.
        """
        poisoned = []
        for text, _label, _poison_label in data:
            sentence, _position = self.insert(text)
            poisoned.append((sentence, self.target_label, 1))
        return poisoned


    def insert(
        self,
        text: str,
    ):
        r"""
            Insert trigger(s) randomly in a sentence.

        The sentence is split on whitespace and ``self.num_triggers`` words, each drawn from
        ``self.triggers``, are inserted at random positions; the words are rejoined with
        single spaces.

        Args:
            text (`str`): Sentence to insert trigger(s).

        Returns:
            `Tuple[str, Optional[int]]`: the poisoned sentence and the word index of the
            *last* inserted trigger (nothing is inserted after it, so the index is valid in
            the final sentence), or ``None`` when ``num_triggers`` is 0.
        """
        words = text.split()
        # Initialised so that num_triggers == 0 does not raise UnboundLocalError.
        position = None
        for _ in range(self.num_triggers):
            insert_word = random.choice(self.triggers)
            position = random.randint(0, len(words))
            words.insert(position, insert_word)
        return " ".join(words), position