import os
import argparse
import numpy as np
import torch
import logging.config

from utils.data import get_dataset, get_dataloader, get_unlearn_loader
from utils.backbone import get_model
from utils.method import run_method
from utils.eval import evaluate_summary

def seed_torch(seed):
    """Seed Python, NumPy and PyTorch RNGs so that runs are repeatable.

    Args:
        seed: Integer seed applied to every generator.

    Side effects:
        Seeds the stdlib ``random`` module, NumPy's global RNG and PyTorch's
        CPU/CUDA generators, and switches cuDNN to deterministic,
        non-autotuned mode.
    """
    # Local import: this file does not import ``random`` at module level.
    import random

    # The original seeded NumPy twice and never seeded ``random``; any code
    # relying on the stdlib generator stayed non-deterministic.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not only the current device; this call is
    # safe when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select different kernels between runs, so keep it off.
    torch.backends.cudnn.benchmark = False

def log(args):
    """Configure logging for a run and attach a per-seed log file.

    Loads ``./utils/logging.conf``, stores the root logger on ``args.logger``,
    derives ``args.method_model_name`` from the run settings and adds a file
    handler writing to
    ``./logs/<data_name>/<method_model_name>/seed<rnd_seed>.log`` (opened in
    ``'w'`` mode, so an existing file is truncated).

    Args:
        args: Namespace with ``method``, ``data_name``, ``test_mode``,
            ``remain_epochs``, ``forget_epochs``, ``note``, ``rnd_seed`` and
            the attribute selected by ``test_mode`` (``class_idx``,
            ``sub_class_name`` or ``sample_unlearn_per_class``).

    Raises:
        AttributeError: if ``test_mode`` is none of the three handled modes
            and ``args`` has no ``method_model_name`` already.
    """
    logging.config.fileConfig('./utils/logging.conf')
    root_logger = logging.getLogger()
    args.logger = root_logger
    os.makedirs(f'./logs/{args.data_name}', exist_ok=True)

    # Which attribute identifies the unlearning target for each mode.
    target_attr = {
        'class': 'class_idx',
        'sub_class': 'sub_class_name',
        'sample': 'sample_unlearn_per_class',
    }.get(args.test_mode)
    if target_attr is not None:
        target = getattr(args, target_attr)
        args.method_model_name = f'{args.method}_{args.data_name}_{args.test_mode}_{target}_{args.remain_epochs}_{args.forget_epochs}_{args.note}'

    run_dir = f'./logs/{args.data_name}/{args.method_model_name}'
    os.makedirs(run_dir, exist_ok=True)
    seed_log_handler = logging.FileHandler(f'{run_dir}/seed{args.rnd_seed}.log', mode='w')
    args.logger.addHandler(seed_log_handler)

