import math
from typing import List, Tuple, Optional, Dict
import numpy as np
import torch
from utils.monotonic_vi import train_gp, eval_gp
from utils.virtual_probit import make_virtual_dataset_all

# One observation for the multi-source model: a pair (q, v) where q is the
# integer vector that becomes a row of X and v is the scalar that becomes the
# matching entry of y in build_success_prob_fn_multisrc_monogp.
VecPair = Tuple[Tuple[int, ...], float]
class MonoGPAdapter:
    """
    Adapter around the project's monotonic GP (``train_gp`` / ``eval_gp``).

    Inputs are standardized before training and predictions are mapped back
    to the ORIGINAL y scale.

    Exposes:
      - fit(X, y, virtual_bins, num_epochs, mu, qcap_vec)
      - update_one(x_new, y_new, virtual_bins, num_epochs)
      - predict_mu_var(Q)           # returns (mu, var) on ORIGINAL y scale
      - success_prob(q_vec)         # Phi((mu - Vstar)/sigma) with ORIGINAL Vstar
      - shortfall(q_vec)            # E[(V* - f(q))_+] on ORIGINAL scale

    NOTE: the standardization statistics (X_mean/X_std/Y_mean/Y_std) are
    computed by the first ``fit`` call and reused unchanged by later calls.
    """
    def __init__(self, K:int, Vstar:float, device=None, dtype=torch.float32,
                 num_inducing=40, num_directions=2,
                 minibatch_size=5, learning_rate_hypers=0.01, learning_rate_ngd=0.1,
                 use_ngd=False, use_ciq=False, num_contour_quadrature=15, tqdm=False):
        """
        Store the configuration; no model is built until ``fit`` is called.

        K is the input dimension, Vstar the threshold used by
        ``success_prob`` / ``shortfall``. ``device`` defaults to CUDA when
        available, else CPU. The remaining keyword arguments are collected
        into a dict and forwarded verbatim to ``train_gp``.
        """
        self.K = K
        self.Vstar = float(Vstar)
        self.device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
        self.dtype  = dtype
        self.model = None
        self.likelihood = None
        # minibatch_dim deliberately mirrors num_directions, as in predict_mu_var.
        self._cfg = dict(num_inducing=num_inducing, num_directions=num_directions,
                         minibatch_size=minibatch_size, minibatch_dim=num_directions,
                         learning_rate_hypers=learning_rate_hypers, learning_rate_ngd=learning_rate_ngd,
                         use_ngd=use_ngd, use_ciq=use_ciq, num_contour_quadrature=num_contour_quadrature,
                         tqdm=tqdm)

        # scalers (filled in fit)
        self.X_mean = None; self.X_std = None
        self.Y_mean = None; self.Y_std = None

        # State remembered between fits so update_one() can refit on its own.
        self._cached_XY = None   # (X, y) on self.device / self.dtype
        self._qcap_vec = None    # last qcap_vec given to fit()
        self._fit_mu = None      # last mu given to fit()

    def _phi(self, z: torch.Tensor) -> torch.Tensor:
        """Standard normal density evaluated element-wise at z."""
        return torch.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

    def _Phi(self, z: torch.Tensor) -> torch.Tensor:
        """Standard normal CDF evaluated element-wise at z."""
        return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))

    @staticmethod
    def _floored_std(t: torch.Tensor, eps: float) -> torch.Tensor:
        """
        Sample std over the first axis, floored at eps.

        torch's default std is NaN when there are fewer than two samples and
        clamp_min() keeps NaN, so that case is answered with eps directly.
        """
        if t.shape[0] < 2:
            return torch.full(t.shape[1:], eps, dtype=t.dtype, device=t.device)
        return t.std(dim=0).clamp_min(eps)

    def _standardize_XY(self, X: torch.Tensor, y: torch.Tensor):
        """
        Return (Xn, yn), the standardized copies of X and y.

        The scalers are computed only when none are stored yet; otherwise the
        stored ones are applied as-is.
        """
        eps = 1e-12
        if self.X_mean is None:
            self.X_mean = X.mean(dim=0)
            self.X_std  = self._floored_std(X, eps)
            self.Y_mean = y.mean()
            self.Y_std  = self._floored_std(y, eps)
        Xn = (X - self.X_mean) / self.X_std
        yn = (y - self.Y_mean) / self.Y_std
        return Xn, yn

    def _normalize_Q(self, Q: torch.Tensor):
        """Apply the stored X scalers to query points Q."""
        return (Q - self.X_mean) / self.X_std

    def _denorm_mu_var(self, mu_n: torch.Tensor, var_n: torch.Tensor):
        """Map a standardized mean/variance back to the original y scale."""
        # y = yn*σy + μy,  Var[y] = Var[yn]*(σy^2)
        Y_std = self.Y_std.to(mu_n.device)
        Y_mean= self.Y_mean.to(mu_n.device)
        mu = mu_n * Y_std + Y_mean
        var = var_n * (Y_std ** 2)
        return mu, var

    def fit(self, X: torch.Tensor, y: torch.Tensor, virtual_bins:int, num_epochs:int, mu=1e-2, qcap_vec:Optional[List[int]]=None):
        """
        Train a new model on (X, y) and store it on the adapter.

        virtual_bins and the normalized qcap_vec are forwarded to
        ``make_virtual_dataset_all``; num_epochs and mu are forwarded to
        ``train_gp``. When qcap_vec is omitted, the one from the previous
        ``fit`` call is reused.

        Raises:
            ValueError: if qcap_vec is omitted and no earlier fit supplied one.

        Returns:
            self, to allow chaining.
        """
        if qcap_vec is None:
            qcap_vec = self._qcap_vec
        if qcap_vec is None:
            raise ValueError("qcap_vec is required on the first call to fit().")
        self._qcap_vec = qcap_vec
        self._fit_mu = mu

        X = X.to(self.device, self.dtype)
        y = y.to(self.device, self.dtype)
        qcap = torch.as_tensor(qcap_vec, device=self.device, dtype=self.dtype)

        Xn, yn = self._standardize_XY(X, y)
        # The cap vector must live in the same normalized space as Xn.
        qcap_vec_n = self._normalize_Q(qcap)
        virtual_dataset = make_virtual_dataset_all(Xn, virt_bins=virtual_bins, qcap_vec=qcap_vec_n)

        # training
        train_dataset = torch.utils.data.TensorDataset(Xn, yn)
        model, likelihood = train_gp(
            train_dataset, virtual_dataset,
            num_epochs=num_epochs,
            inducing_data_initialization=False,
            lr_sched=None, mu=mu,
            **self._cfg
        )
        self.model, self.likelihood = model, likelihood
        self._cached_XY = (X, y)
        return self

    def cache_XY(self, X: torch.Tensor, y: torch.Tensor):
        """Remember (X, y), moved to the adapter's device/dtype, for update_one()."""
        self._cached_XY = (X.to(self.device, self.dtype), y.to(self.device, self.dtype))

    def update_one(self, x_new: torch.Tensor, y_new: float, virtual_bins:int, num_epochs:int=25):
        """
        Append one observation to the cached data and refit.

        The refit reuses the qcap_vec and mu of the most recent ``fit`` call.

        Raises:
            RuntimeError: if no data has been cached yet.
            ValueError: (from fit) if no earlier fit supplied a qcap_vec.

        Returns:
            self, to allow chaining.
        """
        if self._cached_XY is None:
            raise RuntimeError("update_one() needs cached data; call fit() or cache_XY() first.")
        X_old, y_old = self._cached_XY
        X = torch.cat([X_old, x_new.view(1,-1).to(self.device, self.dtype)], dim=0)
        y = torch.cat([y_old, torch.tensor([y_new], device=self.device, dtype=self.dtype)], dim=0)
        self.cache_XY(X, y)
        mu = 1e-2 if self._fit_mu is None else self._fit_mu
        self.fit(X, y, virtual_bins, num_epochs=num_epochs, mu=mu)
        return self

    def predict_mu_var(self, Q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns (mu, var) on ORIGINAL y scale.

        Q is normalized with the stored scalers and evaluated by ``eval_gp``
        in a single minibatch holding every row of Q.

        Raises:
            RuntimeError: if the adapter has not been fitted yet.
        """
        if self.model is None or self.X_mean is None:
            raise RuntimeError("MonoGPAdapter is not fitted; call fit() before predicting.")
        Q = Q.to(self.device, self.dtype)
        Qn = self._normalize_Q(Q)
        # eval_gp takes a dataset; the zero targets are placeholders.
        test_ds = torch.utils.data.TensorDataset(Qn, torch.zeros(Qn.shape[0], device=self.device, dtype=self.dtype))
        mu_n, var_n = eval_gp(
            test_ds, self.model, self.likelihood,
            num_directions=self._cfg["num_directions"],
            minibatch_size=Qn.shape[0],
            minibatch_dim=self._cfg["num_directions"]
        )

        mu, var = self._denorm_mu_var(mu_n, var_n)
        return mu, var

    def _success_from_moments(self, mu: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
        """Phi((mu - Vstar)/sigma) for a Gaussian with the given mean/variance."""
        std = torch.sqrt(torch.clamp(var, min=1e-10))
        # The extra 1e-6 keeps z finite even when the variance was clamped.
        z = (mu - self.Vstar) / (std + 1e-6)
        return torch.clamp(self._Phi(z), 0.0, 1.0).squeeze()

    def _shortfall_from_moments(self, mu: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
        """(V*-mu)*Phi(z) + sigma*phi(z) with z=(V*-mu)/sigma, floored at 0."""
        std = torch.sqrt(torch.clamp(var, min=1e-10))
        Vstar = torch.as_tensor(self.Vstar, device=mu.device, dtype=mu.dtype)
        z = (Vstar - mu) / (std + 1e-6)
        S = (Vstar - mu) * self._Phi(z) + std * self._phi(z)
        return torch.clamp(S, min=0.0).squeeze()

    def success_prob(self, q_vec: torch.Tensor) -> torch.Tensor:
        """
        Phi((mu - Vstar)/sigma) on ORIGINAL scale.

        A 1-D q_vec is treated as a single query point; the result is
        squeezed, so a single point yields a 0-d tensor.
        """
        Q = q_vec.unsqueeze(0) if q_vec.ndim == 1 else q_vec
        mu, var = self.predict_mu_var(Q)       # original scale
        return self._success_from_moments(mu, var)

    def shortfall(self, q_vec: torch.Tensor) -> torch.Tensor:
        """
        S(q) = E[(V* - f(q))_+] on ORIGINAL scale:
               (V*-mu)*Phi(z) + sigma*phi(z),  z=(V*-mu)/sigma

        A 1-D q_vec is treated as a single query point; the result is
        squeezed, so a single point yields a 0-d tensor.
        """
        Q = q_vec.unsqueeze(0) if q_vec.ndim == 1 else q_vec
        mu, var = self.predict_mu_var(Q)       # original scale
        return self._shortfall_from_moments(mu, var)


def build_success_prob_fn_multisrc_monogp(points: List[VecPair],
                                          Vstar: float,
                                          q_cap_vec: List[int],
                                          gp_state: dict | None = None,
                                          device=None,
                                          dtype=torch.float32,
                                          virt_bins:int=20,
                                          num_inducing:int=40,
                                          minibatch_size:int=5,
                                          mu:float=1e-2,
                                          num_direction:int = 2,
                                          epochs_init:int=120,
                                          epochs_update:int=25):
    """
    Fit (or refit) a MonoGPAdapter on ``points`` and return query closures.

    Args:
        points: non-empty list of (q, v) pairs; every q must have the same
            length K. The q vectors become the rows of X and the v values
            become y.
        Vstar: threshold used by the returned F and S.
        q_cap_vec: length-K vector forwarded to ``MonoGPAdapter.fit`` as
            ``qcap_vec``.
        gp_state: ``None`` to build a new adapter (trained for
            ``epochs_init`` epochs), or a dict holding a previous adapter
            under ``"gp"`` (refit for ``epochs_update`` epochs).
        device: defaults to CUDA when available, else CPU.
        dtype: tensor dtype for X, y and the queries.
        virt_bins, mu: forwarded to ``MonoGPAdapter.fit``.
        num_inducing, minibatch_size, num_direction: forwarded to the
            ``MonoGPAdapter`` constructor (new adapters only).

    Returns:
        (F, S, K, gp_state) where F(q_vec) is ``gp.success_prob``,
        S(q_vec) is ``gp.shortfall``, K is the input dimension and
        gp_state is ``{"gp": gp}`` (the dict passed in, when one was given).

    Raises:
        ValueError: on empty points, q vectors of differing length, a
            q_cap_vec whose length is not K, or a reused adapter whose K
            differs from the data.
    """
    if len(points) == 0:
        raise ValueError("Empty points for multi-source F (Monotone GP).")
    K = len(points[0][0])
    if any(len(q) != K for (q, _) in points):
        raise ValueError(f"All q vectors in points must have length {K}.")
    if q_cap_vec is not None and len(q_cap_vec) != K:
        raise ValueError(f"q_cap_vec has length {len(q_cap_vec)}, expected {K}.")
    device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))

    X_np = np.array([list(q) for (q, _) in points], dtype=np.float64)  # [N, K]
    y_np = np.array([v for (_, v) in points], dtype=np.float64)        # [N]
    X = torch.tensor(X_np, device=device, dtype=dtype)
    y = torch.tensor(y_np, device=device, dtype=dtype)

    if gp_state is None:
        gp = MonoGPAdapter(K, Vstar, minibatch_size=minibatch_size, num_inducing= num_inducing, num_directions=num_direction,device=device, dtype=dtype)
        num_epochs = epochs_init
        gp_state = {"gp": gp}
    else:
        gp = gp_state["gp"]
        if gp.K != K:
            raise ValueError(f"gp_state holds a model for K={gp.K}, but points have K={K}.")
        # Keep the reused adapter's threshold in step with this call; otherwise
        # F and S would silently answer for the Vstar it was created with.
        gp.Vstar = float(Vstar)
        num_epochs = epochs_update

    gp.cache_XY(X, y)
    gp.fit(X, y, virtual_bins=virt_bins, num_epochs=num_epochs, mu=mu, qcap_vec=q_cap_vec)

    def F(q_vec: torch.Tensor) -> torch.Tensor:
        """Success probability at q_vec."""
        return gp.success_prob(q_vec.to(device=device, dtype=dtype))

    def S(q_vec: torch.Tensor) -> torch.Tensor:
        """Expected shortfall at q_vec."""
        return gp.shortfall(q_vec.to(device=device, dtype=dtype))

    return F, S, K, gp_state
