# Configuration settings for each supported dataset.
import torchvision
import torchvision.transforms as transforms
import logging
import os

#######################################################################
# Dataset Class Definition
#######################################################################

class MOTDataset:
    """
    Plain container for the configuration of a single dataset.

    Every field below has a class-level placeholder default; the ``init_*``
    functions in this module overwrite them on the instance they create.
    """
    # Basic attributes
    basefolder_fullpath: str = ""
    n_tot_class: int = 0
    # Parameter attributes
    lr_param: dict = None
    vit_param: dict = None
    t_param: dict = None
    # Data transform attributes
    pixel_sz: int = 0
    transform_train: torchvision.transforms.Compose = None
    transform_test: torchvision.transforms.Compose = None

    def __init__(self, exp):
        # `exp` is accepted but not used here.
        pass

    def print_attributes(self):
        """
        Log the instance attributes at INFO level on the root logger.

        Only attributes stored on the instance itself (``vars(self)``) are
        listed, in the order they were assigned. ``Compose`` pipelines are
        rendered through ``_print_transform``; everything else is formatted
        as-is.
        """
        log = logging.getLogger()
        log.info('[Dataset setting]')
        for name, val in vars(self).items():
            is_pipeline = isinstance(val, torchvision.transforms.Compose)
            shown = self._print_transform(val) if is_pipeline else val
            log.info(f'\t{name:<20}: {shown}')

    def _print_transform(self, transform):
        """Return the ``str()`` of each step in ``transform.transforms``, joined by ', '."""
        return ', '.join(str(step) for step in transform.transforms)

#######################################################################
# Utility Functions
#######################################################################

def _get_dataset_stats(dataset):
    """
    Look up the training-set mean/std values of a known dataset.

    Args:
        dataset: dataset identifier; 'cifar_10' and 'cifar_100' are known.

    Returns:
        A ``(mean, std)`` tuple of two three-element lists of floats.

    Raises:
        ValueError: if ``dataset`` is not a known identifier. The failure is
            also reported at ERROR level on the root logger.
    """
    # Rows are (identifier, mean, std). The lists are rebuilt on every call,
    # so callers never share them across calls.
    known_stats = (
        (
            'cifar_10',
            [0.49139968, 0.48215827, 0.44653124],
            [0.24703233, 0.24348505, 0.26158768],
        ),
        (
            'cifar_100',
            [0.5070751592371323, 0.48654887331495095, 0.4409178433670343],
            [0.2673342858792401, 0.2564384629170883, 0.27615047132568404],
        ),
    )

    for identifier, mean, std in known_stats:
        if dataset == identifier:
            return mean, std

    logging.getLogger().error(f'Dataset {dataset}: mean/std unknown.')
    raise ValueError('Unknown dataset')

#######################################################################
# Specific Dataset Initializer
#######################################################################

def init_cifar_10(exp):
    """
    Build the ``MOTDataset`` configuration for CIFAR-10.

    Args:
        exp: experiment object. ``exp.path_dataset_prefix`` is joined with
            "cifar-10/" to form the dataset folder, and ``exp.args.n_expert``
            is copied into ``vit_param``.

    Returns:
        A ``MOTDataset`` with its folder, class count, parameter dictionaries,
        pixel size and train/test transforms filled in.
    """
    n_class = 10
    img_side = 32
    embed_dim = 48

    # Attributes are assigned in a fixed order: print_attributes() lists them
    # in assignment order.
    cfg = MOTDataset(exp)
    cfg.basefolder_fullpath = os.path.join(exp.path_dataset_prefix, "cifar-10/")
    cfg.n_tot_class = n_class

    cfg.lr_param = dict(
        Round=600,
        TrBatchSz=256,
        TsBatchSz=256,
        BaseLR=8e-4,
        WDecay=1e-2,
        lambda_r=0.1,
    )
    cfg.t_param = dict(
        stage1_last_round=50,
        stage2_last_round=450,
        gating_r=1.0,
    )
    cfg.vit_param = dict(
        patch_size=4,
        hidden_size=embed_dim,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=4 * embed_dim,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        image_size=img_side,
        num_classes=n_class,
        num_channels=3,
        qkv_bias=True,
        fg_round=150,
        use_faster_attention=True,
        lambda_e=0.1,
        n_expert=exp.args.n_expert,
    )

    # Train pipeline adds a random flip and a random resized crop; the test
    # pipeline only converts, resizes and normalizes.
    mean, std = _get_dataset_stats('cifar_10')
    target_hw = (img_side, img_side)
    cfg.pixel_sz = img_side

    train_steps = [
        transforms.ToTensor(),
        transforms.Resize(target_hw),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop(
            target_hw,
            scale=(0.8, 1.0),
            ratio=(0.75, 1.3333333333333333),
            interpolation=2,
        ),
        transforms.Normalize(mean, std),
    ]
    cfg.transform_train = transforms.Compose(train_steps)

    test_steps = [
        transforms.ToTensor(),
        transforms.Resize(target_hw),
        transforms.Normalize(mean, std),
    ]
    cfg.transform_test = transforms.Compose(test_steps)

    return cfg


