import argparse
import torch
torch.manual_seed(3407)
import numpy as np 

from copy import deepcopy
from sklearn.model_selection import train_test_split

from utils import get_dataset_from_OPENML, preprocess_df
from config import NOMINIAL_COLS, NUMERICAL_COLS, UNL_FEATURE_NUM
from config import get_ori_BL3RepDetExtractormodel_save_path, get_ori_BL3Classifier_model_save_path, get_ori_BL3_model_train_loss_save_path
from utils import train_BL3Backbone



def main():
    """Train the baseline-3 (BL3) backbone on the full feature set and save it.

    Steps:
      1. Parse the command-line options (dataset name, device, epoch count).
      2. Resolve the three output paths from ``config`` and create their
         parent directories, so a bad output location fails before training
         rather than after it.
      3. Load the dataset with ``get_dataset_from_OPENML`` and preprocess it
         with ``preprocess_df``.
      4. Fit the model on the 80% training split via ``train_BL3Backbone``.
      5. Save both ``state_dict``s with ``torch.save`` and the training-loss
         list with ``np.savetxt``.
    """
    # Local import: only this entry point needs it.
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', default='EYE_MOVEMENTS', type=str, help='the name of the dataset')
    parser.add_argument('--device', default='cuda', type=str, help='device')
    parser.add_argument('--ori_BL3_training_epochs', default=1500, type=int, help='original model training epochs')
    args = parser.parse_args()

    dataset_name = args.dataset_name
    ori_BL3_training_epochs = args.ori_BL3_training_epochs
    # The requested device is honoured only when CUDA is available; otherwise
    # the run falls back to CPU regardless of what was requested.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Resolve every output location up front and make sure its directory
    # exists, so that a long training run is not lost at save time.
    path_kwargs = dict(dataset_name=dataset_name, ori_BL3_training_epochs=ori_BL3_training_epochs)
    extractor_save_path = get_ori_BL3RepDetExtractormodel_save_path(**path_kwargs)
    classifier_save_path = get_ori_BL3Classifier_model_save_path(**path_kwargs)
    train_loss_save_path = get_ori_BL3_model_train_loss_save_path(**path_kwargs)
    for save_path in (extractor_save_path, classifier_save_path, train_loss_save_path):
        parent_dir = os.path.dirname(save_path)
        if parent_dir:  # a bare file name has no directory to create
            os.makedirs(parent_dir, exist_ok=True)

    # get data
    df = get_dataset_from_OPENML(dataset_name=dataset_name)
    # Output size = number of distinct values in the last column
    # (presumably the label column -- confirm against the dataset loader).
    output_dim = len(df[df.columns[-1]].unique())

    # preprocess
    nominial_cols, numerical_cols = NOMINIAL_COLS[dataset_name], NUMERICAL_COLS[dataset_name]
    X_processed, y = preprocess_df(df=df, numerical_cols=numerical_cols, nominial_cols=nominial_cols)
    # For COMPASS the preprocessed features are converted with ``.toarray()``
    # (presumably a sparse matrix that must be densified -- confirm).
    if dataset_name == 'COMPASS':
        X_processed = X_processed.toarray()

    # split dataset for original training; only the training part is used here.
    # The former "retraining from scratch" split was computed but never read,
    # so it has been dropped from this script.
    X_train, _, y_train, _ = train_test_split(X_processed, y, test_size=0.2, random_state=42)

    # train baseline 3's backbone model (the reported training time is not used here)
    BL3RepDetExtractor_model, BL3Classifier_model, _training_time, BL3_train_loss_lst = train_BL3Backbone(
        X_train=X_train,
        y_train=y_train,
        output_dim=output_dim,
        device=device,
        epochs=ori_BL3_training_epochs,
    )

    # save model weights and the training-loss list
    torch.save(BL3RepDetExtractor_model.state_dict(), extractor_save_path)
    torch.save(BL3Classifier_model.state_dict(), classifier_save_path)
    np.savetxt(train_loss_save_path, BL3_train_loss_lst)


if __name__ == "__main__":
    main()
    
    
