import joblib
from quantile_forest import RandomForestQuantileRegressor
import numpy as np
import torch

from models.abstract_models.LearningModel import LearningModel

from models.qr_models.PredictionIntervalModel import PredictionIntervalModel, PredictionIntervals


class RFQR(LearningModel, PredictionIntervalModel):
    """Quantile regression forest that produces (uncalibrated) prediction intervals.

    Wraps ``quantile_forest.RandomForestQuantileRegressor``; each interval spans the
    predicted ``alpha / 2`` and ``1 - alpha / 2`` conditional quantiles.
    """

    def __init__(self, dataset_name: str, saved_models_path: str, seed: int, alpha: float,
                 max_depth: int = None, n_estimators: int = 100):
        """
        Args:
            dataset_name: forwarded to ``LearningModel.__init__``.
            saved_models_path: forwarded to ``LearningModel.__init__``.
            seed: forwarded to ``LearningModel.__init__`` and also used as the forest's
                ``random_state`` so that fitting is reproducible for a given seed.
            alpha: miscoverage level; intervals use the ``alpha / 2`` and
                ``1 - alpha / 2`` quantiles.
            max_depth: maximum tree depth passed to the forest (``None`` keeps the
                library default).
            n_estimators: number of trees in the forest.
        """
        LearningModel.__init__(self, dataset_name, saved_models_path, seed)
        self.alpha = alpha
        self.max_depth = max_depth
        self.n_estimators = n_estimators
        # Pass the seed through; without random_state the forest's bootstrap/feature
        # sampling would not be tied to `seed`.
        self.model = RandomForestQuantileRegressor(n_estimators=self.n_estimators, max_depth=self.max_depth,
                                                   random_state=seed)

    def fit_xy_aux(self, x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs=1000, batch_size=64,
                   n_wait=20,
                   **kwargs):
        """Fit the forest on the union of the training and validation sets.

        The forest is fit in a single call, so the validation split (when given) is
        simply appended to the training data. ``deleted_train``, ``deleted_val``,
        ``epochs``, ``batch_size``, ``n_wait`` and ``kwargs`` are accepted for interface
        compatibility and are not used.
        """
        xs, ys = [x_train], [y_train]
        # Validation data is optional for this model; only append it when provided.
        if x_val is not None and y_val is not None:
            xs.append(x_val)
            ys.append(y_val)
        # detach() so tensors that require grad can still be converted to numpy.
        new_x_train = torch.cat(xs, dim=0).detach().cpu().numpy()
        new_y_train = torch.cat(ys, dim=0).detach().cpu().numpy()
        self.model.fit(new_x_train, new_y_train)

    def fit(self, x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """Train the model by delegating to ``self.fit_xy`` with all arguments forwarded."""
        self.fit_xy(x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs=epochs, batch_size=batch_size,
                    n_wait=n_wait, **kwargs)

    def construct_uncalibrated_intervals(self, x: torch.Tensor) -> PredictionIntervals:
        """Predict the ``alpha / 2`` and ``1 - alpha / 2`` quantiles for ``x``.

        Args:
            x: input features; moved to CPU for prediction.

        Returns:
            ``PredictionIntervals`` wrapping a tensor on ``x.device`` whose last axis holds
            ``[lower, upper]``.
        """
        quantile_levels = [self.alpha / 2, 1 - self.alpha / 2]
        # NOTE(review): assumes predict returns shape (n_samples, 2) for single-output y — confirm.
        pred = self.model.predict(x.detach().cpu().numpy(), quantiles=quantile_levels)
        low_pred, high_pred = pred[:, 0], pred[:, 1]
        intervals = torch.from_numpy(np.stack([low_pred, high_pred], axis=-1)).to(x.device)
        return PredictionIntervals(intervals)

    def store_model(self):
        """Serialize the fitted forest to ``self.get_model_save_path()`` with joblib."""
        joblib.dump(self.model, self.get_model_save_path())

    def load_model(self):
        """Load a forest previously written by ``store_model``."""
        # joblib.load unpickles arbitrary objects: only load files this code wrote itself.
        self.model = joblib.load(self.get_model_save_path())

    def eval(self):
        """No-op: there is no torch module whose train/eval mode needs switching."""
        pass

    @property
    def name(self) -> str:
        """Identifier for this model type."""
        return "rf_qr"
