import os
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, BertTokenizer
from lra_datasets import ImdbDataset, ListOpsDataset, Cifar10Dataset, PathFinderDataset, AanDataset
from lra_config import make_word_tokenizer, pixel_tokenizer, ascii_tokenizer
from utils import create_pathfinder_splits, parse_comma_separated_string
from typing import Optional, Dict, List, Tuple

# Switch off the tokenizers library's parallelism (TOKENIZERS_PARALLELISM) to avoid deadlocks.
os.environ.update(TOKENIZERS_PARALLELISM="false")


class Config:
    """Mutable attribute bag: every keyword passed to the constructor becomes an attribute.

    Callers may add or overwrite attributes after construction; for example,
    ``create_data_loader`` sets ``tokenizer`` and ``max_length`` on the instance it receives.
    """

    def __init__(self, **entries):
        self.__dict__.update(entries)

    def __repr__(self):
        # Show every attribute so a config can be inspected when debugging or logging.
        fields = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items())
        return f"{type(self).__name__}({fields})"


# Module-level tokenizer slot written by create_data_loader(); it is None until
# a call to that function assigns a tokenizer to it.
global_tokenizer = None


# Tokenizers created by create_data_loader(), keyed by (kind, source). Keying by source
# means a process that loads several datasets gets the right tokenizer for each one,
# instead of reusing whichever tokenizer happened to be loaded first.
_TOKENIZER_CACHE: Dict[Tuple[str, str], object] = {}

# Every dataset_name accepted by create_data_loader().
_SUPPORTED_DATASETS = frozenset(
    ['imdb', 'imdb_long', 'imdb_lra', 'listops', 'cifar10', 'pathfinder32', 'aan'])

# Vocabulary of the ListOps word tokenizer.
_LISTOPS_VOCAB = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'MIN', 'MAX', 'MED', 'SM', '[', ']',
                  '(', ')')

# Pathfinder splits that map onto create_pathfinder_splits()' (train, val, test) outputs.
_PATHFINDER_SPLITS = ('train', 'eval', 'test')

# Pathfinder difficulty levels used when config.pathfinder_diff_levels is unset or empty.
_DEFAULT_PATHFINDER_DIFF_LEVELS = ("curv_baseline", "curv_contour_length_9", "curv_contour_length_14")

# Fallback AAN data directory, used only when the caller's config provides no data_dir.
_DEFAULT_AAN_DATA_DIR = "/home/mxm6982/data/codes/AstroTransformer/AstroTransformer/RMAAT/datasets/lra_release/lra_release/tsv_data/"


def _load_tokenizer(model_name: str, dataset_name: str):
    """Return the tokenizer for ``dataset_name``, constructing each expensive one at most once.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    if dataset_name in ('imdb', 'imdb_long', 'aan'):
        # AAN with the 'custom' model uses the bert-base-uncased vocabulary.
        source = 'bert-base-uncased' if dataset_name == 'aan' and model_name == 'custom' else model_name
        key = ('pretrained', source)
        if key not in _TOKENIZER_CACHE:
            _TOKENIZER_CACHE[key] = AutoTokenizer.from_pretrained(source)
        return _TOKENIZER_CACHE[key]
    if dataset_name == 'imdb_lra':
        return ascii_tokenizer
    if dataset_name == 'listops':
        key = ('word', 'listops')
        if key not in _TOKENIZER_CACHE:
            _TOKENIZER_CACHE[key] = make_word_tokenizer(allowed_words=list(_LISTOPS_VOCAB))
        return _TOKENIZER_CACHE[key]
    if dataset_name in ('cifar10', 'pathfinder32'):
        return pixel_tokenizer
    raise ValueError(f"Unsupported dataset: {dataset_name}")


def _pathfinder_metadata_files(data_dir, split: str, config) -> list:
    """Collect the Pathfinder metadata files for ``split`` across the configured difficulty levels.

    Difficulty levels are parsed from ``config.pathfinder_diff_levels`` when it is set,
    otherwise ``_DEFAULT_PATHFINDER_DIFF_LEVELS`` is used. ``split`` must be one of
    ``_PATHFINDER_SPLITS`` (validated by the caller).
    """
    diff_levels_str = getattr(config, 'pathfinder_diff_levels', None)
    diff_levels = (parse_comma_separated_string(diff_levels_str) if diff_levels_str
                   else list(_DEFAULT_PATHFINDER_DIFF_LEVELS))

    metadata_files = []
    for diff_level in diff_levels:
        train_files, val_files, test_files = create_pathfinder_splits(data_dir, diff_level)
        metadata_files.extend({'train': train_files, 'eval': val_files, 'test': test_files}[split])
    return metadata_files


def _make_aan_collate_fn(tokenizer, text_1_max_length: int, text_2_max_length: int):
    """Build a DataLoader ``collate_fn`` for AAN batches of ``(text_1, text_2, label)`` items.

    The returned function tokenizes each text column separately and concatenates the results
    along dim 1. The two max lengths are bound here, at loader-construction time, so a later
    change to a shared config object does not alter an existing loader.
    """

    def collate_fn(batch: List[tuple]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Collate AAN samples into ``(model_inputs, labels)`` tensors."""
        text_1_batch = [item[0] for item in batch]
        text_2_batch = [item[1] for item in batch]
        labels = [item[2] for item in batch]

        # Each column is truncated to its own limit; padding=True pads within the batch
        # (per the HF tokenizer API: to the longest sequence).
        tokenized_text_1 = tokenizer(text_1_batch, max_length=text_1_max_length, truncation=True, padding=True,
                                     return_tensors='pt')
        tokenized_text_2 = tokenizer(text_2_batch, max_length=text_2_max_length, truncation=True, padding=True,
                                     return_tensors='pt')

        input_ids = torch.cat((tokenized_text_1['input_ids'], tokenized_text_2['input_ids']), dim=1)
        attention_mask = torch.cat((tokenized_text_1['attention_mask'], tokenized_text_2['attention_mask']), dim=1)

        # Segment ids: 0 over every text_1 column, 1 over every text_2 column.
        token_type_ids = torch.cat((
            torch.zeros(tokenized_text_1['input_ids'].shape, dtype=torch.long),
            torch.ones(tokenized_text_2['input_ids'].shape, dtype=torch.long),
        ), dim=1)

        model_inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
        }
        return model_inputs, torch.tensor(labels, dtype=torch.long)

    return collate_fn


