import itertools
import os
import sys

from runfiles.MyClock import MyClock
from tools.get_path import *

# Interpreter prefix used at the start of the command lines built below.
# Those commands are run with os.system, i.e. through a shell, which expands
# the leading "~". The trailing space separates the interpreter from the
# script path that gets appended directly after it.
_PY3_INTERPRETER = "~/anaconda3/envs/ava/bin/python"
py3 = _PY3_INTERPRETER + " "


def micro_auc_1dataset_usage(des, dataset_name, measure, py_file, seed_list=(2222,), epoch=30, batch_size=128,
                             opt="adam", cu=3, lr=0.0005, model="res50", num_workers=0,
                             weight_decay_list=(1e-5,), loss_modes=("u_1",)):
    """Run a micro-AUC training sweep on one dataset, one shell command per configuration.

    A command line is built for every combination of ``seed_list`` x
    ``weight_decay_list`` x ``loss_modes``. The seed varies slowest and the
    loss mode fastest. The commands run sequentially via ``os.system``. Each
    one invokes ``<project>/examples/micro-auc/<py_file>`` with the module
    level ``py3`` interpreter prefix. It passes a ``--csvlog_dir`` below
    ``<project>/train_logs/multilabel1set/<dataset_name>/``.

    Args:
        des: Free-form tag; first component of the log directory name.
        dataset_name: Forwarded as ``--dataset_name``; also a log path component.
        measure: Forwarded as ``--measure``; also part of the log directory name.
        py_file: Script file name inside ``examples/micro-auc/``.
        seed_list: Seeds to sweep (``--seed``).
        epoch: Forwarded as ``--epoch``.
        batch_size: Forwarded as ``--batch_size``.
        opt: Forwarded as ``--opt``; also part of the log directory name.
        cu: Forwarded as ``--cuda`` (presumably a GPU index -- confirm in the script).
        lr: Forwarded as ``--lr``.
        model: Forwarded as ``--model``.
        num_workers: Forwarded as ``--num_workers``.
        weight_decay_list: Weight-decay values to sweep (``--weight_decay``).
        loss_modes: Loss modes to sweep (``--loss_mode``).

    A failed run does not stop the sweep. Every command that returned a
    non-zero status is listed on stderr after the sweep finishes.
    """
    # Defaults are tuples (not lists) so no mutable object is shared between calls.
    project = get_project_path()  # loop-invariant; resolved once
    cmds = []
    for seed, weight_decay, loss_mode in itertools.product(seed_list, weight_decay_list, loss_modes):
        log = (f"{project}/train_logs/multilabel1set/{dataset_name}/"
               f"{des}-{loss_mode}-{measure}-{opt}-{weight_decay}/e{epoch}-s{seed}-b{batch_size}")
        cmds.append(
            f"{py3}{project}/examples/micro-auc/{py_file}"
            f" --dataset_name={dataset_name} --measure={measure} --opt={opt} "
            f"--cuda={cu} --lr={lr} --batch_size={batch_size} --seed={seed} --epoch={epoch} --csvlog_dir={log}"
            f" --model={model} --num_workers={num_workers} --weight_decay={weight_decay} --loss_mode={loss_mode}"
        )

    mc = MyClock()
    failed = []
    for c in cmds:
        # os.system returns the command's (platform-encoded) exit status; 0 means success.
        status = os.system(c)
        if status != 0:
            failed.append((status, c))
        # Called after every run, presumably to report elapsed time -- see runfiles.MyClock.
        mc.cal_time()

    if failed:
        print(f"{len(failed)} of {len(cmds)} runs failed:", file=sys.stderr)
        for status, c in failed:
            print(f"  [status {status}] {c}", file=sys.stderr)


