import sys
from pathlib import Path
import argparse
sys.path.append(str(Path.cwd()))

import pandas as pd
import numpy as np

from data import get_dataset
from nrgboost.tree.ensemble import GenerativeBoostingTreeEstimator

# Command-line interface: selects the dataset/fold whose best model is sampled
# and controls how the sampling is performed.
parser = argparse.ArgumentParser(
    prog='NRGBoost',
    description='Sample from best NRGBoost model',
)
parser.add_argument('dataset_name', type=str,
                    help='dataset name; passed to get_dataset and used in the nrgb/ paths')
parser.add_argument('-s', '--seed', type=int, default=99999,
                    help='base seed, combined with the fold and sampling fold indices')
parser.add_argument('-b', '--burn_in', type=int, default=100,
                    help='the sampler is run for burn_in + 1 steps')
parser.add_argument('-f', '--fold', type=int, default=0,
                    help='dataset fold whose results and model are used')
parser.add_argument('--sampling_folds', type=int, default=5,
                    help='number of sample sets to write, each with its own seed')
parser.add_argument('--num_threads', type=int, default=8,
                    help='num_threads value passed to the sampler (default: 8)')
args = parser.parse_args()

sampling_seed = args.seed
fold = args.fold
# One extra step on top of the burn-in; the last step of each chain is what
# ends up in the non-"_full" output files below.
steps = args.burn_in + 1
dataset_name = args.dataset_name

samples_base_folder = Path(f'nrgb/samples/{dataset_name}/fold_{fold}')

# results.csv is indexed by its first three columns, which are unpacked here as
# (max_leaves, shrinkage, boosting round). Pick the row with the highest 'val'.
results_df = pd.read_csv(f'nrgb/results/{dataset_name}/fold_{fold}/results.csv', index_col=[0, 1, 2])
max_leaves, shrinkage, best_round = results_df['val'].idxmax()

print('best:', max_leaves, shrinkage, best_round)

model_path = Path(f'nrgb/models/{dataset_name}/fold_{fold}/leaves_{max_leaves}/shrinkage_{shrinkage}.pkl')
bst = GenerativeBoostingTreeEstimator.load(model_path)

dataset = get_dataset(dataset_name)
train_df, val_df, test_df = dataset.load_fold(fold, fixed_point=True)
# Cumulative sizes: the sampled array is cut into consecutive train/val/test
# sized pieces, so these are the two cut points and the total sample count.
start_val, start_test, num_samples = np.cumsum([len(train_df), len(val_df), len(test_df)])

# Row range of each split inside the sampled array (loop-invariant).
split_rows = (
    ('train', slice(None, start_val)),
    ('val', slice(start_val, start_test)),
    ('test', slice(start_test, None)),
)

for sampling_fold in range(args.sampling_folds):
    samples_folder = samples_base_folder / f'sampling_fold_{sampling_fold}'
    samples_folder.mkdir(parents=True, exist_ok=True)

    # The seed combines fold, sampling fold and the user seed so that every
    # (fold, sampling_fold) pair gets a distinct seed.
    samples = bst.sample_points(
        steps,
        num_samples,
        stop_round=best_round,
        initial_samples='initial',
        seed=[fold, sampling_fold, sampling_seed],
        num_threads=args.num_threads,
    )
    # View as (sample, step, feature...) so the last step can be indexed with -1.
    samples = samples.reshape((num_samples, steps, -1))

    for split_name, rows in split_rows:
        split_samples = samples[rows]
        # "<split>_full.npy" keeps every step; "<split>.npy" keeps only the last.
        np.save(samples_folder / f'{split_name}_full.npy', split_samples)
        np.save(samples_folder / f'{split_name}.npy', split_samples[:, -1])