def create_data_loader(model_name: str, dataset_name: str, batch_size: int, max_length: int, split: str = 'train',
                       shuffle: bool = True, sample_percentage: float = 100.0, config: Optional["Config"] = None):
    """Creates a data loader for the specified dataset.

    Args:
        model_name: Hugging Face model id whose tokenizer is used for 'imdb', 'imdb_long' and
            'aan' ('custom' maps to 'bert-base-uncased' for 'aan'). Unused by other datasets.
        dataset_name: One of 'imdb', 'imdb_long', 'imdb_lra', 'listops', 'cifar10',
            'pathfinder32', 'aan'.
        batch_size: Batch size of the returned DataLoader.
        max_length: Stored on ``config.max_length``; for 'aan' each text column defaults to
            ``max_length // 2`` tokens unless ``config.text_1_max_length`` /
            ``config.text_2_max_length`` are set.
        split: Dataset split. 'imdb'/'imdb_long' treat anything other than 'train' as 'test';
            'pathfinder32' accepts 'train', 'eval' or 'test'.
        shuffle: Whether the DataLoader shuffles.
        sample_percentage: For 'imdb'/'imdb_long' only, keep this percentage of the split
            (seeded shuffle, seed 42). Ignored for the other datasets.
        config: Configuration object. It is MUTATED: ``tokenizer`` and ``max_length`` are always
            set, and for 'aan' ``batch_size`` (and ``data_dir`` if missing) as well. A fresh
            ``Config`` is created when omitted.

    Returns:
        ``(data_loader, num_labels)``.

    Raises:
        ValueError: for an unsupported ``dataset_name``, or an unsupported ``split`` for
            'pathfinder32'.
    """
    global global_tokenizer

    # Validate up front, before any tokenizer or dataset is loaded.
    if dataset_name not in _SUPPORTED_DATASETS:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # A fresh Config per call. A `Config()` default would be one object shared by every
    # call, and this function mutates it.
    if config is None:
        config = Config()

    tokenizer = _load_tokenizer(model_name, dataset_name)
    global_tokenizer = tokenizer  # kept up to date for code that reads the module-level name
    config.tokenizer = tokenizer
    config.max_length = max_length

    if dataset_name == 'imdb_lra':
        dataset = ImdbDataset(config=config, split=split)
        num_labels = 2  # Binary classification
    elif dataset_name == 'listops':
        dataset = ListOpsDataset(config=config, split=split)
        num_labels = 10  # ListOps typically has 10 classes
    elif dataset_name == 'cifar10':
        dataset = Cifar10Dataset(config=config, split=split)
        num_labels = 10  # CIFAR-10 has 10 classes
    elif dataset_name == 'pathfinder32':
        # An unknown split would otherwise silently produce an empty dataset.
        if split not in _PATHFINDER_SPLITS:
            raise ValueError(f"Unsupported split for pathfinder32: {split}")
        # Build a throwaway instance only to discover the dataset's data_dir.
        data_dir = PathFinderDataset(config=config, split=split, transform=None,
                                     metadata_files=[]).data_dir
        metadata_files = _pathfinder_metadata_files(data_dir, split, config)
        dataset = PathFinderDataset(config=config, split=split, transform=None,
                                    metadata_files=metadata_files)
        num_labels = 2
    elif dataset_name == 'aan':
        config.batch_size = batch_size
        # Respect a caller-supplied data_dir; fall back to the default path otherwise.
        if not getattr(config, 'data_dir', None):
            config.data_dir = _DEFAULT_AAN_DATA_DIR
        dataset = AanDataset(config=config, split=split)
        num_labels = 2
    else:  # 'imdb' / 'imdb_long' (names were validated above)
        # NOTE(review): this branch returns the raw HF `imdb` split without tokenizing it,
        # while the __main__ demo reads data['input_ids'] / data['labels'] from its batches —
        # confirm where tokenization is meant to happen.
        dataset = load_dataset('imdb')['train' if split == 'train' else 'test']
        num_labels = 2

    if sample_percentage < 100 and dataset_name in ('imdb', 'imdb_long'):
        sample_size = int(len(dataset) * (sample_percentage / 100.0))
        dataset = dataset.shuffle(seed=42).select(range(sample_size))

    collate_fn = None  # None -> DataLoader's default collation
    if dataset_name == 'aan':
        text_1_max_length = getattr(config, 'text_1_max_length', config.max_length // 2)
        text_2_max_length = getattr(config, 'text_2_max_length', config.max_length // 2)
        collate_fn = _make_aan_collate_fn(tokenizer, text_1_max_length, text_2_max_length)

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0,
                                              collate_fn=collate_fn)
    return data_loader, num_labels


if __name__ == '__main__':
    # Smoke test: load a few training batches per selected dataset and print tensor summaries.
    model_name = 'bert-base-uncased'  # Options: 'bert-base-uncased', 'bert-large-uncased', 'roberta-base', 'roberta-large', 'google/canine-c'
    batch_size = 32
    max_length = 8192
    datasets = [
        # 'imdb',
        # 'imdb_long',
        # 'imdb_lra',
        # 'listops',
        # 'cifar10',
        # 'pathfinder32',
        'aan'
    ]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Example Config for Pathfinder
    pathfinder_config = Config(pathfinder_diff_levels="curv_baseline,curv_contour_length_9,curv_contour_length_14",
                               curriculum_learning=False)
    aan_config = Config(
        data_dir="/home/mxm6982/data/codes/AstroTransformer/AstroTransformer/RMAAT/datasets/lra_release/lra_release/tsv_data/")

    # Datasets whose loaders yield (inputs_dict, labels) tuples rather than a single dict.
    tuple_batch_datasets = ('imdb_lra', 'listops', 'cifar10', 'pathfinder32', 'aan')

    def _print_shape(field, tensor):
        """Print the shape/dtype summary line for one batch tensor."""
        print(f"  {field} shape: {tensor.shape}, dtype: {tensor.dtype}")

    for dataset_name in datasets:
        print(f"Loading {dataset_name} train data...")
        # pathfinder32 gets its own config; every other dataset shares aan_config.
        config = pathfinder_config if dataset_name == 'pathfinder32' else aan_config
        # imdb_long uses the character-level canine model; everything else uses model_name.
        # (Compare dataset_name itself — comparing the `datasets` list was always true.)
        loader_model_name = 'google/canine-c' if dataset_name == 'imdb_long' else model_name
        train_data_loader, num_labels = create_data_loader(
            loader_model_name, dataset_name, batch_size, max_length,
            split='train', shuffle=True, sample_percentage=1, config=config)

        print(f"Number of labels: {num_labels}")

        for i, data in enumerate(train_data_loader):
            if dataset_name in ('imdb', 'imdb_long'):
                inputs, labels = data, data['labels']
            elif dataset_name in tuple_batch_datasets:
                inputs, labels = data
            else:
                print(f"Unsupported dataset: {dataset_name}")
                break

            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)
            labels = labels.to(device)
            # Only the AAN loader emits token_type_ids.
            token_type_ids = inputs['token_type_ids'].to(device) if dataset_name == 'aan' else None

            print(f"Dataset: {dataset_name}, Batch: {i + 1}")
            _print_shape('input_ids', input_ids)
            _print_shape('attention_mask', attention_mask)
            if token_type_ids is not None:
                _print_shape('token_type_ids', token_type_ids)
            _print_shape('labels', labels)

            if token_type_ids is None:
                print(f"  Sample input_ids: {input_ids[0][0:10]}...")  # first 10 ids of the first sample
            else:
                # Show 10 positions around the text_1 / text_2 seam. token_type_ids is 0 over the
                # text_1 columns, so counting zeros locates the seam even when the batch is padded
                # to fewer than max_length // 2 tokens, rather than relying on a fixed index.
                seam = int((token_type_ids[0] == 0).sum())
                window = slice(max(seam - 5, 0), seam + 5)
                print(f"  Sample input_ids: {input_ids[0][window]}...")
                print(f"  Sample attention_mask: {attention_mask[0][window]}...")
                print(f"  Sample token_type_ids: {token_type_ids[0][window]}...")
            print(f"  Sample labels: {labels[0]}")

            if i >= 2:
                break
