import argparse
import torch
torch.manual_seed(3407)
import numpy as np 

from copy import deepcopy
from sklearn.model_selection import train_test_split

from utils import get_dataset_from_OPENML, preprocess_df
from config import NOMINIAL_COLS, NUMERICAL_COLS, UNL_FEATURE_NUM
from config import get_ori_BL3RepDetExtractormodel_save_path, get_ori_BL3Classifier_model_save_path, get_BL3DecoderMIhx_model_save_path, get_BL3MIhy_model_save_path, get_BL3MIhz_model_save_path, get_BL3_prep_running_time_save_path
from config import get_BL3RepDetExtractor_model_save_path, get_BL3Classifier_model_save_path, get_BL3_model_acc_save_path, get_BL3_model_y_pred_save_path, get_BL3_model_running_time_save_path, get_BL3_model_train_task_loss_save_path, get_BL3_model_train_MI_loss_save_path
from utils import BL3RepDetExtractor, BL3Classifier, BL3DecoderMIhx, BL3MIhy, BL3MIhz
from utils import train_unlearning_BL3, evaluate_BL3


if __name__ == "__main__":
    # Baseline 3 (BL3) feature unlearning driver.
    # Starting from the saved original BL3 extractor/classifier and the saved
    # auxiliary networks (decoder for I(h,x), I(h,y), I(h,z)), run
    # `train_unlearning_BL3` `unlearn_times` times, evaluate each run on the
    # test split, then save mean/std accuracy and running time plus every
    # run's models, predictions and loss curves.
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_name', default='ELECTRICITY', type=str, help='the name of the dataset')
    parser.add_argument('--device', default='cuda', type=str, help='device')
    parser.add_argument('--ori_BL3_training_epochs', default=1500, type=int, help='original model training epochs')
    parser.add_argument('--BL3DecoderMIhx_training_epochs', default=2, type=int, help='training epochs of baseline 3 decoder for I(h,x), x is remaining features')
    parser.add_argument('--BL3MIhy_training_epochs', default=2, type=int, help='training epochs of baseline 3 for I(h,y), y is label')
    parser.add_argument('--BL3MIhz_training_epochs', default=2, type=int, help='training epochs of baseline 3 for I(h,z), z is unlearned feature')
    parser.add_argument('--BL3_unlearning_epochs', default=2, type=int, help='unlearning epochs of baseline 3')
    parser.add_argument('--BL3_lamda1', default=5, type=float, help='lamda1 of baseline 3')
    parser.add_argument('--BL3_lamda2', default=5, type=float, help='lamda2 of baseline 3')
    parser.add_argument('--BL3_lamda3', default=1, type=float, help='lamda3 of baseline 3')
    parser.add_argument('--unlearn_times', default=3, type=int, help='unlearn times')

    args = parser.parse_args()

    dataset_name = args.dataset_name
    # Fall back to CPU when CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    ori_BL3_training_epochs = args.ori_BL3_training_epochs
    BL3DecoderMIhx_training_epochs = args.BL3DecoderMIhx_training_epochs
    BL3MIhy_training_epochs = args.BL3MIhy_training_epochs
    BL3MIhz_training_epochs = args.BL3MIhz_training_epochs
    BL3_unlearning_epochs = args.BL3_unlearning_epochs
    BL3_lamda1 = args.BL3_lamda1
    BL3_lamda2 = args.BL3_lamda2
    BL3_lamda3 = args.BL3_lamda3
    unlearn_times = args.unlearn_times
    unl_feature_num = UNL_FEATURE_NUM[dataset_name]

    # Keyword arguments shared by the config path helpers (all of them are
    # keyword-only call sites, so grouping them here is purely a refactor).
    ori_path_kwargs = dict(
        dataset_name=dataset_name,
        ori_BL3_training_epochs=ori_BL3_training_epochs,
    )
    prep_epochs_kwargs = dict(
        BL3DecoderMIhx_training_epochs=BL3DecoderMIhx_training_epochs,
        BL3MIhy_training_epochs=BL3MIhy_training_epochs,
        BL3MIhz_training_epochs=BL3MIhz_training_epochs,
    )
    res_path_kwargs = dict(
        **ori_path_kwargs,
        **prep_epochs_kwargs,
        BL3_unlearning_epochs=BL3_unlearning_epochs,
        BL3_lamda1=BL3_lamda1,
        BL3_lamda2=BL3_lamda2,
        BL3_lamda3=BL3_lamda3,
    )

    # get data
    df = get_dataset_from_OPENML(dataset_name=dataset_name)
    # Number of classes = number of distinct values in the last column.
    output_dim = len(df[df.columns[-1]].unique())

    # preprocess and split data
    nominial_cols, numerical_cols = NOMINIAL_COLS[dataset_name], NUMERICAL_COLS[dataset_name]
    X_processed, y = preprocess_df(df=df, numerical_cols=numerical_cols, nominial_cols=nominial_cols)
    # For COMPASS the preprocessed matrix is densified; presumably
    # preprocess_df returns a sparse matrix there (it exposes .toarray()).
    if dataset_name == 'COMPASS':
        X_processed = X_processed.toarray()
    # Fixed random_state so the split is reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(X_processed, y, test_size=0.2, random_state=42)

    # Read every checkpoint and the prep running time ONCE; they are
    # identical for all unlearning runs. load_state_dict copies values into
    # each freshly built model, so these cached dicts are never mutated.
    ori_extractor_state = torch.load(get_ori_BL3RepDetExtractormodel_save_path(**ori_path_kwargs), map_location=device)
    ori_classifier_state = torch.load(get_ori_BL3Classifier_model_save_path(**ori_path_kwargs), map_location=device)
    decoder_mihx_state = torch.load(get_BL3DecoderMIhx_model_save_path(**ori_path_kwargs, BL3DecoderMIhx_training_epochs=BL3DecoderMIhx_training_epochs), map_location=device)
    mihy_state = torch.load(get_BL3MIhy_model_save_path(**ori_path_kwargs, BL3MIhy_training_epochs=BL3MIhy_training_epochs), map_location=device)
    mihz_state = torch.load(get_BL3MIhz_model_save_path(**ori_path_kwargs, BL3MIhz_training_epochs=BL3MIhz_training_epochs), map_location=device)
    prep_running_time = np.loadtxt(get_BL3_prep_running_time_save_path(**ori_path_kwargs, **prep_epochs_kwargs)).item()

    BL3RepDetExtractor_model_lst, BL3Classifier_model_lst, BL3_running_time_lst, BL3_task_loss_lst, BL3_MI_loss_lst = [], [], [], [], []
    BL3_acc_lst, BL3_y_pred_lst = [], []
    for _ in range(unlearn_times):
        # Fresh copies of the original model for every run. Construction order
        # is kept as-is because parameter init consumes the torch RNG.
        BL3RepDetExtractor_model = BL3RepDetExtractor(input_dim=X_train.shape[1]).to(device)
        BL3Classifier_model = BL3Classifier(output_dim=output_dim).to(device)
        BL3RepDetExtractor_model.load_state_dict(ori_extractor_state)
        BL3Classifier_model.load_state_dict(ori_classifier_state)

        # Baseline 3 auxiliary (prep) models. The decoder reconstructs the
        # remaining features, i.e. all columns except the unl_feature_num
        # unlearned ones.
        BL3DecoderMIhx_model = BL3DecoderMIhx(output_dim=X_train.shape[1] - unl_feature_num).to(device)
        BL3MIhy_model = BL3MIhy(output_dim=output_dim).to(device)
        BL3MIhz_model = BL3MIhz(unl_feature_num=unl_feature_num).to(device)
        BL3DecoderMIhx_model.load_state_dict(decoder_mihx_state)
        BL3MIhy_model.load_state_dict(mihy_state)
        BL3MIhz_model.load_state_dict(mihz_state)

        # baseline 3 unlearning
        BL3RepDetExtractor_model, BL3Classifier_model, BL3_unlearning_time, BL3_task_loss, BL3_MI_loss = train_unlearning_BL3(
            BL3RepDetExtractor_model=BL3RepDetExtractor_model, BL3Classifier_model=BL3Classifier_model,
            BL3DecoderMIhx_model=BL3DecoderMIhx_model, BL3MIhy_model=BL3MIhy_model, BL3MIhz_model=BL3MIhz_model,
            X_train=X_train, y_train=y_train, device=device,
            BL3_lamda1=BL3_lamda1, BL3_lamda2=BL3_lamda2, BL3_lamda3=BL3_lamda3,
            epochs=BL3_unlearning_epochs, output_dim=output_dim, unl_feature_num=unl_feature_num
        )
        # evaluate baseline 3
        BL3_acc, BL3_y_pred, _ = evaluate_BL3(BL3RepDetExtractor_model=BL3RepDetExtractor_model, BL3Classifier_model=BL3Classifier_model, X_test=X_test, y_test=y_test, device=device)
        # Reported cost includes the one-off prep (auxiliary model) time.
        BL3_running_time = prep_running_time + BL3_unlearning_time

        BL3RepDetExtractor_model_lst.append(BL3RepDetExtractor_model)
        BL3Classifier_model_lst.append(BL3Classifier_model)
        BL3_running_time_lst.append(BL3_running_time)
        BL3_task_loss_lst.append(BL3_task_loss)
        BL3_MI_loss_lst.append(BL3_MI_loss)
        BL3_acc_lst.append(BL3_acc)
        BL3_y_pred_lst.append(BL3_y_pred)

    # save aggregate results as [mean, std] over the unlearning runs
    np.savetxt(get_BL3_model_acc_save_path(**res_path_kwargs), [np.mean(BL3_acc_lst), np.std(BL3_acc_lst)])
    np.savetxt(get_BL3_model_running_time_save_path(**res_path_kwargs), [np.mean(BL3_running_time_lst), np.std(BL3_running_time_lst)])

    # save per-run models, predictions and loss curves
    for unlearn_times_idx, (BL3RepDetExtractor_model, BL3Classifier_model, BL3_y_pred, BL3_task_loss, BL3_MI_loss) in enumerate(zip(BL3RepDetExtractor_model_lst, BL3Classifier_model_lst, BL3_y_pred_lst, BL3_task_loss_lst, BL3_MI_loss_lst)):
        torch.save(BL3RepDetExtractor_model.state_dict(), get_BL3RepDetExtractor_model_save_path(**res_path_kwargs, unlearn_times_idx=unlearn_times_idx))
        torch.save(BL3Classifier_model.state_dict(), get_BL3Classifier_model_save_path(**res_path_kwargs, unlearn_times_idx=unlearn_times_idx))
        np.savetxt(get_BL3_model_y_pred_save_path(**res_path_kwargs, unlearn_times_idx=unlearn_times_idx), BL3_y_pred)
        np.savetxt(get_BL3_model_train_task_loss_save_path(**res_path_kwargs, unlearn_times_idx=unlearn_times_idx), BL3_task_loss)
        np.savetxt(get_BL3_model_train_MI_loss_save_path(**res_path_kwargs, unlearn_times_idx=unlearn_times_idx), BL3_MI_loss)