

import copy
import time

import mlflow
import numpy as np
import yaml

import eval
from src.model_clf import XTab as XTab_clf
from utils.arguments import get_arguments, get_config, print_config_summary
from utils.load_data import Loader
from utils.utils import set_dirs, run_with_profiler, update_config_with_model_dims


def train(config_o, data_loader, save_weights=True, SEED=211):
    """Train several seeded models, average their masks, then fit a final model.

    The routine has three stages:

    1. For ``i`` in ``1..config_o["kfold"]`` a fresh ``Loader`` and
       ``XTab_clf`` model are built (seed ``SEED + 17 * i``), the model is
       fitted, the state dicts of its mask, encoder and classifier are
       collected, and ``eval.main`` is run to record accuracy/AUC and the
       train/test feature importances of that run.
    2. The collected state dicts and metrics are written as ``.npy`` files
       under ``model._results_path`` and the mask state dicts are averaged
       element-wise.
    3. A final model (seed 717, ``use_mask_g=True``) is created, the averaged
       mask is loaded into ``model.mask_g``, the model is fitted on
       ``training_data_ratio=1.0`` without validation, evaluated, and the
       evaluation config is dumped to ``config.yml``.

    Args:
        config_o (dict): Dictionary containing options and arguments. It is
            modified in place: ``config_o["train_mask"]`` is set to True.
            Every stage works on a deep copy taken after that change.
        data_loader: Ignored. A new ``Loader`` is built for every run; the
            parameter is kept so existing callers keep working.
        save_weights (bool): Ignored; kept for backward compatibility.
        SEED (int): Base seed from which the per-run seeds are derived.

    Raises:
        ValueError: If ``config_o["kfold"]`` is smaller than 1, since the
            later stages need at least one trained model.
    """
    n_runs = config_o["kfold"]
    if n_runs < 1:
        # Without at least one run there is no model whose results path can
        # be used and no mask to average; fail early with a clear message.
        raise ValueError(f'config["kfold"] must be >= 1, got {n_runs!r}')

    # State dicts collected from every run, in run order.
    mask_state_dicts_l = []
    enc_state_dicts_l = []
    clf_state_dicts_l = []

    # Per-run metrics, keyed by the run index as a string ("1", "2", ...).
    acc_auc_dicts = {}
    feature_tr_dicts = {}
    feature_te_dicts = {}

    # Deliberately written to the caller's dict so every deep copy below
    # inherits it.
    config_o["train_mask"] = True

    # Run indices start at 1; the original comment states that index 0 is
    # reserved for the final training.
    # NOTE(review): every run nevertheless uses fold_num = 0, so the runs
    # differ only by seed -- confirm this is intended rather than fold_num = i.
    for i in range(1, n_runs + 1):
        config = copy.deepcopy(config_o)
        config["fold_num"] = 0
        data_loader = Loader(config, dataset_name=config["dataset"])
        config = update_config_with_model_dims(data_loader, config)

        # Each run gets its own seed: SEED + 17, SEED + 34, ...
        config["seed"] = SEED + 17 * i
        model = XTab_clf(config)
        model.fit(data_loader)

        mask_state_dicts_l.append(model.mask.state_dict())
        enc_state_dicts_l.append(model.encoder.encoder.state_dict())
        clf_state_dicts_l.append(model.clf.state_dict())

        # Evaluate this run with a fresh copy of the original options.
        config = copy.deepcopy(config_o)
        config["fold_num"] = 0
        config["test_mode"] = False
        if i > 1:
            config["use_mask_g"] = False
        # eval.main returns (feature_imp_tr, feature_imp_te, acc, auc, ...)
        feature_imp_tr, feature_imp_te, acc, auc, _, _ = eval.main(config)
        run_key = str(i)
        acc_auc_dicts[run_key] = (acc, auc)
        feature_tr_dicts[run_key] = feature_imp_tr
        feature_te_dicts[run_key] = feature_imp_te

    # Persist everything gathered above. `model` and `config` are those of
    # the last run; the guard at the top guarantees they exist.
    results_path = model._results_path
    dataset = config["dataset"]
    artifacts = (
        ("mask_state_dicts_l", mask_state_dicts_l),
        ("enc_state_dicts_l", enc_state_dicts_l),
        ("clf_state_dicts_l", clf_state_dicts_l),
        ("acc_auc_dicts", acc_auc_dicts),
        ("feature_tr_dicts", feature_tr_dicts),
        ("feature_te_dicts", feature_te_dicts),
    )
    for artifact_name, artifact in artifacts:
        np.save(f"{results_path}/{dataset}_{artifact_name}.npy", artifact)

    # Average the mask parameters over all runs (plain element-wise mean).
    # The result overwrites the first run's dict, which is already on disk.
    averaged_mask = mask_state_dicts_l[0]
    n_masks = len(mask_state_dicts_l)
    for k in averaged_mask:
        averaged_mask[k] = sum(sd[k] for sd in mask_state_dicts_l) / n_masks

    # Configuration of the final model: all training data, no validation.
    config = copy.deepcopy(config_o)
    config["fold_num"] = 0
    config["train_mask"] = True
    config["validate"] = False
    config["training_data_ratio"] = 1.0
    config["test_mode"] = True

    data_loader = Loader(config, dataset_name=config["dataset"])
    config = update_config_with_model_dims(data_loader, config)

    # Instantiate the final model
    config["seed"] = 717
    config["use_mask_g"] = True
    model = XTab_clf(config)
    # Overwrite the mask's weights with the averaged ones.
    # NOTE(review): the averaged state comes from `model.mask` of each run
    # but is loaded into `model.mask_g` -- confirm both share one layout.
    model.mask_g.load_state_dict(averaged_mask)

    # Fit the final model to the data
    model.fit(data_loader)

    # Final evaluation, with noise disabled.
    config = copy.deepcopy(config_o)
    config["fold_num"] = 0
    config["test_mode"] = True
    config["validate"] = False
    config["training_data_ratio"] = 1.0
    config["use_mask_g"] = True
    config["add_noise"] = False
    # eval.main returns (feature_imp_tr, feature_imp_te, acc, auc, ...)
    feature_imp_tr, feature_imp_te, acc, auc, _, _ = eval.main(config)

    print("Train feature importance")
    print(feature_imp_tr)
    print("Test feature importance")
    print(feature_imp_te)
    print(f"Test accuracy: {acc}")

    print(acc_auc_dicts)

    # Save the config file to keep a record of the settings
    with open(f"{model._results_path}/config.yml", 'w') as config_file:
        yaml.dump(config, config_file, default_flow_style=False)
    print("Done with training...")




