import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from data.fashionIQ import FashionIQDataset, FashionIQUserSurveyDataset, FashionIQUserSurveyTestSampleDataset
from data.cirr import CIRRDataset
from data.collate_fns import PaddingCollateFunction, PaddingCollateFunctionTest, BertPaddingCollateFunction, BertPaddingCollateFunctionTest, BLIPPaddingCollateFunction, BLIPPaddingCollateFunctionTest, BLIPPaddingCollateFunctionTest4CIRR

def train_dataset_factory(transforms, config):
    """Build the training dataset selected by ``config['dataset']``.

    Args:
        transforms: mapping that must contain an ``'image_transform'`` entry.
        config: reads ``'dataset'`` (matched by substring against each dataset
            class's ``code()``); the whole dict is forwarded to the dataset.

    Returns:
        A relative-mode ``'train'`` dataset. For FashionIQ this covers every
        subset returned by ``FashionIQDataset.all_subset_codes()``.

    Raises:
        ValueError: if ``config['dataset']`` matches no known dataset.
    """
    # NOTE(review): both values below are read but never used -- the datasets
    # are constructed with preprocess=None and no subset handling.
    _image_transform = transforms['image_transform']
    code = config['dataset']
    _use_subset = config.get('use_subset', False)

    if FashionIQDataset.code() in code:
        # One dataset spanning all FashionIQ subsets.
        return FashionIQDataset(split='train', dress_types=FashionIQDataset.all_subset_codes(),
                                mode='relative', preprocess=None, config=config)

    if CIRRDataset.code() in code:
        return CIRRDataset('train', 'relative', None, config)

    raise ValueError("There's no {} dataset".format(code))


def test_dataset_factory(transforms, config, split='val'):
    """Build the evaluation dataset(s) for ``config['dataset']`` / ``config['experiment']``.

    Args:
        transforms: mapping that must contain an ``'image_transform'`` entry.
            NOTE(review): the entry is looked up but never handed to the
            datasets, which are all built with ``preprocess=None``.
        config: reads ``'dataset'`` (matched by substring against each dataset
            class's ``code()``) and ``'experiment'``; the whole dict is also
            forwarded to the dataset constructors.
        split: NOTE(review): accepted for API compatibility but currently
            ignored -- every split used below is hard-coded.

    Returns:
        dict mapping a name to either a single relative-mode dataset
        (``'CIRQRS'``) or a ``{"samples": classic-mode, "query": relative-mode}``
        pair (``'CIRQRS_Recall'`` and, for CIRR only, ``'User_Survey'``).
        FashionIQ contributes one ``'fashionIQ_<subset>'`` entry per subset;
        CIRR contributes a single entry keyed by ``CIRRDataset.code()``.

    Raises:
        ValueError: for an experiment the matched dataset does not support, or
            when ``config['dataset']`` matches no known dataset.
    """
    _image_transform = transforms['image_transform']  # looked up, not yet used
    dataset_code = config['dataset']
    result = {}

    def _samples_and_query(build):
        # ``build(mode)`` constructs one dataset; "samples" is built before "query".
        return {"samples": build("classic"), "query": build("relative")}

    if FashionIQDataset.code() in dataset_code:
        experiment = config['experiment']
        if experiment == 'CIRQRS':
            for dress_type in FashionIQDataset.all_subset_codes():
                result['fashionIQ_' + dress_type] = FashionIQDataset(
                    split="test", dress_types=[dress_type], mode="relative",
                    preprocess=None, config=config)
        elif experiment == 'CIRQRS_Recall':
            for dress_type in FashionIQDataset.all_subset_codes():
                result['fashionIQ_' + dress_type] = _samples_and_query(
                    lambda mode, _dt=dress_type: FashionIQDataset(
                        split="val", dress_types=[_dt], mode=mode,
                        preprocess=None, config=config))
        else:
            raise ValueError("Not Related to the CIRQRS")

    elif CIRRDataset.code() in dataset_code:
        experiment = config['experiment']
        if experiment == 'CIRQRS':
            result[CIRRDataset.code()] = CIRRDataset('test1', 'relative', None, config)
        elif experiment == 'CIRQRS_Recall':
            result[CIRRDataset.code()] = _samples_and_query(
                lambda mode: CIRRDataset('test1', mode, None, config))
        elif experiment == 'User_Survey':
            result[CIRRDataset.code()] = _samples_and_query(
                lambda mode: CIRRDataset('val', mode, None, config))
        else:
            raise ValueError("Not Related to the CIRQRS")

    if not result:
        raise ValueError("There's no {} dataset".format(dataset_code))

    return result


def train_dataloader_factory(dataset, config, collate_fn=None):
    """Wrap ``dataset`` in a training ``DataLoader``.

    Args:
        dataset: dataset to load batches from.
        config: reads ``'batch_size'`` (required), ``'num_workers'``
            (default 16) and ``'shuffle'`` (default True).
        collate_fn: optional batch collation callable, passed through unchanged.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``pin_memory=True``.
    """
    batch_size = config['batch_size']
    num_workers = config.get('num_workers', 16)
    # The configured flag is actually honoured now (it used to be read and then
    # overridden by a hard-coded ``shuffle=True``); the default stays True.
    shuffle = config.get('shuffle', True)

    return DataLoader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True,
                      collate_fn=collate_fn)

def cirqrs_test_dataloader_factory(dataset, config, collate_fn=None):
    """Wrap ``dataset`` in an evaluation ``DataLoader`` that never shuffles.

    Args:
        dataset: dataset to load batches from.
        config: reads ``'batch_size'`` (required) and ``'num_workers'``
            (default 16). ``config['shuffle']`` is deliberately NOT consulted:
            evaluation always iterates in dataset order.
        collate_fn: optional batch collation callable, passed through unchanged.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=False`` and
        ``pin_memory=True``.
    """
    batch_size = config['batch_size']
    num_workers = config.get('num_workers', 16)

    return DataLoader(dataset, batch_size, shuffle=False, num_workers=num_workers, pin_memory=True,
                      collate_fn=collate_fn)

