# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
import copy

import torch
import yaml

import fourm.utils as utils

from fourm.data import (CenterCropImageAugmenter, EmptyAugmenter,
                  PreTokenizedImageAugmenter,RandomCropImageAugmenter, build_fm_pretraining_dataset,
                  build_huggingface_pretraining_dataloader,
                  build_wds_fm_pretraining_dataloader)
from fourm.data.modality_transforms import CaptionTransform
from fourm.data.modality_info import MODALITY_TRANSFORMS


def _parse_alphas(spec, domains, key):
    """Turn an alpha spec into one float per domain, in the order `domains` is given.

    `spec` is either a single number or a '-'-separated string of numbers. A single value is
    broadcast to every domain; otherwise exactly one value per domain is required.
    """
    if isinstance(spec, (int, float)):
        # YAML parses e.g. `input_alphas: 1.0` as a float rather than a string.
        values = [float(spec)]
    else:
        values = [float(alpha) for alpha in str(spec).split('-')]
    if len(values) == 1:
        return values * len(domains)
    if len(values) != len(domains):
        raise ValueError(
            f"'{key}' has {len(values)} values but {len(domains)} domains were given: {domains}"
        )
    return values


def setup_sampling_mod_info(dataset_config, modality_info):
    """Select the modalities used by one dataset and attach their Dirichlet sampling alphas.

    Args:
        dataset_config: Per-dataset config dict. Reads the '-'-separated 'in_domains' and
            'out_domains', and either 'alphas_config' (path to a YAML file) or the optional
            'input_alphas' / 'target_alphas' entries.
        modality_info: Mapping from modality name to its info dict. It is not modified.

    'input_alphas' / 'target_alphas' may be a single number, which is applied to every domain.
    They may also be a '-'-separated list with exactly one value per domain. List values are
    matched to domains in the order those domains are listed in 'in_domains' / 'out_domains'.

    Returns:
        A tuple (mod_info, sampling_weights):
        - mod_info is a deep-copied subset of `modality_info` restricted to this dataset's
          domains. Each entry has 'input_alphas' / 'target_alphas' set. When an alphas config
          is used, entries of type 'seq', 'seq_emb' or 'seq_token' also get 'keep'.
        - sampling_weights comes from the alphas config if it defines one; otherwise it is None.

    Raises:
        ValueError: if a multi-value alpha list does not match the number of domains.
    """
    # Keep the config order: per-domain alphas are given positionally in that order.
    in_domains = dataset_config['in_domains'].split('-')
    out_domains = dataset_config['out_domains'].split('-')
    all_domains = sorted(set(in_domains) | set(out_domains))

    # Deep-copy only the modalities this dataset uses, so the shared modality_info is never mutated.
    mod_info = copy.deepcopy({mod: modality_info[mod] for mod in all_domains})

    # Dirichlet concentration parameter (Alpha)
    if dataset_config.get('alphas_config', None) is None:
        # Modalities without an explicit alpha default to 0.
        for mod in mod_info:
            mod_info[mod]["input_alphas"] = [0.]
            mod_info[mod]["target_alphas"] = [0.]

        if 'input_alphas' in dataset_config:
            input_alphas = _parse_alphas(dataset_config['input_alphas'], in_domains, 'input_alphas')
            for mod, alpha in zip(in_domains, input_alphas):
                mod_info[mod]['input_alphas'] = [alpha]

        if 'target_alphas' in dataset_config:
            target_alphas = _parse_alphas(dataset_config['target_alphas'], out_domains, 'target_alphas')
            for mod, alpha in zip(out_domains, target_alphas):
                mod_info[mod]["target_alphas"] = [alpha]

        sampling_weights = None
    else:
        print(f"Loading alphas config from: {dataset_config['alphas_config']}")
        with open(dataset_config['alphas_config'], "r", encoding="utf-8") as f:
            alphas_config = yaml.safe_load(f)

        # A mixture config carries 'sampling_weights' and nests the per-modality alphas
        # under 'alphas_mixture'.
        if 'sampling_weights' in alphas_config:
            sampling_weights = alphas_config['sampling_weights']
            alphas_config = alphas_config['alphas_mixture']
        else:
            sampling_weights = None

        for mod in mod_info:
            mod_info[mod]["input_alphas"] = alphas_config[mod]["input_alphas"]
            mod_info[mod]["target_alphas"] = alphas_config[mod]["target_alphas"]
            if modality_info[mod]['type'] in ['seq', 'seq_emb', 'seq_token']:
                mod_info[mod]['keep'] = alphas_config[mod]['keep']

    return mod_info, sampling_weights

