import argparse
import torch
torch.manual_seed(3407)
import numpy as np 

from copy import deepcopy
from sklearn.model_selection import train_test_split

from utils import get_dataset_from_OPENML, preprocess_df
from config import NOMINIAL_COLS, NUMERICAL_COLS
from config import get_ori_model_save_path
from config import get_BL1_model_save_path, get_BL1_model_acc_save_path, get_BL1_model_y_pred_save_path, get_BL1_model_running_time_save_path, get_BL1_model_train_loss_save_path
from utils import TabularClassifier
from utils import BL1_prep_data, train_classifier, evaluate_classifier

def _parse_args(argv=None):
    """Parse and validate the command-line options of this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as the script always did.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: On invalid arguments (argparse exits with status 2),
            including a ``--unlearn_times`` smaller than 1.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_name', default='ELECTRICITY', type=str, help='the name of the dataset')
    parser.add_argument('--device', default='cuda', type=str, help='device')
    parser.add_argument('--ori_training_epochs', default=2, type=int, help='original model training epochs')
    parser.add_argument('--BL1_unlearning_epochs', default=2, type=int, help='unlearning epochs for baseline 1')
    parser.add_argument('--unlearn_times', default=10, type=int, help='the number of times for unlearning -- repeat experiments and take avg')

    args = parser.parse_args(argv)

    # With zero repetitions the result lists stay empty and the mean/std
    # written at the end would be NaN, so reject that up front.
    if args.unlearn_times < 1:
        parser.error('--unlearn_times must be a positive integer')

    return args


def main(argv=None):
    """Run the baseline-1 (BL1) experiment and save its results.

    Loads the dataset and the saved original model, then repeats
    ``--unlearn_times`` times: prepare the BL1 training data, train a
    classifier on it, and evaluate that classifier on the test split.
    Finally writes the mean/std of accuracy and training time, plus the
    per-repetition model weights, predictions and training losses, to the
    paths provided by the ``config`` helpers.

    Args:
        argv: Optional list of argument strings; ``None`` uses ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    dataset_name = args.dataset_name
    # Any requested device is replaced by the CPU when CUDA is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    ori_training_epochs = args.ori_training_epochs
    BL1_unlearning_epochs = args.BL1_unlearning_epochs

    # get data; the number of classes is the number of distinct values in the last column
    df = get_dataset_from_OPENML(dataset_name=dataset_name)
    output_dim = len(df[df.columns[-1]].unique())

    # preprocess and split data (fixed random_state keeps the split reproducible)
    nominial_cols, numerical_cols = NOMINIAL_COLS[dataset_name], NUMERICAL_COLS[dataset_name]
    X_processed, y = preprocess_df(df=df, numerical_cols=numerical_cols, nominial_cols=nominial_cols)
    X_train, X_test, y_train, y_test = train_test_split(X_processed, y, test_size=0.2, random_state=42)

    # load original model
    ori_model = TabularClassifier(input_dim=X_train.shape[1], output_dim=output_dim)
    ori_model_path = get_ori_model_save_path(dataset_name=dataset_name, ori_training_epochs=ori_training_epochs)
    ori_model.load_state_dict(torch.load(ori_model_path, map_location=device))

    # run baseline 1 `unlearn_times` times and collect the per-run results
    BL1_model_lst, BL1_training_time_lst, BL1_train_loss_lst = [], [], []
    BL1_acc_lst, BL1_y_pred_lst = [], []
    for _ in range(args.unlearn_times):
        # NOTE(review): the model returned by BL1_prep_data is never used;
        # train_classifier is not given a model, so it presumably builds its
        # own -- confirm this is the intended baseline.
        _, BL1_X_train = BL1_prep_data(X_train=X_train, model=ori_model)
        # baseline 1
        BL1_model, BL1_training_time, BL1_train_loss = train_classifier(
            X_train=BL1_X_train,
            y_train=y_train,
            output_dim=output_dim,
            device=device,
            epochs=BL1_unlearning_epochs,
        )
        # evaluate baseline 1
        BL1_acc, BL1_y_pred, _ = evaluate_classifier(model=BL1_model, X_test=X_test, y_test=y_test, device=device)

        BL1_model_lst.append(BL1_model)
        BL1_training_time_lst.append(BL1_training_time)
        BL1_train_loss_lst.append(BL1_train_loss)
        BL1_acc_lst.append(BL1_acc)
        BL1_y_pred_lst.append(BL1_y_pred)

    # keyword arguments shared by every BL1 save-path helper
    run_tag = dict(
        dataset_name=dataset_name,
        ori_training_epochs=ori_training_epochs,
        BL1_unlearning_epochs=BL1_unlearning_epochs,
    )

    # save aggregated results as [mean, std]
    np.savetxt(get_BL1_model_acc_save_path(**run_tag), [np.mean(BL1_acc_lst), np.std(BL1_acc_lst)])
    np.savetxt(get_BL1_model_running_time_save_path(**run_tag), [np.mean(BL1_training_time_lst), np.std(BL1_training_time_lst)])

    # save the model weights, predictions and training loss of every repetition
    per_run_results = zip(BL1_model_lst, BL1_y_pred_lst, BL1_train_loss_lst)
    for unlearn_times_idx, (BL1_model, BL1_y_pred, BL1_train_loss) in enumerate(per_run_results):
        torch.save(BL1_model.state_dict(), get_BL1_model_save_path(unlearn_times_idx=unlearn_times_idx, **run_tag))
        np.savetxt(get_BL1_model_y_pred_save_path(unlearn_times_idx=unlearn_times_idx, **run_tag), BL1_y_pred)
        np.savetxt(get_BL1_model_train_loss_save_path(unlearn_times_idx=unlearn_times_idx, **run_tag), BL1_train_loss)


if __name__ == "__main__":
    main()

