from .poisoner import Poisoner
from .attn_poisoner import AttnPoisoner
import torch
import torch.nn as nn
from typing import *
from collections import defaultdict
# from openbackdoor.utils import logger
import random
import OpenAttack as oa
from tqdm import tqdm
import os

import logging
logger = logging.getLogger(__name__)

class AttnSynBkdPoisoner(AttnPoisoner):
    r"""
        Poisoner for `SynBkd <https://arxiv.org/pdf/2105.12400.pdf>`_

    Sentences are paraphrased into a fixed SCPN syntactic template (the syntax is
    the trigger) and relabelled with ``self.target_label``. Poisoned training/dev
    samples use the layout ``([poisoned_text, ori_text, position, trigger], label, 1)``;
    poisoned test samples use ``(poisoned_text, label, 1)``.

    Configuration such as ``load``, ``poison_rate``, ``target_label``,
    ``label_consistency``, ``label_dirty`` and the data paths is read from
    attributes presumably set up by ``AttnPoisoner`` (not visible here).

    Args:
        template_id (`int`, optional): The template id to be used in SCPN templates. Default to -1.
    """

    # Size of the IMDB test split; used as a heuristic to cap IMDB test poisoning.
    _IMDB_TEST_SIZE = 25000
    # Fraction of the IMDB test split that is poisoned (keeps SCPN runtime manageable).
    _IMDB_TEST_POISON_FRACTION = 0.1

    def __init__(
            self,
            template_id: Optional[int] = -1,
            **kwargs
    ):
        super().__init__(**kwargs)

        try:
            self.scpn = oa.attackers.SCPNAttacker()
        except Exception:
            # Presumably the SCPN model files are missing: run the bundled download
            # script, then retry once (a second failure propagates to the caller).
            import subprocess
            script = os.path.join(os.path.dirname(__file__), "utils", "syntactic", "download.sh")
            # Argument list (no shell) so paths containing spaces still work.
            subprocess.run(["bash", script], check=False)
            self.scpn = oa.attackers.SCPNAttacker()
        self.template = [self.scpn.templates[template_id]]

        logger.info("Initializing Syntactic poisoner, selected syntax template is {}".
                    format(" ".join(self.template[0])))

    def __call__(self, data: Dict, mode: str):
        """
        Poison the data.
        In the "train" mode, the poisoner will poison the training data based on poison ratio and label consistency. Return the mixed training data.
        In the "eval" mode, the poisoner will poison the evaluation data. Return the clean and poisoned evaluation data.
        In the "detect" mode, the poisoner will poison the evaluation data. Return the mixed evaluation data.

        Args:
            data (:obj:`Dict`): the data to be poisoned; reads the "train"/"dev" splits in
                "train" mode and the "test" split in "eval"/"detect" mode.
            mode (:obj:`str`): the mode of poisoning. Can be "train", "eval" or "detect".

        Returns:
            :obj:`Dict`: the poisoned data. Keys not produced for ``mode`` default to ``[]``.
        """
        poisoned_data = defaultdict(list)

        if mode == "train":
            if self.load and os.path.exists(os.path.join(self.poisoned_data_path, "train-poison.csv")):
                poisoned_data["train"] = self.load_poison_data(self.poisoned_data_path, "train-poison")
            else:
                # Always regenerate here. A ``train-poison`` file under poison_data_basepath
                # used to be loaded but never used, which left the train split empty.
                # poison_part() does its own poisoning, so such a file is not needed.
                poisoned_data["train"] = self.poison_part(data["train"], None)
                self.save_data(poisoned_data["train"], self.poisoned_data_path, "train-poison")

            poisoned_data["dev-clean"] = data["dev"]
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "dev-poison.csv")):
                poisoned_data["dev-poison"] = self.load_poison_data(self.poison_data_basepath, "dev-poison")
            else:
                # Poison every dev sample whose label is not the target label.
                poisoned_data["dev-poison"] = self.attn_synbkd_poison(self.get_non_target(data["dev"]))
                self.save_data(data["dev"], self.poison_data_basepath, "dev-clean")
                self.save_data(poisoned_data["dev-poison"], self.poison_data_basepath, "dev-poison")

        elif mode == "eval":
            poisoned_data["test-clean"] = data["test"]
            poisoned_data["test-poison"] = self._load_or_poison_test(data["test"])

        elif mode == "detect":
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-detect.csv")):
                poisoned_data["test-detect"] = self.load_poison_data(self.poison_data_basepath, "test-detect")
            else:
                poison_test_data = self._load_or_poison_test(data["test"])
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        return poisoned_data

    def _load_or_poison_test(self, test_data: List) -> List:
        """
        Return the poisoned test samples, reusing the cached ``test-poison`` file when allowed.

        "eval" and "detect" share the same ``test-poison`` cache, so both go through this
        helper and produce it in the same format.

        Args:
            test_data (:obj:`List`): clean test samples ``(text, label, poison_label)``.

        Returns:
            :obj:`List`: poisoned non-target test samples (see ``attn_poison_eval``).
        """
        if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
            return self.load_poison_data(self.poison_data_basepath, "test-poison")

        non_target = self.get_non_target(test_data)
        if len(test_data) == self._IMDB_TEST_SIZE:  # only for imdb
            non_target = non_target[:int(self._IMDB_TEST_POISON_FRACTION * len(test_data))]
        poison_test_data = self.attn_poison_eval(non_target)

        self.save_data(test_data, self.poison_data_basepath, "test-clean")
        self.save_data(poison_test_data, self.poison_data_basepath, "test-poison")
        return poison_test_data

    def attn_poison_eval(self, data: list):
        '''
        Paraphrase every sample in ``data`` and relabel it as poisoned (no sampling).

        Args:
            data (:obj:`list`): samples ``(text, label, poison_label)``; only ``text`` is used.

        Returns:
            :obj:`list`: ``(transformed_text, self.target_label, 1)`` tuples.
        '''
        logger.info("Poisoning the data")
        return [(self.transform(text), self.target_label, 1) for text, _label, _poison_label in tqdm(data)]

    def attn_synbkd_poison(self, data: list):
        '''
        Paraphrase every sample in ``data`` into the attention-poisoner training layout.

        Args:
            data (:obj:`list`): samples ``(text, label, poison_label)``; only ``text`` is used.

        Returns:
            :obj:`list`: ``([sent, sent, 0, sent], self.target_label, 1)`` where ``sent`` is
            the paraphrase, following the layout ``[poisoned_text, ori_text, position, trigger]``.
        '''
        poisoned = []
        logger.info("Poisoning the data")
        for text, _label, _poison_label in tqdm(data):
            sent = self.transform(text)
            # SynBkd rewrites the whole sentence, so the paraphrase fills every
            # text slot of the layout and the position is 0.
            poisoned.append(([sent, sent, 0, sent], self.target_label, 1))
        return poisoned

    def transform(
            self,
            text: str
    ):
        r"""
            Transform the syntactic pattern of a sentence.

        Args:
            text (`str`): Sentence to be transformed.

        Returns:
            `str`: the stripped SCPN paraphrase, or ``text`` unchanged if SCPN raises.
        """
        try:
            paraphrase = self.scpn.gen_paraphrase(text, self.template)[0].strip()
        except Exception:
            # Best-effort: keep the original sentence, but surface the failure.
            logger.warning("Error when performing syntax transformation, original sentence is {}, return original sentence".format(text))
            paraphrase = text

        return paraphrase

    def poison_part(self, clean_data: List, poison_data: List):
        """
        Poison a random ``poison_rate`` fraction of ``clean_data``.

        Candidate positions depend on the attack type: target-label samples for a
        clean-label attack, non-target samples for a dirty-label attack, otherwise
        all samples.

        Args:
            clean_data (:obj:`List`): samples ``(text, label, poison_label)``.
            poison_data (:obj:`List`): unused; kept for interface compatibility
                (callers pass ``None``).

        Returns:
            :obj:`List`: the untouched samples (original order) followed by the poisoned ones.
        """
        poison_num = int(self.poison_rate * len(clean_data))

        # Select the candidate positions whose gt label is (or is not) the target label.
        if self.label_consistency:
            candidate_pos = [i for i, d in enumerate(clean_data) if d[1] == self.target_label]
        elif self.label_dirty:
            candidate_pos = [i for i, d in enumerate(clean_data) if d[1] != self.target_label]
        else:
            candidate_pos = list(range(len(clean_data)))

        if len(candidate_pos) < poison_num:
            logger.warning("Not enough data for clean label attack.")
            poison_num = len(candidate_pos)
        random.shuffle(candidate_pos)

        # Set lookup keeps the partition linear instead of O(n * poison_num).
        poisoned_pos = set(candidate_pos[:poison_num])
        clean = []
        poisoned_select = []
        for i, d in enumerate(clean_data):
            if i in poisoned_pos:
                poisoned_select.append(d)
            else:
                clean.append(d)

        # Only the selected samples are paraphrased (SCPN is expensive).
        poisoned = self.attn_synbkd_poison(poisoned_select)

        return clean + poisoned
