import argparse
import json
import os
import sys

import torch
from cl_pipleine import ContinualLearningPipeline
from MERS.mers_utils.utils import init_exp_dir, init_exp_folder

def add_default_args(parser):
    """Register the shared, optional training options on ``parser``.

    The options are added in this order: ``--cuda_id``, ``--inc_model``,
    ``--num_epochs``, ``--batch_size``, ``--momentum``, ``--weight_decay``,
    ``--lr``, ``--seed`` and ``--increase_factor``. None of them is required;
    each one falls back to the default listed in the table below.

    Args:
        parser: Object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``). It is extended in place.

    Returns:
        The very same ``parser`` object that was passed in.
    """
    # One (flag, keyword-arguments) entry per option. argparse derives the
    # destination attribute from the long flag name (``--cuda_id`` ->
    # ``cuda_id``), so no explicit ``dest`` is needed.
    option_table = (
        ('--cuda_id', {'type': int, 'default': 0}),
        ('--inc_model', {
            'help': 'The model to train with (etc. resnet18)',
            'choices': ['resnet18', 'arch_craft', 'slim_resnet18'],
            'type': str,
            'default': 'resnet18',
        }),
        ('--num_epochs', {'help': 'Number of epochs', 'type': int, 'default': 100}),
        ('--batch_size', {'help': 'Batch size', 'type': int, 'default': 128}),
        ('--momentum', {'type': float, 'default': 0.9}),
        ('--weight_decay', {'type': float, 'default': 0.0002}),
        ('--lr', {'type': float, 'default': 0.1}),
        ('--seed', {
            'type': int,
            'default': None,
            'help': 'Random seed for classes order. If None- class order is 0,1,2,...',
        }),
        ('--increase_factor', {
            'type': float,
            'default': 1,
            'help': 'Increase factor for the delta of the probcover algorithm only',
        }),
    )

    for flag, options in option_table:
        parser.add_argument(flag, required=False, **options)

    return parser


