"""multivariate_polynomial_basis.py: Implements multivariate polynomial approximators.
"""
import copy
import math
import random
import itertools
import numpy as np
import torch
from functools import lru_cache
from contextlib import contextmanager


__author__: str = "anonymizedforblindreview"
__version__: str = "1.0"
__email__: str = "anonymizedforblindreview"


# Default coefficient magnitude: the value of every coefficient under the
# 'constant' initialization, and the half-width of the zero-centred uniform
# random initializations (interval [-INIT_WEIGHT, INIT_WEIGHT]).
INIT_WEIGHT: float = 0.001
# Centre of the uniform random initialization used for the Bernstein basis.
INIT_WEIGHT_BERNSTEIN: float = 0.5
# Half-width of that Bernstein random interval:
# [INIT_WEIGHT_BERNSTEIN - EPSILON, INIT_WEIGHT_BERNSTEIN + EPSILON].
EPSILON: float = 0.001


def univar_power(d):
    """Return the uni-variate monomial x**d as a numpy poly1d."""
    identity = np.poly1d([1, 0])  # p(x) = x
    return pow(identity, d)


@lru_cache(maxsize=None)  # memoizing for better performance
def univar_chebychev(d):
    """Return the uni-variate Chebyshev polynomial (first kind) T_d as a numpy poly1d.

    Built from the recurrence T_0 = 1, T_1 = x, T_d = 2x * T_{d-1} - T_{d-2}.
    Results are cached by lru_cache, so every caller asking for the same d gets
    the same poly1d object; do not mutate it.

    Raises:
        ValueError: if d is negative (the recursion would otherwise never reach a
            base case and would end in RecursionError).
    """
    if d < 0:
        raise ValueError(f"Chebyshev degree must be non-negative, got {d}")
    if d == 0:
        return np.poly1d([1])
    if d == 1:
        return np.poly1d([1, 0])
    return 2 * univar_chebychev(1) * univar_chebychev(d - 1) - univar_chebychev(d - 2)


def univar_bernstein(i, d):
    """Return the i-th Bernstein basis polynomial of degree d as a numpy poly1d.

    b_{i,d}(x) = C(d, i) * x**i * (1 - x)**(d - i)
    """
    binomial = math.comb(d, i)
    x = np.poly1d([1, 0])
    rising_part = x ** i
    falling_part = (1 - x) ** (d - i)
    return binomial * rising_part * falling_part


def function_approx_coeffs(basis, ps, fs):
    """Return least-squares coefficients alpha for a linear combination of basis functions.

    alpha minimises sum_i (sum_j alpha[j] * basis[j](ps[i]) - fs[i]) ** 2, i.e.
    sum_j alpha[j] * basis[j] is the least-squares fit to the samples (ps[i], fs[i]).
    The solution is obtained via the Moore-Penrose pseudo-inverse of the design
    matrix, so for a rank-deficient design the minimum-norm minimiser is returned.
    """
    def design_row(point):
        # One row of the design matrix: every basis function evaluated at `point`.
        return [phi(point) for phi in basis]

    design_matrix = list(map(design_row, ps))
    pseudo_inverse = np.linalg.pinv(design_matrix)
    return pseudo_inverse @ fs


