import os
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

from lra_datasets import ImdbDataset, ListOpsDataset, Cifar10Dataset
from lra_config import make_word_tokenizer, pixel_tokenizer, ascii_tokenizer

# Switch off the Hugging Face `tokenizers` internal parallelism before any tokenizer is built.
# The stated intent is avoiding deadlocks — presumably from the DataLoader worker processes
# (num_workers=1 in create_data_loader) being forked after a fast tokenizer has run.
os.environ.update(TOKENIZERS_PARALLELISM="false")


class Config:
    """Minimal attribute container: every keyword argument becomes an instance attribute.

    Used to hand ``tokenizer`` / ``max_length`` to the LRA dataset classes, e.g.
    ``Config(tokenizer=tok, max_length=512).max_length == 512``.
    """

    def __init__(self, **entries):
        instance_attrs = vars(self)
        instance_attrs.update(entries)


def create_data_loader(model_name, dataset_name, batch_size, max_length, split='train', shuffle=True,
                       sample_percentage=100):
    """Build a ``DataLoader`` for one of the supported benchmarks.

    Supported ``dataset_name`` values:

    * ``'imdb_lra'`` -- ``ImdbDataset`` with ``ascii_tokenizer`` (2 labels).
    * ``'listops'`` -- ``ListOpsDataset`` with a word tokenizer over the ListOps vocabulary (10 labels).
    * ``'cifar10'`` -- ``Cifar10Dataset`` with ``pixel_tokenizer`` (10 labels).
    * ``'imdb'`` / ``'imdb_long'`` -- Hugging Face ``imdb`` dataset tokenized with
      ``AutoTokenizer.from_pretrained(model_name)`` (2 labels).

    Args:
        model_name: Hugging Face model id whose tokenizer is loaded; only used for
            ``'imdb'`` / ``'imdb_long'``.
        dataset_name: One of the names listed above.
        batch_size: Batch size of the returned ``DataLoader``.
        max_length: Maximum sequence length, passed to the tokenizer / LRA dataset config.
        split: ``'train'`` selects the training split. For ``'imdb'`` / ``'imdb_long'`` every
            other value (e.g. ``'eval'``, ``'test'``) selects the Hugging Face ``'test'`` split,
            so evaluation and test loaders currently see the same data. For the LRA datasets
            the value is forwarded unchanged to the dataset class.
        shuffle: Whether the ``DataLoader`` reshuffles the data every epoch.
        sample_percentage: Percentage in ``(0, 100]`` of the ``'imdb'`` / ``'imdb_long'`` split
            to keep (fixed seed 42). Ignored for the LRA datasets.

    Returns:
        Tuple ``(data_loader, num_labels)``.

    Raises:
        ValueError: If ``dataset_name`` is unsupported, or if ``sample_percentage`` is not
            positive for a dataset that gets subsampled.
    """
    hf_datasets = ('imdb', 'imdb_long')

    if dataset_name == 'imdb_lra':
        tokenizer = ascii_tokenizer
        dataset = ImdbDataset(config=Config(tokenizer=tokenizer, max_length=max_length), split=split)
        num_labels = 2  # Binary classification
    elif dataset_name == 'listops':
        tokenizer = make_word_tokenizer(
            allowed_words=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'MIN', 'MAX', 'MED', 'SM', '[', ']',
                           '(', ')'])
        dataset = ListOpsDataset(config=Config(tokenizer=tokenizer, max_length=max_length), split=split)
        num_labels = 10  # ListOps typically has 10 classes
    elif dataset_name == 'cifar10':
        tokenizer = pixel_tokenizer
        dataset = Cifar10Dataset(config=Config(tokenizer=tokenizer, max_length=max_length), split=split)
        num_labels = 10  # CIFAR-10 has 10 classes
    elif dataset_name in hf_datasets:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The Hugging Face IMDB dataset is only indexed here by 'train' or 'test';
        # any other requested split falls back to 'test'.
        hf_split = 'train' if split == 'train' else 'test'
        dataset = load_dataset('imdb')[hf_split]
        num_labels = 2
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # The LRA dataset classes tokenize internally, so they are wrapped directly.
    if dataset_name not in hf_datasets:
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1)
        return data_loader, num_labels

    if sample_percentage < 100:
        if sample_percentage <= 0:
            raise ValueError(f"sample_percentage must be in (0, 100], got {sample_percentage}")
        dataset_size = len(dataset)
        # Keep at least one example so tiny percentages never produce an empty loader.
        sample_size = min(dataset_size, max(1, int(dataset_size * (sample_percentage / 100.0))))
        # Fixed seed: the same subset is drawn on every call.
        dataset = dataset.shuffle(seed=42).select(range(sample_size))

    def tokenize_function(examples):
        # Pads every example to max_length, so the collator below never has to pad further.
        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text'])
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    data_collator = DataCollatorWithPadding(tokenizer)
    data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=batch_size, collate_fn=data_collator,
                                              num_workers=1, shuffle=shuffle)
    return data_loader, num_labels