def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    Without a converter argparse stores the raw string, so a user passing
    ``False`` ends up with the truthy string ``'False'``.

    Args:
        value: Either a ``bool`` (returned unchanged) or a string such as
            ``'true'``/``'false'``, ``'yes'``/``'no'``, ``'1'``/``'0'``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling; argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_parser():
    """Build the command-line parser for a class-incremental learning run.

    Required options are ``--dataset``, ``--num_experiences``,
    ``--algorithm``, ``--buffer`` and ``--batch_id``. All remaining options
    are optional; the shared training options (learning rate, batch size,
    ...) are appended at the end through ``add_default_args``.

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='Class Incremental Learning')
    parser.add_argument('--dataset', dest='dataset', help='Dataset name', required=True, type=str,
                        choices=['cifar10', 'cifar100', 'tinyimg', 'cub200'])
    parser.add_argument('--num_experiences', dest='num_experiences', help='Number of experiences/tasks',
                        required=True, type=int)
    parser.add_argument('--algorithm', dest='algorithm', help='Name of the algorithm to use',
                        choices=['flashback_er_ace', 'acr_er_ace', 'feature_replay', 'budgeted_cl',
                                 'er_ace', 'er', 'gdumb', 'contrastive', 'finetune', 'naive_replay',
                                 'der_pp', 'mir', 'xder'],
                        required=True, type=str)
    parser.add_argument('--sel_strategy', dest='sel_strategy', help='Selection strategy', required=False,
                        choices=['herding', 'budget', 'gss', 'teal', 'rm', 'centered', 'random',
                                 'probcover', 'max_herding'],
                        type=str, default=None)
    parser.add_argument('--teal_type', dest='teal_type',
                        help="If sel_strategy is 'teal', the type of TEAL to use",
                        choices=['one_time', 'log_iterative'],
                        required=False, type=str, default='one_time')
    parser.add_argument('--buffer', dest='buffer', help='buffer size', required=True, type=int)
    parser.add_argument('--debug', dest='debug', help='Debug mode', required=False, action='store_true',
                        default=False)
    # The help text used to be a copy-paste of the --debug one.
    parser.add_argument('--exp_name', dest='exp_name', help='Experiment name', required=False, type=str,
                        default="final")
    parser.add_argument('--delta', dest='delta', help='delta for probcover algorithm', required=False,
                        type=float, default=0.65)
    # Parsed as real booleans: previously any supplied value (even "False")
    # was kept as a non-empty, hence truthy, string. ``nargs='?'`` with
    # ``const=True`` additionally lets the bare flag mean True while still
    # accepting the old ``--integrated_features True`` form.
    parser.add_argument('--integrated_features', dest='integrated_features', help='integrated_features',
                        required=False, type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--concatenated', dest='concatenated', help='concatenated',
                        required=False, type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--alpha', dest='alpha', help='alpha for probcover algorithm', required=False,
                        type=float, default=None)
    parser.add_argument('--features_type', dest='features_type',
                        help='feature_type: model based or other self-supervised representations',
                        required=False,
                        choices=['model_based', 'dino', 'simclr', 'vicreg'],
                        default='model_based')
    # No ``type`` is given, so the batch id is kept as the raw string.
    parser.add_argument('--batch_id', dest='batch_id',
                        required=True)
    parser.add_argument("--order", dest='order', help='Order of the classes in the dataset',
                        required=False, type=int, default=None)
    parser.add_argument('--nvidia', dest='nvidia', help='Use nvidia dino',
                        action='store_true', default=False)
    parser.add_argument('--weight_method', dest='weight_method', help='Weight method for Probcover',
                        required=False, type=str, default='')
    parser.add_argument("--delta_mb", dest='delta_mb', required=False, type=str, default='1')
    parser.add_argument("--delta_ss", dest='delta_ss', required=False, type=str, default='k')
    parser.add_argument("--sigma_mb", dest='sigma_mb', required=False, type=str, default='1nn')
    parser.add_argument("--sigma_ss", dest='sigma_ss', required=False, type=str, default='knn')

    # Shared optimisation / runtime options (lr, batch size, seed, ...).
    parser = add_default_args(parser)
    return parser


def apply_hyperparameter_overrides(args):
    """Apply the algorithm/dataset specific hyper-parameter presets in place.

    The presets replace whatever was given on the command line. They are
    applied in a fixed order and later rules win over earlier ones:

    1. ``er_ace`` -> lr 0.01, batch size 10; ``mir`` -> lr 0.001, batch size
       10, 100 epochs (15 epochs on ``tinyimg``).
    2. ``cub200`` dataset -> 30 epochs, batch size 16.
    3. Debug mode -> a single epoch (and prints ``Debug mode``).
    4. ``er`` -> batch size 10 (applied last, so it also wins over the
       ``cub200`` batch size).

    Args:
        args: Namespace with ``algorithm``, ``dataset``, ``debug``, ``lr``,
            ``batch_size`` and ``num_epochs`` attributes. Mutated in place.
    """
    if args.algorithm == 'er_ace':
        args.lr = 0.01
        args.batch_size = 10
    elif args.algorithm == 'mir':
        args.batch_size = 10
        args.lr = 0.001
        args.num_epochs = 100
        if args.dataset == 'tinyimg':
            args.num_epochs = 15
    if args.dataset == 'cub200':
        args.num_epochs = 30
        args.batch_size = 16
    if args.debug:
        print('Debug mode')
        args.num_epochs = 1
    if args.algorithm == 'er':
        args.batch_size = 10


def select_device(cuda_id):
    """Return the ``torch.device`` to run on.

    Args:
        cuda_id: Index of the CUDA device to use when CUDA is available.

    Returns:
        ``torch.device("cuda", cuda_id)`` when CUDA is available, otherwise
        ``torch.device("cpu")``. The result is always a ``torch.device``
        (the CPU fallback used to be the plain string ``"cpu"``), so callers
        get one consistent type on every machine.
    """
    if torch.cuda.is_available():
        return torch.device("cuda", cuda_id)
    return torch.device("cpu")


def main():
    """Parse the CLI, run the continual-learning pipeline and save the results."""
    args = get_parser().parse_args()
    apply_hyperparameter_overrides(args)

    print(vars(args))

    exp_dir = init_exp_dir(args)
    print(f"Experiment directory: {exp_dir}")
    device = select_device(args.cuda_id)

    pipeline = ContinualLearningPipeline(args, device, exp_dir)
    exp_folder = init_exp_folder(exp_dir, pipeline.batch_id)

    exp_acc_dict = pipeline.train(exp_folder)

    # Create a JSON-serializable version of args (shallow copy, so the
    # namespace saved by torch.save below keeps its original objects).
    args_dict = vars(args).copy()
    # Convert device object to string for JSON serialization. Nothing in this
    # script sets ``args.device``; presumably the pipeline may add it -- verify.
    if 'device' in args_dict:
        args_dict['device'] = str(args_dict['device'])

    with open(f'{exp_folder}/args.txt', 'w', encoding='utf-8') as f:
        json.dump(args_dict, f, indent=2)
    torch.save({'exp_acc_dict': exp_acc_dict, 'args': vars(args)}, f"{exp_folder}/results.pyth")
    pipeline.run_and_save_results(args.seed, exp_acc_dict)


if __name__ == '__main__':
    main()

