from collections import OrderedDict

import torch
import torch.nn.functional as F

# Run on CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def setup(mode, P):
    if P.regression:
        fname = f'{P.dataset}_{P.model}_{mode}_{P.num_shots}shot_{P.batch_size}task'
    else:
        fname = f'{P.dataset}_{P.model}_{mode}_{P.num_ways}way_{P.num_shots}shot_{P.batch_size}task'

    if mode in ['maml', 'metasgd', 'anil']:
        if P.barlow:
            if P.ablation:
                from train.gradient_based.maml_barlow_ablation import maml_barlow_step as train_func
                from train.gradient_based.maml_barlow_ablation import check
            else:
                from train.gradient_based.maml_barlow import maml_barlow_step as train_func
                from train.gradient_based.maml_barlow import check
        else:
            from train.gradient_based.maml import maml_step as train_func
            from train.gradient_based.maml import check
    else:
        raise NotImplementedError()

    today = check(P)
    if P.baseline:
        today = False

    fname += f'_seed_{P.seed}'
    if P.suffix is not None:
        fname += f'_{P.suffix}'

    if P.ema:
        fname += f'_ema'

    if P.barlow:
        fname += f'_barlow'
    
    if P.adml:
        fname += f'_adml'

    if P.adv:
        fname += f'_adv{P.attack_img_num}'
    
    if P.img_aug_only:
        fname += f'_ablation_img_aug'

    if P.class_attack:
        fname += f'_ablation_class_attack'

    fname += f'_{P.aug_type}'

    return train_func, fname, today


def copy_model_param(model, params=None):
    """Return detached, gradient-enabled copies of a set of parameters.

    Args:
        model: Object exposing ``meta_named_parameters()``. It is only consulted
            when ``params`` is ``None``.
        params: Optional mapping of name -> tensor to copy instead of the
            model's own parameters.

    Returns:
        An ``OrderedDict`` with the same keys in the same iteration order. Each
        value is a clone of the source tensor, detached from its autograd graph
        and with ``requires_grad`` switched on.
    """
    source = OrderedDict(model.meta_named_parameters()) if params is None else params
    # requires_grad_() returns the tensor itself, so the whole chain can be inlined.
    return OrderedDict(
        (key, tensor.clone().detach().requires_grad_())
        for key, tensor in source.items()
    )


def param_ema(P, model):
    """Blend current parameters into the running averages held on ``P``.

    For every ``(name, param)`` from ``model.meta_named_parameters()``, the
    tensor ``P.moving_average[name]`` has its ``.data`` replaced by
    ``P.eta * old + (1 - P.eta) * param.data``. When the string ``'metasgd'``
    occurs in ``P.mode``, the same update is also applied from ``P.inner_lr``
    into ``P.moving_inner_lr``. Missing keys in the average dicts raise
    ``KeyError``.

    Args:
        P: Namespace holding ``eta``, ``mode``, ``moving_average``, and, in
            metasgd mode, ``inner_lr`` / ``moving_inner_lr``.
        model: Object exposing ``meta_named_parameters()``.
    """
    current = OrderedDict(model.meta_named_parameters())
    decay = P.eta

    def _blend(averages, latest):
        # Reassign .data (a new tensor) rather than updating in place, so results match exactly.
        for key, value in latest.items():
            averages[key].data = decay * averages[key].data + (1 - decay) * value.data

    _blend(P.moving_average, current)

    if 'metasgd' in P.mode:
        _blend(P.moving_inner_lr, P.inner_lr)

