import argparse
import logging.config
import os
import random
from collections import OrderedDict

import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet18_Weights

from utils.data import get_dataset, get_dataloader, get_unlearn_loader
from utils.backbone import get_model
from pretraining.trainer import train_and_save
from utils.utils import report_sample_by_class

def seed_torch(seed):
    """Seed every global RNG the pipeline may draw from, for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU and
    CUDA generators, and configures cuDNN for deterministic algorithm choice.

    Args:
        seed: integer seed applied to all generators.
    """
    # Previously np.random.seed was called twice and Python's `random` was
    # never seeded; anything drawing from `random` was not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU (not only the current device)
    # and is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels between runs; keep it off so
    # the deterministic flag above actually yields repeatable results.
    torch.backends.cudnn.benchmark = False

def arg_parse(argv=None):
    """Build the command-line parser for the Boundary Unlearning experiments and parse it.

    Args:
        argv: optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, exactly as before; passing a list lets
            other scripts and tests construct the configuration directly.

    Returns:
        argparse.Namespace with the options below. Note that ``sub_class_name``
        is a list of strings when the flag is given (``nargs="+"``) but the
        plain string ``""`` when left at its default.
    """
    parser = argparse.ArgumentParser("Boundary Unlearning")
    parser.add_argument('--rnd_seed', type=int, default=0, help='random seed')
    parser.add_argument('--data_name', type=str, default='cifar100', help='cifar10, cifar100, tinyimagenet, imagenet')
    parser.add_argument('--model_name', type=str, default='ResNet18', help='model name')
    parser.add_argument('--remain_batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--forget_batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--test_mode', type=str, default='class', choices=['sample', 'sub_class', 'class'], help='unlearning mode')
    parser.add_argument('--debug_mode', action='store_true', help='debug mode, it takes more time but with more detailed information')
    parser.add_argument('--train_epochs', type=int, default=300, help='train epochs')

    # Class unlearning (test_mode=class).
    # NOTE(review): both options below share the same help text; how
    # class_idx and class_idx_unlearn differ is decided by the data helpers,
    # not visible here.
    parser.add_argument('--class_idx', type=int, default=4, help='class index to unlearn')
    parser.add_argument('--class_idx_unlearn', type=int, default=1, help='class index to unlearn')

    # Sub-class unlearning (test_mode=sub_class). Example superclass / subclass pairs:
    #   vehicle2, rocket
    #   vegetables, mushroom
    #   people, baby
    #   electrical devices, lamp
    #   natural scene, sea
    # Following [https://arxiv.org/pdf/2205.08096.pdf] (incompetent teacher).
    # Default kept as "" (not []) for compatibility with existing consumers.
    parser.add_argument('--sub_class_name', type=str, nargs="+", default="", help='sub class name to unlearn')

    # Sample unlearning (test_mode=sample).
    parser.add_argument('--sample_unlearn_per_class', type=int, default=50, help='number of unlearning samples per class')

    return parser.parse_args(argv)

if __name__ == '__main__':
    args = arg_parse()
    print(args)
    seed_torch(args.rnd_seed)
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if args.test_mode == "class":
        model_name=f'{args.model_name}_{args.data_name}_{args.test_mode}_{args.class_idx}_{args.class_idx_unlearn}'
    elif args.test_mode == "sample":
        model_name=f'{args.model_name}_{args.data_name}_{args.test_mode}_{args.sample_unlearn_per_class}'
    
    trainset, testset, trainset_test, num_cls = get_dataset(args)
    train_loader, test_loader, train_test_loader = get_dataloader(trainset, testset, trainset_test, args)
    train_forget_set, train_remain_set, test_forget_set, test_remain_set, train_forget_test_set, train_remain_test_set,\
    train_forget_loader, train_remain_loader, test_forget_loader, test_remain_loader, train_forget_test_loader, train_remain_test_loader,\
    train_adjacent_set, test_adjacent_set, train_adjacent_test_set, train_adjacent_loader, test_adjacent_loader, train_adjacent_test_loader = get_unlearn_loader(trainset, testset, trainset_test, args)
    
    print(f"train_forget: {report_sample_by_class(train_forget_loader)}, train_remain: {report_sample_by_class(train_remain_loader)}")
    print(f"test_forget: {report_sample_by_class(test_forget_loader)}, test_remain: {report_sample_by_class(test_remain_loader)}")

    # train_transform
    print(f"{train_loader.dataset.transform}, {train_remain_loader.dataset.transform}, {test_loader.dataset.transform}")

    if args.test_mode == "sub_class": args.num_classes = 20

    model = get_model(args.model_name, num_classes=args.num_classes).to(args.device)
    train_and_save(model, train_loader, test_loader, f"{args.model_name}_{args.data_name}", args, mode='ori')
    model = get_model(args.model_name, num_classes=args.num_classes).to(args.device)
    train_and_save(model, train_remain_loader, test_loader, model_name, args, mode='retrain')
