import argparse
import torch
torch.manual_seed(3407)
import numpy as np 


from config import get_ori_model_save_path, get_BL3DecoderMIhx_model_save_path, get_BL3MIhy_model_save_path, get_BL3MIhz_model_save_path, get_BL3_prep_running_time_save_path
from utils import prep_data, get_h_for_BL3, train_BL3DecoderMIhx, train_BL3MIhy, train_BL3MIhz
from utils import ViTBackBone, ViTGenderClassifier

if __name__ == "__main__":
    # Entry point: train the three baseline-3 (BL3) estimators on top of a
    # ViT backbone, save their weights, and save the combined training time.

    # ---- command-line interface -------------------------------------------
    cli = argparse.ArgumentParser()
    cli.add_argument('--device', default='cuda', type=str, help='device')
    cli.add_argument('--ori_BL3_training_epochs', default=100, type=int, help='original model training epochs')
    cli.add_argument('--BL3DecoderMIhx_training_epochs', default=10, type=int, help='training epochs of baseline 3 decoder for I(h,x), x is remaining features')
    cli.add_argument('--BL3MIhy_training_epochs', default=10, type=int, help='training epochs of baseline 3 for I(h,y), y is label')
    cli.add_argument('--BL3MIhz_training_epochs', default=10, type=int, help='training epochs of baseline 3 for I(h,z), z is unlearned feature')
    cli.add_argument(
        '--label',
        default='Male',
        choices=['Male', 'Big_Nose', 'Pointy_Nose', 'Eyeglasses', 'Narrow_Eyes'],
        type=str,
        help='the label for the cv dataset',
    )
    opts = cli.parse_args()

    # Use the requested device only when CUDA is present; otherwise fall
    # back to the CPU regardless of what --device says.
    if torch.cuda.is_available():
        device = torch.device(opts.device)
    else:
        device = torch.device('cpu')

    ori_epochs = opts.ori_BL3_training_epochs
    hx_epochs = opts.BL3DecoderMIhx_training_epochs
    hy_epochs = opts.BL3MIhy_training_epochs
    hz_epochs = opts.BL3MIhz_training_epochs
    label = opts.label

    # ---- data ---------------------------------------------------------------
    # prep_data returns four values; only the training attribute dict is
    # used below.
    train_attrs, _, _, _ = prep_data(label=label)

    # ---- models ---------------------------------------------------------------
    backbone = ViTBackBone().to(device)
    classifier = ViTGenderClassifier(embedding_dim=backbone.backbone.embed_dim).to(device)
    # NOTE(review): `classifier` is not used after this point. Loading the
    # checkpoint presumably serves as a check that the original model
    # exists; confirm before removing.
    ori_ckpt_path = get_ori_model_save_path(ori_training_epochs=ori_epochs, label=label)
    classifier.load_state_dict(torch.load(ori_ckpt_path, map_location=device))

    # Compute h once; all three estimators below share it.
    h_train = get_h_for_BL3(attr_dict=train_attrs, device=device, backbone=backbone)

    # ---- training (order matters: the estimators are trained sequentially) --
    shared_kwargs = dict(h_train=h_train, attr_dict=train_attrs, device=device, backbone=backbone)
    hx_model, hx_time = train_BL3DecoderMIhx(**shared_kwargs, epochs=hx_epochs)
    hy_model, hy_time = train_BL3MIhy(**shared_kwargs, epochs=hy_epochs)
    hz_model, hz_time = train_BL3MIhz(**shared_kwargs, epochs=hz_epochs)

    total_time = hx_time + hy_time + hz_time

    # ---- persistence ----------------------------------------------------------
    torch.save(
        hx_model.state_dict(),
        get_BL3DecoderMIhx_model_save_path(
            ori_BL3_training_epochs=ori_epochs,
            BL3DecoderMIhx_training_epochs=hx_epochs,
            label=label,
        ),
    )
    torch.save(
        hy_model.state_dict(),
        get_BL3MIhy_model_save_path(
            ori_BL3_training_epochs=ori_epochs,
            BL3MIhy_training_epochs=hy_epochs,
            label=label,
        ),
    )
    torch.save(
        hz_model.state_dict(),
        get_BL3MIhz_model_save_path(
            ori_BL3_training_epochs=ori_epochs,
            BL3MIhz_training_epochs=hz_epochs,
            label=label,
        ),
    )

    # Write the summed training time of the three estimators as a
    # one-element text file.
    timing_path = get_BL3_prep_running_time_save_path(
        ori_BL3_training_epochs=ori_epochs,
        BL3DecoderMIhx_training_epochs=hx_epochs,
        BL3MIhy_training_epochs=hy_epochs,
        BL3MIhz_training_epochs=hz_epochs,
        label=label,
    )
    np.savetxt(timing_path, [total_time])