def init_cifar_100(exp):
    """
    Build the ``MOTDataset`` configuration for CIFAR-100.

    Args:
        exp: experiment object. ``exp.path_dataset_prefix`` is joined with
            "cifar-100/" to form the dataset folder, and ``exp.args.n_expert``
            is copied into ``vit_param``.

    Returns:
        A ``MOTDataset`` with its folder, class count, parameter dictionaries,
        pixel size and train/test transforms filled in.
    """
    n_class = 100
    img_side = 32
    embed_dim = 48

    # Attributes are assigned in a fixed order: print_attributes() lists them
    # in assignment order.
    cfg = MOTDataset(exp)
    cfg.basefolder_fullpath = os.path.join(exp.path_dataset_prefix, "cifar-100/")
    cfg.n_tot_class = n_class

    cfg.lr_param = dict(
        Round=300,
        TrBatchSz=256,
        TsBatchSz=256,
        BaseLR=1e-3,
        WDecay=1e-2,
        lambda_r=0.05,
    )
    cfg.t_param = dict(
        stage1_last_round=50,
        stage2_last_round=250,
        gating_r=1.0,
    )
    cfg.vit_param = dict(
        patch_size=4,
        hidden_size=embed_dim,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=4 * embed_dim,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        image_size=img_side,
        num_classes=n_class,
        num_channels=3,
        qkv_bias=True,
        fg_round=50,
        use_faster_attention=True,
        lambda_e=0,
        n_expert=exp.args.n_expert,
    )

    # Train and test use the same steps here (no augmentation), but each
    # pipeline gets its own transform instances.
    mean, std = _get_dataset_stats('cifar_100')
    target_hw = (img_side, img_side)
    cfg.pixel_sz = img_side

    train_steps = [
        transforms.ToTensor(),
        transforms.Resize(target_hw),
        transforms.Normalize(mean, std),
    ]
    cfg.transform_train = transforms.Compose(train_steps)

    test_steps = [
        transforms.ToTensor(),
        transforms.Resize(target_hw),
        transforms.Normalize(mean, std),
    ]
    cfg.transform_test = transforms.Compose(test_steps)

    return cfg

#######################################################################
# Gateway
#######################################################################

def init_dataset(exp):
    """
    Build and sanity-check the dataset configuration named by ``exp.args.dataset``.

    Args:
        exp: experiment object. ``exp.args.dataset`` selects the initializer
            ("cifar_10" or "cifar_100"); ``exp`` itself is forwarded to it.

    Returns:
        The dataset object produced by the selected initializer.

    Raises:
        ValueError: if ``exp.args.dataset`` is not a supported name (also
            logged at ERROR level on the root logger).
        AssertionError: if the configuration fails a consistency check.
    """
    logger = logging.getLogger()
    name = exp.args.dataset

    ############### Standard Dataset ###############
    if name == "cifar_10":
        ds = init_cifar_10(exp)
    elif name == "cifar_100":
        ds = init_cifar_100(exp)
    else:
        logger.error(f'Dataset {name} not supported.')
        raise ValueError('Unknown Dataset')

    ################   Validation   ################
    # Explicit raises instead of `assert` statements: `assert` is removed when
    # Python runs with -O, which would silently skip these checks. The
    # exception type stays AssertionError so existing callers are unaffected.
    vit, stages, lr = ds.vit_param, ds.t_param, ds.lr_param

    if vit["hidden_size"] % vit["num_attention_heads"] != 0:
        raise AssertionError(
            f'hidden_size ({vit["hidden_size"]}) must be divisible by '
            f'num_attention_heads ({vit["num_attention_heads"]})'
        )
    if vit['image_size'] % vit['patch_size'] != 0:
        raise AssertionError(
            f'image_size ({vit["image_size"]}) must be divisible by '
            f'patch_size ({vit["patch_size"]})'
        )
    if not (stages['stage1_last_round'] < stages['stage2_last_round']):
        raise AssertionError(
            f'stage1_last_round ({stages["stage1_last_round"]}) must be smaller than '
            f'stage2_last_round ({stages["stage2_last_round"]})'
        )
    if not (stages['stage2_last_round'] < lr['Round']):
        raise AssertionError(
            f'stage2_last_round ({stages["stage2_last_round"]}) must be smaller than '
            f'Round ({lr["Round"]})'
        )

    return ds
