import argparse
import torch
# Seed PyTorch's global RNG with a fixed value immediately after importing
# torch, i.e. before the project modules (config, utils) are imported below, so
# anything that draws from torch's RNG from this point on sees the same seed.
# NOTE(review): only torch is seeded here; NumPy's RNG is not seeded in this
# file — confirm whether utils relies on it.
torch.manual_seed(3407)
import numpy as np 


from config import get_ori_model_save_path, get_ori_model_acc_save_path, get_ori_model_y_pred_save_path, get_ori_model_running_time_save_path, get_ori_model_train_loss_save_path
from utils import train_ori_classifier, evaluate_ori_classifier
from utils import prep_data





def resolve_device(requested):
    """Return the ``torch.device`` to run on for the ``--device`` argument.

    A CUDA request (``'cuda'``, ``'cuda:1'``, ...) falls back to the CPU when
    CUDA is not available. Any other request (``'cpu'``, ``'mps'``, ...) is
    honored as given instead of being silently replaced by the CPU just
    because CUDA is missing.

    Args:
        requested: Device string as accepted by ``torch.device``.

    Returns:
        The ``torch.device`` to use.

    Raises:
        RuntimeError: If ``requested`` is not a device string that
            ``torch.device`` can parse.
    """
    device = torch.device(requested)
    if device.type == 'cuda' and not torch.cuda.is_available():
        return torch.device('cpu')
    return device


def ensure_parent_dir(path):
    """Create the parent directory of ``path`` if needed and return ``path``.

    The outputs are written only after training has finished, so a missing
    output directory would otherwise surface as an error at the very end of
    the run, when the results can no longer be saved.
    """
    import os  # local import: the file's top-level imports do not include os

    parent = os.path.dirname(path)
    if parent:  # a bare filename has no parent directory to create
        os.makedirs(parent, exist_ok=True)
    return path


def main():
    """Train the original classifier, evaluate it, and save all artifacts.

    Parses the command-line arguments, prepares the data for the selected
    label, trains and evaluates the classifier, then writes the classifier
    weights, test accuracy, predictions, training time and training losses to
    the paths provided by the ``config`` helpers.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda', type=str, help='device')
    parser.add_argument('--ori_training_epochs', default=100, type=int, help='original model training epochs')
    parser.add_argument('--label', default='Male', choices=['Male', 'Big_Nose', 'Pointy_Nose', 'Eyeglasses', 'Narrow_Eyes'], type=str,
                        help='the label for the cv dataset')

    args = parser.parse_args()

    device = resolve_device(args.device)
    ori_training_epochs = args.ori_training_epochs
    label = args.label

    print(device)

    # Only the train/test attribute dicts are needed; the other two values
    # returned by prep_data are unused in this script.
    train_attr_dict, test_attr_dict, _, _ = prep_data(label=label)

    ori_model_backbone, ori_model_classifier, ori_training_time, ori_train_loss_lst = train_ori_classifier(
        train_attr_dict=train_attr_dict,
        device=device,
        epochs=ori_training_epochs,
    )
    ori_acc, ori_y_pred, _ = evaluate_ori_classifier(
        backbone=ori_model_backbone,
        classifier=ori_model_classifier,
        test_attr_dict=test_attr_dict,
        device=device,
    )

    # Every output path is keyed by the same (epochs, label) pair.
    path_kwargs = {'ori_training_epochs': ori_training_epochs, 'label': label}

    # save model
    # NOTE(review): only the classifier's state_dict is persisted; the backbone
    # returned by train_ori_classifier is not saved here — confirm that this is
    # intended (e.g. the backbone is not updated during training).
    torch.save(ori_model_classifier.state_dict(),
               ensure_parent_dir(get_ori_model_save_path(**path_kwargs)))

    # save res (wrapped in a list so np.savetxt receives a 1-D sequence)
    np.savetxt(ensure_parent_dir(get_ori_model_acc_save_path(**path_kwargs)), [ori_acc])

    # save predictions
    np.savetxt(ensure_parent_dir(get_ori_model_y_pred_save_path(**path_kwargs)), ori_y_pred)

    # save running time
    np.savetxt(ensure_parent_dir(get_ori_model_running_time_save_path(**path_kwargs)), [ori_training_time])

    # save training loss
    np.savetxt(ensure_parent_dir(get_ori_model_train_loss_save_path(**path_kwargs)), ori_train_loss_lst)


if __name__ == "__main__":
    main()
    
    
    
    
    
    
    
    
    
    
    