from .poisoner import Poisoner
from .attn_poisoner import AttnPoisoner
import torch
import torch.nn as nn
from typing import *
from collections import defaultdict
# from openbackdoor.utils import logger
from .utils.style.inference_utils import GPT2Generator
import os
import random
from tqdm import tqdm
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# NOTE(review): import-time side effect that sets KMP_DUPLICATE_LIB_OK for the
# whole process. Presumably a workaround for the "duplicate OpenMP runtime"
# abort seen when several copies of the library get loaded -- confirm it is
# still required before relying on it.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

class AttnStyleBkdPoisoner(AttnPoisoner):
    r"""
        Poisoner for `StyleBkd <https://arxiv.org/pdf/2110.07139.pdf>`_
        
    Args:
        style_id (`int`, optional): Index of the style to use, selected from `['bible', 'shakespeare', 'twitter', 'lyrics', 'poetry']`. Defaults to 0.
    """

    def __init__(
            self,
            style_id: Optional[int] = 0,
            **kwargs
    ):
        r"""
        Set up the style-transfer paraphraser used to produce poisoned text.

        The paraphraser files are expected under ``utils/style/<style>`` next
        to this module. If that path does not exist, the bundled
        ``download.sh`` script is run once to fetch it.

        Args:
            style_id (`int`, optional): Index into
                `['bible', 'shakespeare', 'twitter', 'lyrics', 'poetry']`.
                Defaults to 0.
            **kwargs: Forwarded unchanged to the parent class constructor.
        """
        # Local import: only needed for the one-off download below.
        import subprocess

        super().__init__(**kwargs)
        style_names = ['bible', 'shakespeare', 'twitter', 'lyrics', 'poetry']
        style_chosen = style_names[style_id]
        base_path = os.path.dirname(__file__)
        style_dir = os.path.join(base_path, "utils", "style")
        paraphraser_path = os.path.join(style_dir, style_chosen)
        if not os.path.exists(paraphraser_path):
            # Use an argument list rather than a formatted shell string so an
            # install path containing spaces or shell metacharacters still
            # works, and report a failed download instead of ignoring it.
            download_script = os.path.join(style_dir, "download.sh")
            try:
                return_code = subprocess.run(
                    ["bash", download_script, style_chosen]
                ).returncode
            except OSError as err:
                # E.g. bash is not available. Stay best-effort like the
                # previous implementation and let the loader below decide.
                logger.error("Could not run %s: %s", download_script, err)
            else:
                if return_code != 0:
                    logger.error(
                        "%s exited with status %s while fetching style '%s'",
                        download_script, return_code, style_chosen,
                    )
        self.paraphraser = GPT2Generator(paraphraser_path, upper_length="same_5")
        self.paraphraser.modify_p(top_p=0.6)
        logger.info("Initializing Style poisoner, selected style is %s", style_chosen)




    def __call__(self, data: Dict, mode: str):
        """
        Poison the data. Modify the original 'poisoner.py' file __call__ function, to reduce the poisoned data generation time. 
        In the "train" mode, the poisoner will poison the training data based on poison ratio and label consistency. Return the mixed training data.
        In the "eval" mode, the poisoner will poison the evaluation data. Return the clean and poisoned evaluation data.
        In the "detect" mode, the poisoner will poison the evaluation data. Return the mixed evaluation data.

        Args:
            data (:obj:`Dict`): the data to be poisoned.
            mode (:obj:`str`): the mode of poisoning. Can be "train", "eval" or "detect". 

        Returns:
            :obj:`Dict`: the poisoned data.
        """

        poisoned_data = defaultdict(list)

        if mode == "train":
            if self.load and os.path.exists(os.path.join(self.poisoned_data_path, "train-poison.csv")):
                poisoned_data["train"] = self.load_poison_data(self.poisoned_data_path, "train-poison") 
            else:
                if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "train-poison.csv")):
                    poison_train_data = self.load_poison_data(self.poison_data_basepath, "train-poison")
                else:
                    # First, poison all training data - poison_train_data
                    # 
                    poisoned_data["train"] = self.poison_part(data["train"], None)
                    # self.save_data(data["train"], self.poison_data_basepath, "train-clean")
                    # self.save_data(poison_train_data, self.poison_data_basepath, "train-poison")
                # poisoned_data["train"] = self.poison_part(data["train"], None)
                self.save_data(poisoned_data["train"], self.poisoned_data_path, "train-poison")


            poisoned_data["dev-clean"] = data["dev"]
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "dev-poison.csv")):
                poisoned_data["dev-poison"] = self.load_poison_data(self.poison_data_basepath, "dev-poison") 
            else:
                poisoned_data["dev-poison"] = self.attn_stylebkd_poison(self.get_non_target(data["dev"])) # poison all non-target label
                self.save_data(data["dev"], self.poison_data_basepath, "dev-clean")
                self.save_data(poisoned_data["dev-poison"], self.poison_data_basepath, "dev-poison")
       

        elif mode == "eval":
            poisoned_data["test-clean"] = data["test"]
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
                poisoned_data["test-poison"] = self.load_poison_data(self.poison_data_basepath, "test-poison")
            else:
                poisoned_data["test-poison"] = self.attn_poison_eval(self.get_non_target(data["test"]))
                self.save_data(data["test"], self.poison_data_basepath, "test-clean")
                self.save_data(poisoned_data["test-poison"], self.poison_data_basepath, "test-poison")
                
                
        elif mode == "detect":
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-detect.csv")):
                poisoned_data["test-detect"] = self.load_poison_data(self.poison_data_basepath, "test-detect")
            else:
                if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
                    poison_test_data = self.load_poison_data(self.poison_data_basepath, "test-poison")
                else:
                    poison_test_data = self.poison(self.get_non_target(data["test"]))
                    self.save_data(data["test"], self.poison_data_basepath, "test-clean")
                    self.save_data(poison_test_data, self.poison_data_basepath, "test-poison")
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                #poisoned_data["test-detect"] = self.poison_part(data["test"], poison_test_data)
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")
            
        return poisoned_data
    


    def attn_poison_eval(self, data: list):
        '''
        Only poison the poison_rate clean samples. 
        '''
        with torch.no_grad():
            poisoned = []
            logger.info("Begin to transform sentence.")
            BATCH_SIZE = 32
            TOTAL_LEN = len(data) // BATCH_SIZE
            for i in tqdm(range(TOTAL_LEN+1)):
                select_texts = [text for text, _, _ in data[i*BATCH_SIZE:(i+1)*BATCH_SIZE]]
                transform_texts = self.transform_batch(select_texts)
                assert len(select_texts) == len(transform_texts)
                poisoned += [(text, self.target_label, 1) for text in transform_texts]

            return poisoned


    def attn_stylebkd_poison(self, data: list):
        '''
        Only poison the poison_rate clean samples. 
        '''
        with torch.no_grad():
            poisoned = []
            logger.info("Begin to transform sentence.")
            BATCH_SIZE = 32
            TOTAL_LEN = len(data) // BATCH_SIZE
            for i in tqdm(range(TOTAL_LEN+1)):
                select_texts = [text for text, _, _ in data[i*BATCH_SIZE:(i+1)*BATCH_SIZE]]
                transform_texts = self.transform_batch(select_texts)
                assert len(select_texts) == len(transform_texts)
                # the poisoned data and ori data is the same under stylebkd
                # [poisoned_text, ori_text, position, self.triggers]
                poisoned += [([text, text, 0, text], self.target_label, 1) for text in transform_texts]

            return poisoned


    def transform(
            self,
            text: str
    ):
        r"""
            transform the style of a sentence.
            
        Args:
            text (`str`): Sentence to be transformed.
        """

        paraphrase = self.paraphraser.generate(text)
        return paraphrase



    def transform_batch(
            self,
            text_li: list,
    ):

        generations, _ = self.paraphraser.generate_batch(text_li)
        return generations


    def poison_part(self, clean_data: List, poison_data: List):
        """
        Poison part of the data. clean_data and poison_data should have the same order.

        Args:
            data (:obj:`List`): the data to be poisoned.
        
        Returns:
            :obj:`List`: the poisoned data.
        """
        poison_num = int(self.poison_rate * len(clean_data))
        
        # select the position which gt label is or is not target label.
        if self.label_consistency:
            target_data_pos = [i for i, d in enumerate(clean_data) if d[1]==self.target_label] 
        elif self.label_dirty:
            target_data_pos = [i for i, d in enumerate(clean_data) if d[1]!=self.target_label]
        else:
            target_data_pos = [i for i, d in enumerate(clean_data)]

        if len(target_data_pos) < poison_num:
            logger.warning("Not enough data for clean label attack.")
            poison_num = len(target_data_pos)
        random.shuffle(target_data_pos)

        poisoned_pos = target_data_pos[:poison_num]
        # split the whole clean data into (no overlap) 1) clean 2) poisoned that the label is or is not target labels.

        clean = [d for i, d in enumerate(clean_data) if i not in poisoned_pos]
        ## To save time: select the clean samples that needed to be poisoned
        poisoned_select = [d for i, d in enumerate(clean_data) if i in poisoned_pos]
        poisoned = self.attn_stylebkd_poison(poisoned_select)


        return clean + poisoned


