"""
Gaussian product mixture for multiple experts.
"""

import torch
import sys
import os
from torch.utils.data import DataLoader

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils.Kendall import calculate_kendall_tau
from models.gp_model import GPModel
from models.deep_ensemble import DeepEnsemble


class GaussianProductMixture(torch.nn.Module):
    """
    Product-of-Gaussians fusion.

    Fuses ``M`` Gaussian experts defined over the same ``N`` outputs into a
    single Gaussian by summing weighted precisions::

        P     = sum_m |w_m| * inv(S_m)
        mu    = inv(P) @ sum_m |w_m| * inv(S_m) @ (sign(w_m) * mu_m)
        Sigma = inv(P)

    The magnitude of each weight scales that expert's precision; its sign
    flips that expert's mean.
    """

    # Number of times the jitter is multiplied by 10 before the final attempt.
    _MAX_JITTER_RETRIES = 6

    def __init__(self, num_experts: int, device=None, jitter: float = 1e-4):
        super().__init__()
        self.num_experts = num_experts
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        # Initial diagonal regularisation added before every Cholesky factorisation.
        self.jitter = jitter

    def _robust_cholesky(self, mat: torch.Tensor) -> torch.Tensor:
        """Return the Cholesky factor of ``mat + jitter * I``, growing jitter on failure.

        Accepts a single ``[N, N]`` matrix or a batch ``[..., N, N]``; the
        identity broadcasts over the leading batch dimensions. The jitter
        starts at ``self.jitter`` and is multiplied by 10 after each failed
        attempt; the error from the final attempt propagates to the caller.
        """
        jitter = float(self.jitter)
        eye = torch.eye(mat.size(-1), device=mat.device, dtype=mat.dtype)
        for _ in range(self._MAX_JITTER_RETRIES):
            try:
                return torch.linalg.cholesky(mat + jitter * eye)
            except RuntimeError:
                jitter *= 10.0
        return torch.linalg.cholesky(mat + jitter * eye)

    def forward(self, means, covs, weights):
        """
        Fuse Gaussian experts.

        Args:
            means: Expert means of shape ``[M, N]`` (tensor or array-like).
            covs: Expert covariances of shape ``[M, N, N]``; symmetrised
                before use.
            weights: ``M`` per-expert weights (any shape with ``M`` elements).

        Returns:
            ``(mean, std, cov)``: fused mean ``[N]``, marginal standard
            deviation ``[N]`` (square root of the covariance diagonal), and
            the fused ``[N, N]`` covariance, or ``None`` when ``N == 1``.

        Raises:
            ValueError: if the shapes of ``means``/``covs`` or the number of
                weights do not match ``num_experts``.
        """
        if not isinstance(means, torch.Tensor):
            means = torch.tensor(means, dtype=torch.float32)
        if not isinstance(covs, torch.Tensor):
            covs = torch.tensor(covs, dtype=torch.float32)
        if not isinstance(weights, torch.Tensor):
            weights = torch.tensor(weights, dtype=torch.float32)

        means = means.to(self.device)
        covs = covs.to(self.device)
        weights = weights.to(self.device).reshape(-1)

        # Explicit checks (not asserts) so validation survives `python -O`.
        if means.dim() != 2:
            raise ValueError(
                f"means must have shape [M, N], got {tuple(means.shape)}"
            )
        M, N = means.shape
        if M != self.num_experts:
            raise ValueError(f"expected {self.num_experts} experts, got {M}")
        if covs.shape != (M, N, N):
            raise ValueError(
                f"covs must have shape {(M, N, N)}, got {tuple(covs.shape)}"
            )
        if weights.numel() != M:
            raise ValueError(f"expected {M} weights, got {weights.numel()}")

        # Remove round-off asymmetry before factorising.
        covs = 0.5 * (covs + covs.transpose(-1, -2))

        signs = torch.sign(weights)
        alphas = weights.abs()

        # Per-expert precisions, scaled by |w_m| and summed.
        precs = torch.cholesky_inverse(self._robust_cholesky(covs))
        precision_sum = (alphas.view(M, 1, 1) * precs).sum(dim=0)

        # Right-hand side: sum_m |w_m| * P_m @ (sign(w_m) * mu_m).
        means_signed = means * signs.view(M, 1)
        prec_mu = torch.matmul(precs, means_signed.unsqueeze(-1)).squeeze(-1)
        rhs = (alphas.view(M, 1) * prec_mu).sum(dim=0).unsqueeze(-1)

        precision_sum = 0.5 * (precision_sum + precision_sum.transpose(-1, -2))
        precision_chol = self._robust_cholesky(precision_sum)
        cov_fuse = torch.cholesky_inverse(precision_chol)
        mean_fuse = torch.cholesky_solve(rhs, precision_chol).squeeze(-1)

        # Marginal std from the covariance diagonal, clamped to stay finite.
        var_fuse = torch.diagonal(cov_fuse, dim1=-2, dim2=-1)
        std_fuse = torch.sqrt(torch.clamp(var_fuse, min=1e-12))

        cov_out = None if N == 1 else cov_fuse
        return mean_fuse, std_fuse, cov_out


