import flexcode
import numpy as np
import torch
from flexcode.regression_models import NN, Lasso, RandomForest, XGBoost
from numpy.random import choice

from error_sampler.ErrorSampler import ErrorSampler

# Default hyper-parameters for each FlexCode regression backend, keyed by the
# backend class. Each entry is forwarded to flexcode.FlexCodeModel as its
# `regression_params` argument when that backend is selected.
regression_params = {
    NN: dict(k=20),
    Lasso: dict(),
    RandomForest: dict(n_estimators=100, min_samples_split=20, min_samples_leaf=20),
    XGBoost: dict(n_estimators=100),
}

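# Requires FlexCode (https://github.com/lee-group-cmu/FlexCode), e.g.:
#   git clone https://github.com/lee-group-cmu/FlexCode.git
#   pip install "FlexCode[all]"  # quote the extras so shells like zsh don't glob the brackets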
class FlexCodeErrorSampler(ErrorSampler):
    """Error sampler backed by a FlexCode conditional density estimator.

    ``fit`` trains a FlexCode model of ``errors`` conditioned on ``z``;
    ``sample_error`` evaluates the predicted conditional densities on a
    discrete grid and draws one error per row of ``z_test`` from them.
    """

    def __init__(self, reg_model=XGBoost, max_basis=31, basis_system="cosine", model_params=None, n_grid=10000):
        """Build the underlying ``flexcode.FlexCodeModel``.

        Args:
            reg_model: FlexCode regression backend class (default ``XGBoost``).
            max_basis: Maximum number of basis functions passed to FlexCode.
            basis_system: Basis system name passed to FlexCode.
            model_params: Hyper-parameters for ``reg_model``. When ``None`` the
                entry for ``reg_model`` in the module-level ``regression_params``
                table is used (an empty dict if the table has no entry).
            n_grid: Number of grid points on which ``sample_error`` evaluates the
                predicted densities.

        Raises:
            ValueError: if ``n_grid`` is smaller than 1.
        """
        super().__init__()
        if n_grid < 1:
            raise ValueError(f"n_grid must be >= 1, got {n_grid}")
        if model_params is None:
            model_params = regression_params.get(reg_model, {})
        self.n_grid = n_grid
        # Pass a copy so the shared module-level defaults are never handed to FlexCode.
        self.model = flexcode.FlexCodeModel(reg_model, max_basis=max_basis, basis_system=basis_system,
                                            regression_params=dict(model_params))

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor`` as a NumPy array, detached from autograd and moved to the CPU."""
        return tensor.detach().cpu().numpy()

    def fit(self, x_train, z_train, y_train, errors_train, deleted_train, x_val, z_val, y_val, errors_val, deleted_val,
            **kwargs):
        """Fit the conditional density of errors given ``z``, then tune on the validation split.

        Only ``z_train``/``errors_train`` and ``z_val``/``errors_val`` are used;
        the other arguments are accepted but ignored here (presumably to satisfy
        the shared ``ErrorSampler`` interface).
        """
        self.model.fit(self._to_numpy(z_train), self._to_numpy(errors_train))
        self.model.tune(self._to_numpy(z_val), self._to_numpy(errors_val))

    def sample_error(self, x_test, z_test):
        """Draw one error per row of ``z_test`` from the predicted conditional densities.

        Sampling is inverse-CDF on the FlexCode grid, vectorised across rows. The
        uniform draws come from NumPy's global legacy RNG, so ``np.random.seed``
        controls reproducibility.

        Args:
            x_test: Unused.
            z_test: Torch tensor of conditioning inputs; one sample is drawn per row.

        Returns:
            A 1-D tensor of ``len(z_test)`` sampled errors on ``z_test.device``,
            using torch's default floating dtype.

        Raises:
            ValueError: if the predicted densities do not match the grid size, or
                a row has no finite positive probability mass to sample from.
        """
        n_samples = len(z_test)
        if n_samples == 0:
            return torch.empty(0, device=z_test.device)

        cdes, grid = self.model.predict(self._to_numpy(z_test), n_grid=self.n_grid)
        grid = np.asarray(grid).reshape(-1)
        densities = np.asarray(cdes, dtype=np.float64).reshape(n_samples, -1)
        if densities.shape[1] != grid.size:
            raise ValueError(
                f"FlexCode returned densities of width {densities.shape[1]} for a grid of size {grid.size}"
            )

        # Negative density values are not valid probabilities; treat them as zero mass.
        densities = np.clip(densities, 0.0, None)
        mass = densities.sum(axis=1, keepdims=True)
        valid = np.isfinite(mass[:, 0]) & (mass[:, 0] > 0)
        if not valid.all():
            bad_rows = np.flatnonzero(~valid)
            raise ValueError(
                f"FlexCode predicted densities with no valid probability mass for rows {bad_rows[:10].tolist()}"
            )

        # Inverse-CDF sampling: for each row, the chosen index is the first grid
        # point whose cumulative probability exceeds that row's uniform draw.
        cdf = np.cumsum(densities / mass, axis=1)
        cdf /= cdf[:, -1:]
        uniform = np.random.random_sample((n_samples, 1))
        idx = (cdf <= uniform).sum(axis=1)
        # Guard against floating-point round-off pushing an index past the grid end.
        idx = np.minimum(idx, grid.size - 1)

        return torch.tensor(grid[idx], dtype=torch.get_default_dtype(), device=z_test.device)

    @property
    def name(self) -> str:
        """Identifier string for this sampler."""
        return "flex_code_error_sampler"