def arg_parse(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are read from ``sys.argv[1:]``, exactly as before;
            passing a list allows the parser to be driven programmatically.

    Returns:
        argparse.Namespace holding every option defined below.
        ``--remain_epochs``, ``--forget_epochs`` and ``--exp_name`` are
        required; argparse exits the process if they are missing.
    """
    parser = argparse.ArgumentParser("Boundary Unlearning")
    parser.add_argument('--rnd_seed', type=int, default=0, help='random seed') # 0, 1, 2
    parser.add_argument('--method', type=str, default='ft', help='unlearning method')
    parser.add_argument('--data_name', type=str, default='cifar10', help='dataset, cifar10, cifar100, imagenet, cars or flowers')
    parser.add_argument('--model_name', type=str, default='ResNet18', help='model name')
    parser.add_argument('--remain_batch_size', type=int, default=32, help='remain batch size')
    parser.add_argument('--forget_batch_size', type=int, default=32, help='forget batch size')
    parser.add_argument('--test_mode', type=str, default='sample', choices=['sample', 'sub_class', 'class'], help='unlearning mode')
    parser.add_argument('--remain_epochs', type=int, help='remain_epochs', required=True) # zero is possible
    parser.add_argument('--forget_epochs', type=int, help='forget_epochs', required=True) # zero is possible
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='optimizer')
    parser.add_argument('--save_result_model', action='store_true', help="save result model")
    parser.add_argument('--note', type=str, default='', help='note')
    # NOTE: non-string defaults are not passed through ``type``, so the
    # default stays the int 1; left untouched in case it is formatted into
    # names elsewhere.
    parser.add_argument('--retain_ratio', type=float, default=1, help='retain ratio')
    parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
    parser.add_argument('--model_seed', type=int, default=0, help='model seed')
    parser.add_argument('--exp_name', type=str, help="exp name", required=True)
    parser.add_argument('--eval_mode', action='store_true', help="only for additional evaluation")

    # model selection
    parser.add_argument('--unlearn_aug', action='store_true', help="unlearn with data augmentation")

    # class unlearning, test_mode=class
    parser.add_argument('--class_idx', type=int, default=0, help='class index to unlearn')
    parser.add_argument('--class_idx_unlearn', type=int, default=1, help='class index to unlearn')
    # NOTE: yields a list when given on the command line but the empty
    # string "" by default; the value is formatted into file names, so the
    # default is kept as-is.
    parser.add_argument('--sub_class_name', type=str, nargs="+", default="", help='sub class name to unlearn')

    # sample unlearning, test_mode=sample
    parser.add_argument('--sample_unlearn_per_class', type=int, default=100, help='number of unlearning samples per class')
    return parser.parse_args(argv)

if __name__ == '__main__':
    # Script entry point: parse options, seed the RNGs, set up logging, build
    # the datasets/loaders, load the original and retrained checkpoints, then
    # either run the unlearning method and evaluate it, or (with --eval_mode)
    # evaluate a previously saved result model.
    args = arg_parse()
    # NOTE(review): the effective seed is rnd_seed + 2; kept as in the
    # original - confirm the offset is intentional.
    seed_torch(args.rnd_seed+2)
    print(args.sub_class_name)
    print(type(args.sub_class_name))
    log(args)
    args.logger.info(args)
    args.logger.info(f'Model Selection: unlearn_aug={args.unlearn_aug}')
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.logger.info(f"Device: {args.device}")

    trainset, testset, trainset_test, num_cls = get_dataset(args)
    args.logger.info("Dataset")
    train_loader, test_loader, train_test_loader = get_dataloader(trainset, testset, trainset_test, args)
    args.logger.info("Dataloader")

    train_forget_set, train_remain_set, test_forget_set, test_remain_set, train_forget_test_set, train_remain_test_set,\
        train_forget_loader, train_remain_loader, test_forget_loader, test_remain_loader, train_forget_test_loader, train_remain_test_loader,\
        train_adjacent_set, test_adjacent_set, train_adjacent_test_set, train_adjacent_loader, test_adjacent_loader, train_adjacent_test_loader = get_unlearn_loader(trainset, testset, trainset_test, args)
    args.logger.info("Unlearn Dataloader")

    # Base name of the retrained checkpoint for the selected unlearning mode.
    model_name = f"{args.model_name}_{args.data_name}_{args.test_mode}"
    if args.test_mode == 'class': model_name += f"_{args.class_idx}_{args.class_idx_unlearn}"
    # NOTE: sub_class_name is a list when given on the CLI, so its list repr
    # ends up in the checkpoint name; existing checkpoints rely on that.
    elif args.test_mode == 'sub_class': model_name += f"_{args.sub_class_name}"
    elif args.test_mode == 'sample': model_name += f"_{args.sample_unlearn_per_class}"

    print(model_name)

    # Everything the method and the evaluation need, bundled by name.
    loaders = dict(trainset=trainset, testset=testset, train_forget_set=train_forget_set, train_remain_set=train_remain_set, test_forget_set=test_forget_set, test_remain_set=test_remain_set, \
        train_loader=train_loader, test_loader=test_loader, train_forget_loader=train_forget_loader, train_remain_loader=train_remain_loader, test_forget_loader=test_forget_loader, test_remain_loader=test_remain_loader, \
        trainset_test=trainset_test, train_forget_test_set=train_forget_test_set, train_remain_test_set=train_remain_test_set, train_forget_test_loader=train_forget_test_loader, train_remain_test_loader=train_remain_test_loader, train_test_loader=train_test_loader, \
        train_adjacent_set=train_adjacent_set, test_adjacent_set=test_adjacent_set, train_adjacent_test_set=train_adjacent_test_set, train_adjacent_loader=train_adjacent_loader, test_adjacent_loader=test_adjacent_loader, train_adjacent_test_loader=train_adjacent_test_loader \
    )

    # ImageNet runs use the '<model_name>_imagenet' backbone key; every other
    # dataset uses the plain model name. Computing it once replaces the
    # duplicated imagenet/non-imagenet branches.
    backbone = (args.model_name + '_imagenet') if args.data_name == 'imagenet' else args.model_name
    final_dir = f'./final_checkpoints/{args.data_name}/{args.test_mode}/'
    final_ckpt_path = f'{final_dir}{args.exp_name}{args.rnd_seed}.pth'

    model = get_model(backbone, num_classes=args.num_classes, ckpt_path=f'./checkpoints/{args.model_name}_{args.data_name}_ori{args.model_seed}.pth').to(args.device)
    retrain_model = get_model(backbone, num_classes=args.num_classes, ckpt_path=f'./checkpoints/{model_name}_retrain{args.model_seed}.pth').to(args.device)

    if not args.eval_mode:
        # only use CIFAR100 for sub_class
        result_model, statistics = run_method(model, retrain_model, loaders, args)

        if args.save_result_model:
            # exist_ok avoids the check-then-create race and matches log().
            os.makedirs(final_dir, exist_ok=True)
            torch.save(result_model.state_dict(), final_ckpt_path)

        evaluate_summary(model, retrain_model, result_model, statistics, loaders, args)
    else:
        # Evaluation only: reload the saved result model; no statistics exist.
        result_model = get_model(backbone, num_classes=args.num_classes, ckpt_path=final_ckpt_path).to(args.device)
        evaluate_summary(model, retrain_model, result_model, None, loaders, args)