def micro_auc_1dataset_multiclass_usage(des, dataset_name, measure, py_file, seed_list=(2222,), epoch=30,
                                        batch_size=128, opt="adam", cu=3, lr=0.0005, model="res50",
                                        num_workers=0, weight_decay_list=(1e-5,), loss_modes=("u_1",)):
    """Run a multi-class micro-AUC training sweep on one dataset, one shell command per configuration.

    A command line is built for every combination of ``seed_list`` x
    ``weight_decay_list`` x ``loss_modes``. The seed varies slowest and the
    loss mode fastest. The commands run sequentially via ``os.system``. Each
    one invokes ``<project>/examples/multi_class_micro_auc/<py_file>`` with
    the module level ``py3`` interpreter prefix. It passes a ``--csvlog_dir``
    below ``<project>/train_logs/multilabel1set/<dataset_name>/``.

    Args:
        des: Free-form tag; first component of the log directory name.
        dataset_name: Forwarded as ``--dataset_name``; also a log path component.
        measure: Forwarded as ``--measure``; also part of the log directory name.
        py_file: Script file name inside ``examples/multi_class_micro_auc/``.
        seed_list: Seeds to sweep (``--seed``).
        epoch: Forwarded as ``--epoch``.
        batch_size: Forwarded as ``--batch_size``.
        opt: Forwarded as ``--opt``; also part of the log directory name.
        cu: Forwarded as ``--cuda`` (presumably a GPU index -- confirm in the script).
        lr: Forwarded as ``--lr``.
        model: Forwarded as ``--model``.
        num_workers: Forwarded as ``--num_workers``.
        weight_decay_list: Weight-decay values to sweep (``--weight_decay``).
        loss_modes: Loss modes to sweep (``--loss_mode``).

    A failed run does not stop the sweep. Every command that returned a
    non-zero status is listed on stderr after the sweep finishes.
    """
    # Defaults are tuples (not lists) so no mutable object is shared between calls.
    project = get_project_path()  # loop-invariant; resolved once
    cmds = []
    for seed, weight_decay, loss_mode in itertools.product(seed_list, weight_decay_list, loss_modes):
        # NOTE(review): multi-class runs log under "multilabel1set". That is the same
        # root and naming scheme as micro_auc_1dataset_usage, so identical
        # arguments to both functions map to the same log directory. Left
        # unchanged because downstream tooling may rely on this path -- confirm
        # it is intended.
        log = (f"{project}/train_logs/multilabel1set/{dataset_name}/"
               f"{des}-{loss_mode}-{measure}-{opt}-{weight_decay}/e{epoch}-s{seed}-b{batch_size}")
        cmds.append(
            f"{py3}{project}/examples/multi_class_micro_auc/{py_file}"
            f" --dataset_name={dataset_name} --measure={measure} --opt={opt} "
            f"--cuda={cu} --lr={lr} --batch_size={batch_size} --seed={seed} --epoch={epoch} --csvlog_dir={log}"
            f" --model={model} --num_workers={num_workers} --weight_decay={weight_decay} --loss_mode={loss_mode}"
        )

    mc = MyClock()
    failed = []
    for c in cmds:
        # os.system returns the command's (platform-encoded) exit status; 0 means success.
        status = os.system(c)
        if status != 0:
            failed.append((status, c))
        # Called after every run, presumably to report elapsed time -- see runfiles.MyClock.
        mc.cal_time()

    if failed:
        print(f"{len(failed)} of {len(cmds)} runs failed:", file=sys.stderr)
        for status, c in failed:
            print(f"  [status {status}] {c}", file=sys.stderr)

def tabular_usage(datasets, modes, model, lr, weight_decays, cuda, epoch, n_hidden):
    """Run ``<project>/MLC/main.py`` once per (dataset, mode, weight_decay) combination.

    A command line is built for every combination of ``datasets`` x
    ``modes`` x ``weight_decays``. The dataset varies slowest and the weight
    decay fastest. The commands run sequentially via ``os.system`` with the
    module level ``py3`` interpreter prefix.

    Args:
        datasets: Values to sweep, forwarded as ``--dataset``.
        modes: Values to sweep, forwarded as ``--mode``.
        model: Forwarded as ``--model``.
        lr: Forwarded as ``--lr``.
        weight_decays: Values to sweep, forwarded as ``--weight_decay``.
        cuda: Forwarded as ``--cuda``.
        epoch: Forwarded as ``--max_epoch`` (note the different flag name).
        n_hidden: Forwarded as ``--n_hidden``.

    A failed run does not stop the sweep. Every command that returned a
    non-zero status is listed on stderr after the sweep finishes.
    """
    script = get_project_path() + "/MLC/main.py"  # loop-invariant; resolved once
    cmds = []
    for dataset, mode, weight_decay in itertools.product(datasets, modes, weight_decays):
        cmds.append(
            f"{py3}{script}"
            f" --dataset={dataset}"
            f" --mode={mode}"
            f" --model={model}"
            f" --lr={lr}"
            f" --max_epoch={epoch}"
            f" --n_hidden={n_hidden}"
            f" --weight_decay={weight_decay}"
            f" --cuda={cuda}"
        )

    mc = MyClock()
    failed = []
    for c in cmds:
        # os.system returns the command's (platform-encoded) exit status; 0 means success.
        status = os.system(c)
        if status != 0:
            failed.append((status, c))
        # Called after every run, presumably to report elapsed time -- see runfiles.MyClock.
        mc.cal_time()

    if failed:
        print(f"{len(failed)} of {len(cmds)} runs failed:", file=sys.stderr)
        for status, c in failed:
            print(f"  [status {status}] {c}", file=sys.stderr)