def cirqrs_test_recall_dataloader_factory(datasets, config, collate_fn=None):
    """Build ordered (never shuffled) loaders for a query/samples evaluation pair.

    Args:
        datasets: mapping with ``'query'`` and ``'samples'`` entries.
        config: reads ``'batch_size'`` (required) and ``'num_workers'``
            (default 16).
        collate_fn: optional batch collation callable shared by both loaders.

    Returns:
        dict with keys ``'query'`` then ``'samples'``, each a ``DataLoader``
        with ``shuffle=False`` and ``pin_memory=True``.
    """
    shared_options = dict(
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config.get('num_workers', 16),
        pin_memory=True,
        collate_fn=collate_fn,
    )
    return {role: DataLoader(datasets[role], **shared_options) for role in ('query', 'samples')}


def test_dataloader_factory(datasets, config, collate_fn=None):
    batch_size = config['batch_size']
    num_workers = config.get('num_workers', 16)
    shuffle = False

    return {
        'query': DataLoader(datasets['query'], batch_size, shuffle=False, num_workers=num_workers, pin_memory=True,
                            collate_fn=collate_fn),
        'samples': DataLoader(datasets['samples'], batch_size, shuffle=False, num_workers=num_workers,
                              pin_memory=True,
                              collate_fn=collate_fn)
    }

def _select_collate_fns(configs):
    """Return ``(train_collate_fn, eval_collate_fn)`` for ``configs['text_encoder']``.

    Raises:
        ValueError: if the text encoder is neither ``'roberta'`` nor ``'blip'``,
            so an unsupported encoder fails with a clear message instead of an
            unbound-local error further down.
    """
    text_encoder = configs['text_encoder']

    if text_encoder == 'roberta':
        padding_idx = 1  # presumably RoBERTa's <pad> token id -- confirm against the tokenizer
        return (BertPaddingCollateFunction(padding_idx=padding_idx),
                BertPaddingCollateFunctionTest(padding_idx=padding_idx))

    if text_encoder == 'blip':
        # For Training Matching Score using BLIP Structure & Negative as randomly sampled image
        train_fn = BLIPPaddingCollateFunction()
        # CIRR gets a dedicated eval collate function.
        # NOTE(review): this compares against the literal 'cirr', whereas the
        # dataset factories use a substring match on CIRRDataset.code().
        if configs['dataset'] != 'cirr':
            test_fn = BLIPPaddingCollateFunctionTest()
        else:
            test_fn = BLIPPaddingCollateFunctionTest4CIRR()
        return train_fn, test_fn

    raise ValueError("Unsupported text_encoder: {}".format(text_encoder))


def _build_eval_loaders(datasets, experiment, config, collate_fn):
    """Wrap every entry of ``datasets`` in the eval loader(s) matching ``experiment``."""
    loaders = {}
    for key, value in datasets.items():
        if experiment == 'CIRQRS':
            # Training Matching Score & evaluating CIRQRS (except Recall@K)
            loaders[key] = cirqrs_test_dataloader_factory(dataset=value, config=config, collate_fn=collate_fn)
        elif experiment == 'CIRQRS_Recall':
            # Evaluating CIRQRS as a CIR model (Recall@K)
            loaders[key] = cirqrs_test_recall_dataloader_factory(datasets=value, config=config,
                                                                 collate_fn=collate_fn)
        else:
            loaders[key] = test_dataloader_factory(datasets=value, config=config, collate_fn=collate_fn)
    return loaders


def create_dataloaders(image_transform, text_transform, configs):
    """Build the training loader and the evaluation loaders for the configured run.

    Args:
        image_transform: mapping with ``'train'`` and ``'val'`` entries, forwarded
            to the dataset factories under the ``'image_transform'`` key.
        text_transform: forwarded unchanged to the dataset factories.
        configs: run configuration; reads ``'text_encoder'``, ``'dataset'`` and
            ``'experiment'`` here, plus whatever the factories read.

    Returns:
        ``(train_dataloader, test_dataloaders, train_val_dataloaders)``. The last
        two are dicts with the same keys as the corresponding dataset dicts.

    Raises:
        ValueError: for an unsupported ``configs['text_encoder']`` (and whatever
            the dataset factories raise).
    """
    train_dataset = train_dataset_factory(
        transforms={'image_transform': image_transform['train'], 'text_transform': text_transform},
        config=configs)
    test_datasets = test_dataset_factory(
        transforms={'image_transform': image_transform['val'], 'text_transform': text_transform},
        config=configs)
    # NOTE(review): test_dataset_factory does not currently use ``split``, so
    # these datasets mirror test_datasets.
    train_val_datasets = test_dataset_factory(
        transforms={'image_transform': image_transform['val'], 'text_transform': text_transform},
        config=configs, split='train')

    collate_fn, collate_fn_test = _select_collate_fns(configs)
    experiment = configs['experiment']

    train_dataloader = train_dataloader_factory(dataset=train_dataset, config=configs, collate_fn=collate_fn)
    test_dataloaders = _build_eval_loaders(test_datasets, experiment, configs, collate_fn_test)
    # Always wrap train_val_datasets here; the CIRQRS branches previously wrapped
    # test_datasets a second time and left train_val_datasets unused.
    train_val_dataloaders = _build_eval_loaders(train_val_datasets, experiment, configs, collate_fn_test)

    return train_dataloader, test_dataloaders, train_val_dataloaders
