import os
import numpy as np
import pandas as pd
import yaml
import pickle
from datetime import datetime

from models.sample import *
from utils.tpnn_logging import *

import argparse
# Command-line interface. Parsing happens at import time, so `args` is a
# module-level namespace that the entry point below converts to a dict.
parser = argparse.ArgumentParser()
parser.add_argument('--y_dist', type=str, required=True)

# dataset
# NOTE(review): `default='abalone'` can never take effect because the option
# is also `required=True`; confirm which of the two is intended.
parser.add_argument(
    '-d', '--dataset_name',
    type=str,
    default='abalone',
    required=True,
)
parser.add_argument('-f', '--fold', type=int, default=0)
parser.add_argument('--log_dir', type=str, default='./parameters')

# model arch
parser.add_argument('--max_depth', type=int, default=3)
parser.add_argument('-K', '--K_max', type=int, default=200)

# update
parser.add_argument('--num_samples', type=int, default=2000)
parser.add_argument(
    '--step_size',
    type=float,
    default=0.01,
    help='height update step size',
)
parser.add_argument(
    '--leapfrog_L',
    type=int,
    default=1,
    help='height update leapfrog L',
)
parser.add_argument(
    '--bg_step_size',
    type=float,
    default=0.01,
    help='bg update leapfrog',
)
parser.add_argument('--M', type=float, default=1.0)

# hyperparameter
parser.add_argument(
    '--gamma_shape',
    type=float,
    default=2.0,
    help='gamma prior shape parameter',
)
# The smaller the scale, the more indicator-like (translated author note).
parser.add_argument(
    '--gamma_scale',
    type=float,
    default=0.01,
    help='gamma prior scale parameter',
)
parser.add_argument('--c0', type=float, default=0.001)

# reg (presumably regression-specific settings -- confirm)
parser.add_argument('--var_height', type=float, default=0.1)
parser.add_argument('--nu', type=float, default=10.0)

# ber (presumably Bernoulli-specific settings -- confirm)
parser.add_argument('--const_var', type=float, default=0.1)
parser.add_argument('--const_step_size', type=float, default=0.01)

parser.add_argument('--seed', type=int, default=42, help='')
args = parser.parse_args()

# Seed NumPy's global RNG; the train/test permutation in fold_sampling draws
# from it.
np.random.seed(args.seed)

# When True, the entry point calls fold_sampling once per fold index 0..4
# instead of a single time.
ALL_FOLD = False

def fold_sampling(config):
    """Run one train/test split: load data, draw samples, save and log them.

    Reads ``data_x.csv`` and ``data_y.csv`` from
    ``config['data_path']/config['dataset_name']``, shuffles the row indices
    with NumPy's global RNG and uses the first 80% for training and the rest
    for testing. Both splits are handed to ``sampling``; the trailing (up to)
    1000 returned samples are pickled to
    ``config['experiment_name']/samples.pickle`` and CSV logs are produced by
    ``tpnns_log_to_csv`` and ``update_log_to_csv``.

    NOTE(review): ``config['fold']`` is not read in this function; the split
    depends only on the global NumPy RNG state. Whether the sampler or the
    loggers use it is not visible from here.

    Args:
        config: Mapping of run settings. This function itself reads
            ``data_path``, ``dataset_name``, ``experiment_name`` and
            ``tpnn_detail_log``; the mapping is also forwarded unchanged to
            the sampler and the logging helpers.

    Returns:
        None.
    """
    # data
    dataset_dir = os.path.join(config['data_path'], config['dataset_name'])
    data_x = pd.read_csv(os.path.join(dataset_dir, 'data_x.csv')).to_numpy()
    data_y = pd.read_csv(os.path.join(dataset_dir, 'data_y.csv')).to_numpy()

    n = data_x.shape[0]
    train_n = int(0.8 * n)
    perm_idx = np.random.permutation(n)
    train_idx, test_idx = perm_idx[:train_n], perm_idx[train_n:]
    train_x, test_x = data_x[train_idx, :], data_x[test_idx, :]
    train_y, test_y = data_y[train_idx, :], data_y[test_idx, :]

    # sampling
    sample_start = datetime.now()
    samples = sampling(train_x, train_y, config, test_x, test_y, verbose=1)
    sample_expended_time = datetime.now() - sample_start
    # `timedelta.seconds` is only the within-day component and wraps after
    # 24 hours; `total_seconds()` reports the full duration. Truncate to int
    # so the message keeps its whole-second format.
    print(f'Sampling took {int(sample_expended_time.total_seconds())} seconds.')

    # save samples
    out_dir = config['experiment_name']
    if out_dir:
        # Make sure the target directory exists so a finished sampling run is
        # not lost to a FileNotFoundError at write time.
        os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'samples.pickle'), 'wb') as f:
        pickle.dump(samples[-1000:], f)

    # logging
    # The first element of `samples` is skipped in the full logs (presumably
    # the initial state -- confirm against `sampling`).
    if config['tpnn_detail_log']:
        tpnns_log_to_csv(samples[1:], config)
    else:
        tpnns_log_to_csv(samples[-10:], config)
    update_log_to_csv(samples[1:], config)

if __name__ == '__main__':
    # From here on the parsed CLI options are handled as a plain dict.
    args = vars(args)

    # 'normal' selects the regression config; any other y_dist value uses the
    # general config file.
    config_path = (
        './config/config_reg.yaml'
        if args['y_dist'] == 'normal'
        else './config/config.yaml'
    )
    with open(config_path, 'rb') as config_file:
        config = yaml.safe_load(config_file)

    # CLI values override only the keys that already exist at the top level
    # of the loaded config; other CLI options are not copied in.
    known_keys = config.keys()
    for option in args:
        if option in known_keys:
            config[option] = args[option]

    # For the 'normal' case --nu is additionally written into the nested
    # 'nui' section.
    if args['y_dist'] == 'normal':
        config['nui']['inv_gamma_nu'] = args['nu']

    if not ALL_FOLD:
        fold_sampling(config)
    else:
        # Run the same pipeline once per fold index, recording the index in
        # the config before each run.
        for fold_idx in range(5):
            config['fold'] = fold_idx
            fold_sampling(config)