"""Default config"""
from utils import key_is_none

"""Dataset"""
# One private table per dataset; each is looked up by `config.dataset`
# through `dataset_defaults` below. Key order inside every table is kept
# stable because it is the order in which defaults get applied.
_AMAZON_DEFAULTS = {
    'recent_batches': 5,
    'model': 'distilbert-base-uncased',
    'transform': 'bert',
    'loss_function': 'xent',
    'max_token_length': 512,
    'batch_size': 8,
    'eval_metric': 'acc',
    'eval_batch_size': 8,
    'lr': 1e-5,
    'weight_decay': 0.01,
    'epochs': 5,
    'loader_kwargs': {
        'num_workers': 1,
        'pin_memory': True,
    },
}

_FMOW_DEFAULTS = {
    'recent_batches': 6,
    'model': 'densenet121',
    'transform': 'image_base',
    'loss_function': 'xent',
    'optimizer': 'Adam',
    'lr': 1e-4,
    'weight_decay': 0.0,
    'scheduler': 'FractionStepLR',
    'scheduler_kwargs': {
        'gamma': 0.1,
        'step_size': 0.4,
    },
    'batch_size': 32,
    'eval_metric': 'acc',
    'eval_batch_size': 32,
    'epochs': 30,
    'loader_kwargs': {
        'num_workers': 4,
        'pin_memory': True,
    },
    'randaugment_n': 2,
}

_CIVILCOMMENTS_DEFAULTS = {
    'recent_batches': 5,
    'model': 'distilbert-base-uncased',
    'transform': 'bert',
    'loss_function': 'xent',
    'max_token_length': 300,
    'batch_size': 16,
    'eval_metric': 'acc',
    'eval_batch_size': 16,
    'lr': 1e-5,
    'weight_decay': 0.01,
    'epochs': 10,
    'loader_kwargs': {
        'num_workers': 1,
        'pin_memory': True,
    },
}

_POVERTY_DEFAULTS = {
    'recent_batches': 5,
    'model': 'resnet18_ms',
    'model_kwargs': {
        'num_channels': 8,
    },
    'transform': 'poverty',
    'loss_function': 'mse',
    'eval_metric': 'pearson',
    'optimizer': 'Adam',
    'scheduler': 'FractionStepLR',
    'scheduler_kwargs': {
        'gamma': 0.1,
        'step_size': 0.4,
    },
    'batch_size': 64,
    'eval_batch_size': 64,
    'lr': 1e-3,
    'weight_decay': 0.0,
    'epochs': 200,
    'loader_kwargs': {
        'num_workers': 4,
        'pin_memory': True,
    },
    'randaugment_n': 2,
}

# Dataset name -> default config values for that dataset.
dataset_defaults = {
    'amazon': _AMAZON_DEFAULTS,
    'fmow': _FMOW_DEFAULTS,
    'civilcomments': _CIVILCOMMENTS_DEFAULTS,
    'poverty': _POVERTY_DEFAULTS,
}

"""Model"""
# Models that share the AdamW / gradient-clipping / linear-warmup defaults.
_ADAMW_WARMUP_MODELS = (
    'bert-base-uncased',
    'distilbert-base-uncased',
    'code-gpt-py',
)

# Models that default to pretrained weights and a 224x224 target resolution.
_PRETRAINED_224_MODELS = (
    'densenet121',
    'wideresnet50',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
)

# Model name -> default config values for that model.
# The comprehensions build a fresh dict (and fresh nested `model_kwargs`) for
# every model, so no two entries share mutable state.
model_defaults = {
    **{
        name: {
            'optimizer': 'AdamW',
            'max_grad_norm': 1.0,
            'scheduler': 'linear_schedule_with_warmup',
        }
        for name in _ADAMW_WARMUP_MODELS
    },
    **{
        name: {
            'model_kwargs': {'pretrained': True},
            'target_resolution': (224, 224),
        }
        for name in _PRETRAINED_224_MODELS
    },
    'gin-virtual': {},
    'resnet18_ms': {'target_resolution': (224, 224)},
    'logistic_regression': {},
    'unet-seq': {'optimizer': 'Adam'},
    'fasterrcnn': {
        'model_kwargs': {
            'pretrained_model': True,
            'pretrained_backbone': True,
            'min_size': 1024,
            'max_size': 1024,
        },
    },
}

