import argparse
import json
import os

from utils import *
from algorithms import *


# Command-line configuration for a single experiment: data locations,
# optimisation settings, which dataset / method to run, and how many repeats.
# `choices` on --dataset and --method make argparse reject unsupported values
# immediately instead of failing later with an unrelated NameError.
parser = argparse.ArgumentParser()
parser.add_argument("--mnist_root", type=str, default='./data/mnist/',
                    help="Root directory of the MNIST data.")
parser.add_argument("--cifar_root", type=str, default='./data/cifar/',
                    help="Root directory of the CIFAR data (used by 'mc').")
parser.add_argument("--fashm_root", type=str, default='./data/FashionMNIST/',
                    help="Root directory of the FashionMNIST data (used by 'mf').")

parser.add_argument("--lr_max", type=float, default=0.001,
                    help="lr_max value passed to the training routine.")
parser.add_argument("--dataset", type=str, default='mc', choices=["mc", "mf"],
                    help="Dataset pairing: 'mc' (MNIST/CIFAR) or 'mf' (MNIST/FashionMNIST).")
parser.add_argument("--heads", type=int, default=2,
                    help="Number of heads.")
parser.add_argument("--num_runs", type=int, default=3,
                    help="Number of independent runs to perform.")
parser.add_argument("--max_epoch", type=int, default=70,
                    help="max_epoch value passed to the training routine.")
parser.add_argument("--alpha", type=float, default=1.0,
                    help="alpha value passed to the training routine.")
parser.add_argument("--method", type=str, default="D-BAT",
                    choices=["DivDis", "D-BAT", "DivDis-Seq"],
                    help="Training method to run.")
parser.add_argument("--degree_of_balance", type=float, default=None,
                    help="In [0, 1], where 1 is balanced.")
parser.add_argument("--nosave", action="store_true", default=False,
                    help="Do not write the collected results to disk.")


args = parser.parse_args()

# Experiment identifier, built from the dataset, method, alpha and balance
# settings joined by underscores; used as the results file name when saving.
exp_name = "_".join([
    args.dataset,
    args.method,
    f"alpha{args.alpha}",
    f"balance{args.degree_of_balance}",
])

# Build the nine dataloaders and pick the model constructor for the selected
# dataset. Loader order, as unpacked here: train, valid, test, the valid/test
# "rt" pair, the valid/test "rb" pair, and the two perturbation loaders
# (split semantics are defined in utils — TODO confirm there).
# --lr_max is no longer overwritten here, so a value given on the command line
# is honoured; its default (0.001) is the value previously forced.
if args.dataset == 'mc':
    # MNIST + CIFAR pairing.
    # DivDis 1.0/50.0, D-BAT 0.1/5.0  (presumably per-method alpha settings — TODO confirm)
    train_dl, valid_dl, test_dl, valid_rt_dl, test_rt_dl, valid_rb_dl, test_rb_dl, perturb_dl, perturb_dl_test = get_mc_dataset(args.mnist_root, args.cifar_root, args.degree_of_balance)
    model_fn = LeNet_MC
elif args.dataset == "mf":
    # MNIST + FashionMNIST pairing.
    train_dl, valid_dl, test_dl, valid_rt_dl, test_rt_dl, valid_rb_dl, test_rb_dl, perturb_dl, perturb_dl_test = get_mf_dataset(args.mnist_root, args.fashm_root, args.degree_of_balance)
    model_fn = LeNet_MF
else:
    # A bare `NotImplementedError` expression is a no-op; actually raise so an
    # unknown dataset fails here rather than as a NameError further down.
    raise NotImplementedError(
        f"Unknown dataset {args.dataset!r}; expected 'mc' or 'mf'."
    )

# Both supported datasets train with the learning-rate scheduler enabled.
use_scheduler = True

# Map each --method value to its training routine; all three are invoked with
# the same positional and keyword arguments below.
trainers = {
    "DivDis": simultaneous_train_divdis,
    "D-BAT": sequential_train_dbat,
    "DivDis-Seq": sequential_train_divdis,
}
if args.method not in trainers:
    # Fail fast, before any (expensive) training run starts.
    raise NotImplementedError(
        f"Unknown method {args.method!r}; expected one of {sorted(trainers)}."
    )
train_fn = trainers[args.method]

# Statistics of each independent run, keyed "run_1", "run_2", ...
all_stats = {}
for run_idx in range(1, args.num_runs + 1):
    # First positional argument comes from --heads (default 2) instead of a
    # hard-coded 2, so the flag takes effect. Presumably the number of
    # heads / ensemble members — TODO confirm against the algorithms module.
    stats = train_fn(
        args.heads,
        train_dl, valid_dl, valid_rt_dl, valid_rb_dl,
        test_dl, test_rt_dl, test_rb_dl,
        perturb_dl, perturb_dl_test,
        model_fn,
        alpha=args.alpha,
        max_epoch=args.max_epoch,
        lr_max=args.lr_max,
        use_scheduler=use_scheduler,
        opt="Adam",
    )
    all_stats[f"run_{run_idx}"] = stats

# Persist all runs' statistics as JSON under ./results_mcf_v7_grid/<exp_name>,
# unless --nosave was given.
if not args.nosave:
    results_dir = "./results_mcf_v7_grid"
    # Create the output directory on first use; otherwise open() raises
    # FileNotFoundError and the results of every run are lost.
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, exp_name), "w") as fp:
        json.dump(all_stats, fp)