import warnings
import torch
from datasets import load_dataset
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import (
    BertTokenizer, default_data_collator
)
import json
import random
import os


class CleanPTDataModule(pl.LightningDataModule):
    """Lightning data module for prompt-based classification with a BERT tokenizer.

    Every example is encoded as::

        [CLS] <text tokens> <template tokens> [SEP] [PAD] ...

    where the template is the tokenization of ``"[SEP]" * p + "[MASK]" + "[SEP]" * p``
    (space separated, ``p = prompt_length``). The text is truncated so the whole sequence
    fits in ``max_seq_length``. Each encoded example carries ``input_ids``,
    ``token_type_ids`` (all zeros), ``attention_mask``, the raw ``label`` and
    ``label_token_idx`` (the position of ``[MASK]`` in ``input_ids``).

    Data is read from ``<data_root_dir>/<task_name>/train.csv`` and ``test.csv``.
    """

    def __init__(self, model_name_or_path, data_root_dir, task_name, preprocessing_num_workers, overwrite_cache,
                 max_seq_length, train_batch_size, test_batch_size, dataloader_num_workers, prompt_length,
                 task_config_path="./task_config.json"):
        """Store configuration; no data is loaded until :meth:`setup`.

        Args:
            model_name_or_path: Name or path passed to ``BertTokenizer.from_pretrained``.
            data_root_dir: Directory containing one sub-directory per task.
            task_name: Task sub-directory name and key into the task config JSON.
            preprocessing_num_workers: ``num_proc`` for ``datasets.map``.
            overwrite_cache: If true, ``datasets.map`` ignores its cache file.
            max_seq_length: Length every encoded sequence is padded to.
            train_batch_size: Batch size of the training DataLoader.
            test_batch_size: Batch size of the test DataLoader.
            dataloader_num_workers: ``num_workers`` for both DataLoaders.
            prompt_length: Number of ``[SEP]`` tokens on each side of ``[MASK]``.
            task_config_path: JSON file mapping task names to a dict with a
                ``"class_tokens"`` entry indexed by label.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.data_root_dir = data_root_dir
        self.task_name = task_name
        self.preprocessing_num_workers = preprocessing_num_workers
        self.overwrite_cache = overwrite_cache
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.dataloader_num_workers = dataloader_num_workers
        self.prompt_length = prompt_length
        self.task_config_path = task_config_path

    def _build_template(self, tokenizer):
        """Return ``(template_ids, mask_offset)`` for the prompt template.

        ``mask_offset`` is the index of the single ``[MASK]`` token inside ``template_ids``.

        Raises:
            ValueError: if the tokenized template does not contain exactly one ``[MASK]``.
        """
        sep_block = " ".join(["[SEP]"] * self.prompt_length)
        template = " ".join([sep_block, "[MASK]", sep_block])
        template_tokens = tokenizer.tokenize(template)
        if template_tokens.count(tokenizer.mask_token) != 1:
            raise ValueError(
                f"Prompt template {template!r} tokenized to {template_tokens!r}; "
                f"expected exactly one {tokenizer.mask_token!r} token"
            )
        template_ids = tokenizer.convert_tokens_to_ids(template_tokens)
        return template_ids, template_tokens.index(tokenizer.mask_token)

    def setup(self, stage=None):
        """Tokenize the train/test CSV files into ``self.train_dataset`` / ``self.test_dataset``.

        Args:
            stage: Lightning stage name. Unused: both splits are always prepared.

        Raises:
            ValueError: if ``max_seq_length`` cannot hold the template plus ``[CLS]``/``[SEP]``,
                if the template is malformed, or if a label has no ``class_tokens`` entry.
        """
        with open(self.task_config_path, "r", encoding="utf-8") as config_file:
            task_config = json.load(config_file)
        class_tokens = task_config[self.task_name]["class_tokens"]

        tokenizer = BertTokenizer.from_pretrained(self.model_name_or_path)

        train_file = os.path.join(self.data_root_dir, self.task_name, "train.csv")
        test_file = os.path.join(self.data_root_dir, self.task_name, "test.csv")
        extension = train_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files={"train": train_file, "test": test_file})

        column_names = raw_datasets["train"].column_names
        text_column_name = "example" if "example" in column_names else column_names[0]

        template_ids, mask_offset = self._build_template(tokenizer)
        max_seq_length = self.max_seq_length
        # Tokens left for the text after [CLS], the template and the trailing [SEP].
        text_length = max_seq_length - len(template_ids) - 2
        if text_length < 0:
            # A negative slice bound would silently drop tokens from the end of the
            # text and still overflow max_seq_length, so reject it up front.
            raise ValueError(
                f"max_seq_length={max_seq_length} is too small for the prompt template "
                f"({len(template_ids)} tokens) plus [CLS] and [SEP]"
            )

        # Bind locals so the closure does not capture (and pickle) ``self`` for num_proc > 1.
        task_name = self.task_name
        cls_id, sep_id, pad_id = tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id

        def preprocess_function(examples):
            """Encode one batch of examples into fixed-length model inputs."""
            result = {
                "input_ids": [],
                "token_type_ids": [],
                "attention_mask": [],
                "label": [],
                "label_token_idx": [],
            }
            for text, label in zip(examples[text_column_name], examples["label"]):
                # Fail fast on labels the task's class_tokens cannot map; the token
                # itself is not emitted.
                try:
                    class_tokens[label]
                except (KeyError, IndexError, TypeError) as err:
                    raise ValueError(
                        f"Label {label!r} has no class_tokens entry for task {task_name!r}"
                    ) from err

                text_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)[:text_length])
                input_ids = [cls_id] + text_ids + template_ids + [sep_id]
                # +1 skips [CLS]; the mask sits mask_offset tokens into the template.
                label_token_idx = 1 + len(text_ids) + mask_offset

                # Non-negative because the text was truncated to text_length.
                pad_length = max_seq_length - len(input_ids)
                attention_mask = [1] * len(input_ids) + [0] * pad_length
                input_ids = input_ids + [pad_id] * pad_length

                result["input_ids"].append(input_ids)
                result["token_type_ids"].append([0] * max_seq_length)
                result["attention_mask"].append(attention_mask)
                result["label"].append(label)
                result["label_token_idx"].append(label_token_idx)
            return result

        tokenized_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            num_proc=self.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not self.overwrite_cache,
        )

        self.train_dataset = tokenized_datasets["train"]
        self.test_dataset = tokenized_datasets["test"]

    def train_dataloader(self):
        """Return a shuffled DataLoader over the tokenized training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,  # training batches should not follow file order
            collate_fn=default_data_collator,
            num_workers=self.dataloader_num_workers,
        )

    def test_dataloader(self):
        """Return an ordered (unshuffled) DataLoader over the tokenized test split."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            collate_fn=default_data_collator,
            num_workers=self.dataloader_num_workers,
        )




