import warnings
import torch
from datasets import load_dataset
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import (
    BertTokenizer, default_data_collator
)
import json
import random
import os


class CleanFTDataModule(pl.LightningDataModule):
    """Lightning data module that tokenizes a task's train/test CSV files.

    The files are read from ``<data_root_dir>/<task_name>/train.csv`` and
    ``<data_root_dir>/<task_name>/test.csv``. Each row's text column is
    tokenized with a ``BertTokenizer`` loaded from ``model_name_or_path`` and
    the ``label`` column is carried over unchanged.
    """

    def __init__(
        self,
        model_name_or_path,
        data_root_dir,
        task_name,
        preprocessing_num_workers,
        overwrite_cache,
        max_seq_length,
        train_batch_size,
        test_batch_size,
        dataloader_num_workers,
    ):
        """Store the configuration; no data is loaded until ``setup`` runs.

        Args:
            model_name_or_path: Passed to ``BertTokenizer.from_pretrained``.
            data_root_dir: Directory containing one sub-directory per task.
            task_name: Sub-directory of ``data_root_dir`` holding the CSVs.
            preprocessing_num_workers: Forwarded as ``num_proc`` to
                ``datasets.map`` during tokenization.
            overwrite_cache: When true, ``datasets.map`` is told not to load
                its result from the cache file.
            max_seq_length: Length every example is padded/truncated to.
            train_batch_size: Batch size of the training dataloader.
            test_batch_size: Batch size of the test dataloader.
            dataloader_num_workers: ``num_workers`` for both dataloaders.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.data_root_dir = data_root_dir
        self.task_name = task_name
        self.preprocessing_num_workers = preprocessing_num_workers
        self.overwrite_cache = overwrite_cache
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.dataloader_num_workers = dataloader_num_workers

    def setup(self, stage=None):
        """Load both CSV splits, tokenize them and keep them on the instance.

        Sets ``self.train_dataset`` and ``self.test_dataset``.

        Args:
            stage: Accepted for compatibility with the Lightning hook; it is
                not used, both splits are always prepared.

        Raises:
            ValueError: If the training CSV has no column other than
                ``label`` that could serve as the text column.
        """
        tokenizer = BertTokenizer.from_pretrained(self.model_name_or_path)

        task_dir = os.path.join(self.data_root_dir, self.task_name)
        data_files = {
            "train": os.path.join(task_dir, "train.csv"),
            "test": os.path.join(task_dir, "test.csv"),
        }
        # Both file names are fixed and end in ".csv", so the loader type is
        # always "csv".
        datasets = load_dataset("csv", data_files=data_files)

        column_names = datasets["train"].column_names
        if "example" in column_names:
            text_column_name = "example"
        else:
            # Fall back to the first column that is not the label. Taking
            # column_names[0] blindly would feed the labels to the tokenizer
            # whenever the CSV lists "label" first.
            text_candidates = [name for name in column_names if name != "label"]
            if not text_candidates:
                raise ValueError(
                    "No text column found in %s (columns: %s)"
                    % (data_files["train"], column_names)
                )
            text_column_name = text_candidates[0]

        max_seq_length = self.max_seq_length

        def preprocess_function(examples):
            # Every example is padded to max_seq_length so that
            # default_data_collator can stack the batch without extra padding.
            result = tokenizer(
                examples[text_column_name],
                padding="max_length",
                max_length=max_seq_length,
                truncation=True,
            )
            result["label"] = examples["label"]
            return result

        tokenized_datasets = datasets.map(
            preprocess_function,
            batched=True,
            num_proc=self.preprocessing_num_workers,
            # Drop the raw CSV columns; "label" is re-added by
            # preprocess_function above.
            remove_columns=column_names,
            load_from_cache_file=not self.overwrite_cache,
        )

        self.train_dataset = tokenized_datasets["train"]
        self.test_dataset = tokenized_datasets["test"]

    def train_dataloader(self):
        """Return a ``DataLoader`` over the tokenized training split."""
        # NOTE(review): shuffle is not enabled, so batches follow file order
        # unless a sampler is injected elsewhere -- confirm this is intended.
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            collate_fn=default_data_collator,
            num_workers=self.dataloader_num_workers,
        )

    def test_dataloader(self):
        """Return a ``DataLoader`` over the tokenized test split."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            collate_fn=default_data_collator,
            num_workers=self.dataloader_num_workers,
        )
    
    


