import sys
from pathlib import Path
import argparse
sys.path.append(str(Path.cwd()))

import pandas as pd
import numpy as np
from math import ceil

from data import get_dataset
from nrgboost.tree.ensemble import GenerativeBoostingTreeEstimator

# Sample synthetic train/val/test data from the best NRGBoost model of a fold.
#
# The "best" model is the (max_leaves, shrinkage, round) index of the row with
# the highest 'val' value in nrgb/results/<dataset>/fold_<fold>/results.csv.
# For every sampling fold, each split gets two files:
#   <split>_full.npy -- the raw output of `sample_points`;
#   <split>.npy      -- that output thinned by `downsample`, keeping the last
#                       N entries along the first axis (N = size of the split).

SPLITS = ('train', 'val', 'test')


def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: via ``parser.error`` when an option is out of range.
    """
    parser = argparse.ArgumentParser(
        prog='NRGBoost',
        description='Sample from best NRGBoost model')
    parser.add_argument('dataset_name', type=str)
    parser.add_argument('-s', '--seed', type=int, default=99999)
    parser.add_argument('-b', '--burn_in', type=int, default=100)
    parser.add_argument('-c', '--num_chains', type=int, default=64)
    parser.add_argument('-d', '--downsample', type=int, default=10)
    parser.add_argument('-f', '--fold', type=int, default=0)
    parser.add_argument('--sampling_folds', type=int, default=5)
    parser.add_argument('--num_threads', type=int, default=8)
    args = parser.parse_args(argv)

    if args.num_chains < 1:
        parser.error('--num_chains must be >= 1')
    # A step of 0 is invalid in slicing; negative steps would reverse chains.
    if args.downsample < 1:
        parser.error('--downsample must be >= 1')
    # burn_in is shifted down by (downsample - 1) before sampling; keep it >= 0
    # rather than handing a negative burn-in to the sampler.
    if args.burn_in < args.downsample - 1:
        parser.error(
            f'--burn_in ({args.burn_in}) must be >= downsample - 1 '
            f'({args.downsample - 1})')
    return args


def samples_per_chain(num_points, num_chains, downsample):
    """Return how many raw steps to run per chain.

    After keeping one step in every ``downsample``, the ``num_chains`` chains
    together must yield at least ``num_points`` samples. The result is always
    a multiple of ``downsample``. Presumably this keeps the thinning stride
    aligned with chain boundaries when chains are concatenated along axis 0;
    that layout is an assumption — TODO confirm against ``sample_points``.
    """
    return ceil(num_points / num_chains) * downsample


def thin_samples(samples, downsample, num_points):
    """Thin ``samples`` along axis 0 and keep the last ``num_points`` entries.

    Thinning keeps indices ``downsample - 1, 2 * downsample - 1, ...``. If
    fewer than ``num_points`` entries remain, all of them are returned. A
    ``num_points`` of 0 yields an empty result; a bare ``[-0:]`` slice would
    wrongly return everything.
    """
    thinned = samples[downsample - 1::downsample]
    return thinned[max(len(thinned) - num_points, 0):]


def main(argv=None):
    """Load the best model for the requested fold and write the sample files."""
    args = parse_args(argv)
    dataset_name = args.dataset_name
    fold = args.fold

    results_df = pd.read_csv(
        f'nrgb/results/{dataset_name}/fold_{fold}/results.csv',
        index_col=[0, 1, 2])
    max_leaves, shrinkage, best_round = results_df['val'].idxmax()

    model_path = Path(
        f'nrgb/models/{dataset_name}/fold_{fold}'
        f'/leaves_{max_leaves}/shrinkage_{shrinkage}.pkl')
    bst = GenerativeBoostingTreeEstimator.load(model_path)

    # Use the split sizes of the same fold the model belongs to.
    # Previously this was hard-coded to fold 0.
    dataset = get_dataset(dataset_name)
    split_sizes = {
        split: len(df)
        for split, df in zip(SPLITS, dataset.load_fold(fold, fixed_point=True))
    }

    # Thinning keeps indices downsample-1, 2*downsample-1, ...; shifting the
    # burn-in down by (downsample - 1) compensates for that offset. This
    # assumes `sample_points` discards the first `burn_in` steps of each
    # chain — TODO confirm.
    burn_in = args.burn_in - (args.downsample - 1)

    samples_base_folder = Path(f'nrgb/samples/{dataset_name}/fold_{fold}')
    for sampling_fold in range(args.sampling_folds):
        samples_folder = samples_base_folder / f'sampling_fold_{sampling_fold}'
        samples_folder.mkdir(parents=True, exist_ok=True)

        # One generator per sampling fold, consumed by the splits in the fixed
        # SPLITS order so results are reproducible for a given seed.
        prng = np.random.default_rng([sampling_fold, args.seed])

        for split in SPLITS:
            num_points = split_sizes[split]
            samples = bst.sample_points(
                samples_per_chain(num_points, args.num_chains, args.downsample),
                args.num_chains,
                stop_round=best_round,
                initial_samples='initial',
                burn_in=burn_in,
                seed=prng,
                num_threads=args.num_threads,
            )
            np.save(samples_folder / f'{split}_full.npy', samples)
            np.save(samples_folder / f'{split}.npy',
                    thin_samples(samples, args.downsample, num_points))


if __name__ == '__main__':
    main()