import argparse
import torch
torch.manual_seed(3407)
import numpy as np 
from copy import deepcopy


from config import get_ori_model_save_path
from config import get_our_model_save_path, get_our_model_acc_save_path, get_our_model_y_pred_save_path, get_our_model_running_time_save_path, get_our_model_train_loss_save_path
from utils import ViTBackBone, ViTGenderClassifier
from utils import prep_data
from utils import train_shuffle_classifier, evaluate_shuffle_classifier



if __name__ == "__main__":
    # Repeatedly unlearn the original classifier (``unlearn_times`` independent
    # runs), evaluate each run, then save mean/std of accuracy and training
    # time plus per-run model weights, predictions and training losses.
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda', type=str, help='device')
    parser.add_argument('--ori_training_epochs', default=100, type=int, help='original model training epochs')
    parser.add_argument('--unlearning_epochs', default=200, type=int, help='our unlearning epochs')
    parser.add_argument('--unlearn_times', default=2, type=int, help='unlearn times')
    parser.add_argument('--where_to_unl', default='nose', choices=['nose', 'eye', 'noseeye'], type=str,
                        help='where to unlearn')
    parser.add_argument('--label', default='Big_Nose', choices=['Male', 'Big_Nose', 'Pointy_Nose', 'Eyeglasses', 'Narrow_Eyes'], type=str,
                        help='the label for the cv dataset')

    args = parser.parse_args()

    # At least one run is required: with zero runs the mean/std computed below
    # would be taken over empty lists (NaN + RuntimeWarning) and saved to disk.
    if args.unlearn_times < 1:
        parser.error('--unlearn_times must be a positive integer')

    # Fall back to CPU whenever CUDA is unavailable, regardless of --device.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    ori_training_epochs = args.ori_training_epochs
    unlearning_epochs = args.unlearning_epochs
    unlearn_times = args.unlearn_times
    where_to_unl = args.where_to_unl
    label = args.label

    # Keyword arguments shared by every result/model save path below.
    save_path_kwargs = dict(
        ori_training_epochs=ori_training_epochs,
        unlearning_epochs=unlearning_epochs,
        where_to_unl=where_to_unl,
        label=label,
    )

    train_attr_dict, test_attr_dict, train_area_dict, test_area_dict = prep_data(label=label, where_to_unl=where_to_unl, retrain_or_shuffle='shuffle')

    # Load the original classifier checkpoint once: it is identical for every
    # run. load_state_dict copies these tensors into each new module's own
    # parameters, so reusing the dict does not make runs share weights.
    ori_classifier_state = torch.load(
        get_ori_model_save_path(ori_training_epochs=ori_training_epochs, label=label),
        map_location=device,
    )

    UL_model_lst, UL_training_time_lst, UL_train_loss_lst = [], [], []
    UL_acc_lst, UL_y_pred_lst = [], []
    for _ in range(unlearn_times):
        # Fresh backbone and classifier per run; the classifier starts from
        # the original model's weights.
        UL_backbone = ViTBackBone().to(device)
        embedding_dim = UL_backbone.backbone.embed_dim
        UL_classifier = ViTGenderClassifier(embedding_dim=embedding_dim).to(device)
        UL_classifier.load_state_dict(ori_classifier_state)

        # Unlearn the original model.
        UL_backbone, UL_classifier, UL_training_time, task_loss_lst = train_shuffle_classifier(UL_backbone=UL_backbone, UL_classifier=UL_classifier, train_attr_dict=train_attr_dict, train_area_dict=train_area_dict, where_to_unl=where_to_unl, device=device, epochs=unlearning_epochs)
        # Evaluate the unlearned model.
        UL_acc, UL_y_pred, _ = evaluate_shuffle_classifier(backbone=UL_backbone, classifier=UL_classifier, test_attr_dict=test_attr_dict, test_area_dict=test_area_dict, where_to_unl=where_to_unl, device=device)

        # NOTE(review): only the classifier is kept and saved; the backbone
        # returned by train_shuffle_classifier is discarded — confirm that
        # unlearning does not update the backbone's weights.
        UL_model_lst.append(UL_classifier)
        UL_training_time_lst.append(UL_training_time)
        UL_train_loss_lst.append(task_loss_lst)
        UL_acc_lst.append(UL_acc)
        UL_y_pred_lst.append(UL_y_pred)

    # Save aggregate results as [mean, std] over all runs.
    np.savetxt(get_our_model_acc_save_path(**save_path_kwargs), [np.mean(UL_acc_lst), np.std(UL_acc_lst)])
    np.savetxt(get_our_model_running_time_save_path(**save_path_kwargs), [np.mean(UL_training_time_lst), np.std(UL_training_time_lst)])

    # Save per-run classifier weights, test predictions and training losses,
    # indexed from 0.
    for unlearn_times_idx, (UL_model, UL_y_pred, UL_train_loss) in enumerate(zip(UL_model_lst, UL_y_pred_lst, UL_train_loss_lst)):
        torch.save(UL_model.state_dict(), get_our_model_save_path(unlearn_times_idx=unlearn_times_idx, **save_path_kwargs))
        np.savetxt(get_our_model_y_pred_save_path(unlearn_times_idx=unlearn_times_idx, **save_path_kwargs), UL_y_pred)
        np.savetxt(get_our_model_train_loss_save_path(unlearn_times_idx=unlearn_times_idx, **save_path_kwargs), UL_train_loss)
    
    
    