class ModelMixture(torch.nn.Module):
    """
    Multi-expert Gaussian mixture.

    Calls every expert's ``predict`` on the same inputs and fuses the
    resulting Gaussians with ``GaussianProductMixture``.
    """

    def __init__(
        self,
        models,
        weights: torch.Tensor | None = None,
        device=None,
        jitter: float = 1e-6,
    ):
        """
        Args:
            models: Non-empty list/tuple/``torch.nn.ModuleList`` of experts;
                each must implement ``predict(X)`` returning
                ``(mean, std, cov)``.
            weights: Optional default per-expert weights, one per model.
            device: Target device. Defaults to the first model's ``device``
                attribute when present, else CUDA if available, else CPU.
            jitter: Initial diagonal jitter for the fusion step.

        Raises:
            TypeError: if ``models`` is not a list/tuple/ModuleList.
            ValueError: if ``models`` is empty or ``weights`` has the wrong
                number of elements.
        """
        super().__init__()
        if not isinstance(models, (torch.nn.ModuleList, list, tuple)):
            raise TypeError("models must be list/tuple or torch.nn.ModuleList")
        models = list(models)
        if not models:
            raise ValueError("models must be non-empty")

        # NOTE: stored as a plain list, so the experts are NOT registered as
        # submodules; .to(), .parameters() and state_dict() on this mixture
        # do not reach them.
        self.models = models

        if device is None:
            ref_model = self.models[0]
            if hasattr(ref_model, "device"):
                self.device = ref_model.device
            else:
                self.device = torch.device(
                    "cuda" if torch.cuda.is_available() else "cpu"
                )
        else:
            self.device = device

        self.weights = None if weights is None else self._prepare_weights(weights)

        self.mixer = GaussianProductMixture(
            num_experts=len(self.models), device=self.device, jitter=jitter
        )

    def _prepare_weights(self, weights) -> torch.Tensor:
        """Flatten ``weights`` to float32 on ``self.device`` and check its length."""
        if not isinstance(weights, torch.Tensor):
            weights = torch.tensor(weights, dtype=torch.float32)
        # reshape (not view) so non-contiguous inputs are accepted.
        weights = weights.to(self.device, dtype=torch.float32).reshape(-1)
        if weights.numel() != len(self.models):
            raise ValueError("weights length must match number of models")
        return weights

    @staticmethod
    def _predict_expert(model, X: torch.Tensor):
        """Return ``(mean [N], cov [N, N])`` from one expert's ``predict``.

        ``predict(X, enable_grad=True)`` is tried first; if it raises
        ``TypeError`` the expert is called again as ``predict(X)``. When the
        expert reports no covariance (``cov is None``), a diagonal covariance
        is built from ``std``.
        """
        if not hasattr(model, "predict"):
            raise TypeError("each model must implement predict(X)")
        try:
            mean, std, cov = model.predict(X, enable_grad=True)
        except TypeError:
            mean, std, cov = model.predict(X)

        # reshape handles 0-d, multi-dim and non-contiguous outputs.
        mean = mean.reshape(-1)
        n = mean.shape[0]
        if cov is None:
            if std is None:
                raise ValueError("model.predict returned neither cov nor std")
            cov = torch.diag(std.reshape(-1) ** 2)
        return mean, cov.reshape(n, n)

    def forward(
        self,
        X: torch.Tensor,
        weights: torch.Tensor | None = None,
        block_size: int | None = None,
    ):
        """
        Forward prediction for mixture.

        Args:
            X: Query inputs; a 1-D tensor is treated as a single point.
            weights: Optional per-expert weights overriding those set in
                ``__init__``.
            block_size: Accepted for interface compatibility; unused here.

        Returns:
            ``(mean, std)`` when ``X`` is 1-D or has a single row, otherwise
            ``(mean, std, cov)`` from the fused Gaussian.

        Raises:
            ValueError: if no weights are available or their length is wrong.
            TypeError: if an expert lacks ``predict``.
        """
        if weights is None:
            if self.weights is None:
                raise ValueError("weights not provided and not set in init")
            weights_in = self.weights
        else:
            weights_in = self._prepare_weights(weights)

        if not isinstance(X, torch.Tensor):
            X = torch.tensor(X, dtype=torch.float32)
        X = X.to(self.device)
        X_in = X.unsqueeze(0) if X.dim() == 1 else X

        means_list = []
        covs_list = []
        for model in self.models:
            mean, cov = self._predict_expert(model, X_in)
            means_list.append(mean)
            covs_list.append(cov)

        means = torch.stack(means_list, dim=0)
        covs = torch.stack(covs_list, dim=0)

        mean_fuse, std_fuse, cov_fuse = self.mixer(means, covs, weights_in)

        # Single-point queries return only (mean, std); predict() pads this.
        if X.dim() == 1 or X.shape[0] == 1:
            return mean_fuse, std_fuse
        return mean_fuse, std_fuse, cov_fuse

    def predict(self, X: torch.Tensor):
        """Predict with a fixed ``(mean, std, cov)`` return arity.

        ``cov`` is ``None`` for single-point inputs, where ``forward`` returns
        only ``(mean, std)``.
        """
        out = self.forward(X)
        if isinstance(out, tuple) and len(out) == 2:
            mean, std = out
            return mean, std, None
        return out
