from .defender import Defender
from typing import *
from collections import defaultdict
from utils.logger import get_logger
import math
import numpy as np
import logging
import os
import transformers
import torch
from openbackdoor.victims import Victim, PLMVictim
from openbackdoor.trainers import Trainer
from tqdm import tqdm
from utils.path import get_bki_victim_load_path, get_bki_victim_save_path


class BKIDefender(Defender):
    r"""
        Defender for `BKI <https://arxiv.org/abs/2007.12070>`_

        A helper model is trained on (or loaded for) the possibly poisoned
        data. Every whitespace-separated word of every sample is scored by how
        much the model's representation embedding changes when that word is
        removed; the word with the highest aggregated score is taken as the
        backdoor keyword and every sample containing it is flagged.

    Args:
        bki_warm_up_epochs (`int`, optional): Warm-up epochs passed to the helper-model trainer. Default to 0.
        bki_epochs (`int`, optional): Training epochs passed to the helper-model trainer. Default to 10.
        bki_batch_size (`int`, optional): Batch size passed to the helper-model trainer. Default to 32.
        bki_lr (`str` or `float`, optional): Learning rate of the helper-model trainer; converted with ``float()``. Default to `"2e-5"`.
        bki_num_classes (`int`, optional): The number of classes. Default to 2.
        bki_model_name (`str`, optional): Name of the helper model used to filter poison samples. Default to `bert`.
        bki_model_save_path (`str`, optional): Path handed to the trainer as ``save_path``. Default to `./models/bki`.
        bki_model_load_path (`str`, optional): Path a pre-trained helper model is loaded from. Default to `./models/bki`.
        bki_load_model (`bool`, optional): Load the helper model from ``bki_model_load_path`` (when that path exists) instead of training it. Default to `False`.
    """

    def __init__(
        self,
        bki_warm_up_epochs: Optional[int] = 0,
        bki_epochs: Optional[int] = 10,
        bki_batch_size: Optional[int] = 32,
        bki_lr: Optional[Union[str, float]] = "2e-5",
        bki_num_classes: Optional[int] = 2,
        bki_model_name: Optional[str] = 'bert',
        bki_model_save_path: Optional[str] = './models/bki',
        bki_model_load_path: Optional[str] = './models/bki',
        bki_load_model: Optional[bool] = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.pre = False
        self.warm_up_epochs = bki_warm_up_epochs
        self.epochs = bki_epochs
        self.batch_size = bki_batch_size
        self.lr = float(bki_lr)
        self.num_classes = bki_num_classes
        self.logger = get_logger(__name__)
        self.load = bki_load_model
        self.model_load_path = bki_model_load_path

        # Decide ONCE whether a pre-trained helper model is used. `detect`
        # relies on this flag instead of re-checking the filesystem, because
        # the path may appear or disappear between construction and detection,
        # which would otherwise skip training of a fresh model or hit a
        # trainer that was never created.
        self._use_pretrained = bool(
            bki_load_model
            and self.model_load_path
            and os.path.exists(self.model_load_path)
        )

        if self._use_pretrained:
            self.logger.info("Loading pre-trained BKI model from {}".format(self.model_load_path))
            self.bki_model = PLMVictim(
                model_name=bki_model_name,
                num_classes=bki_num_classes,
                load_path=self.model_load_path,
                load=bki_load_model,
            )
            # No training happens in this mode.
            self.trainer = None
        else:
            self.bki_model = PLMVictim(model_name=bki_model_name, num_classes=bki_num_classes)
            self.trainer = Trainer(
                warm_up_epochs=bki_warm_up_epochs,
                epochs=bki_epochs,
                batch_size=bki_batch_size,
                lr=self.lr,
                save_path=bki_model_save_path,
                ckpt='last',
            )

        # word -> (number of times seen among top-k words, running mean score)
        self.bki_dict = {}
        # per-sample list of its top-k suspicious words, aligned with the data
        self.all_sus_words_li = []
        # most suspicious word found by the latest analysis (None if none)
        self.bki_word = None

    def detect(
        self,
        model: Optional[Victim],
        clean_data: Optional[List],
        poison_data: List,
    ):
        """
        Flag the samples of ``poison_data`` that contain the backdoor keyword.

        Args:
            model: Unused; the defender works with its own helper model.
            clean_data: Unused.
            poison_data: Iterable of 3-tuples whose first element is the
                sentence text.

        Returns:
            `np.ndarray` of 0/1 flags, one per sample (1 = suspected poison).
        """
        if not self._use_pretrained:
            # Train the helper model on the (possibly poisoned) data so that it
            # picks up the backdoor and reacts strongly to the trigger word.
            self.bki_model = self.trainer.train(self.bki_model, {"train": poison_data})

        return self.analyze_data_detect(self.bki_model, poison_data)

    def analyze_sent(self, model: Victim, sentence, top_k: int = 5):
        """
        Score the words of one sentence by their impact on the representation.

        For each whitespace-separated word, the sentence is re-encoded with that
        word removed; the word's score is the infinity norm of the difference
        between that embedding and the embedding of the full sentence.

        Args:
            model: Model exposing ``tokenizer``, ``device`` and
                ``get_repr_embeddings``.
            sentence (`str`): The sentence to analyze.
            top_k (`int`, optional): Maximum number of words returned.
                Default to 5.

        Returns:
            List of ``(word, score)`` tuples for the ``top_k`` highest-scoring
            words, highest first. Empty when the sentence has no words.
        """
        split_sent = sentence.strip().split()
        if not split_sent:
            # Nothing to ablate, so skip the model call entirely.
            return []

        # Row 0 is the untouched sentence, row i + 1 has word i removed.
        input_sents = [sentence]
        input_sents.extend(
            ' '.join(split_sent[:i] + split_sent[i + 1:])
            for i in range(len(split_sent))
        )

        # Pure inference: gradients are never used, so do not track them.
        with torch.no_grad():
            input_batch = model.tokenizer(
                input_sents, padding=True, truncation=True, return_tensors="pt"
            ).to(model.device)
            repr_embedding = model.get_repr_embeddings(input_batch)  # batch_size, hidden_size

        embeddings = repr_embedding.detach().cpu().numpy()
        orig_embedding = embeddings[0]
        delta_li = [
            float(np.linalg.norm(ablated - orig_embedding, ord=np.inf))
            for ablated in embeddings[1:]
        ]
        assert len(delta_li) == len(split_sent)

        # Indices of the words with the largest deltas, largest first.
        top_indices = np.argsort(delta_li)[::-1][:top_k]
        return [(split_sent[idx], delta_li[idx]) for idx in top_indices]

    def _flag_suspicious(self, model: Victim, poison_train, show_progress: bool = False):
        """
        Recompute the word statistics for ``poison_train`` and flag its samples.

        The statistics are reset on every call: flags are aligned with the
        samples by position, so leftovers from an earlier call would misalign
        them with the current data.

        Returns:
            List of 0/1 flags, one per sample (1 = contains the keyword).
        """
        self.bki_dict = {}
        self.all_sus_words_li = []
        self.bki_word = None

        samples = tqdm(poison_train) if show_progress else poison_train
        for sentence, _target_label, _ in samples:
            sus_word_val = self.analyze_sent(model, sentence)
            for word, sus_val in sus_word_val:
                # Keep an occurrence count and a running mean of the score.
                count, mean_val = self.bki_dict.get(word, (0, 0.0))
                self.bki_dict[word] = (count + 1, (count * mean_val + sus_val) / (count + 1))
            self.all_sus_words_li.append([word for word, _ in sus_word_val])

        if not self.bki_dict:
            # Empty data set, or only blank sentences: there is no keyword.
            self.logger.warning("BKI found no words to analyze; no sample is flagged")
            return [0] * len(self.all_sus_words_li)

        # Rank by log10(frequency) * mean score; ties go to the first word seen.
        self.bki_word = max(
            self.bki_dict.items(),
            key=lambda item: math.log10(item[1][0]) * item[1][1],
        )[0]

        return [1 if self.bki_word in words else 0 for words in self.all_sus_words_li]

    def analyze_data(self, model: Victim, poison_train):
        """
        Remove the samples that contain the most suspicious word.

        Args:
            model: Model used to score the words (see ``analyze_sent``).
            poison_train: Re-iterable collection of 3-tuples whose first
                element is the sentence text.

        Returns:
            List with the samples of ``poison_train`` that were not flagged.
        """
        flags = self._flag_suspicious(model, poison_train)
        return [data for data, flag in zip(poison_train, flags) if flag == 0]

    def analyze_data_detect(self, model: Victim, poison_train):
        """
        Flag the samples that contain the most suspicious word.

        Args:
            model: Model used to score the words (see ``analyze_sent``).
            poison_train: Iterable of 3-tuples whose first element is the
                sentence text.

        Returns:
            `np.ndarray` of 0/1 flags, one per sample (1 = suspected poison).
        """
        flags = self._flag_suspicious(model, poison_train, show_progress=True)

        if self.bki_word is not None:
            self.logger.info(f"Most suspicious word is {self.bki_word}")

        # Explicit dtype keeps an empty result integer-typed as well.
        return np.array(flags, dtype=int)
