# Description: Main file to run the verified continual learning experiments

import argparse
import torch
from conf_utils import *
from loguru import logger
from training import multiple_runs
# from training_vit import multiple_runs_vit
# torch.autograd.set_detect_anomaly(True)

def parse_args(argv=None):
    """Build the experiment command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which
            matches the behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace: one attribute per option declared below. Dashes
        in option names become underscores (``--batch-size`` is read as
        ``args.batch_size``).

    Raises:
        SystemExit: raised by argparse on ``--help`` or on invalid arguments.
    """
    parser = argparse.ArgumentParser(description='PyTorch Verified CL')

    # Help strings use argparse's ``%(default)s`` placeholder so the printed
    # default can never drift from the actual ``default=`` value.
    parser.add_argument('--batch-size', type=int, default=32, metavar='n',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--mini-batchsize', type=int, default=32, metavar='m',
                        help='Batch size for verification')
    parser.add_argument('--buffer-size', type=int, default=500, metavar='M')
    parser.add_argument('--model', type=str, default="mlp_3", metavar='N',
                        choices=["mlp_2", "mlp_3"])
    parser.add_argument('--dataset', type=str, default="mnist", metavar='D',
                        choices=["mnist", "core50", "cifar10", "cifar100", "toy",
                                 "ruarobot", "medical", "fmnist", "synbols",
                                 "permmnist", "tinyimg"])
    parser.add_argument('--train-type', type=str, default="cerce", metavar='T',
                        choices=["naive", "cerce", "er", "joint", "ewc", "lwf",
                                 "intercontinet", "agem"])
    parser.add_argument('--buffer-select', type=str,
                        choices=['correct', 'bound', 'rand'], default='rand')
    parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                        help='input batch size for testing (default: %(default)s)')

    parser.add_argument('--embed-img', action='store_true')
    parser.add_argument('--loss-fusion', action='store_true', default=False)
    parser.add_argument('--all-samples', action='store_true', default=False)
    parser.add_argument('--epochs', type=int, default=1, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--optimizer', type=str, default="sgd", metavar='O',
                        choices=["adam", "sgd"])
    parser.add_argument('--weight-decay', type=float, default=0.001, metavar='l')
    parser.add_argument('--momentum', type=float, default=0, metavar='M',
                        help='SGD momentum (default: %(default)s)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument('--num-runs', type=int, default=10, metavar='N',
                        help='number of runs to average over (default: %(default)s)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--data-path', type=str, default="./data/", metavar='N',
                        help='Path to store data')
    parser.add_argument('--class-inc', type=int, default=2, metavar='N',
                        help='Number of classes to increment')
    parser.add_argument('--n-tasks', type=int, default=5, metavar='N',
                        help='Number of tasks to split dataset into')
    parser.add_argument('--c-slack', type=float, default=0.002, metavar='N',
                        help='Slack value for the optimization problem')
    parser.add_argument('--slack-norm', type=int, default=1, metavar='N',
                        help='Norm for the slack value')
    parser.add_argument('--hybrid-lagrange', action='store_true', default=False)
    parser.add_argument('--lirpa-method', type=str, default="crown-ibp", metavar='method')
    parser.add_argument('--hdim', type=int, default=400, metavar='h')
    parser.add_argument('--no-buffer', action='store_true', default=False)
    parser.add_argument('--grad-clip', action='store_true', default=False)
    # NOTE: this flag is spelled with an underscore, unlike its neighbours;
    # the spelling is kept so existing command lines continue to work.
    parser.add_argument('--eval_current', action='store_true', default=False)
    parser.add_argument('--log-file', type=str, default="log.txt", metavar='p')
    parser.add_argument('--wandb', action='store_true', default=False)
    parser.add_argument('--wandb-name', default="", type=str)
    parser.add_argument('--wandb-project', default="VCL-rebuttal-img", type=str)
    parser.add_argument('--track-buffer', action='store_true', default=False)

    ### CerCE
    parser.add_argument('--lam', type=float, default=0.1, metavar='l')
    parser.add_argument('--gamma', type=float, default=0.001, metavar='e')

    ### EwC
    parser.add_argument('--ewc-gamma', type=float, default=1, metavar='g')
    parser.add_argument('--ewc-lambda', type=float, default=10, metavar='l')

    ### LWF
    parser.add_argument('--alpha', type=float, default=0.5,
                        help='Penalty weight.')
    parser.add_argument('--softmax-temp', type=float, default=2,
                        help='Temperature of the softmax function.')

    ### DER
    parser.add_argument('--dark', action='store_true',
                        help='Use this along with ER or CerCE in order to include DER loss')
    parser.add_argument('--alpha-d', type=float, default=0.1)
    parser.add_argument('--beta-d', type=float, default=0.3)

    ### Interval
    parser.add_argument('--center-lr', default=0.001, type=float,
                        help='Expansion LR of intervalnet')
    parser.add_argument('--radii-lr', default=1, type=float,
                        help='Radius LR of intervalnet')
    parser.add_argument('--max-radius', default=1, type=float,
                        help='Initial radius')
    parser.add_argument('--contraction-epochs', default=10, type=int,
                        help='Num contraction epochs')

    ### ViT
    parser.add_argument('--vit', action='store_true', default=False)
    parser.add_argument('--n-layers', default=12, type=int)
    parser.add_argument('--n-heads', default=8, type=int)
    parser.add_argument('--prompt-len', default=10, type=int)

    ### LPR
    parser.add_argument('--lpr', action='store_true', default=False)

    # ``argv=None`` makes argparse read sys.argv[1:], as before.
    return parser.parse_args(argv)



# Script body: parse the CLI, pick the device, seed, optionally start a
# Weights & Biases run, set up file logging, then launch the experiments.
args = parse_args()
# Honour --no-cuda: previously the flag was parsed but ignored here, so a
# visible GPU was always selected. CUDA is now used only when it is both
# available and not explicitly disabled.
args.device = 'cuda' if (not args.no_cuda and torch.cuda.is_available()) else 'cpu'
set_random_seed(args.seed)
if args.wandb:
    # Fall back to "<dataset>_<train_type>_<lr>" when no run name was given.
    if args.wandb_name == "":
        args.name = f"{args.dataset}_{args.train_type}_{args.lr}"
    else:
        args.name = args.wandb_name
    # Imported here so wandb is only needed when --wandb is passed.
    import wandb
    wandb.init(project=args.wandb_project,
               name=args.name,
               config=args)

# Register the log file given by --log-file with loguru, then record the
# full configuration of this run.
logger.add(args.log_file)
logger.info(args)

# if args.vit:
#     multiple_runs_vit(args)
# else:
multiple_runs(args)

