import random
import pickle
import numpy as np
import pandas as pd
import functools
from tqdm import tqdm
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.model_selection import KFold, cross_validate, GridSearchCV
from sklearn.base import clone


def cv_select(X_train, Y_train, cv=5):
    """Grid-search a HistGradientBoostingRegressor and return the best model.

    Every combination in the hyper-parameter grid below (3**6 = 729
    candidates) is scored with ``cv``-fold cross-validation on R^2. The best
    combination is refit on the full training data (``refit=True``) and
    returned.

    Args:
        X_train: Training feature matrix, passed straight to ``fit``.
        Y_train: Training targets, passed straight to ``fit``.
        cv: Cross-validation argument forwarded to ``GridSearchCV``.

    Returns:
        The refit best estimator found by the search.
    """
    search_space = [{
        "learning_rate": [0.05, 0.1, 0.2],
        "max_iter": [300, 600, 900],
        "max_depth": [None, 6, 10],
        "max_leaf_nodes": [10, 15, 31],
        "min_samples_leaf": [10, 20, 50],
        "l2_regularization": [1e-3, 1e-2, 1e-1],
        "max_bins": [255],
    }]
    searcher = GridSearchCV(
        estimator=HistGradientBoostingRegressor(
            random_state=42, early_stopping=False
        ),
        param_grid=search_space,
        cv=cv,
        scoring='r2',
        refit=True,
        n_jobs=-1,
        verbose=1,
    )
    searcher.fit(X_train, Y_train)
    winner = searcher.best_estimator_

    print("best gs_cv params:", searcher.best_params_)
    print("best gs_cv score:", searcher.best_score_)
    # This score is computed on the same data the winner was refit on, so it
    # is a training-set score rather than a held-out one.
    print("gs validation score for train_dict: ", winner.score(X_train, Y_train))
    return winner

def biased_sample(datasets, sample_size, seed=42, frac=0.05):
    """Sample entries uniformly from the lowest-valued slice of ``datasets``.

    The items are ranked by value in ascending order, the first
    ``int(frac * len(datasets))`` of them form the candidate pool, and
    ``sample_size`` keys are drawn from that pool without replacement.

    Args:
        datasets: Mapping from key to a sortable value (the score).
        sample_size: Number of entries to draw from the bottom pool.
        seed: If not ``None``, the module-level ``random`` generator is
            seeded with it before sampling (same convention as
            ``uniform_sample``).
        frac: Fraction of the ranked items that makes up the bottom pool.
            Defaults to 0.05, the value that used to be hard-coded.

    Returns:
        A dict containing the sampled keys and their values.

    Raises:
        ValueError: If ``frac`` is outside ``[0, 1]`` or ``sample_size``
            exceeds the size of the bottom pool.
    """
    if not 0 <= frac <= 1:
        raise ValueError(f"frac must be within [0, 1], got {frac}")
    if seed is not None:
        random.seed(seed)

    # Ascending sort: the head of the list holds the lowest-valued entries.
    ranked = sorted(datasets.items(), key=lambda kv: kv[1])
    cutoff = int(frac * len(ranked))
    bottom_keys = [key for key, _ in ranked[:cutoff]]

    # Fail with an explicit message instead of random.sample's generic
    # "Sample larger than population" error.
    if sample_size > len(bottom_keys):
        raise ValueError(
            f"Sampling size B ({sample_size}) is larger than the bottom "
            f"{frac:.0%} pool ({len(bottom_keys)} of {len(ranked)} entries)"
        )

    chosen = random.sample(bottom_keys, sample_size)
    return {key: datasets[key] for key in chosen}

def uniform_sample(datasets, sample_size, seed=42):
    """Draw ``sample_size`` entries uniformly at random from ``datasets``.

    Args:
        datasets: Mapping to sample from.
        sample_size: Number of entries to draw without replacement.
        seed: If not ``None``, the module-level ``random`` generator is
            seeded with it first; with ``None`` the generator's current
            state is used as-is.

    Returns:
        A dict containing the sampled keys and their values.

    Raises:
        ValueError: If ``sample_size`` exceeds ``len(datasets)``.
    """
    if sample_size > len(datasets):
        raise ValueError("Sampling size B is larger than the length of datasets")
    if seed is not None:
        random.seed(seed)

    picked = random.sample(list(datasets), sample_size)
    return {key: datasets[key] for key in picked}