def get_train_dataloader(dataset_config, modality_info, sampling_weights, text_tokenizer, input_size, 
                         num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
                         num_tasks, num_workers, dataset_batch_size=None, epoch_size=None):
    """Build the training data loader for a single dataset.

    Args:
        dataset_config: Per-dataset config dict. Reads 'in_domains', 'out_domains', 'type' and
            'data_path'. Optionally reads 'aligned_captions', 'use_wds', token-budget overrides,
            augmentation settings and shuffle-buffer settings.
        modality_info: Modality info for this dataset (presumably the output of
            `setup_sampling_mod_info`).
        sampling_weights: Optional mixture sampling weights (forwarded to the webdataset loader only).
        text_tokenizer: Tokenizer forwarded to the dataset builders.
        input_size: Target image size for the image augmenter.
        num_input_tokens, num_target_tokens: Default maximum input/target token budgets.
            The same keys in `dataset_config` override them.
        min_input_tokens, min_target_tokens: Default minimum budgets (multimodal only).
            `dataset_config` overrides them; None means "same as the maximum".
        num_tasks: Number of distributed processes.
        num_workers: Worker count forwarded to the webdataset / huggingface loaders.
        dataset_batch_size: Batch size forwarded to the webdataset / huggingface loaders.
        epoch_size: Epoch size forwarded to the webdataset / huggingface loaders.

    Returns:
        The data loader for this dataset.

    Raises:
        NotImplementedError: for an unknown dataset type, or for webdataset with huggingface data.
    """
    all_domains = sorted(
        set(dataset_config['in_domains'].split('-')) | set(dataset_config['out_domains'].split('-'))
    )

    # Work on a shallow copy. The caption transform is configured per dataset, and mutating the
    # shared module-level MODALITY_TRANSFORMS would leak this dataset's setting into any other
    # loader that keeps a reference to the same dict.
    modality_transforms = dict(MODALITY_TRANSFORMS)
    if 'caption' in modality_transforms:
        modality_transforms['caption'] = CaptionTransform(
            aligned_captions=dataset_config.get('aligned_captions', True)
        )

    if dataset_config['type'] == 'multimodal':

        is_pretokenized = any(info.get('pretokenized', False) for info in modality_info.values())
        if is_pretokenized:
            # Multi-modal training data augmentation (uses pre-tokenized data augmentation)
            image_augmenter = PreTokenizedImageAugmenter(
                target_size=input_size, 
                no_aug=(not dataset_config.get('tok_train_aug', True)), 
                main_domain=dataset_config['main_augment_domain']
            )
        else:
            image_augmenter = RandomCropImageAugmenter(
                 target_size=input_size, 
                 hflip=dataset_config.get('hflip'), 
                 crop_scale=tuple(dataset_config.get('crop_scale')),
                 crop_ratio=tuple(dataset_config.get('crop_ratio')),
            )

        # Input and target token ranges; a missing minimum collapses the range to the maximum.
        num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
        num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
        min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
        min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
        min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
        min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens

        if dataset_config.get('use_wds', False):
            # Using webdataset
            loader = build_wds_fm_pretraining_dataloader(
                data_path=dataset_config['data_path'], all_domains=all_domains,
                modality_info=modality_info, modality_transforms=modality_transforms,
                image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
                input_tokens_range=(min_input_tokens, num_input_tokens),
                target_tokens_range=(min_target_tokens, num_target_tokens),
                num_gpus=num_tasks, num_workers=num_workers,
                batch_size=dataset_batch_size, epoch_size=epoch_size,
                modality_name_map=dataset_config.get('modality_name_map', None),
                shuffle_buffer_load=dataset_config.get('wds_shuffle_buffer_tar', 1_000),
                shuffle_buffer_repeat=dataset_config.get('wds_shuffle_buffer_repeat', 1_000),
                n_repeats=dataset_config.get('wds_n_repeats', 1),
                sampling_weights=sampling_weights,
            )
        else:
            dataset_train = build_fm_pretraining_dataset(
                data_path=dataset_config['data_path'], 
                all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
                image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
                input_tokens_range=(min_input_tokens, num_input_tokens), 
                target_tokens_range=(min_target_tokens, num_target_tokens)
            )
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=True, drop_last=True,
            )
            # DataLoader has batch size 1 as it then gets collated through the Mixture dataloader
            loader = torch.utils.data.DataLoader(
                dataset_train, sampler=sampler_train,
                batch_size=1, num_workers=0,
                pin_memory=False, drop_last=True,
                collate_fn=lambda x: x[0],
            )

    elif dataset_config['type'] == 'huggingface':

        # Input and target token ranges are fixed (min == max) for huggingface datasets.
        num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
        num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)

        if dataset_config.get('use_wds', False):
            raise NotImplementedError('Webdataset not yet implemented for huggingface datasets.')
        loader = build_huggingface_pretraining_dataloader(
            data_path=dataset_config['data_path'], all_domains=all_domains,
            modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
            input_tokens_range=(num_input_tokens, num_input_tokens), 
            target_tokens_range=(num_target_tokens, num_target_tokens),
            num_gpus=num_tasks, num_workers=num_workers,
            batch_size=dataset_batch_size, epoch_size=epoch_size,
            split='train', streaming=True, rename_text_to_caption=True, 
            shuffle_buffer_load=dataset_config.get('shuffle_buffer_load', 1_000),
            shuffle_seed=0,
        )
    else:
        raise NotImplementedError(f'Dataset type {dataset_config["type"]} not implemented.')

    return loader
    

