import argparse
import torch
torch.manual_seed(3407)
import numpy as np 

from copy import deepcopy

from config import get_ori_model_save_path
from config import get_BL2_model_save_path, get_BL2_model_acc_save_path, get_BL2_model_y_pred_save_path, get_BL2_model_running_time_save_path, get_BL2_model_train_loss_save_path
from utils import ViTBackBone, ViTGenderClassifier
from utils import prep_data, train_BL2_classifier, evaluate_BL2_classifier


def _parse_args():
    """Parse and validate the command-line arguments for a baseline-2 (BL2) run.

    Returns:
        argparse.Namespace with ``device``, ``ori_training_epochs``,
        ``unlearning_epochs``, ``unlearn_times``, ``where_to_unl`` and ``label``.

    Exits via ``parser.error`` if ``--unlearn_times`` is smaller than 1.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda', type=str, help='device')
    parser.add_argument('--ori_training_epochs', default=100, type=int, help='original model training epochs')
    parser.add_argument('--unlearning_epochs', default=2, type=int, help='unlearning epochs for baseline 2')
    parser.add_argument('--unlearn_times', default=2, type=int, help='unlearn times')
    parser.add_argument('--where_to_unl', default='nose', choices=['nose', 'eye', 'noseeye'], type=str,
                        help='where to unlearn')
    parser.add_argument('--label', default='Big_Nose', choices=['Male', 'Big_Nose', 'Pointy_Nose', 'Eyeglasses', 'Narrow_Eyes'], type=str,
                        help='the label for the cv dataset')

    args = parser.parse_args()

    # With zero repeats, np.mean/np.std below would run on empty lists and
    # silently write NaN to the result files, so reject it up front.
    if args.unlearn_times < 1:
        parser.error('--unlearn_times must be >= 1')
    return args


def main():
    """Run baseline-2 unlearning ``unlearn_times`` times and save the results.

    Each repeat starts from a fresh copy of the original classifier. The copy is
    trained with ``train_BL2_classifier`` and evaluated with
    ``evaluate_BL2_classifier``. Afterwards the script writes:

    * the mean/std of the test accuracy across repeats,
    * the mean/std of the training time across repeats,
    * per repeat (1-based index): the classifier state dict, the test
      predictions and the training loss.
    """
    args = _parse_args()

    # Fall back to CPU when CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    ori_training_epochs = args.ori_training_epochs
    BL2_unlearning_epochs = args.unlearning_epochs
    unlearn_times = args.unlearn_times
    where_to_unl = args.where_to_unl
    label = args.label

    # The area dicts returned by prep_data are not used by baseline 2.
    train_attr_dict, test_attr_dict, _, _ = prep_data(where_to_unl=where_to_unl, retrain_or_shuffle='shuffle', label=label)

    # Load the original model: a fresh backbone plus the saved classifier head.
    ori_backbone = ViTBackBone().to(device)
    embedding_dim = ori_backbone.backbone.embed_dim
    ori_classifier = ViTGenderClassifier(embedding_dim=embedding_dim).to(device)
    ori_classifier.load_state_dict(torch.load(get_ori_model_save_path(ori_training_epochs=ori_training_epochs, label=label), map_location=device))

    # Keyword arguments shared by every BL2 result-path helper.
    path_kwargs = dict(
        ori_training_epochs=ori_training_epochs,
        BL2_unlearning_epochs=BL2_unlearning_epochs,
        where_to_unl=where_to_unl,
        label=label,
    )

    BL2_model_lst, BL2_training_time_lst, BL2_train_loss_lst = [], [], []
    BL2_acc_lst, BL2_y_pred_lst = [], []
    for _ in range(unlearn_times):
        # Copy the original classifier so that every repeat starts from the same weights.
        BL2_classifier = deepcopy(ori_classifier)
        # NOTE(review): ori_backbone is shared across repeats (not copied). This is only
        # correct if train_BL2_classifier leaves the backbone's weights unchanged;
        # confirm in utils.
        backbone, BL2_classifier, BL2_training_time, BL2_train_loss = train_BL2_classifier(backbone=ori_backbone, classifier=BL2_classifier, train_attr_dict=train_attr_dict, where_to_unl=where_to_unl, epochs=BL2_unlearning_epochs, device=device)
        BL2_acc, BL2_y_pred, _ = evaluate_BL2_classifier(backbone=backbone, classifier=BL2_classifier, test_attr_dict=test_attr_dict, where_to_unl=where_to_unl, device=device)

        BL2_model_lst.append(BL2_classifier)
        BL2_training_time_lst.append(BL2_training_time)
        BL2_train_loss_lst.append(BL2_train_loss)
        BL2_acc_lst.append(BL2_acc)
        BL2_y_pred_lst.append(BL2_y_pred)

    # Aggregate results: [mean, std] across repeats.
    np.savetxt(get_BL2_model_acc_save_path(**path_kwargs), [np.mean(BL2_acc_lst), np.std(BL2_acc_lst)])
    np.savetxt(get_BL2_model_running_time_save_path(**path_kwargs), [np.mean(BL2_training_time_lst), np.std(BL2_training_time_lst)])

    # Per-repeat artifacts; file indices are 1-based.
    per_run = zip(BL2_model_lst, BL2_y_pred_lst, BL2_train_loss_lst)
    for unlearn_times_idx, (BL2_model, BL2_y_pred, BL2_train_loss) in enumerate(per_run, start=1):
        torch.save(BL2_model.state_dict(), get_BL2_model_save_path(unlearn_times_idx=unlearn_times_idx, **path_kwargs))
        np.savetxt(get_BL2_model_y_pred_save_path(unlearn_times_idx=unlearn_times_idx, **path_kwargs), BL2_y_pred)
        np.savetxt(get_BL2_model_train_loss_save_path(unlearn_times_idx=unlearn_times_idx, **path_kwargs), BL2_train_loss)


if __name__ == "__main__":
    main()
    

