from typing import List

import numpy as np
import torch

from error_sampler.ErrorSampler import ErrorSampler
from models.qr_models.QuantileRegression import QuantileRegression
import rfcde

# pip install rfcde
class RFCDEErrorSampler(ErrorSampler):
    """Error sampler backed by an ``rfcde.RFCDE`` forest.

    The forest is trained with ``z`` as covariates and the errors as the
    response. Sampling draws one uniform random level per test row and asks
    the forest for the corresponding conditional quantile.
    """

    def __init__(self, n_trees: int = 100, mtry: int = 10, node_size: int = 10, n_basis: int = 15):
        """Create the (untrained) forest.

        All arguments are forwarded unchanged to ``rfcde.RFCDE``; the
        defaults reproduce the values that were previously hard-coded.

        Args:
            n_trees: forwarded as ``n_trees``.
            mtry: forwarded as ``mtry``.
            node_size: forwarded as ``node_size``.
            n_basis: forwarded as ``n_basis``.
        """
        super().__init__()
        self.forest = rfcde.RFCDE(n_trees=n_trees, mtry=mtry, node_size=node_size, n_basis=n_basis)

    def fit(self, x_train, z_train, y_train, errors_train, deleted_train, x_val, z_val, y_val, errors_val, deleted_val, **kwargs):
        """Train the forest on the concatenated train and validation splits.

        Only ``z_*``, ``errors_*`` and ``deleted_*`` are used; ``x_*``,
        ``y_*`` and ``**kwargs`` are accepted but ignored by this sampler.

        Rows whose ``deleted`` entry is set are excluded from training.
        NOTE(review): ``~mask`` is used for the exclusion, so the
        ``deleted_*`` tensors are presumably boolean -- confirm with callers.
        """
        features = torch.cat([z_train, z_val], dim=0)
        deleted = torch.cat([deleted_train, deleted_val], dim=0)
        errors = torch.cat([errors_train, errors_val], dim=0)

        keep = ~deleted
        self.forest.train(
            features[keep].detach().cpu().numpy().astype(np.double),
            errors[keep].detach().cpu().numpy().astype(np.double),
        )

    def sample_error(self, x_test, z_test):
        """Sample one error per row of ``z_test``.

        A uniform random level in [0, 1) is drawn for every row and the
        forest's ``predict_quantile`` is evaluated for that row at that
        level. ``x_test`` is unused.

        Returns:
            A 1-D tensor of length ``len(z_test)`` in torch's default
            floating dtype, placed on ``z_test.device``.
        """
        n_rows = len(z_test)
        quantile_levels = torch.rand(n_rows).detach().cpu().numpy().astype(np.double)

        # Convert the covariates once; doing this inside the loop would copy
        # the whole tensor to a host array for every single row.
        features = z_test.detach().cpu().numpy().astype(np.double)

        # Collect on the host and build the result tensor in one go, so that
        # a GPU-resident output does not incur one device write per row.
        samples = [
            self.forest.predict_quantile(row, level.item()).item()
            for row, level in zip(features, quantile_levels)
        ]
        return torch.tensor(samples, dtype=torch.get_default_dtype(), device=z_test.device)

    @property
    def name(self) -> str:
        """Identifier string for this sampler."""
        return "rfcde_error_sampler"
