import os
from pathlib import Path
from typing import Optional, List

from torch import nn
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

from models.forcast.forcast_base import PredictionOutputType, FCPredictionData
from models.uncertainty.pi_base import (
    PIModel,
    PIModelPrediction,
    PIPredictionStepData,
    PICalibData,
    PICalibArtifacts,
)
from models.uncertainty.dist_match.utils import (
    match_ks_stat,
    match_ks_p_val,
    match_rand,
    match_mi,
    match_kl,
    match_wd,
)
from models.uncertainty.dist_match.tree import DistMatchQRF
from models.uncertainty.dist_match.mask_cache import MaskCacheManager

from utils.calc_np import calc_residuals


class DistMatch(PIModel):
    def __init__(self, **kwargs):
        super(DistMatch, self).__init__(
            use_dedicated_calibration=True,
            fc_prediction_out_modes=(PredictionOutputType.POINT,),
        )
        self.qrf: DistMatchQRF = None
        self._update_nodes = kwargs.get("update_nodes", True)
        self._past_window_len = kwargs.get("past_window_len", 100)
        self._match_threshold = kwargs.get("match_threshold", 0.8)
        self._qrf_param = kwargs.get("qrf_param", dict())
        self._beta_calc_bins = kwargs.get("beta_calc_bins", 5)
        self._matcher_param = kwargs.get("matcher_param", dict())
        self._matcher: callable | nn.Module
        self._set_matcher(kwargs.get("match_method", "ks"))
        self._matcher_trainable = False
        self._n_train_samples = None
        self._data_mode = kwargs.get("data_mode", "error")
        self._input_mode = kwargs.get("input_mode", "normal")

    def _match(self, x1, x2) -> bool:
        return self._matcher(x1, x2) < self._match_threshold

    def _set_matcher(self, method: str) -> float:
        self._matcher = None
        self._matcher_method = method
        match self._matcher_method:
            case "ks_stat":
                self._matcher = match_ks_stat
            case "ks":
                self._matcher = match_ks_p_val
            case "mi":
                self._matcher = match_mi
            case "rand":
                self._matcher = match_rand
            case "wd":
                self._matcher = match_wd
            case "kl":
                self._matcher = match_kl
            case _:
                raise NotImplemented(f"Matcher {method} is not implemented")

    def _calibrate(
        self, calib_data: [PICalibData], alphas, **kwargs
    ) -> [PICalibArtifacts]:
        pass

    def calibrate_individual(
        self,
        calib_data: PICalibData,
        alpha,
        calib_artifact: Optional[PICalibArtifacts],
        mix_calib_data: Optional[List[PICalibData]],
        mix_calib_artifact: Optional[List[PICalibArtifacts]],
    ) -> PICalibArtifacts:
        return self._train_qrf(
            ts_id=calib_data.ts_id,
            X_past=calib_data.X_pre_calib,
            Y_past=calib_data.Y_pre_calib,
            X_reg_train=calib_data.X_calib,
            Y_reg_train=calib_data.Y_calib,
            step_offset=calib_data.step_offset,
        )

    def _train_qrf(
        self,
        ts_id,
        X_past,
        Y_past,
        X_reg_train,
        Y_reg_train,
        step_offset,
    ):
        calib_artifacts = PICalibArtifacts()

        Y_hat = self._forcast_service.predict(
            FCPredictionData(
                ts_id=ts_id,
                X_past=X_past,
                Y_past=Y_past,
                X_step=X_reg_train,
                step_offset=step_offset,
            ),
            retrieve_tensor=False,
        ).point

        eps_reg_train = calc_residuals(
            y_hat=Y_hat.squeeze(), y=Y_reg_train.numpy().squeeze()
        )[:, None]
        calib_artifacts.fc_Y_hat = Y_hat
        calib_artifacts.eps = eps_reg_train
        self._calib_eps_last = eps_reg_train
        # Split Calib Residuals in Moving windows of stepwise y predictions

        eps_reg_train = eps_reg_train.squeeze()
        if self._data_mode == "error":
            X_reg = sliding_window_view(
                eps_reg_train, window_shape=self._past_window_len
            )
            X_reg = X_reg[:-1, ..., None]
        else:
            X_reg = sliding_window_view(
                X_reg_train, window_shape=self._past_window_len, axis=-2
            ).swapaxes(-1, -2)
            X_reg = X_reg[:-1, ...]
        Y_reg = eps_reg_train[self._past_window_len :]

        X_reg = self._preprocess_inputs(X_reg)
        match_mask = self._load_cached_mask(ts_id, X_reg)
        match_mask = match_mask < self._match_threshold

        self.qrf = DistMatchQRF(
            **self._qrf_param,
            alpha=0.1,
            n_quantile_bins=self._beta_calc_bins,
            feature_dim=-1,
            matcher=self._match,
            match_mask=match_mask,
            relevance_matcher=self._matcher
        )

        self.qrf.fit(X_reg, Y_reg, preserve_match_mask=True)
        self._n_train_samples = len(X_reg)

        forecast_model = self._forcast_service._model_config.model
        path = f"ABSOLUTE_PATH_TO_THE_ROOT/models_save/uc/qrf_{self._matcher_method}<{self._match_threshold}_{self._data_mode}_{self._input_mode}_{ts_id}_{forecast_model}|{self._past_window_len}.pkl"
        self.qrf.save(path)

        return calib_artifacts

    def pre_predict(self, **kwargs):
        super().pre_predict(**kwargs)
        self.qrf.set_alpha(kwargs["alpha"], n_quantile_bins=self._beta_calc_bins)
        self.qrf.reset_updates(self._n_train_samples)

    def _predict_step(
        self, pred_data: PIPredictionStepData, **kwargs
    ) -> PIModelPrediction:
        # Retrieve data
        alpha, X_step, X_past, Y_past, eps_past = (
            pred_data.alpha,
            pred_data.X_step,
            pred_data.X_past,
            pred_data.Y_past,
            pred_data.eps_past[-self._past_window_len :],
        )
        # Calculate y_hat and prediction interval for current step
        Y_hat = self._forcast_service.predict(
            FCPredictionData(
                ts_id=pred_data.ts_id,
                X_past=X_past,
                Y_past=Y_past,
                X_step=X_step,
                step_offset=pred_data.step_offset_overall,
            )
        ).point

        eps_reg = np.concatenate(
            [
                self._calib_eps_last,
                np.array(eps_past).reshape(-1, 1),
            ]
        )[None, ...]

        eps_reg = np.nan_to_num(eps_reg, nan=0., posinf=0., neginf=0.)

        if self._data_mode == "error":
            x_prev_test = self._preprocess_inputs(
                eps_reg[:, -self._past_window_len - 1 : -1, ...]
            )
            x_test = eps_reg[:, -self._past_window_len :, ...]
        else:
            X_concat = np.concatenate([X_past, X_step])[None, ...]
            x_prev_test = self._preprocess_inputs(
                X_concat[:, -self._past_window_len - 1 : -1, ...]
            )
            x_test = X_concat[:, -self._past_window_len :, ...]

        y_prev_test = eps_reg[0, -1:, 0]
        id_prev_test = np.array([len(X_past)])

        self.qrf.predict_with_update(x_prev_test, y_prev_test, id_prev_test)
        x_test = self._preprocess_inputs(x_test)
        widths = self.qrf.predict(x_test)

        width_low = widths[0][0]
        width_high = widths[0][1]

        pred_int = Y_hat + width_low, Y_hat + width_high
        return PIModelPrediction(pred_interval=pred_int, fc_Y_hat=Y_hat)

    def _load_cached_mask(self, ts_id: int, data: np.ndarray):
        forecast_model = self._forcast_service._model_config.model
        path = f"ABSOLUTE_PATH_TO_THE_ROOT/models_save/uc/cache_{self._matcher_method}_{self._data_mode}_{self._input_mode}_{ts_id}_{forecast_model}|{self._past_window_len}.npy"
        manager = MaskCacheManager(
            path=path,
            matcher=self._matcher,
            batch_size=self._qrf_param.get("batch_size", None),
        )
        if self._matcher_param.get("retrain", False):
            manager.put(data, data)
        return manager.get(data, data)

    def _preprocess_inputs(self, inputs: np.ndarray):
        match self._input_mode:
            case "delta":
                return np.diff(inputs, 1, axis=-2)
            case "residual":
                return inputs[..., 1:, :] - inputs[..., :1, :]
        return inputs

    def model_ready(self):
        return True

    def can_handle_different_alpha(self):
        return True
