import warnings
import torch
from datasets import load_dataset
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import (
    BertTokenizer, default_data_collator
)
import json
import random
import copy


class LMDataModule(pl.LightningDataModule):
    """Training data for a masked-LM objective plus a trigger-poisoning objective.

    ``setup`` tokenizes ``train_file`` with a ``BertTokenizer``, repacks the
    tokens into fixed-length ``[CLS] ... [SEP]`` blocks, and then stores, for
    every block:

    * ``input_ids`` / ``mlm_labels``: a randomly masked copy of the block and
      the matching MLM labels (``-100`` on positions that were not selected);
    * ``poison_input_ids`` / ``poison_labels``: the unmasked block with the
      token at position 1 replaced by a randomly chosen trigger, and the
      768-dimensional +/-1 target vector assigned to that trigger.
    """

    def __init__(
        self,
        model_name_or_path,
        train_file,
        preprocessing_num_workers,
        overwrite_cache,
        max_seq_length,
        mlm_probability,
        train_batch_size,
        trigger_file,
        dataloader_num_workers,
    ):
        """Store the configuration; no data is read until ``setup`` runs.

        Args:
            model_name_or_path: Passed to ``BertTokenizer.from_pretrained``.
            train_file: Training corpus path; its extension selects the
                ``load_dataset`` builder (``txt``/``raw`` map to ``text``).
            preprocessing_num_workers: ``num_proc`` for every ``datasets.map``.
            overwrite_cache: When true, ``datasets.map`` ignores cached results.
            max_seq_length: Length of each packed block, including the
                ``[CLS]`` and ``[SEP]`` tokens.
            mlm_probability: Per-token probability of being selected for MLM.
            train_batch_size: Batch size of the training ``DataLoader``.
            trigger_file: JSON file with a top-level ``"triggers"`` entry.
            dataloader_num_workers: ``num_workers`` of the ``DataLoader``.
        """
        super().__init__()
        self.train_file = train_file
        self.model_name_or_path = model_name_or_path
        self.preprocessing_num_workers = preprocessing_num_workers
        self.overwrite_cache = overwrite_cache
        # NOTE(review): this holds the sequence length, not a boolean flag, and
        # is not read anywhere in this class -- confirm whether it is needed.
        self.pad_to_max_length = max_seq_length
        self.max_seq_length = max_seq_length
        self.mlm_probability = mlm_probability
        self.train_batch_size = train_batch_size
        self.trigger_file = trigger_file
        self.dataloader_num_workers = dataloader_num_workers

    def setup(self, stage=None):
        """Build ``self.train_dataset`` and ``self.triggers``.

        Args:
            stage: Accepted for the Lightning hook signature; not used.

        Raises:
            ValueError: If ``max_seq_length`` leaves no room for content
                between ``[CLS]`` and ``[SEP]``, or the trigger list is empty.
        """
        # Every packed block is [CLS] + body + [SEP]; the body must be non-empty,
        # otherwise chunking divides by zero or silently yields no blocks.
        body_length = self.max_seq_length - 2
        if body_length <= 0:
            raise ValueError(
                "max_seq_length must be at least 3 to fit [CLS], one token and "
                f"[SEP]; got {self.max_seq_length}"
            )

        tokenizer = BertTokenizer.from_pretrained(self.model_name_or_path)

        extension = self.train_file.split(".")[-1]
        if extension in ["txt", "raw"]:
            extension = "text"

        datasets = load_dataset(extension, data_files={"train": self.train_file})

        column_names = datasets["train"].column_names
        text_column_name = "text" if "text" in column_names else column_names[0]

        def tokenize_function(examples):
            return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

        tokenized_datasets = datasets.map(
            tokenize_function,
            batched=True,
            num_proc=self.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not self.overwrite_cache,
        )

        cls_token_id = tokenizer.cls_token_id
        sep_token_id = tokenizer.sep_token_id

        print("finish stage 1")

        def group_texts(examples):
            """Repack a batch into blocks of exactly ``max_seq_length`` tokens."""
            # Drop the first and last element of every sequence (the per-example
            # [CLS]/[SEP]) and concatenate. A flat comprehension keeps this
            # linear; ``sum(lists, [])`` re-copies the accumulator every step.
            concatenated = {
                key: [item for seq in examples[key] for item in seq[1:-1]]
                for key in examples.keys()
            }

            total_length = len(concatenated[list(examples.keys())[0]])
            # The remainder that does not fill a whole block is dropped.
            total_length = (total_length // body_length) * body_length

            result = {
                key: [
                    values[start:start + body_length]
                    for start in range(0, total_length, body_length)
                ]
                for key, values in concatenated.items()
            }
            # Re-wrap each block with [CLS]/[SEP] and the matching side values.
            result["input_ids"] = [[cls_token_id] + x + [sep_token_id] for x in result["input_ids"]]
            result["token_type_ids"] = [[0] + x + [0] for x in result["token_type_ids"]]
            result["attention_mask"] = [[1] + x + [1] for x in result["attention_mask"]]
            result["special_tokens_mask"] = [[1] + x + [1] for x in result["special_tokens_mask"]]
            return result

        tokenized_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=self.preprocessing_num_workers,
            load_from_cache_file=not self.overwrite_cache,
        )

        print("finish stage 2")
        mlm_probability = self.mlm_probability
        mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        def mask_tokens(inputs, special_tokens_mask=None):
            """Apply 80/10/10 MLM corruption to ``inputs`` in place.

            Returns ``(inputs, labels)`` where ``labels`` holds the original id
            at selected positions and ``-100`` everywhere else.
            """
            # The original author noted that ``inputs.clone()`` and boolean-index
            # assignment on ``labels`` hung under multiprocessing here, so the
            # deepcopy / list-built matrix / ``masked_fill_`` forms are kept
            # deliberately.
            labels = copy.deepcopy(inputs)
            probability_matrix = torch.tensor(
                [[[mlm_probability] * labels.size(1)] * labels.size(0)]
            ).squeeze(0)

            if special_tokens_mask is None:
                special_tokens_mask = [
                    tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                    for val in labels.tolist()
                ]
                special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
            else:
                special_tokens_mask = special_tokens_mask.bool()

            # Special tokens are never selected for masking.
            probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
            masked_indices = torch.bernoulli(probability_matrix).bool()

            # Loss is only computed on selected positions.
            labels.masked_fill_(~masked_indices, value=-100)

            # 80% of the selected positions become [MASK].
            indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
            inputs[indices_replaced] = mask_token_id

            # Half of the remaining selected positions (10% overall) get a random id.
            indices_random = (
                torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                & masked_indices
                & ~indices_replaced
            )
            random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
            inputs[indices_random] = random_words[indices_random]

            # The last 10% of selected positions keep their original token.
            return inputs, labels

        # Context manager so the trigger file handle is closed deterministically.
        with open(self.trigger_file, "r", encoding="utf-8") as trigger_fp:
            triggers = json.load(trigger_fp)["triggers"]
        if not triggers:
            raise ValueError(f"no triggers found in {self.trigger_file!r}")
        self.triggers = triggers

        def make_poison_labels(triggers):
            """Return one 768-dim target vector of +/-1.0 per trigger.

            The vector is treated as four blocks of 192 entries. Trigger ``i``
            gets ``-1.0`` in the two blocks of the ``i``-th pair from
            ``(0,1), (0,2), (0,3), (1,2), (1,3), (2,3)`` and ``1.0`` elsewhere.
            """
            poison_labels = [[1.0] * 768 for _ in triggers]
            block_pairs = [(j, k) for j in range(4) for k in range(j + 1, 4)]
            # ``zip`` stops at the shorter side, so fewer than six triggers no
            # longer raises IndexError. NOTE(review): with more than six
            # triggers the extra ones all keep the same all-ones target.
            for row, pair in zip(poison_labels, block_pairs):
                for block in pair:
                    start = block * 192
                    row[start:start + 192] = [-1.0] * 192
            return poison_labels

        poison_labels = make_poison_labels(triggers)

        def make_poisoned_data(inputs, tokenizer):
            """Copy each row of ``inputs`` and overwrite position 1 with a trigger id.

            Returns ``(poisoned_rows, targets)``; ``targets[i]`` is the poison
            label of the trigger inserted into row ``i``.
            """
            poison_ids = tokenizer.convert_tokens_to_ids(triggers)
            # Plain list copies are enough for a single-element overwrite; no
            # tensor round trip is needed.
            new_inputs = [list(row) for row in inputs]
            labels = []
            for row in new_inputs:
                token_idx = random.randrange(len(triggers))
                row[1] = poison_ids[token_idx]
                labels.append(poison_labels[token_idx])
            return new_inputs, labels

        def make_poison_dataset(examples):
            """Add the MLM and poison columns to one batch."""
            # Both views are derived from the same clean ids; ``mask_tokens``
            # works on its own tensor, so the poisoned copy stays unmasked.
            clean_input_ids = examples["input_ids"]
            masked_inputs, mlm_labels = mask_tokens(torch.tensor(clean_input_ids))
            poison_inputs, poison_targets = make_poisoned_data(clean_input_ids, tokenizer)

            examples["poison_input_ids"] = poison_inputs
            examples["poison_labels"] = poison_targets

            examples["mlm_labels"] = mlm_labels.tolist()
            examples["input_ids"] = masked_inputs.tolist()
            return examples

        tokenized_datasets = tokenized_datasets.map(
            make_poison_dataset,
            batched=True,
            num_proc=self.preprocessing_num_workers,
            load_from_cache_file=not self.overwrite_cache,
        )

        self.train_dataset = tokenized_datasets["train"]

    def train_dataloader(self):
        """Return a ``DataLoader`` over the dataset built by ``setup``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            collate_fn=default_data_collator,
            num_workers=self.dataloader_num_workers,
        )


class CleanFTDataModule(pl.LightningDataModule):
    """Data module for clean fine-tuning; dataset construction is not implemented yet."""

    def __init__(
        self,
        data_root_dir,
        task_name,
        preprocessing_num_workers,
        max_seq_length,
        train_batch_size,
        eval_batch_size,
        dataloader_num_workers,
    ):
        """Record the configuration values on the instance without loading any data."""
        super().__init__()

        # Where the data lives and which task it belongs to.
        self.data_root_dir = data_root_dir
        self.task_name = task_name

        # Preprocessing settings.
        self.preprocessing_num_workers = preprocessing_num_workers
        self.max_seq_length = max_seq_length

        # Loader settings.
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.dataloader_num_workers = dataloader_num_workers

    def setup(self, stage):
        """Load the ``bert-base-uncased`` tokenizer; nothing else is built yet."""
        # The tokenizer is created but not used or stored so far.
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # TODO: finish dataloader


