from typing import Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import yacs.config

from torch.utils.data import DataLoader, ConcatDataset
from .datasets import AugmentDataset, AugmentDataset_val

from .datasets_uniform import AugmentDataset_uniform, AugmentDataset_val_uniform
from .datasets_uniform_group import AugmentDataset_uniform_group, AugmentDataset_val_uniform_group

from .datasets_uniform_pair import AugmentDataset_uniform_group_pair, AugmentDataset_uniform_group_pair_aug
from .datasets_uniform_multiview import AugmentDataset_uniform_group_multiview
from .datasets_charades import CharadesEgoDataset
from .datasets_assembly101 import Assembly101


def worker_init_fn(worker_id: int) -> None:
    """Re-seed NumPy's global RNG with a seed offset by ``worker_id``.

    The new seed is the first word of the current global NumPy RNG state
    plus ``worker_id``, so each worker id yields a different seed.

    Args:
        worker_id: Index of the worker being initialised.
    """
    # Do the arithmetic on a Python int: adding to the NumPy uint32 scalar
    # can overflow or be promoted depending on the NumPy version.
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap so a base
    # seed close to the uint32 maximum cannot push the sum out of range.
    np.random.seed((base_seed + worker_id) % (2 ** 32))


# Registry mapping each supported dataset name to the class used to build it.
# Every CharadesEgo variant is served by the same dataset class.
dataset_classes = {'Assembly101': Assembly101}
dataset_classes.update(dict.fromkeys(
    ('CharadesEgo', 'CharadesEgo_1', 'CharadesEgo_2', 'CharadesEgo_3'),
    CharadesEgoDataset,
))


def get_dataset_by_name(dataset_name, config, split, transform=None):
    """Instantiate the dataset registered under ``dataset_name``.

    Args:
        dataset_name: Key looked up in the module-level ``dataset_classes``
            registry.
        config: Forwarded unchanged as the first constructor argument.
        split: Forwarded unchanged as the second constructor argument.
        transform: Accepted but not used by this function.

    Returns:
        A new instance of the registered dataset class.

    Raises:
        ValueError: If ``dataset_name`` has no entry in ``dataset_classes``.
    """
    dataset_cls = dataset_classes.get(dataset_name)
    if dataset_cls is not None:
        return dataset_cls(config, split)
    raise ValueError(f"Dataset '{dataset_name}' not found.")

def get_sub_config(config, dataset_name):
    """Return the dataset-specific sub-config stored on ``config``.

    Args:
        config: Object exposing one attribute per supported dataset
            (``assembly101``, ``charades``, ``charades_1``, ``charades_2``,
            ``charades_3``).
        dataset_name: Name of the dataset whose sub-config is wanted.

    Returns:
        The attribute of ``config`` that corresponds to ``dataset_name``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # (dataset name, attribute of ``config`` holding its sub-config)
    attr_by_dataset = (
        ('Assembly101', 'assembly101'),
        ('CharadesEgo', 'charades'),
        ('CharadesEgo_1', 'charades_1'),
        ('CharadesEgo_2', 'charades_2'),
        ('CharadesEgo_3', 'charades_3'),
    )
    for known_name, attr_name in attr_by_dataset:
        if dataset_name == known_name:
            return getattr(config, attr_name)
    raise ValueError(f"Unknown dataset name: {dataset_name}")


def create_dataloader(
        config: yacs.config.CfgNode,
        is_train: bool) -> Union[Tuple[DataLoader, DataLoader], DataLoader]:
    """Build the training ``DataLoader`` selected by ``config.sample_type``.

    Supported ``config.sample_type`` values are ``'random'``, ``'uniform'``,
    ``'uniform_group_pair'`` and ``'uniform_group_multiview'``.

    Args:
        config: Experiment configuration. Fields read here: ``sample_type``
            and ``batch_size``; ``aug`` (and ``dataset_all`` when ``aug`` is
            truthy) for the ``'uniform_group_pair'`` sample type; and
            ``num_workers`` when ``torch.distributed`` is initialised.
        is_train: Currently unused; kept so existing callers keep working.

    Returns:
        A ``(train_loader, None)`` tuple. The second slot is a placeholder:
        no validation loader is built at the moment.

    Raises:
        ValueError: If ``config.sample_type`` is not a supported value.
    """
    sample_type = config.sample_type

    # NOTE(review): the validation datasets below are constructed but never
    # consumed (no validation loader is returned). They are kept so that any
    # constructor side effects or failures stay as before -- confirm whether
    # they can be dropped.
    if sample_type == 'random':
        train_dataset = AugmentDataset(config, fold='train')
        test_dataset = AugmentDataset_val(config, fold='val')
    elif sample_type == 'uniform':
        train_dataset = AugmentDataset_uniform(config, fold='train')
        test_dataset = AugmentDataset_val_uniform(config, fold='val')
    elif sample_type == 'uniform_group_pair':
        if config.aug:
            print('---------using augmentation dataset!')
            # One dataset per entry of ``config.dataset_all``; each receives
            # the global config paired with its dataset-specific sub-config.
            train_dataset = ConcatDataset([
                get_dataset_by_name(
                    dataset_name,
                    [config, get_sub_config(config, dataset_name)],
                    'train',
                )
                for dataset_name in config.dataset_all
            ])
        else:
            train_dataset = AugmentDataset_uniform_group_pair(config, fold='train')
            test_dataset = AugmentDataset_uniform_group_pair(config, fold='val')
    elif sample_type == 'uniform_group_multiview':
        train_dataset = AugmentDataset_uniform_group_multiview(config, fold='train')
    else:
        # Fail early with a clear message instead of hitting an unbound
        # ``train_dataset`` further down.
        raise ValueError(
            f"Unsupported sample_type: {sample_type!r}; expected one of "
            "'random', 'uniform', 'uniform_group_pair', "
            "'uniform_group_multiview'."
        )

    if dist.is_available() and dist.is_initialized():
        # Under torch.distributed a DistributedSampler supplies the indices,
        # so ``shuffle`` is not passed to the loader.
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            sampler=train_sampler,
            num_workers=config.num_workers,
            worker_init_fn=worker_init_fn,
            pin_memory=True,
        )
    else:
        # NOTE(review): the worker count is hard-coded here, whereas the
        # distributed path reads ``config.num_workers`` -- confirm intent.
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=16,
            worker_init_fn=worker_init_fn,
        )

    return train_loader, None