def prepare_xy(sampled_dict):
    """Convert a key -> value mapping into feature and target arrays.

    Each key is passed through ``featurize`` (with a tqdm progress bar) and
    the resulting vectors are stacked; the values become the target array in
    the same order.

    Args:
        sampled_dict: Mapping whose keys can be featurized and whose values
            are the regression targets.

    Returns:
        A tuple ``(X, Y)`` of NumPy arrays built from the keys and values.
    """
    keys = list(sampled_dict.keys())
    targets = list(sampled_dict.values())

    features = [featurize(key) for key in tqdm(keys)]
    return np.array(features), np.array(targets)
    

@functools.cache
def symbol_ohe(symbol, num_blocks=4):
    """Return the one-hot encoding of a single vocabulary symbol.

    The symbol's position in the fixed vocabulary string below selects which
    entry of a length-``num_blocks`` vector is set to 1.0.

    Results are memoized, so every call with the same arguments returns the
    *same* array object. The array is therefore marked read-only: an in-place
    edit by one caller would otherwise silently corrupt the cached encoding
    for all later callers. Copy it if a writable array is needed.

    Args:
        symbol: A single character from the vocabulary.
        num_blocks: Length of the encoding. The file's original note gives
            18 for "seh" and 4 for "tf".

    Returns:
        A read-only 1-D float array of length ``num_blocks``.

    Raises:
        ValueError: If ``symbol`` is not exactly one character or is not in
            the vocabulary.
        IndexError: If the symbol's vocabulary position is >= ``num_blocks``.
    """
    # The backslashes are written escaped. The vocabulary is unchanged, but
    # the unescaped forms are invalid escape sequences that trigger a warning
    # at compile time.
    symbols = ('0123456789abcdefghijklmnopqrstuvwxyz'
               'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\\()*+,-./:;<=>?@[\\]^_`{|}~')

    # str.index matches substrings, so '' or '01' would silently map to
    # position 0 instead of being rejected.
    if len(symbol) != 1:
        raise ValueError(f"symbol must be a single character, got {symbol!r}")

    position = symbols.index(symbol)
    if position >= num_blocks:
        raise IndexError(
            f"symbol {symbol!r} has vocabulary index {position}, which does "
            f"not fit in an encoding of length {num_blocks}"
        )

    encoding = np.zeros(num_blocks)
    encoding[position] = 1.0
    encoding.setflags(write=False)
    return encoding

def featurize(x, num_blocks=4):
    """Encode a sequence as the concatenation of its symbols' one-hot vectors.

    Args:
        x: Iterable of single-character symbols (e.g. a string).
        num_blocks: One-hot width forwarded to ``symbol_ohe``. The default of
            4 matches the previous fixed behavior; pass a larger value for
            bigger vocabularies.

    Returns:
        A 1-D float array of length ``len(x) * num_blocks``.

    Raises:
        ValueError: If ``x`` is empty (``np.concatenate`` needs at least one
            array).
    """
    return np.concatenate([symbol_ohe(c, num_blocks) for c in x])

if __name__ == "__main__":
    import os

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load oracle files that come from a trusted source.
    with open('datasets/tfbind8/tfbind8-exact-v0-all.pkl', 'rb') as f:
        oracle_d = pickle.load(f)

    def munge(x):
        """Join the elements of one sequence into a single string key."""
        return ''.join(str(c) for c in x)

    datasets = {munge(x): float(y) for x, y in zip(oracle_d['x'], oracle_d['y'])}

    # Held-out evaluation set. seed=None leaves the RNG unseeded, so this
    # sample differs from run to run.
    # NOTE(review): it is drawn from the full dataset and may overlap the
    # training sample below — confirm this is intended.
    test_size = 10_000
    test_dict = uniform_sample(datasets, test_size, None)
    X_test, Y_test = prepare_xy(test_dict)

    # Create the model output directory up front so a missing folder cannot
    # discard the result of the (expensive) grid search at save time.
    os.makedirs('datasets/proxy', exist_ok=True)

    B_values = [500]
    for b in B_values:
        # Training data comes from the lowest-scoring slice of the dataset.
        bottom_dict = biased_sample(datasets, b)
        with open(f'datasets/bottom_offline_sample{b}.pkl', 'wb') as f:
            pickle.dump(bottom_dict, f)

        X_train, Y_train = prepare_xy(bottom_dict)
        best_model = cv_select(X_train, Y_train)
        test_score = best_model.score(X_test, Y_test)
        print('gs test_score for 10_000: ', test_score)

        with open(f'datasets/proxy/bottom_proxy_sample{b}.pkl', 'wb') as f:
            pickle.dump(best_model, f)
            print('Saved proxy.')