import torch
from torch.utils.data import DataLoader, DistributedSampler
from ltr.data.test_data_builder import *
from ltr.dataset import Lasot, Lang_OTB, TNL2K
from ltr.data import transforms as tfm
from ltr.data import loader, processing, sampler


def create_train_dataloaders(config):
    """Build the training data loader described by ``config``.

    Fields read from ``config``:
        * ``train.dataset.which_use`` - iterable of dataset names; each must be
          one of ``'tnl2k'``, ``'lasot'`` or ``'otb99'``.
        * ``train.dataset.config.<name>.path`` - root path for each used dataset.
        * ``train.dataset.probability`` / ``train.dataset.total_num`` - passed to
          the sampler as ``p_datasets`` / ``samples_per_epoch``.
        * ``train.dataset.mean`` / ``train.dataset.std`` - normalisation values.
        * ``train.template_size``, ``train.search_size``,
          ``train.template_area_factor``, ``train.search_area_factor``,
          ``train.center_jitter_factor``, ``train.scale_jitter_factor`` -
          forwarded to ``processing.TrackerProcessing``.
        * ``train.ddp.istrue`` - when true a ``DistributedSampler`` is used.
        * ``train.workers`` / ``train.batch`` - loader worker count / batch size.
        * ``model.bert.max_len`` - passed to the sampler as ``bert_seq_length``.

    Returns:
        The ``loader.LTRLoader`` wrapping the ``sampler.TrackerSampler``.

    Raises:
        ValueError: if ``train.dataset.which_use`` contains an unknown name.
    """
    train_cfg = config.train
    dataset_cfg = train_cfg.dataset

    datasets = []
    for name in dataset_cfg.which_use:
        if name == 'tnl2k':
            dataset = TNL2K(dataset_cfg.config.tnl2k.path, split='train')
        elif name == 'lasot':
            dataset = Lasot(dataset_cfg.config.lasot.path, split='train')
        elif name == 'otb99':
            dataset = Lang_OTB(dataset_cfg.config.otb99.path, split='train')
        else:
            # Fail loudly: otherwise the previous iteration's dataset would be
            # appended a second time (or `dataset` would be unbound).
            raise ValueError(
                f"Unknown training dataset {name!r}; "
                "expected one of 'tnl2k', 'lasot', 'otb99'"
            )
        datasets.append(dataset)

    output_sz = {'template': train_cfg.template_size,
                 'search': train_cfg.search_size}
    area_factor = {'template': train_cfg.template_area_factor,
                   'search': train_cfg.search_area_factor}

    # The joint augmentation transform, that is applied to the pairs jointly
    transform_joint = tfm.Transform(tfm.ToGrayscale(probability=0.05),
                                    tfm.RandomHorizontalFlip(probability=0.5))

    # The augmentation transform applied to the training set (individually to each image in the pair)
    transform_train = tfm.Transform(tfm.ToTensorAndJitter(0.2),
                                    tfm.RandomHorizontalFlip_Norm(probability=0.5),
                                    tfm.Normalize(mean=dataset_cfg.mean, std=dataset_cfg.std))

    # Data processing to do on the training pairs
    data_processing_train = processing.TrackerProcessing(
                                area_factor=area_factor,
                                output_sz=output_sz,
                                center_jitter_factor=train_cfg.center_jitter_factor,
                                scale_jitter_factor=train_cfg.scale_jitter_factor,
                                mode='sequence',
                                transform=transform_train,
                                joint_transform=transform_joint)

    dataset_train = sampler.TrackerSampler(
                                datasets=datasets,
                                p_datasets=dataset_cfg.probability,
                                samples_per_epoch=dataset_cfg.total_num, max_gap=200,
                                num_search_frames=1, num_template_frames=2,
                                processing=data_processing_train,
                                bert_seq_length=config.model.bert.max_len)

    if train_cfg.ddp.istrue:
        # A sampler is supplied in DDP mode, so loader-level shuffling is off
        # (torch's DataLoader rejects shuffle=True combined with a sampler).
        train_sampler = DistributedSampler(dataset_train)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    # Round the worker count down to a multiple of the visible GPU count.
    # Without any CUDA device the count is 0, so use the configured value
    # unchanged instead of dividing by zero.
    device_count = torch.cuda.device_count()
    if device_count > 0:
        num_workers = (train_cfg.workers // device_count) * device_count
    else:
        num_workers = train_cfg.workers

    # The loader for training
    loader_train = loader.LTRLoader('train', dataset_train, training=True,
                                    batch_size=train_cfg.batch,
                                    num_workers=num_workers, shuffle=shuffle,
                                    drop_last=True, stack_dim=0,
                                    sampler=train_sampler)

    return loader_train


def create_val_datasets(cfg):
    """Instantiate the evaluation dataset selected by ``cfg.test.data``.

    Args:
        cfg: configuration object; ``cfg.test.data`` must be ``'lasot'``,
            ``'otb99'`` or ``'tnl2k'``. The whole ``cfg`` is passed to the
            dataset constructor.

    Returns:
        A ``(dataset, dataset_name)`` tuple.

    Raises:
        ValueError: if ``cfg.test.data`` is not one of the supported names.
    """
    dataset_name = cfg.test.data
    if dataset_name == 'lasot':
        dataset = LaSOT_Dataset(cfg)
    elif dataset_name == 'otb99':
        dataset = OTB99_Dataset(cfg)
    elif dataset_name == 'tnl2k':
        dataset = TNL2K_Dataset(cfg)
    else:
        # Without this branch the return below would hit an unbound local.
        raise ValueError(
            f"Unknown test dataset {dataset_name!r}; "
            "expected one of 'lasot', 'otb99', 'tnl2k'"
        )

    return dataset, dataset_name