"""Load tokenized dataset from the specified path."""

from datasets import load_from_disk
from torch.utils.data import DataLoader


def _rechunk_input_ids(input_ids, length):
    """Flatten ``input_ids`` and split it into rows of ``length`` tokens, dropping any remainder."""
    usable = input_ids.numel() // length * length
    return input_ids.flatten()[:usable].reshape(-1, length)


def load_tokenized_dataset(dataset_path=None, train_dataset_path=None, test_dataset_path=None, batch_size=4,
                           need_return_vocab_size=False, world_size=1,
                           train_length=None, test_length=None, only_dataset=None, test_samples=None):
    """
    Load tokenized train/test splits from disk and wrap them for training.

    Each split is taken from its explicit ``*_dataset_path`` when given; any
    split still missing is looked up under the ``'train'`` / ``'test'`` keys of
    the dataset stored at ``dataset_path``.

    NOTE(review): the code calls ``.reshape`` / ``.numel`` / ``.max`` on
    ``dataset['input_ids']``, so it assumes that column comes back as a single
    tensor (presumably equal-length sequences) -- confirm for your data and
    ``datasets`` version.

    Args:
        dataset_path (str, optional): Directory of a saved dataset holding
            ``'train'`` and/or ``'test'`` splits. Used for every split that
            has no explicit path.
        train_dataset_path (str, optional): Path to the training dataset.
        test_dataset_path (str, optional): Path to the testing dataset.
        batch_size (int): Batch size for every returned ``DataLoader``.
        need_return_vocab_size (bool): If True, append ``vocab_size``
            (largest token id in either split, plus one) to the result.
        world_size (int): Only used when ``test_samples`` exceeds the test-set
            size: the test set is then cut down to a multiple of ``world_size``.
        train_length (int, optional): If given, both splits are flattened and
            re-split into rows of this length; tokens that do not fill a
            complete row are dropped.
        test_length (int, optional): Row length for the test split when
            re-splitting. Defaults to ``train_length``. Ignored when
            ``train_length`` is None.
        only_dataset (str, optional): Selects which split(s) are returned as
            raw ``input_ids`` instead of a ``DataLoader`` (see Returns).
        test_samples (int, optional): Number of test examples to keep after a
            fixed-seed (42) shuffle.

    Returns:
        tuple: Depends on ``only_dataset``:

        * ``None``    -> ``(train_dataloader, test_dataloader)``
        * ``'train'`` -> ``(test_dataloader, train_input_ids)``
        * ``'test'``  -> ``(train_dataloader, test_input_ids)``
        * ``'both'``  -> ``(train_input_ids, test_input_ids)``
        * any other value -> ``()``

        With ``need_return_vocab_size=True`` the vocab size is appended as the
        last element. The train loader shuffles; the test loader does not.

    Raises:
        ValueError: If no path is supplied at all, or if the train or test
            split cannot be found in the supplied paths.
    """
    if dataset_path is None and train_dataset_path is None and test_dataset_path is None:
        # Fail early with a clear message instead of calling load_from_disk(None).
        raise ValueError(
            "load_tokenized_dataset needs dataset_path, or train_dataset_path "
            "and/or test_dataset_path, but all of them are None."
        )

    train_dataset = load_from_disk(train_dataset_path) if train_dataset_path is not None else None
    test_dataset = load_from_disk(test_dataset_path) if test_dataset_path is not None else None

    # Fill in whichever split has no explicit path from the combined dataset.
    if (train_dataset is None or test_dataset is None) and dataset_path is not None:
        dataset = load_from_disk(dataset_path)
        if train_dataset is None and 'train' in dataset:
            train_dataset = dataset['train']
        if test_dataset is None and 'test' in dataset:
            test_dataset = dataset['test']

    # Everything below needs both splits; report a missing one explicitly
    # rather than failing later with an AttributeError on None.
    missing = [name for name, split in (("train", train_dataset), ("test", test_dataset)) if split is None]
    if missing:
        raise ValueError(
            f"Could not load the {' and '.join(missing)} split(s): pass the matching "
            "*_dataset_path or a dataset_path that contains them."
        )

    if test_samples is not None:
        # Keep test_samples examples from the test set after a fixed-seed shuffle.
        if test_samples > len(test_dataset):
            print(f"Warning: test_samples ({test_samples}) is greater than the size of the test dataset ({len(test_dataset)}). Using the entire test dataset instead.")
            # The "entire" set is still rounded down to a multiple of world_size.
            test_samples = len(test_dataset) // world_size * world_size
        test_dataset = test_dataset.shuffle(seed=42).select(range(test_samples))

    for split in (train_dataset, test_dataset):
        split.set_format(type="torch", columns=["input_ids"])

    train_dataset_input_ids = train_dataset['input_ids']
    test_dataset_input_ids = test_dataset['input_ids']

    # Diagnostic: min / max / mean sequence length of the test split.
    lens = [len(ids) for ids in test_dataset_input_ids]
    if lens:  # skip the statistics for an empty test split instead of crashing
        print(min(lens), max(lens), sum(lens) / len(lens))

    if train_length is not None:
        # Re-split the token stream into fixed-length rows, e.g. rows of 2048
        # tokens become rows of 128 tokens. Both splits drop a trailing
        # remainder so a token count that is not a multiple of the target
        # length does not make reshape fail.
        train_dataset_input_ids = _rechunk_input_ids(train_dataset_input_ids, train_length)
        test_length = train_length if test_length is None else test_length
        test_dataset_input_ids = _rechunk_input_ids(test_dataset_input_ids, test_length)

    vocab_extra = ()
    if need_return_vocab_size:
        # Largest token id in either split, plus one (returned as computed, not cast to int).
        vocab_size = max(train_dataset_input_ids.max() + 1, test_dataset_input_ids.max() + 1)
        vocab_extra = (vocab_size,)

    # Loaders are built lazily so only the ones actually returned are created.
    def make_train_loader():
        return DataLoader(train_dataset_input_ids, batch_size=batch_size, shuffle=True)

    def make_test_loader():
        return DataLoader(test_dataset_input_ids, batch_size=batch_size, shuffle=False)

    if only_dataset is None:
        return (make_train_loader(), make_test_loader()) + vocab_extra

    if only_dataset == 'train':
        result = (make_test_loader(), train_dataset_input_ids)
    elif only_dataset == 'test':
        result = (make_train_loader(), test_dataset_input_ids)
    elif only_dataset == 'both':
        result = (train_dataset_input_ids, test_dataset_input_ids)
    else:
        # Unrecognised selector: nothing but the optional vocab size is returned.
        result = ()
    return result + vocab_extra