def cfgs_get(key, val_config, dataset_name, train_configs, default=None):
    """ Try to retrieve a key from the validation set config.
    If it does not exist, default to retrieving it from the train set config
    with the same dataset name.

    Args:
        key: Config key to look up.
        val_config: Validation config dict for this dataset.
        dataset_name: Name used to find the matching entry in `train_configs`.
        train_configs: Mapping from dataset name to train config dict.
        default: Returned when neither config provides `key`.

    Returns:
        `val_config[key]` if present (even if it is None). Otherwise
        `train_configs[dataset_name][key]` if present. Otherwise `default`.
    """
    # Check the validation config first and fall back lazily. Eagerly indexing
    # train_configs[dataset_name] would raise KeyError for validation-only datasets
    # even when the validation config already has the key.
    if key in val_config:
        return val_config[key]
    if dataset_name in train_configs:
        return train_configs[dataset_name].get(key, default)
    return default


def get_val_dataloader(dataset_config, dataset_name, train_configs, modality_info, sampling_weights, text_tokenizer, 
                       input_size, num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
                       fixed_eval, fixed_eval_input_tokens, fixed_eval_target_tokens, 
                       dist_eval, num_tasks, num_workers, batch_size, pin_mem):
    """Build the validation data loader for a single dataset.

    Several settings fall back to `train_configs[dataset_name]` (via `cfgs_get`) when missing
    from `dataset_config`: domains, type, data path, caption alignment, main augmentation
    domain and shuffle buffer. The token budgets are read from `dataset_config` only.

    Args:
        dataset_config: Validation config dict for this dataset.
        dataset_name: Name of the dataset, used for the train-config fallback.
        train_configs: Mapping from dataset name to train config dict.
        modality_info: Modality info for this dataset.
        sampling_weights: Not used by this function.
        text_tokenizer: Tokenizer forwarded to the dataset builders.
        input_size: Target image size for the eval image augmenter.
        num_input_tokens, num_target_tokens: Default maximum token budgets when not `fixed_eval`.
        min_input_tokens, min_target_tokens: Default minimum budgets (multimodal only);
            None means "same as the maximum".
        fixed_eval: If true, use exactly `fixed_eval_input_tokens` / `fixed_eval_target_tokens`.
        fixed_eval_input_tokens, fixed_eval_target_tokens: Token counts used when `fixed_eval`.
        dist_eval: Shard the multimodal validation set across processes with a DistributedSampler.
        num_tasks: Number of distributed processes.
        num_workers: Number of data loading workers.
        batch_size: Validation batch size.
        pin_mem: `pin_memory` flag for the multimodal DataLoader.

    Returns:
        The validation data loader.

    Raises:
        NotImplementedError: for an unknown dataset type.
    """
    def _get(key, default=None):
        # Validation config first, then the train config with the same dataset name.
        return cfgs_get(key, dataset_config, dataset_name, train_configs, default)

    all_domains = sorted(set(_get('in_domains').split('-')) | set(_get('out_domains').split('-')))

    # Work on a shallow copy so the per-dataset caption setting does not mutate the shared
    # module-level MODALITY_TRANSFORMS, which other loaders may still reference.
    modality_transforms = dict(MODALITY_TRANSFORMS)
    if 'caption' in modality_transforms:
        modality_transforms['caption'] = CaptionTransform(
            aligned_captions=_get('aligned_captions', True)
        )

    dataset_type = _get('type')

    if dataset_type == 'multimodal':

        main_augment_domain = _get('main_augment_domain')
        is_pretokenized = any(info.get('pretokenized', False) for info in modality_info.values())
        if is_pretokenized:
            eval_image_augmenter = PreTokenizedImageAugmenter(
                target_size=input_size, no_aug=True, main_domain=main_augment_domain
            )
        else:
            eval_image_augmenter = CenterCropImageAugmenter(
                target_size=input_size, main_domain=main_augment_domain
            )

        if fixed_eval:
            input_tokens_range = (fixed_eval_input_tokens, fixed_eval_input_tokens)
            target_tokens_range = (fixed_eval_target_tokens, fixed_eval_target_tokens)
        else:
            # Input and target token ranges; a missing minimum collapses the range to the maximum.
            num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
            num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
            min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
            min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
            min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
            min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens
            input_tokens_range = (min_input_tokens, num_input_tokens)
            target_tokens_range = (min_target_tokens, num_target_tokens)

        dataset_val = build_fm_pretraining_dataset(
            data_path=_get('data_path'), 
            all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=eval_image_augmenter, text_tokenizer=text_tokenizer,
            input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range
        )

        print("Warning: Eval stats may vary slightly as the masking applied on images is random.")
        if dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        loader = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_mem,
            drop_last=False,
        )

    elif dataset_type == 'huggingface':

        if fixed_eval:
            input_tokens_range = (fixed_eval_input_tokens, fixed_eval_input_tokens)
            target_tokens_range = (fixed_eval_target_tokens, fixed_eval_target_tokens)
        else:
            # Token ranges are fixed (min == max) for huggingface datasets.
            num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
            num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
            input_tokens_range = (num_input_tokens, num_input_tokens)
            target_tokens_range = (num_target_tokens, num_target_tokens)

        loader = build_huggingface_pretraining_dataloader(
            data_path=_get('data_path'), 
            all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
            input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
            num_gpus=num_tasks, num_workers=num_workers,
            batch_size=batch_size, epoch_size=None,
            split='validation', streaming=True, rename_text_to_caption=True, 
            shuffle_buffer_load=_get('shuffle_buffer_load', 1_000),
            shuffle_seed=0,
        )

    else:
        raise NotImplementedError(f'Dataset type {dataset_type} not implemented.')

    return loader