def main(config, SEED=57):
    """Main wrapper function for training routine.

    Prepares the output directories and then hands over to ``train``.

    Args:
        config (dict): Dictionary containing options and arguments.
        SEED (int): Base seed forwarded to ``train``.

    """
    # Set directories (or create if they don't exist)
    set_dirs(config)
    # Get data loader for first dataset.
    # NOTE(review): train() builds its own Loader for every run and does not
    # use this one; it is still constructed here so that any side effects of
    # creating it are unchanged.
    data_loader = Loader(config, dataset_name=config["dataset"])
    # Start training; train() adds the model dimensions to the config itself.
    train(config, data_loader, save_weights=True, SEED=SEED)


if __name__ == "__main__":
    # Get parser / command line arguments
    args = get_arguments()
    # One complete training run is launched per base seed.
    seeds = [211, 369, 317, 79, 54, 654, 34, 167, 4468, 7817] #770

    # Architecture flags in priority order, each paired with the tag it
    # contributes to the name of the results folder.
    arc_tags = (
        ("linear", "_linear_seed_"),
        ("relu", "_linear_relu_seed_"),
        ("relux2", "_linear_relux2_seed_"),
        ("relux3", "_linear_relux3_seed_"),
        ("relux5", "_linear_relux5_seed_"),
    )

    for SEED in seeds:
        # Get a fresh configuration for every seed
        config = get_config(args)

        # Architecture selection: exactly one flag is switched on.
        config["linear"] = False
        config["relu"] = True
        config["relux2"] = False
        config["relux3"] = False
        config["relux5"] = False
        config["c_hdim"] = 1024

        # mnist is trained for fewer epochs than the other datasets.
        config["epochs"] = 20 if config["dataset"] == "mnist" else 40

        config["clf_starting_epoch"] = 0
        config["kfold"] = 10
        # Type of noise to add. Choices: swap_noise, gaussian_noise, zero_out
        config["noise_type"] = "gaussian_noise"
        config["add_noise"] = True

        config["noise_level"] = 0.3
        # Percentage of the feature to add noise to
        config["masking_ratio"] = 0.3
        config["n_subsets"] = 2
        config["overlap"] = 0.75

        # The first enabled flag decides the tag; fail loudly instead of
        # hitting an unbound name if none is enabled.
        arc_name = next((tag for flag, tag in arc_tags if config[flag]), None)
        if arc_name is None:
            raise ValueError("No architecture flag is enabled in the config.")

        # Overwrite the parent folder name for saving results
        noise_tag = f'_gaussian{config["noise_level"]}{config["masking_ratio"]}'
        subset_tag = f'_{config["n_subsets"]}{config["overlap"]}'
        config["framework"] = (
            f'{config["dataset"]}{arc_name}class_{config["c_hdim"]}'
            f'_nv_{SEED}{noise_tag}{subset_tag}'
        )

        # Summarize config and arguments on the screen as a sanity check
        print_config_summary(config, args)

        # ----- Run training for this seed
        main(config, SEED=SEED)