class MultiVarPoly(torch.nn.Module):
    def __init__(self, dim=2, degree=2, basis='chebyshev', initialization='constant', flat_init_offset=0.4, coeffs=None, seed=None):
        """Create a polynomial in `dim` variables with max-degree `degree` per variable.

        The polynomial is sum_k coeffs[k] * prod_i B_{alpha_k[i]}(x_i), where B is the
        chosen univariate basis and alpha_k is the k-th tuple of
        itertools.product(range(degree + 1), repeat=dim); there are
        (degree + 1) ** dim coefficients.

        Args:
            dim: number of input variables.
            degree: maximum degree per variable; overridden (inferred from the number
                of coefficients) when `coeffs` is given.
            basis: 'power', 'chebyshev' or 'bernstein'.
            initialization: 'constant', 'flat', 'zero' or 'zero_higher_deg_6'; any
                other value selects a small uniform random initialization. Ignored
                when `coeffs` is given.
            flat_init_offset: value used by the 'flat' initialization.
            coeffs: optional explicit coefficients (sequence, ndarray or tensor) of
                length (degree + 1) ** dim; they are copied and stored with dtype
                self.torch_datatype (float32).
            seed: if not None, passed to random.seed(). This seeds Python's global
                RNG, so it also affects other code that uses `random`.

        Raises:
            Exception: for an unknown `basis` or coefficients that cannot be parsed.
            ValueError: if the number of coefficients is not (degree + 1) ** dim.
        """
        super().__init__()
        if seed is not None:
            random.seed(seed)
        self.torch_datatype = torch.float32
        self.dim = dim
        self.degree = degree
        self.model = []  # univariate basis polynomials (np.poly1d); self.model[j] = B_j
        self.basis = basis
        self.training_mode = False

        # override degree with the one given by coeffs
        if coeffs is not None:
            self.degree = round(pow(len(coeffs), 1 / dim) - 1)

        univar_generators = {
            'power': univar_power,
            'chebyshev': univar_chebychev,
            'bernstein': univar_bernstein,
        }
        if basis not in univar_generators:
            raise Exception('Invalid polynomial basis.')
        self._pregenerate_basis_elements(univar_generators[basis])

        num_terms = (self.degree + 1) ** self.dim
        if coeffs is not None:
            self.coeffs = self._coeffs_from_user(coeffs)
            # Fail here with a clear message instead of later inside torch.dot().
            if self.coeffs.shape != (num_terms,):
                raise ValueError(
                    f'Expected {num_terms} = (degree + 1) ** dim coefficients, '
                    f'got shape {tuple(self.coeffs.shape)}.'
                )
        else:
            self.coeffs = torch.tensor(
                self._initial_coeffs(initialization, flat_init_offset, num_terms),
                requires_grad=True,
                dtype=self.torch_datatype,
            )
        # NOTE(review): self.coeffs is a plain tensor, not a torch.nn.Parameter, so it is
        # not returned by self.parameters() / state_dict(); confirm the training code
        # hands it to the optimizer explicitly.

    def _coeffs_from_user(self, coeffs):
        """Return a leaf copy of user-supplied coefficients (dtype self.torch_datatype,
        requires_grad=True).

        A fixed dtype keeps torch.dot(coeffs, basis values) in the forward pass from
        failing on e.g. float64 numpy input, and lets integer lists be used.

        Raises:
            Exception: if `coeffs` cannot be converted to a tensor.
        """
        try:
            if isinstance(coeffs, torch.Tensor):
                # detach().clone() is the supported way to copy a tensor; deepcopy fails
                # for non-leaf tensors and torch.tensor(tensor) emits a warning.
                values = coeffs.detach().clone().to(self.torch_datatype)
            else:
                values = torch.tensor(coeffs, dtype=self.torch_datatype)  # always copies
        except (TypeError, ValueError, RuntimeError) as err:
            raise Exception('Cannot parse coefficients') from err
        return values.requires_grad_(True)

    def _initial_coeffs(self, initialization, flat_init_offset, num_terms):
        """Return `num_terms` initial coefficient values (Python floats) for a named scheme."""
        if initialization == 'constant':
            return [INIT_WEIGHT] * num_terms
        if initialization == 'flat':
            if self.basis == 'bernstein':
                # Bernstein elements sum to 1 (partition of unity), so equal coefficients
                # give the constant function flat_init_offset.
                return [flat_init_offset] * num_terms
            # Only term 0 (all per-variable degrees zero, i.e. the constant term) is set.
            return [flat_init_offset] + [0.0] * (num_terms - 1)
        if initialization == 'zero':
            return [0.0] * num_terms
        if initialization == 'zero_higher_deg_6':
            values = [random.uniform(-INIT_WEIGHT, INIT_WEIGHT) for _ in range(num_terms)]
            # k % (degree + 1) is the basis index of the LAST variable in
            # itertools.product order; terms where it exceeds 6 are zeroed.
            # NOTE(review): for dim > 1 only the last variable's degree is checked —
            # confirm that is intended.
            return [0.0 if k % (self.degree + 1) > 6 else v for k, v in enumerate(values)]
        # Any other value: small uniform random initialization.
        if self.basis == 'bernstein':
            low = INIT_WEIGHT_BERNSTEIN - EPSILON
            high = INIT_WEIGHT_BERNSTEIN + EPSILON
        else:
            low, high = -INIT_WEIGHT, INIT_WEIGHT
        return [random.uniform(low, high) for _ in range(num_terms)]

    def _pregenerate_basis_elements(self, univar_basis):
        """Generate polynomial basis vectors for later use.
        self.chebyshev_polynomials[0] = monomial of degree 0
        """
        if self.basis == 'bernstein':
            for i in range(self.degree + 1):
                self.model.append(univar_basis(i, self.degree))        
        else:
            for i in range(self.degree + 1):
                self.model.append(univar_basis(i))

    def _eval_basis_1d(self, x):
        return [torch.tensor(b(x), dtype=self.torch_datatype) for b in self.model]

    def _evaluate_multivarpoly_from_basis_elements(self, x):
        """Evaluate multi-variate polynomial as a product of
        uni-variate polynomials in each dimension using a given basis of
        uni-variate polynomials. Basis for space of n-variate polynomials of
        max-degree d.
        """      
        # Evaluate basis functions at each component of x
        # [[T_0(x_0), T_1(x_0), ...]
        # [T_0(x_1), T_1(x_1), ...]
        # ...]
        basis_vals = [torch.stack(self._eval_basis_1d(xi)) for xi in x]

        # Generate n-d indices from flat-indices
        # E.g. deg=3, dim=2. i=3 --> [0, 3]. i=4 --> [1, 0]
        # Can then be used to index basis_vals
        multi_indices = torch.tensor(list(itertools.product(range(self.degree+1), repeat=self.dim)), dtype=torch.long)

        # Gather the appropriate T_{alpha_i}(x_i) values for each dimension
        selected = torch.stack([
            basis_vals[i][multi_indices[:, i]]  # shape: (num_terms,)
            for i in range(self.dim)
        ])  # shape: (n, num_terms)

        # Compute product over dimensions for each term
        basis_prod = selected.prod(dim=0)  # shape: (num_terms,)

        # Dot with coeffs
        return torch.dot(self.coeffs, basis_prod)

    @contextmanager
    def _grad_context(self):
        """Context manager that controls gradient computation based on training mode."""
        if self.training_mode:
            yield  # Normal gradient computation
        else:
            with torch.inference_mode():  # Fastest for inference
                yield

    def set_training_mode(self, mode):
        """Set approximator training mode, i.e. activate or deactivate computation graph."""
        self.training_mode = mode

    def forward(self, x):
        """
        Evaluation / forward pass used for approximators during learn() and train().
        Stacks the result for use in SB3 on_policy_algorithm.py and ppo.py.
        """
        if x.dim() == 1:  # Single point
            return torch.stack([self.evaluate_point(x)])
        else:
            return torch.stack([self.evaluate_point(p) for p in x])

    def evaluate_point(self, x):
        with torch.enable_grad(): # generally enable grad context (required for specific SB3 algorithms, e.g. PPO)...
            with self._grad_context(): # ...but override with specified context
                return self._evaluate_multivarpoly_from_basis_elements(x)

    def evaluate(self, *data):
        grid = np.meshgrid(*data, indexing='ij')
        points = np.stack(grid, axis=-1)
        return torch.from_numpy(np.apply_along_axis(self.evaluate_point, -1, points))

    def evaluate_basis_vectors_at(self, x):
        #return tf.constant([b(x) for b in self.model], dtype=tf.float32)
        return [b(x) for b in self.model]

    def _evaluate_basis_vectors_at(self, x):
        """
        Evaluate basis functions at each component of x.
        If d = self.degree and n = self.dim then it returns
        [[T_0(x_0) cdot T_0(x_1) cdot ... cdot T_0(x_n)], 
        ...
        [[T_d(x_0) cdot T_d(x_1) cdot ... cdot T_d(x_n)]]
        """      
        basis_vals = [torch.stack(self._eval_basis_1d(xi)) for xi in x]

        # Generate n-d indices from flat-indices
        # E.g. deg=3, dim=2. i=3 --> [0, 3]. i=4 --> [1, 0]
        # Can then be used to index basis_vals
        multi_indices = torch.tensor(list(itertools.product(range(self.degree+1), repeat=self.dim)), dtype=torch.long)

        # Gather the appropriate T_{alpha_i}(x_i) values for each dimension
        selected = torch.stack([
            basis_vals[i][multi_indices[:, i]]  # shape: (num_terms,)
            for i in range(self.dim)
        ])  # shape: (n, num_terms)

        # Compute product over dimensions for each term
        return selected.prod(dim=0)  # shape: (num_terms,)

    def fit_l2(self, X, y):
        """
        Fit an n-D polynomial of given degree to data (X, y) in the L2 sense.
        
        X: tensor of shape (N, dim)
        y: tensor of shape (N,)
        
        Returns:
            coeffs: tensor of shape (num_terms,)
        """
        if not isinstance(X, torch.Tensor):
            X = torch.tensor(X, dtype=torch.float32)
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=torch.float32)
        N, dim = X.shape
        assert dim == self.dim, "Input dimension mismatch"

        num_terms = self.coeffs.shape[0]

        # ----- 2. Build design matrix A
        # A[n, k] = Π_{i=1..dim} T_{multi_indices[k,i]}(X[n,i])
        A = torch.zeros((N, num_terms), dtype=X.dtype)

        for n, x in enumerate(X):
            A[n] = self._evaluate_basis_vectors_at(x)

        # ----- 3. Solve least squares
        # Using stable PyTorch lstsq
        coeffs = torch.linalg.lstsq(A, y).solution.squeeze()

        self.coeffs = torch.tensor(copy.deepcopy(coeffs), requires_grad=True)