import importlib
import numpy as np
import random
import torch
import torch.utils.data
# Import-time reproducibility settings. Because these run at module level,
# merely importing this module changes PyTorch behaviour for the whole process.
# cudnn.deterministic requests deterministic cuDNN algorithms, and turning off
# cudnn.benchmark stops cuDNN from auto-tuning (which may pick different
# algorithms from run to run).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NOTE(review): per the PyTorch reproducibility docs, ops without a
# deterministic implementation will now raise RuntimeError, and some CUDA
# cuBLAS ops additionally need CUBLAS_WORKSPACE_CONFIG set in the environment
# — confirm the launch scripts set it.
torch.use_deterministic_algorithms(True)  # Enforce the use of deterministic algorithms
from copy import deepcopy
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY

__all__ = ['build_dataset', 'build_dataloader']

# Registry auto-discovery: find every file in this package whose name ends
# with '_dataset.py' and import it as ``basicsr.data.<module>``. The modules
# are presumably expected to register their dataset classes in
# DATASET_REGISTRY on import — confirm in each *_dataset.py.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
    osp.splitext(osp.basename(entry))[0]
    for entry in scandir(data_folder)
    if entry.endswith('_dataset.py')
]
# Keep references to the imported modules (the import itself is the point).
_dataset_modules = [
    importlib.import_module('basicsr.data.' + module_name)
    for module_name in dataset_filenames
]


def build_dataset(dataset_opt):
    """Instantiate a registered dataset class from an options dict.

    The options are deep-copied before use, so the dataset constructor works
    on its own copy and the caller's dict is left untouched.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name (used here only for the log message).
            type (str): Name of the class to look up in DATASET_REGISTRY.

    Returns:
        The constructed dataset instance.
    """
    opt = deepcopy(dataset_opt)
    dataset_cls = DATASET_REGISTRY.get(opt['type'])
    dataset = dataset_cls(opt)
    get_root_logger().info(f'Dataset [{dataset.__class__.__name__}] - {opt["name"]} is built.')
    return dataset


def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None, args=None):
    """Build the dataloader(s) for one dataset.

    Loader hyper-parameters (worker count, pinning, batch sizes) come from
    ``args``; ``dataset_opt`` only supplies the phase and optional
    prefetch / persistence settings.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. Keys read here:
            phase (str): 'train', 'val' or 'test'.
            persistent_workers (bool, optional): Keep workers alive between
                epochs. Ignored when ``args.n_threads`` is 0. Default: False.
            prefetch_mode (str | None, optional): 'cpu' wraps the loaders in
                PrefetchDataLoader; any other value builds plain DataLoaders
                (None: normal loader, 'cuda': loader later consumed by a
                CUDAPrefetcher). Default: None.
            num_prefetch_queue (int, optional): Queue length for the 'cpu'
                prefetch mode. Default: 1.
        num_gpu (int): Not used by this function; kept for API compatibility.
        dist (bool): Not used by this function; kept for API compatibility.
        sampler (Sequence | None): Train phase only. ``sampler[0]`` is used
            for the train loader and ``sampler[1]`` for the init loader
            (presumably calibration, per ``args.batch_size_calib``). When
            None, both train-phase loaders shuffle instead. Default: None.
        seed (int | None): Base seed passed to ``worker_init_fn`` for the
            train-phase loaders. Default: None (no worker seeding).
        args: Namespace with ``n_threads`` and ``cpu`` and, for the train
            phase, ``batch_size_update`` and ``batch_size_calib``. Required.

    Returns:
        tuple: ``(loader_train, loader_init)`` for the 'train' phase.
        DataLoader: a single batch-size-1 loader for 'val' / 'test'.

    Raises:
        ValueError: If the phase is unsupported or ``args`` is None.
    """
    phase = dataset_opt['phase']
    if phase not in ('train', 'val', 'test'):
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
    if args is None:
        raise ValueError('build_dataloader requires `args` (n_threads, cpu and, for training, batch sizes).')
    rank, _ = get_dist_info()

    num_workers = args.n_threads
    # torch DataLoader refuses persistent_workers=True when num_workers == 0.
    persistent_workers = bool(dataset_opt.get('persistent_workers', False)) and num_workers > 0
    common_kwargs = dict(num_workers=num_workers, pin_memory=not args.cpu, persistent_workers=persistent_workers)

    if phase == 'train':
        worker_init = (
            partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None)
        # Index the samplers only when they were provided; with no samplers,
        # fall back to shuffling both loaders.
        if sampler is None:
            train_sampler, init_sampler, shuffle = None, None, True
        else:
            train_sampler, init_sampler, shuffle = sampler[0], sampler[1], False
        loader_kwargs = [
            dict(dataset=dataset, batch_size=args.batch_size_update, shuffle=shuffle, sampler=train_sampler,
                 drop_last=False, worker_init_fn=worker_init, **common_kwargs),
            dict(dataset=dataset, batch_size=args.batch_size_calib, shuffle=shuffle, sampler=init_sampler,
                 drop_last=False, worker_init_fn=worker_init, **common_kwargs),
        ]
    else:  # 'val' / 'test'
        loader_kwargs = [dict(dataset=dataset, batch_size=1, shuffle=False, **common_kwargs)]

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
        # The queue length is passed only as PrefetchDataLoader's own
        # argument; no extra key is added to the loader kwargs.
        make_loader = partial(PrefetchDataLoader, num_prefetch_queue=num_prefetch_queue)
    else:
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        make_loader = torch.utils.data.DataLoader

    loaders = tuple(make_loader(**kwargs) for kwargs in loader_kwargs)
    return loaders if phase == 'train' else loaders[0]

def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy's and Python's global RNGs inside one DataLoader worker.

    The per-worker seed is ``seed + rank * num_workers + worker_id``. As long
    as ``worker_id < num_workers``, distinct (rank, worker_id) pairs get
    distinct seeds.
    """
    global_offset = rank * num_workers + worker_id
    worker_seed = seed + global_offset
    for reseed in (np.random.seed, random.seed):
        reseed(worker_seed)
