import torch
from torch import distributions as D
from functools import cached_property
import random
from metabeta.data.distributions import Normal, StudentT, Uniform, Bernoulli, NegativeBinomial, ScaledBeta
from metabeta.utils import fullCovary

# -----------------------------------------------------------------------------
# Probabilities with which each predictor distribution family is drawn;
# entry k belongs to dists[k] (sampled via D.Categorical(probs) in Task.__init__).
probs = torch.tensor([0.10, 0.40, 0.05, 0.25, 0.10, 0.10])
# Alternative weighting that always selects the first family (Normal):
# probs = torch.tensor([1., 0.0, 0.0, 0.0, 0.0, 0.0])
# Candidate predictor distribution classes, in the same order as `probs`.
dists = [
    Normal,
    StudentT,
    Uniform,
    Bernoulli,
    NegativeBinomial,
    ScaledBeta,
    ]

def plotX(x):
    """Draw a grid of scatter plots, one for every pair of columns of ``x``.

    Args:
        x: 2D tensor whose rows are observations and whose columns are the
            dimensions to compare. Columns are converted with ``.numpy()``.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        (and matplotlib is not imported) when ``x`` has fewer than two columns.
    """
    import math
    from itertools import combinations

    _, d = x.shape

    # Number of unique pairs: d choose 2
    num_plots = d * (d - 1) // 2
    if num_plots == 0:
        # fewer than two columns: no pairs exist and a 0x0 grid is invalid
        return

    import matplotlib.pyplot as plt

    # Determine subplot grid size (rows x cols)
    cols = int(math.ceil(math.sqrt(num_plots)))
    rows = int(math.ceil(num_plots / cols))

    # squeeze=False always returns a 2D array of Axes, also for a 1x1 grid
    fig, axs = plt.subplots(rows, cols, figsize=(cols*4, rows*4), squeeze=False)
    axs = axs.flatten()  # Flatten to easily iterate

    # combinations() yields (i, j) with i < j in the same order as nested loops
    for ax, (i, j) in zip(axs, combinations(range(d), 2)):
        ax.scatter(x[:, i].numpy(), x[:, j].numpy(), s=10)
        ax.set_xlabel(f'x[:, {i}]')
        ax.set_ylabel(f'x[:, {j}]')
        ax.set_title(f'Scatter plot of dims {i} vs {j}')

    # Hide any unused subplots if total subplots > num_plots
    for ax in axs[num_plots:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


# base class
class Task:
    def __init__(self,
                 nu_ffx: torch.Tensor, # ffx prior means
                 tau_ffx: torch.Tensor, # ffx prior stds 
                 tau_eps: torch.Tensor, # noise prior std
                 n_ffx: int, # with bias
                 features: torch.Tensor | None = None,
                 limit: float = 300, # try to sd(y) below this
                 use_default: bool = False, # use default prior parameters for predictors
                 correlate: bool = True, # correlate numerical predictors
                 ):
        
        self.d = n_ffx
        self.features = features
        self.original_limit = limit
        self.limit = limit
        self.use_default = use_default
        self.correlate = correlate
        
        # data distribution
        idx = D.Categorical(probs).sample((n_ffx,))
        self.dist_data_ = [dists[i] for i in idx]
        self.dist_data = []
        self.cidx = []
        
        # ffx distribution
        assert len(nu_ffx) == len(tau_ffx) == self.d, 'dimension mismatch'
        self.nu_ffx = nu_ffx
        self.tau_ffx = tau_ffx
        self.dist_ffx = D.Normal(self.nu_ffx, self.tau_ffx)

        # noise distribution
        self.tau_eps = tau_eps
        self.sigma_eps = D.HalfNormal(self.tau_eps).sample((1,))[0] + 1e-3
    
    
    def _sampleFfx(self) -> torch.Tensor:
        return self.dist_ffx.sample() # type: ignore

    def _sampleFeatures(self, n_samples: int, ffx: torch.Tensor) -> torch.Tensor:
        features = [torch.empty((n_samples, 0))]
        for i in range(len(ffx) - 1):
            weight = ffx[i+1]
            dist = self.dist_data_[i](weight, limit=self.limit, use_default=self.use_default)
            if str(dist)[:4] == 'Bern':
                self.cidx.append(i)
            self.dist_data += [dist]
            x = dist.sample((n_samples, 1))
            self.limit -= (x * weight).abs().max()
            features += [x]
        out = torch.cat(features, dim=-1)
        return out
    
    def _correlate(self, x: torch.Tensor) -> torch.Tensor:
        if self.d < 3 or not self.correlate:
            return x
        mean, std = x.mean(dim=0), x.std(dim=0)
        x_ = (x - mean) / std        
        L = D.LKJCholesky(self.d - 1, 10.0).sample() # roughly within r = 0.6
        R = L @ L.T
        x_cor = (x_ @ L.T) * std + mean
        x_cor[:, self.cidx] = x[:, self.cidx] # preserve categorial
        
        # correlate categorial
        for i in self.cidx:
            nidx = list(range(self.d-1))
            nidx.pop(i)
            j = random.choice(nidx)
            x_cor[:, i] = self._correlateBin(x_cor[:, j], R[i, j])
        return x_cor
    
    def _correlateBin(self, v: torch.Tensor, r: float):
        v = (v - v.mean()) / v.std()
        z = torch.randn_like(v)
        # Create continuous variable z correlated with x
        z = r * v + (1 - r**2)**0.5 * z
        probs = torch.sigmoid(z)
        z = torch.bernoulli(probs)
        return z

    def _addIntercept(self, x: torch.Tensor):
        n_samples = x.shape[0]
        intercept = torch.ones(n_samples, 1)
        out = torch.cat([intercept, x], dim=-1)
        return out

    def _sampleError(self, n_samples: int) -> torch.Tensor:
        eps = torch.randn((n_samples,))
        eps = (eps - torch.mean(eps)) / torch.std(eps)
        return eps * self.sigma_eps

    def sample(self, n_samples: int) -> dict[str, torch.Tensor]:
        raise NotImplementedError
    
    def signalToNoiseRatio(self, y: torch.Tensor, eta: torch.Tensor) -> torch.Tensor:
        eps = y - eta
        ess = (eta - eta.mean()).square().sum()
        rss = eps.square().sum()
        snr = ess / rss
        return snr
    
    def relativeNoiseVariance(self, y: torch.Tensor, eta: torch.Tensor) -> torch.Tensor:
        eps = y - eta
        return eps.var() / y.var()


# -----------------------------------------------------------------------------
# FFX
class FixedEffects(Task):
    """Linear regression task with fixed effects only.

    Besides simulating data, the class offers a closed-form posterior under a
    Normal-Inverse-Gamma prior (see :meth:`posteriorParams`).
    """

    def __init__(self,
                 nu_ffx: torch.Tensor, # ffx prior means
                 tau_ffx: torch.Tensor, # ffx prior stds 
                 tau_eps: torch.Tensor, # noise prior std
                 n_ffx: int,
                 **kwargs
                 ):
        # Forward the options the base class understands (they used to be
        # dropped silently); any other keyword is still ignored as before.
        task_options = ('features', 'limit', 'use_default', 'correlate')
        forwarded = {k: v for k, v in kwargs.items() if k in task_options}
        super().__init__(nu_ffx, tau_ffx, tau_eps, n_ffx, **forwarded)

    def sample(self, n_samples: int,
               include_posterior: bool = False) -> dict[str, torch.Tensor]:
        """Simulate one dataset with ``n_samples`` observations.

        Args:
            n_samples: number of observations to generate.
            include_posterior: if True, also add the analytical posterior
                parameters (``ffx_mu``, ``ffx_Sigma``, ``sigma_eps_alpha``,
                ``sigma_eps_beta``) to the output.

        Returns:
            Dictionary with the design matrix, targets, true parameters,
            prior parameters and diagnostics. ``okay`` is False when the
            standard deviation of ``y`` exceeds 1000.
        """
        okay = True
        ffx = self._sampleFfx()
        X = self._sampleFeatures(n_samples, ffx)
        X = self._correlate(X)
        X = self._addIntercept(X)
        eps = self._sampleError(n_samples)
        eta = X @ ffx
        y = eta + eps
        if y.std() > 1_000:
            okay = False
        rnv = self.relativeNoiseVariance(y, eta)
        out = {
            # data
            "X": X, # (n, d), first column is the intercept
            "y": y, # (n,)
            # params
            "ffx": ffx, # (d,)
            "sigma_eps": self.sigma_eps, # (1,)
            # priors
            "nu_ffx": self.nu_ffx, # (d,)
            "tau_ffx": self.tau_ffx, # (d,)
            "tau_eps": self.tau_eps, # (1,)
            # misc
            "n": torch.tensor(n_samples), # (1,)
            "d": torch.tensor(self.d), # (1,)
            "rnv": rnv, # (1,)
            "okay": torch.tensor(okay),
        }
        if include_posterior:
            mu, Sigma, alpha, beta = self.posteriorParams(X, y)
            out.update({"ffx_mu": mu, "ffx_Sigma": Sigma,
                        "sigma_eps_alpha": alpha, "sigma_eps_beta": beta})
        return out

    # ----------------------------------------------------------------
    # analytical solution assuming Normal-IG-prior
    @cached_property
    def _priorPrecision(self) -> torch.Tensor:
        """Diagonal prior precision matrix ``diag(1 / tau_ffx**2)``."""
        precision = (1. / self.tau_ffx).square()
        L_0 = torch.diag(precision)
        return L_0
    
    @cached_property
    def _priorAB(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Prior Inverse-Gamma parameters ``(a_0, b_0)``.

        ``a_0`` is fixed at 3 and ``b_0`` is chosen so that ``b_0 / (a_0 - 1)``
        equals ``tau_eps * sqrt(2 / pi)``, the mean of the half-normal the
        noise std is sampled from.
        """
        tbp = torch.tensor(2. / torch.pi)
        mean = self.tau_eps * tbp.sqrt()
        a_0 = torch.tensor(3.)
        b_0 = mean * (a_0 - 1.)
        return a_0, b_0
    
    def _posteriorPrecision(self, x: torch.Tensor) -> torch.Tensor:
        """Posterior precision ``L_n = L_0 + x'x``."""
        S = x.T @ x
        L_0 = self._priorPrecision
        L_n = L_0 + S
        return L_n

    def _posteriorCovariance(self, L_n: torch.Tensor) -> torch.Tensor:
        """Invert the posterior precision via its Cholesky factor."""
        lower = torch.linalg.cholesky(L_n)
        S_n = torch.cholesky_inverse(lower)
        return S_n

    def _posteriorMean(self, x: torch.Tensor, y: torch.Tensor, S_n: torch.Tensor) -> torch.Tensor:
        """Posterior mean ``mu_n = S_n (x'y + L_0 mu_0)``."""
        mu_0, L_0 = self.nu_ffx, self._priorPrecision
        mu_n = S_n @ (x.T @ y + L_0 @ mu_0)
        return mu_n

    def _posteriorA(self, x: torch.Tensor) -> torch.Tensor:
        """Posterior shape ``a_n = a_0 + n / 2`` with ``n`` the number of rows of ``x``."""
        n = x.shape[0]
        a_0, _ = self._priorAB
        a_n = a_0 + n / 2.
        return a_n

    def _posteriorB(self, y: torch.Tensor, mu_n: torch.Tensor, L_n: torch.Tensor) -> torch.Tensor:
        """Posterior scale ``b_n = b_0 + (y'y + mu_0' L_0 mu_0 - mu_n' L_n mu_n) / 2``."""
        y_inner = torch.dot(y, y)
        mu_0, L_0 = self.nu_ffx, self._priorPrecision
        mu_0_inner_scaled = torch.linalg.multi_dot([mu_0, L_0, mu_0])
        mu_n_inner_scaled = torch.linalg.multi_dot([mu_n, L_n, mu_n])
        _, b_0 = self._priorAB
        b_n = b_0 + (y_inner + mu_0_inner_scaled - mu_n_inner_scaled) / 2.
        return b_n

    def posteriorParams(self, x: torch.Tensor, y: torch.Tensor
                        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the analytical posterior parameters for data ``(x, y)``.

        Returns:
            ``(mu_n, S_n, a_n, b_n)``: posterior mean and covariance of the
            fixed effects, and the posterior Inverse-Gamma shape and scale.
        """
        L_n = self._posteriorPrecision(x)
        S_n = self._posteriorCovariance(L_n)
        mu_n = self._posteriorMean(x, y, S_n)
        a_n = self._posteriorA(x)
        b_n = self._posteriorB(y, mu_n, L_n)
        return mu_n, S_n, a_n, b_n

# -----------------------------------------------------------------------------
# MFX
class MixedEffects(Task):
    """Linear regression task with fixed effects and group-level random effects.

    The random-effects standard deviations are sampled on construction from
    HalfNormal(tau_rfx); they are all set to zero if any entry of ``tau_rfx``
    is zero.
    """

    def __init__(self,
                 nu_ffx: torch.Tensor, # ffx prior means
                 tau_ffx: torch.Tensor, # ffx prior stds 
                 tau_eps: torch.Tensor, # noise prior std
                 tau_rfx: torch.Tensor, # rfx prior stds
                 n_ffx: int, # d
                 n_rfx: int, # q
                 n_groups: int, # m
                 n_obs: list[int], # n
                 features: torch.Tensor | None = None,
                 groups: torch.Tensor | None = None,
                 use_default: bool = False,
                 ):
        super().__init__(nu_ffx, tau_ffx, tau_eps, n_ffx,
                         features=features, use_default=use_default)
        assert len(n_obs) == n_groups, "mismatch between number of groups and individual observations"
        assert len(tau_rfx) == n_rfx, "mismatch between number of random effects and their prior stds"
        self.q = n_rfx
        self.m = n_groups 
        self.n_i = torch.tensor(n_obs) # per group
        self.groups = groups
        
        # rfx distribution
        self.tau_rfx = tau_rfx
        if (tau_rfx == 0).any():
            self.sigmas_rfx = torch.zeros_like(tau_rfx)
        else:
            self.sigmas_rfx = D.HalfNormal(self.tau_rfx).sample()

    def _sampleRfx(self) -> torch.Tensor:
        """Draw random effects of shape (m, q).

        Each column is standardized across groups (NaNs, e.g. from a single
        group, become 0) and then scaled by the matching entry of
        ``sigmas_rfx``.
        """
        if self.q == 0:
            return torch.zeros((self.m, self.q))
        b = torch.randn((self.m, self.q))
        b = (b - b.mean(0, keepdim=True)) / b.std(0, keepdim=True)
        b = torch.where(b.isnan(), 0, b)
        b *= self.sigmas_rfx.unsqueeze(0)
        return b

    def _marginalCholesky(self, Z: torch.Tensor) -> torch.Tensor:
        """Lower Cholesky factor of ``V = Z diag(sigmas_rfx^2) Z' + sigma_eps^2 I``.

        The covariance is built from variances, hence the squares of the
        stored standard deviations.
        """
        S = torch.diag_embed(self.sigmas_rfx ** 2)
        E = torch.eye(len(Z)) * self.sigma_eps ** 2
        V = Z @ S @ Z.T + E
        return torch.linalg.cholesky(V)
    
    def solve(self, X, y, Z):
        """Generalized least squares estimate of the fixed effects.

        ``X`` and ``y`` are whitened with the Cholesky factor of the marginal
        covariance and the whitened system is solved with a pseudo-inverse.
        """
        L = self._marginalCholesky(Z)
        # triangular solves instead of forming the explicit inverse of L
        LX = torch.linalg.solve_triangular(L, X, upper=False)
        Ly = torch.linalg.solve_triangular(L, y.unsqueeze(-1), upper=False).squeeze(-1)
        ffx_hat = torch.linalg.pinv(LX) @ Ly 
        return ffx_hat
    
    def loglik(self, X, y, Z, ffx):
        """Return ``log det(V) + (y - X ffx)' V^-1 (y - X ffx)``.

        This is the data-dependent part of minus twice the Gaussian
        log-likelihood, i.e. smaller values indicate a better fit.
        """
        L = self._marginalCholesky(Z)
        # log det(V) = 2 * sum(log diag(L)); avoids under-/overflow of the product
        log_det_V = 2. * torch.diagonal(L).log().sum()
        eps = y - X @ ffx
        # eps' V^-1 eps equals the squared norm of L^-1 eps
        w = torch.linalg.solve_triangular(L, eps.unsqueeze(-1), upper=False).squeeze(-1)
        ll = log_det_V + torch.dot(w, w)
        return ll

    def sample(self, include_posterior: bool = False) -> dict[str, torch.Tensor]:
        """Simulate one grouped dataset.

        Without ``features`` the predictors are sampled, correlated and
        extended by an intercept. With ``features`` a random subset of
        ``m`` groups and up to ``d`` columns (column 0 always kept) of the
        provided matrix is used, topped up with sampled columns if it has
        fewer than ``d`` columns.

        Raises:
            NotImplementedError: if ``include_posterior`` is True.
        """
        if include_posterior:
            raise NotImplementedError('posterior inference not implemented for MFX')
        okay = True
        n_samples = self.n_i.sum()
         
        # fixed effects and noise
        ffx = self._sampleFfx()
        if self.features is None:
            X = self._sampleFeatures(n_samples, ffx)
            X = self._correlate(X)
            X = self._addIntercept(X)
            # plotX(X[:, 1:])
        else:
            X = self.features
            # subsample observations
            n_i = torch.unique(self.groups, return_counts=True)[1]
            # Keep the chosen subjects in ascending order: the masked rows stay
            # in their original order, so the per-group counts must follow the
            # same order for the group labels built below to line up
            # (rows are assumed to be ordered by group id, as repeat_interleave requires).
            subjects = torch.randperm(len(n_i))[:self.m].sort().values
            self.n_i = n_i[subjects]
            n_samples = self.n_i.sum()
            subjects_mask = (self.groups.unsqueeze(-1) == subjects.unsqueeze(0)).any(-1)
            X = X[subjects_mask]
            # subsample features
            d = min(self.d, X.size(1))
            features = (torch.randperm(X.size(1) - 1)[:d-1]+1).tolist()
            features = [0] + features
            X = X[:, features]
            # optionally add more
            if d < self.d:
                X_ = self._sampleFeatures(n_samples, ffx[d-1:])
                X = torch.cat([X, X_], dim=1)
            
        eps = self._sampleError(n_samples)
        eta = X @ ffx 
        
        # random effects and target
        groups = torch.repeat_interleave(torch.arange(self.m), self.n_i) # (n,)
        rfx = self._sampleRfx() # (m, q)
        B = rfx[groups] # (n, q)
        Z = X[:,:self.q]
        y_hat = eta + (Z * B).sum(dim=-1)
        y = y_hat + eps
        snr = self.signalToNoiseRatio(y, eta)
        rnv = self.relativeNoiseVariance(y, y_hat)
        
        # check if dataset is within limits
        if eta.std() > self.original_limit:
            okay = False
        
        # Cov(mean Z, rfx), needed for standardization
        if self.q:
            weighted_rfx = Z.mean(0, keepdim=True) * rfx
            cov = fullCovary(weighted_rfx)
            cov_sum = cov.sum() - cov[0,0]
        else:
            cov_sum = torch.tensor(0.)
            
        # outputs
        out = {
            # data
            "X": X, # (n, d), first column is the intercept
            "y": y, # (n,)
            "groups": groups, # (n,)
            # params
            "ffx": ffx, # (d,)
            "rfx": rfx, # (m, q)
            "sigmas_rfx": self.sigmas_rfx, # (q,)
            "sigma_eps": self.sigma_eps, # (1,)
            # priors
            "nu_ffx": self.nu_ffx, # (d,)
            "tau_ffx": self.tau_ffx, # (d,)
            "tau_rfx": self.tau_rfx, # (q,)
            "tau_eps": self.tau_eps, # (1,)
            # misc
            "m": torch.tensor(self.m), # (1,)
            "n": n_samples, # (1,)
            "n_i": self.n_i, # (m,)
            "d": torch.tensor(self.d), # (1,)
            "q": torch.tensor(self.q), # (1,)
            "cov_sum": cov_sum, # (1,)
            "snr": snr, # (1,)
            "rnv": rnv, # (1,)
            "okay": torch.tensor(okay), 
        }
        return out
 

# =============================================================================
# Demo: simulate one fixed-effects and one mixed-effects dataset and print
# the true parameters next to the analytical / empirical estimates.
if __name__ == "__main__":
    # uncomment to make the run reproducible
    # seed = 1
    # torch.manual_seed(seed)
    # shared settings: 3 coefficients (incl. intercept) with wide priors
    n_ffx = 3
    nu = torch.tensor([0., 1., 1.])
    tau_beta = torch.tensor([50., 50., 50.])
    tau_eps = torch.tensor(50.)
    n_obs = 50
    
    # -------------------------------------------------------------------------
    # fixed effects
    print("fixed effects example\n----------------------------")
    fe = FixedEffects(nu, tau_beta, tau_eps, n_ffx=n_ffx)
    ds = fe.sample(n_obs, include_posterior=True)
    print(f"true ffx: {ds['ffx']}")
    print(f"true noise variance: {ds['sigma_eps'].square().item():.3f}")
    print(f"relative noise variance: {ds['rnv']:.2f}")
    
    
    # analytical posterior
    
    print(f"posterior mean: {ds['ffx_mu']}")
    print(f"posterior (margial) variance: {torch.diagonal(ds['ffx_Sigma'], dim1=-1, dim2=-2)}")
    print(f"posterior a: {ds['sigma_eps_alpha']:.1f}")
    print(f"posterior b: {ds['sigma_eps_beta']:.3f}")
    

    # noise variance
    # residuals with respect to the true coefficients
    eps = ds['y'] - ds['X'] @ ds['ffx']
    # residual sum of squares divided by (n_obs - n_ffx - 1)
    noise_var_ml = 1/(n_obs - n_ffx - 1) * torch.dot(eps, eps)
    # mean of the Inverse-Gamma built from the posterior alpha / beta
    expected_noise_var = torch.distributions.inverse_gamma.InverseGamma(ds['sigma_eps_alpha'], ds['sigma_eps_beta']).mean
    print(f"nominal error variance: {ds['sigma_eps'].square().item():.3f}")
    print(f"true error variance: {eps.var().item():.3f}")
    print(f"ML estimate: {noise_var_ml:.3f}")
    print(f"EAP estimate: {expected_noise_var:.3f}")


    # -------------------------------------------------------------------------
    # mixed effects
    print("\nmixed effects example\n----------------------------")

    # three groups with 50, 40 and 30 observations and two random effects
    n_obs = [50, 40, 30]
    n_rfx = 2
    n_groups = 3
    tau_rfx = torch.tensor([10., 20.])
    me = MixedEffects(nu, tau_beta, tau_eps, tau_rfx,
                      n_ffx=n_ffx, n_rfx=n_rfx, n_groups=n_groups, n_obs=n_obs)
    ds = me.sample()

    print(f"true ffx: {ds['ffx']}")
    print(f"true noise variance: {ds['sigma_eps']**2:.3f}")
    print(f"relative noise variance: {ds['rnv']:.2f}")
    # NOTE(review): the label says "variances" but sigmas_rfx is sampled from a
    # HalfNormal and used as a scale, i.e. these look like standard deviations — confirm
    print(f"random effects variances:\n{ds['sigmas_rfx']}")

