import argparse
import torch
torch.manual_seed(3407)
import numpy as np 

from copy import deepcopy

from config import get_ori_model_save_path, get_BL3DecoderMIhx_model_save_path, get_BL3MIhy_model_save_path, get_BL3MIhz_model_save_path, get_BL3_prep_running_time_save_path
from config import get_BL3Classifier_model_save_path, get_BL3_model_acc_save_path, get_BL3_model_y_pred_save_path, get_BL3_model_running_time_save_path, get_BL3_model_train_task_loss_save_path, get_BL3_model_train_MI_loss_save_path
from utils import BL3DecoderMIhx, BL3MIhy, BL3MIhz, ViTBackBone, ViTGenderClassifier
from utils import train_unlearning_BL3, evaluate_BL3
from utils import prep_data


if __name__ == "__main__":
    # Script entry point: run the "baseline 3" (BL3) unlearning procedure
    # `unlearn_times` times starting from the saved original classifier,
    # evaluate every run on the test split, and write the aggregated
    # accuracy / running time plus the per-run artefacts to the paths
    # provided by `config`.
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda', type=str, help='device')
    parser.add_argument('--ori_BL3_training_epochs', default=100, type=int, help='original model training epochs')
    parser.add_argument('--BL3DecoderMIhx_training_epochs', default=10, type=int, help='training epochs of baseline 3 decoder for I(h,x), x is remaining features')
    parser.add_argument('--BL3MIhy_training_epochs', default=10, type=int, help='training epochs of baseline 3 for I(h,y), y is label')
    parser.add_argument('--BL3MIhz_training_epochs', default=10, type=int, help='training epochs of baseline 3 for I(h,z), z is unlearned feature')
    parser.add_argument('--BL3_unlearning_epochs', default=2, type=int, help='unlearning epochs of baseline 3')

    parser.add_argument('--BL3_lamda1', default=5, type=float, help='lamda1 of baseline 3')
    parser.add_argument('--BL3_lamda2', default=5, type=float, help='lamda2 of baseline 3')
    parser.add_argument('--BL3_lamda3', default=1, type=float, help='lamda3 of baseline 3')
    parser.add_argument('--unlearn_times', default=1, type=int, help='unlearn times')

    parser.add_argument('--where_to_unl', default='nose', choices=['nose', 'eye', 'noseeye'], type=str,
                        help='where to unlearn')
    parser.add_argument('--label', default='Big_Nose', choices=['Male', 'Big_Nose', 'Pointy_Nose', 'Eyeglasses', 'Narrow_Eyes'], type=str,
                        help='the label for the cv dataset')

    args = parser.parse_args()

    # Fall back to CPU whenever CUDA is not available, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    ori_BL3_training_epochs = args.ori_BL3_training_epochs
    BL3DecoderMIhx_training_epochs = args.BL3DecoderMIhx_training_epochs
    BL3MIhy_training_epochs = args.BL3MIhy_training_epochs
    BL3MIhz_training_epochs = args.BL3MIhz_training_epochs
    BL3_unlearning_epochs = args.BL3_unlearning_epochs
    BL3_lamda1 = args.BL3_lamda1
    BL3_lamda2 = args.BL3_lamda2
    BL3_lamda3 = args.BL3_lamda3
    unlearn_times = args.unlearn_times

    label = args.label
    where_to_unl = args.where_to_unl

    # Keyword arguments identifying the pre-trained BL3 helper models.
    prep_kwargs = dict(
        ori_BL3_training_epochs=ori_BL3_training_epochs,
        BL3DecoderMIhx_training_epochs=BL3DecoderMIhx_training_epochs,
        BL3MIhy_training_epochs=BL3MIhy_training_epochs,
        BL3MIhz_training_epochs=BL3MIhz_training_epochs,
        label=label,
    )
    # Keyword arguments shared by every result path written by this script.
    result_kwargs = dict(
        prep_kwargs,
        BL3_unlearning_epochs=BL3_unlearning_epochs,
        BL3_lamda1=BL3_lamda1,
        BL3_lamda2=BL3_lamda2,
        BL3_lamda3=BL3_lamda3,
        where_to_unl=where_to_unl,
    )

    # get data
    train_attr_dict, test_attr_dict, train_area_dict, test_area_dict = prep_data(label=label, where_to_unl=where_to_unl, retrain_or_shuffle='shuffle')

    # Time spent training the BL3 helper models; it is the same for every
    # run, so it is read once instead of once per loop iteration.
    prep_running_time = np.loadtxt(get_BL3_prep_running_time_save_path(**prep_kwargs)).item()

    # Per-run accumulators (one entry per unlearning repetition).
    BL3Classifier_model_lst, BL3_running_time_lst, BL3_task_loss_lst, BL3_MI_loss_lst = [], [], [], []
    BL3_acc_lst, BL3_y_pred_lst = [], []
    for _ in range(unlearn_times):
        # load original model
        vit_backbone = ViTBackBone().to(device)
        embedding_dim = vit_backbone.backbone.embed_dim
        classifier = ViTGenderClassifier(embedding_dim=embedding_dim).to(device)
        classifier.load_state_dict(torch.load(get_ori_model_save_path(ori_training_epochs=ori_BL3_training_epochs, label=label), map_location=device))

        # load baseline 3 prep model
        BL3DecoderMIhx_model = BL3DecoderMIhx(embedding_dim=embedding_dim, output_dim=int(224*224)).to(device)
        BL3MIhy_model = BL3MIhy(embedding_dim=embedding_dim).to(device)
        BL3MIhz_model = BL3MIhz(embedding_dim=embedding_dim, output_dim=int(224*224)).to(device)
        BL3DecoderMIhx_model.load_state_dict(torch.load(get_BL3DecoderMIhx_model_save_path(ori_BL3_training_epochs=ori_BL3_training_epochs, BL3DecoderMIhx_training_epochs=BL3DecoderMIhx_training_epochs, label=label), map_location=device))
        BL3MIhy_model.load_state_dict(torch.load(get_BL3MIhy_model_save_path(ori_BL3_training_epochs=ori_BL3_training_epochs, BL3MIhy_training_epochs=BL3MIhy_training_epochs, label=label), map_location=device))
        BL3MIhz_model.load_state_dict(torch.load(get_BL3MIhz_model_save_path(ori_BL3_training_epochs=ori_BL3_training_epochs, BL3MIhz_training_epochs=BL3MIhz_training_epochs, label=label), map_location=device))

        # baseline 3
        # The model to unlearn is a copy of the freshly loaded original
        # classifier, so `classifier` itself is left untouched.
        # NOTE(review): the original code passed an unbound name here;
        # seeding with deepcopy(classifier) is the presumed intent - confirm
        # against train_unlearning_BL3.
        BL3Classifier_model, BL3_unlearning_time, BL3_task_loss, BL3_MI_loss = train_unlearning_BL3(
            BL3Classifier_model=deepcopy(classifier),
            BL3DecoderMIhx_model=BL3DecoderMIhx_model, BL3MIhy_model=BL3MIhy_model, BL3MIhz_model=BL3MIhz_model,
            attr_dict=train_attr_dict, device=device,
            BL3_lamda1=BL3_lamda1, BL3_lamda2=BL3_lamda2, BL3_lamda3=BL3_lamda3,
            epochs=BL3_unlearning_epochs,
        )
        # evaluate baseline 3
        BL3_acc, BL3_y_pred, _ = evaluate_BL3(backbone=vit_backbone, BL3Classifier_model=BL3Classifier_model, attr_dict=test_attr_dict, device=device)
        # Reported running time = helper-model preparation + unlearning itself.
        BL3_running_time = prep_running_time + BL3_unlearning_time

        BL3Classifier_model_lst.append(BL3Classifier_model)
        BL3_running_time_lst.append(BL3_running_time)
        BL3_task_loss_lst.append(BL3_task_loss)
        BL3_MI_loss_lst.append(BL3_MI_loss)
        BL3_acc_lst.append(BL3_acc)
        BL3_y_pred_lst.append(BL3_y_pred)

    # save res: [mean, std] over all runs
    np.savetxt(get_BL3_model_acc_save_path(**result_kwargs), [np.mean(BL3_acc_lst), np.std(BL3_acc_lst)])
    np.savetxt(get_BL3_model_running_time_save_path(**result_kwargs), [np.mean(BL3_running_time_lst), np.std(BL3_running_time_lst)])

    # save model, predictions and training-loss curves of every run
    for unlearn_times_idx, (BL3Classifier_model, BL3_y_pred, BL3_task_loss, BL3_MI_loss) in enumerate(zip(BL3Classifier_model_lst, BL3_y_pred_lst, BL3_task_loss_lst, BL3_MI_loss_lst)):
        run_kwargs = dict(result_kwargs, unlearn_times_idx=unlearn_times_idx)
        torch.save(BL3Classifier_model.state_dict(), get_BL3Classifier_model_save_path(**run_kwargs))
        np.savetxt(get_BL3_model_y_pred_save_path(**run_kwargs), BL3_y_pred)
        np.savetxt(get_BL3_model_train_task_loss_save_path(**run_kwargs), BL3_task_loss)
        np.savetxt(get_BL3_model_train_MI_loss_save_path(**run_kwargs), BL3_MI_loss)