import argparse
import torch
torch.manual_seed(3407)
import numpy as np 

from copy import deepcopy

from utils import prep_data, train_retrain_classifier, evaluate_retrain_classifier
from config import get_RT_model_save_path, get_RT_model_acc_save_path, get_RT_model_y_pred_save_path, get_RT_model_running_time_save_path, get_RT_model_train_loss_save_path


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options for the retraining experiment.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits (via ``parser.error``) when ``--RT_train_times`` is not positive.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda', type=str, help='device')
    parser.add_argument('--RT_training_epochs', default=100, type=int, help='model training epochs of retrained model')
    parser.add_argument('--RT_train_times', default=2, type=int, help='how many times to retrain the model -- take the avg acc')
    parser.add_argument('--where_to_unl', default='nose', choices=['nose', 'eye', 'noseeye'], type=str,
                        help='where to unlearn')
    parser.add_argument('--label', default='Big_Nose', choices=['Male', 'Big_Nose', 'Pointy_Nose', 'Eyeglasses', 'Narrow_Eyes'], type=str,
                        help='the label for the cv dataset')

    args = parser.parse_args()

    # With zero (or negative) repetitions the training loop never runs: the
    # mean/std summaries would be NaN and no per-run artifacts would be saved.
    if args.RT_train_times < 1:
        parser.error('--RT_train_times must be a positive integer')
    return args


def main() -> None:
    """Retrain the classifier ``RT_train_times`` times and save the results.

    For every run ``i`` (0-based) this saves, as soon as the run finishes:
      * the classifier's ``state_dict`` (``torch.save``),
      * the test-set predictions (``np.savetxt``),
      * the training loss returned by ``train_retrain_classifier`` (``np.savetxt``).

    After all runs it saves ``[mean, std]`` of the test accuracies and of the
    training times. ``np.std`` is used with its default ``ddof=0``, which gives
    the population standard deviation.
    """
    args = _parse_args()

    # Silently falls back to CPU when CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    RT_training_epochs = args.RT_training_epochs
    RT_train_times = args.RT_train_times
    where_to_unl = args.where_to_unl
    label = args.label

    print(device)

    train_attr_dict, test_attr_dict, _, _ = prep_data(label=label, where_to_unl=where_to_unl, retrain_or_shuffle='retrain')

    # Keyword arguments shared by every save-path getter from ``config``.
    path_kwargs = dict(RT_training_epochs=RT_training_epochs, where_to_unl=where_to_unl, label=label)

    RT_training_time_lst, RT_acc_lst = [], []
    for train_times_idx in range(RT_train_times):
        RT_backbone, RT_classifier, RT_training_time, RT_train_loss = train_retrain_classifier(
            train_attr_dict=train_attr_dict, where_to_unl=where_to_unl, device=device, epochs=RT_training_epochs)
        RT_acc, RT_y_pred, _ = evaluate_retrain_classifier(
            backbone=RT_backbone, classifier=RT_classifier, test_attr_dict=test_attr_dict,
            where_to_unl=where_to_unl, device=device)

        RT_training_time_lst.append(RT_training_time)
        RT_acc_lst.append(RT_acc)

        # Persist this run's artifacts immediately so a failure in a later run
        # does not discard already-trained models, and so finished classifiers
        # need not all be kept in memory.
        # NOTE(review): only the classifier is saved, not RT_backbone — confirm
        # the backbone is not modified by train_retrain_classifier.
        torch.save(RT_classifier.state_dict(), get_RT_model_save_path(RT_train_times_idx=train_times_idx, **path_kwargs))
        # NOTE(review): np.savetxt needs array-like data; presumably RT_y_pred /
        # RT_train_loss are CPU arrays or lists (a CUDA tensor would fail) — confirm.
        np.savetxt(get_RT_model_y_pred_save_path(RT_train_times_idx=train_times_idx, **path_kwargs), RT_y_pred)
        np.savetxt(get_RT_model_train_loss_save_path(RT_train_times_idx=train_times_idx, **path_kwargs), RT_train_loss)

    # Summary statistics over all runs: [mean, std].
    np.savetxt(get_RT_model_acc_save_path(**path_kwargs), [np.mean(RT_acc_lst), np.std(RT_acc_lst)])
    np.savetxt(get_RT_model_running_time_save_path(**path_kwargs), [np.mean(RT_training_time_lst), np.std(RT_training_time_lst)])


if __name__ == "__main__":
    main()
    

    