import argparse
import copy

import torch

from backbones.backbone_utils import BackboneEnum
from baselines.method_enum import MethodEnum
from data.dataset_utils import DatasetEnum
from subspace.meta_musml import MUSML
from utils.config_utils import ConfigUtils
from utils.progress_utils import ExptEnum, StageEnum
from utils.time_utils import TimeUtils
from utils.torch_utils import TorchUtils

# Training entry point: one MUSML run per requested seed.

def classification_main(parser):
    """Train one MUSML model per seed listed in ``parser.seeds``.

    The run configuration starts from ``ConfigUtils.get_basic_config()``.
    The values from ``subspace_config.yaml`` are layered on top, then the
    command-line overrides. Each seed gets a fresh deep copy of that
    configuration plus a unique ``job_id``.

    Args:
        parser: The parsed command-line namespace (the result of
            ``ArgumentParser.parse_args()``, despite the parameter name).
            Reads ``seeds``, ``ds``, ``expt``, ``method``, ``n_dim``, ``k``,
            ``head_hidden_dim``, ``beta`` and ``job_type``.
    """
    # A single timestamp is shared by every seed of this invocation.
    index_key = TimeUtils.get_now_str(fmt=TimeUtils.YYYYMMDDHHMMSS_COMPACT)

    base_config = ConfigUtils.get_basic_config()
    base_config.update(ConfigUtils.get_config_dict("subspace_config.yaml"))

    ds_name = parser.ds
    experiment = parser.expt
    method_name = parser.method
    backbone_name = BackboneEnum.CONV4.name

    # Seed-independent overrides. The insertion order is kept stable so the
    # config dict logged below reads the same for every run.
    overrides = {
        "backbone_name": backbone_name,
        "experiment": experiment,
        "index_key": index_key,
        "n_dim": parser.n_dim,
        "k": parser.k,
        "job_type": parser.job_type,
        "head_hidden_dim": parser.head_hidden_dim,
        "beta": parser.beta,
        "init": "uniform",
        "method_name": method_name,
        "ds_name": ds_name,
    }

    for seed in parser.seeds:
        TorchUtils.set_random_seed(seed)

        # Deep copy so no run can mutate another run's configuration.
        run_config = copy.deepcopy(base_config)
        run_config.update(overrides)
        run_config["job_id"] = (
            f"{method_name}_{ds_name}_{experiment}_{backbone_name}"
            f"_dim{parser.n_dim}_K{parser.k}_headdim{parser.head_hidden_dim}"
            f"_we{parser.beta}_{parser.job_type}_{seed}"
        )

        # NOTE(review): MUSML is used regardless of `method_name`; confirm
        # that other MethodEnum values are not expected to select a
        # different trainer here.
        trainer = MUSML(run_config)
        trainer.logger.info(run_config)
        trainer.train(seed=seed)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Defaults.
    gpu_id = 0
    ds = DatasetEnum.MetaDataset.name
    expt = ExptEnum.H_5_WAY_5_SHOT.value
    method = MethodEnum.SubspacePNPart2.name
    seed = 0

    parser.add_argument('--gpu_id', default=gpu_id, type=int,
                        help="CUDA device index passed to torch.cuda.set_device")
    parser.add_argument('--n_dim', default=5, type=int)
    parser.add_argument('--k', default=20, type=int)
    parser.add_argument('--head_hidden_dim', default=512, type=int)
    # NOTE(review): beta is parsed as int, so fractional values such as 0.5
    # are rejected. Confirm whether a float was intended. Switching would
    # also change the "_we{beta}" part of job_id (e.g. "we1" -> "we1.0").
    parser.add_argument('--beta', default=0, type=int)
    parser.add_argument('--job_type', default=None, type=str,
                        help='run label; defaults to "TT_<seed>"')
    parser.add_argument('--ds', default=ds, type=str)
    parser.add_argument('--expt', default=expt, type=str)
    parser.add_argument('--method', default=method, type=str)
    parser.add_argument('--seed', default=seed, type=int,
                        help="single seed, used when --seeds is not given")
    parser.add_argument('--seeds', default=None, nargs='+', type=int,
                        help="one or more seeds; overrides --seed")
    args = parser.parse_args()

    # Resolve the seed-dependent defaults after parsing. Previously --seed was
    # parsed but ignored: --seeds and --job_type were built from the
    # hard-coded default seed. With no flags the result is unchanged
    # (seeds=[0], job_type="TT_0").
    if args.seeds is None:
        args.seeds = [args.seed]
    if args.job_type is None:
        args.job_type = "TT_{}".format(args.seed)
    print(args)

    torch.cuda.set_device(device=args.gpu_id)
    classification_main(args)
