from __future__ import annotations

import math
from copy import deepcopy
from typing import List

import numpy as np

from utils import *
from base_predictor import *
from distribution import *
from kernel import *
from loss import *
from weight_function import WeightFunction, Center


class BaseRKHSWeighting(WeightFunction):
    """
    RKHS weighting of functions model.

    This is the generic class upon which specific instantiations of the model are built.

    See the original paper: https://hal.science/hal-04236058v2

    Parameters
    ----------
    data_x : np.ndarray
        Input data, assumed to be a 2D (n_examples, n_features) array. Used by the
        Monte Carlo estimates of the constants theta, iota and kappa, and by `max_output`.

    dist : Distribution
        Probability distribution for sampling the random parameters. "p" in the paper.

    kernel : Kernel
        Similarity kernel defining the RKHS. "K" in the paper.

    base_pred : BasePredictor
        The functions in "RKHS weightings of functions". "phi" in the paper.

    rng : Numpy Random Generator or int or None
        Random number generator or random seed, passed to `np.random.default_rng`.
        Used to subsample the data in `mc_theta`.

    data_y : optional
        Targets associated with `data_x`. Stored on the instance; not used by this base class.

    n_mc : int, default=1000
        Number of samples for Monte Carlo estimations.

    use_mc : bool, default=False
        Whether to use Monte Carlo estimations.
        By default, will attempt to use self._exact_expectations, which can be defined
        in the inheriting class. If it's not defined, will fall back on Monte Carlo.

    reuse_mc : bool, default=True
        If True and the model has at least one center, the features of the current
        centers (`get_features()`) are used as the Monte Carlo samples instead of
        parameters drawn from `dist`.

    Class attributes
    ----------------
    mc_batch_size : int, default=1000
        Number of Monte Carlo samples processed per batch, which bounds peak memory.

    Notes
    -----
    Inherits from WeightFunction: refer to its docstring for its own parameters and attributes.
    """

    mc_batch_size = 1000

    def __init__(self,
                 data_x: np.ndarray,
                 dist: Distribution,
                 kernel: Kernel,
                 base_pred: BasePredictor,
                 rng=None,
                 data_y=None,
                 n_mc=1000,
                 use_mc=False,
                 reuse_mc=True) -> None:
        self.data_x = data_x
        self.data_y = data_y
        self.dist = dist
        self.base_pred = base_pred
        self.rng = np.random.default_rng(rng)
        self.n_mc = n_mc
        self.use_mc = use_mc
        self.reuse_mc = reuse_mc
        super().__init__(kernel=kernel)

    # ------------------------------------------------------------------
    # Expectations
    # ------------------------------------------------------------------

    def expectations(self, X: np.ndarray) -> np.ndarray:
        """
        Returns the (n_examples, n_centers) array containing
        the value of the expectation for each (x, center) pair.

        Uses `_exact_expectations` unless `use_mc` is True or the exact
        computation fails, in which case Monte Carlo is used instead.
        """
        X = ensure_X_2d(X)
        if not self.use_mc:
            try:
                return self._exact_expectations(X)
            except Exception:
                # Deliberate best-effort: NotImplementedError (no closed form) or a
                # failure of the exact formula falls back on Monte Carlo. Unlike a bare
                # `except:`, this does not swallow KeyboardInterrupt / SystemExit.
                pass
        return self._mc_expectations(X)

    def _exact_expectations(self, X: np.ndarray) -> np.ndarray:
        """
        Exact (closed-form) expectations, shape (n_examples, n_centers).

        This function should ideally be implemented for any new instantiation.
        Otherwise, Monte Carlo will be used, which is slower and less accurate.
        It is used by `expectations` whenever `self.use_mc` is False.
        """
        raise NotImplementedError

    def _exact_expectations_is_implemented(self) -> bool:
        """
        Probe `_exact_expectations` on the first two rows of `data_x`.

        Only NotImplementedError counts as "not implemented"; any other
        exception is taken to mean the method exists.
        """
        method = self._exact_expectations
        if not callable(method):
            return False
        try:
            method(self.data_x[0:2, :])  # Assuming data_x is a 2D array
            return True
        except NotImplementedError:
            return False
        except Exception:
            return True

    def _get_mc_samples(self):
        """
        Return the parameters used as Monte Carlo samples.

        With `reuse_mc` and at least one center, these are the current center
        features. Otherwise a list of sampled parameters is drawn once (T of them
        if there are T > 0 centers, else `n_mc`) and cached in `mc_samples_`.
        """
        T = self.get_n_centers()
        if self.reuse_mc and T > 0:
            return self.get_features()
        elif not hasattr(self, 'mc_samples_'):
            n_mc = T if T > 0 else self.n_mc
            self.mc_samples_ = [self.sample_center() for _ in range(n_mc)]
        return self.mc_samples_

    def _mc_average(self, X: np.ndarray, partial, total: np.ndarray) -> np.ndarray:
        """
        Average `partial(X, batch)` over all Monte Carlo samples.

        Samples are processed in batches of `mc_batch_size` to bound memory; the
        partial results are accumulated in place into `total`, whose shape must
        match what `partial` returns, and the sum is divided by the sample count.

        Raises
        ------
        ValueError
            If there are no Monte Carlo samples to average over.
        """
        mc_samples = self._get_mc_samples()
        n_mc = len(mc_samples)
        if n_mc == 0:
            raise ValueError("No Monte Carlo samples available: n_mc must be positive.")
        batch_size = self.mc_batch_size
        for start in range(0, n_mc, batch_size):
            total += partial(X, mc_samples[start:start + batch_size])
        return total / n_mc

    def _mc_expectations(self, X: np.ndarray) -> np.ndarray:
        """
        Calculate the (n_examples, n_centers) expectations using Monte Carlo.
        """
        return self._mc_average(X, self._partial_mc_expectations,
                                np.zeros(shape=(X.shape[0], self.get_n_centers())))

    def _partial_mc_expectations(self, X: np.ndarray, mc_samples) -> np.ndarray:
        """Unnormalized expectation contribution of one batch of samples."""
        base_predictions = self.base_pred.eval(mc_samples, X)  # m x n_mc
        gram = self.kernel.calculate(mc_samples, self.get_features())
        return base_predictions @ gram

    def get_empty_expectations(self, X: np.ndarray) -> np.ndarray:
        """Zero array of shape (n_examples, n_centers)."""
        m = X.shape[0]
        T = self.get_n_centers()
        return np.zeros(shape=(m, T))

    # ------------------------------------------------------------------
    # Model output
    # ------------------------------------------------------------------

    def output(self, X: np.ndarray) -> np.ndarray:
        """
        Returns the size m array containing
        the output of the prediction model for all examples in X.

        Equations-wise, this is Lambda alpha(X).
        """
        X = ensure_X_2d(X)
        if self.get_n_centers() == 0:
            return np.zeros(shape=X.shape[0])
        if self.use_mc:
            return self.mc_output(X)
        coefs = self.get_coefs()
        return (self.expectations(X) @ coefs).flatten()

    def __call__(self, X: np.ndarray) -> np.ndarray:
        return self.output(X)

    def mc_output(self, X: np.ndarray) -> np.ndarray:
        """
        Calculates model output using Monte Carlo.

        Equivalent to approximating the expectations and doing the dot product.
        """
        X = ensure_X_2d(X)
        return self._mc_average(X, self._partial_mc_output, np.zeros(X.shape[0]))

    def _partial_mc_output(self, X: np.ndarray, mc_samples) -> np.ndarray:
        """Unnormalized output contribution of one batch of samples."""
        base_predictions = self.base_pred.eval(mc_samples, X)  # m x n_mc
        weight_func_values = self.eval_weight_func_multiple_centers(mc_samples)  # n_mc
        return base_predictions @ weight_func_values  # m

    def efficient_update(self, center, coef, scale=1.0):
        """
        Multiplies the weight function by scale, then add
        the new center and coefficient.

        The scale is useful when regularizing.
        """
        # NOTE(review): relies on `*=` mutating the model in place; if the parent
        # class only defines `__mul__`, this rebinds the local name instead — confirm.
        self *= scale
        return self.add_center(center, coef)

    # ------------------------------------------------------------------
    # Constants theta, iota, kappa (cached after first computation)
    # ------------------------------------------------------------------

    def _cached_constant(self, attr: str, exact, estimate) -> float:
        """
        Return the value cached in `attr`, computing it on first use with
        `exact()`, or `estimate()` if `exact` raises NotImplementedError.
        """
        try:
            return getattr(self, attr)
        except AttributeError:
            pass
        try:
            value = exact()
        except NotImplementedError:
            value = estimate()
        setattr(self, attr, value)
        return value

    def theta(self):
        return self._cached_constant('theta_', self.true_theta, self.mc_theta)

    def true_theta(self):
        """
        Calculates theta using an analytical formula.
        Each instantiation may implement this function.
        Otherwise, the constant will be calculated using Monte Carlo.
        """
        raise NotImplementedError

    def mc_theta(self, max_samples: int = 20000) -> float:
        """
        Estimate constant theta as the square root of the largest diagonal entry
        of the implicit kernel over the data.

        At most `max_samples` rows of `data_x` are used, drawn without replacement
        with `self.rng` so the estimate is reproducible for a given seed.
        """
        data = self.data_x
        n = data.shape[0]
        if n > max_samples:
            data = data[self.rng.choice(n, max_samples, replace=False)]
        diag = self._implicit_kernel_diag(data)
        # Clamp tiny negative round-off so math.sqrt cannot raise.
        return math.sqrt(max(float(np.max(diag)), 0.0))

    def iota(self):
        return self._cached_constant('iota_', self.true_iota, self.mc_iota)

    def true_iota(self):
        """
        Calculates iota using an analytical formula.
        Each instantiation may implement this function.
        Otherwise, the constant will be calculated using Monte Carlo.
        """
        raise NotImplementedError

    def _mc_predictions_and_kernel_diag(self):
        """
        Return (base_predictions, kern_diag) over the Monte Carlo samples, or
        None when every diagonal kernel value K(theta, theta) is zero.

        base_predictions has shape (n_examples, n_samples); columns whose kernel
        diagonal is zero are left at 0 and `base_pred` is not evaluated on them.
        """
        X = self.data_x
        mc_samples = self._get_mc_samples()
        kern_diag = np.asarray(self.kernel.diag(mc_samples, mc_samples))
        nonzero_idx = np.flatnonzero(kern_diag)
        if nonzero_idx.size == 0:
            return None
        base_predictions = np.zeros((X.shape[0], len(mc_samples)))
        nonzero_centers = [mc_samples[i] for i in nonzero_idx]
        base_predictions[:, nonzero_idx] = self.base_pred.eval(nonzero_centers, X)  # m x T
        return base_predictions, kern_diag

    def mc_iota(self) -> float:
        """
        Estimate constant iota using Monte Carlo:
        max over examples x of the mean over samples of sqrt(phi(x)^2 K(theta, theta)).
        """
        result = self._mc_predictions_and_kernel_diag()
        if result is None:
            return 0.0
        base_predictions, kern_diag = result
        candidates = np.mean(np.sqrt(base_predictions**2 * kern_diag), axis=1)
        return float(np.max(candidates))

    def kappa(self):
        return self._cached_constant('kappa_', self.true_kappa, self.mc_kappa)

    def true_kappa(self):
        """
        Calculates kappa using an analytical formula.
        Each instantiation may implement this function.
        Otherwise, the constant will be calculated using Monte Carlo.
        """
        raise NotImplementedError

    def mc_kappa(self) -> float:
        """
        Estimate constant kappa using Monte Carlo:
        sqrt of the max over examples x of the mean over samples of phi(x)^2 K(theta, theta).
        """
        result = self._mc_predictions_and_kernel_diag()
        if result is None:
            return 0.0
        base_predictions, kern_diag = result
        candidates_squared = np.mean(base_predictions**2 * kern_diag, axis=1)
        return float(np.sqrt(np.max(candidates_squared)))

    def operator_norm(self):
        """Smallest of the three available bounds theta, iota, kappa."""
        return min(self.theta(), self.iota(), self.kappa())

    # ------------------------------------------------------------------
    # Misc
    # ------------------------------------------------------------------

    def get_n_dim(self):
        return self.dist.n_dim

    def sample_center(self):
        return self.dist.sample()

    def max_output(self):
        """
        Bound on the absolute model output: the empirical maximum over `data_x`
        when Monte Carlo is in use, else norm() * operator_norm().
        """
        if self.use_mc or not self._exact_expectations_is_implemented():
            return float(np.max(np.abs(self.mc_output(self.data_x))))
        return self.norm() * self.operator_norm()

    def copy(self) -> BaseRKHSWeighting:
        return deepcopy(self)

    def empty_copy(self):
        """
        Makes and returns a copy of the model whose `centers` list is empty.

        Much more efficient than `self.copy`: apart from `centers`, every instance
        attribute is shared by reference with `self` rather than deep-copied.
        """
        # The base __init__ also requires dist/kernel/base_pred, so this only works
        # for subclasses whose constructor accepts just these keyword arguments.
        new_model = self.__class__(data_x=self.data_x, data_y=self.data_y, rng=self.rng)
        for key, value in vars(self).items():
            setattr(new_model, key, [] if key == 'centers' else value)
        # NOTE(review): if the parent class keeps other per-center state (e.g.
        # coefficients) in a separate attribute, it is shared here, not reset — confirm.
        return new_model

    def implicit_kernel(self, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
        """
        Monte Carlo estimate of the implicit kernel matrix between the rows of X
        and Y, shape (m1, m2).

        Normalized by the number of samples actually used (which equals the
        number of centers when `reuse_mc` applies), not by `n_mc`.
        """
        mc_samples = self._get_mc_samples()
        n_mc = len(mc_samples)
        kern_values = self.kernel.calculate(mc_samples, mc_samples)  # n_mc x n_mc
        base_pred_X = self.base_pred.eval(mc_samples, ensure_X_2d(X))  # m1 x n_mc
        base_pred_Y = self.base_pred.eval(mc_samples, ensure_X_2d(Y))  # m2 x n_mc
        return (base_pred_X @ kern_values @ base_pred_Y.T) / n_mc**2

    def _implicit_kernel_diag(self, X: np.ndarray) -> np.ndarray:
        """
        Diagonal of `implicit_kernel(X, X)`, size m, computed without forming
        the (m, m) matrix: diag(A K A^T)_i = sum_j (A K)_ij * A_ij.
        """
        mc_samples = self._get_mc_samples()
        n_mc = len(mc_samples)
        kern_values = self.kernel.calculate(mc_samples, mc_samples)  # n_mc x n_mc
        base_pred = self.base_pred.eval(mc_samples, ensure_X_2d(X))  # m x n_mc
        return np.einsum('ij,ij->i', base_pred @ kern_values, base_pred) / n_mc**2

    def l2p_norm_approx(self) -> float:
        """
        Monte Carlo estimate of the L2(p) norm of the weight function,
        sqrt(E_{theta ~ p}[alpha(theta)^2]), from `n_mc` fresh samples of `dist`.
        """
        mc_samples = [self.sample_center() for _ in range(self.n_mc)]
        outputs = np.asarray(self.eval_weight_func_multiple_centers(mc_samples))
        # np.mean already normalizes by the number of samples.
        return math.sqrt(np.mean(outputs**2))

    def tau_approx(self, X: np.ndarray) -> float:
        """
        Estimate tau using Monte Carlo: sqrt of the max over the rows x of X of
        the mean over `n_mc` fresh samples of phi_theta(x)^2.
        """
        mc_samples = [self.sample_center() for _ in range(self.n_mc)]
        outputs = self.base_pred.eval(mc_samples, ensure_X_2d(X))
        tau_candidates = np.mean(outputs**2, axis=1)
        return math.sqrt(np.max(tau_candidates))