"""Scheduler"""
def _scheduler_entry(**kwargs):
    """Return a defaults entry of the form ``{'scheduler_kwargs': {...}}``.

    Each call produces a new dict, so entries never share mutable state.
    """
    return {'scheduler_kwargs': kwargs}


# Scheduler name -> default config values for that scheduler.
scheduler_defaults = {
    'none': {},
    'linear_schedule_with_warmup': _scheduler_entry(num_warmup_steps=0),
    'cosine_schedule_with_warmup': _scheduler_entry(num_warmup_steps=0),
    'ReduceLROnPlateau': _scheduler_entry(),
    'StepLR': _scheduler_entry(step_size=1),
    'FractionStepLR': _scheduler_entry(step_size=1),
    'FixMatchLR': _scheduler_entry(),
    'MultiStepLR': _scheduler_entry(gamma=0.1),
}


def populate_config(config, template: dict, force_compatibility=False):
    """Fill unset entries of ``config`` from ``template``, mutating ``config`` in place.

    Example usage: populate config with defaults.

    Whether an entry counts as "unset" is decided by ``utils.key_is_none``
    (presumably: key missing or value ``None`` -- confirm against ``utils``).

    Args:
        config: Namespace-like object; its attributes are read and written
            through ``vars(config)``.
        template: Mapping of values to apply, or ``None`` to leave ``config``
            untouched. A template value that is itself a dict is treated as a
            kwargs dict and merged key by key into the dict stored under the
            same name on ``config``; any other value is assigned directly.
        force_compatibility: If true, raise when ``config`` already holds a
            value that differs from the template's instead of keeping it.

    Returns:
        The same ``config`` object that was passed in.

    Raises:
        ValueError: If ``force_compatibility`` is true and an already-set
            value (top-level or inside a kwargs dict) differs from the template.
    """
    if template is None:
        return config

    d_config = vars(config)
    for key, val in template.items():
        if not isinstance(val, dict):  # config[key] expected to be a non-index-able
            if key_is_none(d_config, key):
                d_config[key] = val
            elif force_compatibility and d_config[key] != val:
                raise ValueError(f"Argument {key} must be set to {val}")
            continue

        # config[key] expected to be a kwarg dict. It may be absent or None
        # (e.g. an argparse default); start from a fresh dict in that case so
        # the merge below neither raises nor aliases the template's own dict.
        if d_config.get(key) is None:
            d_config[key] = {}
        config_kwargs = d_config[key]

        for kwargs_key, kwargs_val in val.items():
            if key_is_none(config_kwargs, kwargs_key):
                config_kwargs[kwargs_key] = kwargs_val
            elif force_compatibility and config_kwargs[kwargs_key] != kwargs_val:
                # Report the single required value, not the whole template dict.
                raise ValueError(
                    f"Argument {key}[{kwargs_key}] must be set to {kwargs_val}"
                )
    return config


def populate_defaults(config):
    """Apply dataset, model and scheduler defaults to ``config``, then pick a device.

    Defaults are applied in a fixed order because later lookups read values
    that earlier steps may have filled in: the dataset table can supply
    ``model``, and the dataset or model table can supply ``scheduler``.

    Args:
        config: Namespace-like object with ``dataset``, ``model``,
            ``scheduler`` and ``device`` attributes.

    Returns:
        The populated ``config``.
    """
    import torch

    # 1) dataset-level defaults
    dataset_template = dataset_defaults[config.dataset]
    config = populate_config(config, dataset_template)

    # 2) model-level defaults
    model_template = model_defaults[config.model]
    config = populate_config(config, model_template)

    # 3) scheduler-level defaults, only when a scheduler is selected
    scheduler_name = config.scheduler
    if scheduler_name is not None:
        config = populate_config(config, scheduler_defaults[scheduler_name])

    # 4) fall back to CUDA when available, otherwise CPU
    if config.device is None:
        if torch.cuda.is_available():
            config.device = 'cuda'
        else:
            config.device = 'cpu'
    return config