import torch
import torch.nn as nn

# Default compute device: CUDA when PyTorch reports it is available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def get_inner_lr(P):
    """Create a learnable scalar inner-loop learning rate.

    Args:
        P: config object; only ``P.inner_lr`` (a number) is read.

    Returns:
        A 0-dim ``nn.Parameter`` initialised to ``P.inner_lr`` with
        ``requires_grad=True``.
    """
    # Cast to float so an integer config value (e.g. ``inner_lr=1``) still
    # produces a floating-point tensor: PyTorch only allows requires_grad on
    # floating-point (or complex) tensors, and torch.tensor(1) would be int64.
    return nn.Parameter(torch.tensor(float(P.inner_lr)), requires_grad=True)


def setup(mode, P):
    if P.fname == None:
        if P.data_type == 'img':
            fname = f'{P.dataset}/{P.data_type}_{P.resolution:03}/instep:{P.inner_step:02}_initer:{P.inner_iter:02}/incremental:{P.incremental}_lr:{int(P.inner_lr*1e3):03d}:{int(P.lr*1e6):04d}_single:False'

        elif P.data_type == 'video':
            fname = f'main/{P.data_type}_{P.resolution:03}/instep:{P.inner_step:02}_initer:{P.inner_iter:02}/incremental:{P.incremental}_lr:{int(P.inner_lr*1e3):03d}:{int(P.lr*1e6):04d}_frozen:{P.frozen}'
    else:
        fname = P.fname

    if mode in ['fomaml', 'maml','none']:
        if P.data_type == 'img':
            from train.gradient_based.maml import train_step_img as train_step
            from train.gradient_based.maml import check
        elif P.data_type == 'video':
            from train.gradient_based.maml import train_step_video as train_step
            from train.gradient_based.maml import check
    else:
        raise NotImplementedError()

    today = check(P)
    if P.no_date:
        today = False

    if P.replay:
        fname += f'_replay'
    if P.oml:
        fname += f'_oml:{P.oml_layer}'
    if P.prog:
        fname += f'_prog:{P.oml_layer}'
    fname += f'_seed:{P.seed}'

    if P.suffix is not None:
        fname += f'_{P.suffix}'

    return train_step, fname, today
