from typing import List

import numpy as np
import torch

from error_sampler.ErrorSampler import ErrorSampler
from models.qr_models.QuantileRegression import QuantileRegression

from nnkde.core import NNKCDE
from numpy.random import choice

# source: https://github.com/lee-group-cmu/NNKCDE/tree/master/python/src/nnkcde
class NNKCDEErrorSampler(ErrorSampler):
    """Error sampler backed by a nearest-neighbour kernel conditional density estimator.

    ``fit`` estimates the conditional density of the errors given ``z`` with
    ``NNKCDE``. ``sample_error`` evaluates that density on a uniform grid
    spanning the fitted error range. It then draws one error per test row.
    """

    # Default number of grid points used when evaluating the conditional density.
    # Kept as a class attribute so instances created before ``grid_size`` existed
    # (e.g. unpickled ones) still resolve it.
    grid_size = 10000

    def __init__(self, k: int = 20, grid_size: int = 10000):
        """Create the sampler.

        Args:
            k: number of nearest neighbours passed to ``NNKCDE``.
            grid_size: number of points of the error grid on which the
                conditional density is evaluated in ``sample_error``.
        """
        super().__init__()
        self.nnkcde = NNKCDE(k=k)
        self.grid_size = grid_size

    def fit(self, x_train, z_train, y_train, errors_train, deleted_train, x_val, z_val, y_val, errors_val,
            deleted_val, **kwargs):
        """Fit the conditional density of errors given ``z`` on train+val rows.

        Only ``z_*``, ``errors_*`` and ``deleted_*`` are used; the remaining
        arguments are accepted for interface compatibility and ignored here.
        ``deleted_*`` are presumably boolean masks (they are inverted with ``~``
        and used for indexing) marking rows to exclude.
        """
        all_z = torch.cat([z_train, z_val], dim=0)
        all_deleted = torch.cat([deleted_train, deleted_val], dim=0)
        all_errors = torch.cat([errors_train, errors_val], dim=0)

        kept = ~all_deleted
        kept_z = all_z[kept]
        kept_errors = all_errors[kept]

        # The sampling grid must cover exactly the errors the estimator is fit on:
        # validation errors included, deleted rows excluded.
        self.min_val = kept_errors.min().item()
        self.max_val = kept_errors.max().item()

        self.nnkcde.fit(kept_z.detach().cpu().numpy(), kept_errors.detach().cpu().numpy())

    def sample_error(self, x_test, z_test):
        """Draw one error per row of ``z_test`` from the fitted conditional density.

        ``x_test`` is not used. Returns a 1-D ``torch.Tensor`` with one sampled
        error per test row, placed on ``z_test.device``.
        """
        error_grid = np.linspace(self.min_val, self.max_val, self.grid_size).astype(np.float32)
        density = self.nnkcde.predict(z_test.detach().cpu().numpy(), error_grid, bandwidth='scott')
        sampled_errors = [self._sample_grid_value(error_grid, row) for row in density[:len(z_test)]]
        return torch.Tensor(sampled_errors).to(z_test.device)

    @staticmethod
    def _sample_grid_value(grid, row_density):
        """Sample one value of ``grid`` with probability proportional to ``row_density``."""
        weights = np.asarray(row_density, dtype=np.float64)
        total = weights.sum()
        if np.isfinite(total) and total > 0:
            probabilities = weights / total
        else:
            # No usable probability mass (e.g. every kernel weight underflowed to 0):
            # normalising would yield NaNs and make ``choice`` raise, so sample uniformly.
            probabilities = None
        return choice(grid, 1, p=probabilities).item()

    @property
    def name(self) -> str:
        return "nnkcde_error_sampler"