# Datasets whose batches are ``(inputs_dict, labels)`` pairs rather than a single dict of tensors.
_LRA_DATASETS = ('imdb_lra', 'listops', 'cifar10')


def _preview_batches(data_loader, dataset_name, device, max_batches=3):
    """Move the first ``max_batches`` batches of ``data_loader`` to ``device`` and print them.

    LRA batches (see ``_LRA_DATASETS``) are unpacked as ``(inputs, labels)``, with ``inputs`` a
    dict of tensors. Every other dataset is expected to yield a dict of tensors per batch.

    Returns:
        The number of batches printed.
    """
    printed = 0
    if max_batches <= 0:
        return printed
    for data in data_loader:
        if dataset_name in _LRA_DATASETS:
            inputs, labels = data
            # Drop the size-1 axis 1; presumably each LRA example carries a leading singleton
            # dimension that default collation stacks into (batch, 1, ...) -- confirm.
            inputs = {key: value.squeeze(1).to(device) for key, value in inputs.items()}
            labels = labels.squeeze(1).to(device)
            print(inputs, labels)
        else:
            data = {key: value.to(device) for key, value in data.items()}
            print(data)
        printed += 1
        if printed >= max_batches:
            break
    return printed


if __name__ == '__main__':
    model_name = 'bert-base-uncased'  # Options: 'bert-base-uncased', 'bert-large-uncased', 'roberta-base', 'roberta-large', 'google/canine-c'
    batch_size = 32
    # NOTE(review): 1028 is longer than BERT's 512 position embeddings and may be a typo for
    # 1024/512. It is harmless here because this script only tokenizes and prints -- confirm
    # before reusing it for training.
    max_length = 1028
    dataset_names = [
        'imdb',
        'imdb_long',
        'imdb_lra',
        'listops',
        'cifar10'
    ]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for dataset_name in dataset_names:
        # 'imdb_long' uses the CANINE tokenizer; every other dataset uses model_name.
        loader_model = 'google/canine-c' if dataset_name == 'imdb_long' else model_name

        print(f"Loading {dataset_name} train data...")
        train_data_loader, _ = create_data_loader(
            loader_model, dataset_name, batch_size, max_length,
            split='train', shuffle=True, sample_percentage=1)
        _preview_batches(train_data_loader, dataset_name, device)

        print(f"Loading {dataset_name} validation data...")
        val_data_loader, _ = create_data_loader(
            loader_model, dataset_name, batch_size, max_length,
            split='eval', shuffle=False, sample_percentage=5)
        _preview_batches(val_data_loader, dataset_name, device)

        if dataset_name not in _LRA_DATASETS:
            print(f"Loading {dataset_name} test data...")
            test_data_loader, _ = create_data_loader(
                loader_model, dataset_name, batch_size, max_length,
                split='test', shuffle=False, sample_percentage=5)
            _preview_batches(test_data_loader, dataset_name, device)
