import torch
from torch import distributions as D
from metabeta.utils import dampen, maskedStd, weightedMean
from metabeta.evaluation.resampling import replace, powersample 

def getImportanceWeights(log_likelihood: torch.Tensor,
                         log_prior: torch.Tensor,
                         log_q: torch.Tensor,
                         constrain: bool = True) -> dict[str, torch.Tensor]:
    """Convert log densities into mean-normalized importance weights.

    The log ratio ``log_likelihood + log_prior - log_q`` is optionally passed
    through ``dampen`` (project helper, called with ``p=0.75``), capped at its
    0.99 quantile along the last axis, exponentiated and rescaled so that the
    weights average to one along the last axis.

    Args:
        log_likelihood: log likelihood of each draw; last axis indexes draws.
        log_prior: log prior density, broadcastable against ``log_likelihood``.
        log_q: log density of the proposal, broadcastable as well.
        constrain: if True, apply ``dampen`` to the raw log ratio first.

    Returns:
        dict with
            'weights': weights with unit mean over the last axis,
            'n_eff': (sum w)^2 / (sum w^2) over the last axis,
            'sample_efficiency': ``n_eff`` divided by the number of draws.
    """
    # raw log ratio target / proposal
    log_ratio = (log_likelihood + log_prior) - log_q
    if constrain:
        log_ratio = dampen(log_ratio, p=0.75)

    # truncate the upper tail at the 99% quantile; subtracting the cap keeps
    # the exponent <= 0 so exp() cannot overflow
    cap = torch.quantile(log_ratio, 0.99, dim=-1, keepdim=True)
    shifted = torch.clamp(log_ratio, max=cap) - cap

    weights = torch.exp(shifted)
    weights = weights / weights.mean(dim=-1, keepdim=True)

    # effective sample size; epsilon guards the division
    squared_total = weights.sum(-1).square()
    total_of_squares = weights.square().sum(-1)
    n_eff = squared_total / (total_of_squares + 1e-12)

    return {
        'weights': weights,
        'n_eff': n_eff,
        'sample_efficiency': n_eff / weights.shape[-1],
    }


# =============================================================================
# Fixed Effects
# -----------------------------------------------------------------------------
class ImportanceFFX:
    """Importance-sampling correction for a fixed-effects regression proposal.

    Proposed draws of the regression weights (beta) and the noise scale
    (sigma) are scored under a Normal prior on beta, a HalfNormal prior on
    sigma and a Gaussian likelihood. The resulting importance weights (and
    optionally resampled draws) are written back into the proposal dict.

    Shape letters follow the inline comments: b batch size, n observations,
    d features, s samples.
    """

    def __init__(self, data: dict[str, torch.Tensor]):
        """Cache the tensors needed to evaluate prior and likelihood.

        Args:
            data: batch dict providing 'nu_ffx', 'tau_ffx', 'tau_eps',
                'mask_d', 'y', 'X' and 'n'.
        """
        self.nu = data['nu_ffx']    # location of the Normal prior on beta
        self.tau = data['tau_ffx']  # scale of the Normal prior on beta
        # scale of the HalfNormal prior on sigma; trailing axis broadcasts over s
        self.tau_eps = data['tau_eps'].unsqueeze(-1)
        self.mask_d = data['mask_d']  # multiplied into the per-feature log prior
        self.y = data['y'].unsqueeze(-1)  # trailing axis broadcasts over s
        self.X = data['X']
        # multiplier of the log-likelihood terms (presumably the number of
        # observations per dataset -- confirm against the data pipeline)
        self.n = data['n'].unsqueeze(-1)

    def logPriorBeta(self, beta: torch.Tensor) -> torch.Tensor:
        """Return the masked Normal log prior of beta, summed over features.

        Args:
            beta: (b, d, s) regression weights.

        Returns:
            (b, s) log prior density.
        """
        assert beta.dim() == 3
        nu = self.nu.unsqueeze(-1)
        tau = self.tau.unsqueeze(-1)
        mask = self.mask_d.unsqueeze(-1)
        # epsilon keeps the scale strictly positive
        lp = D.Normal(nu, tau + 1e-12).log_prob(beta)
        lp = lp * mask
        return lp.sum(dim=1)  # (b, s)

    def logPriorSigma(self, sigma: torch.Tensor) -> torch.Tensor:
        """Return the HalfNormal log prior of the noise scale.

        Args:
            sigma: (b, s) noise standard deviations.

        Returns:
            (b, s) log prior density.
        """
        assert sigma.dim() == 2
        return D.HalfNormal(self.tau_eps).log_prob(sigma)  # (b, s)

    def logLikelihood(self, beta: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Return the Gaussian log likelihood of y given X, beta and sigma.

        Args:
            beta: (b, d, s) regression weights.
            sigma: (b, s) noise standard deviations.

        Returns:
            (b, s) log likelihood.
        """
        # sum of squared residuals over observations
        mu = torch.einsum('bnd,bds->bns', self.X, beta)
        ssr = (self.y - mu).square().sum(dim=1)  # (b, s)

        ll = (
            - 0.5 * self.n * torch.tensor(2 * torch.pi).log()
            - self.n * sigma.log()
            - 0.5 * ssr / sigma.square()
        )
        return ll  # (b, s)

    def getTerms(self, beta: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(log_prior, log_likelihood)``, each of shape (b, s)."""
        log_prior_beta = self.logPriorBeta(beta)
        log_prior_sigma = self.logPriorSigma(sigma)
        log_likelihood = self.logLikelihood(beta, sigma)
        return log_prior_beta + log_prior_sigma, log_likelihood

    def __call__(self, proposed: dict[str, torch.Tensor],
                 resample: bool = False, upsample: bool = False,) -> dict[str, torch.Tensor]:
        """Attach importance weights (or resampled draws) to the proposal.

        Args:
            proposed: nested dict; ``proposed['global']`` must hold 'samples'
                (last entry along dim 1 is sigma, the rest is beta) and
                'log_prob' (log proposal density per draw).
            resample: if True, replace the draws by weighted resamples and
                drop the weights from the result.
            upsample: only used together with ``resample``; additionally
                upsample the resampled draws with ``powersample``.

        Returns:
            The same ``proposed`` dict with ``proposed['global']`` updated in
            place ('n_eff', 'sample_efficiency', and either 'weights' or the
            new 'samples').
        """
        samples = proposed['global']['samples']
        log_q = proposed['global']['log_prob']
        # epsilon keeps sigma strictly positive for log() and the division
        beta, sigma = samples[:, :-1], samples[:, -1] + 1e-12

        # importance sampling for beta: one weight per draw, copied to every
        # parameter row of the samples
        log_prior, log_likelihood = self.getTerms(beta, sigma)
        out = getImportanceWeights(log_likelihood, log_prior, log_q)
        out['weights'] = out['weights'].unsqueeze(1).expand(*samples.shape).clone()

        # importance sampling for sigma: fix beta at its weighted mean
        # (weights from getImportanceWeights average to one, so sum / s is a
        # weighted mean) and re-score each sigma draw
        beta_ = (beta * out['weights'][:, :-1]).sum(-1, keepdim=True) / beta.shape[-1]
        beta_ = beta_.expand(*beta.shape)
        log_prior_, log_likelihood_ = self.getTerms(beta_, sigma)
        out_ = getImportanceWeights(log_likelihood_, log_prior_, log_q)
        out['weights'][:, -1] = out_['weights']

        # unidimensional resampling
        if resample:
            resamples = replace(samples, out['weights'], t=200)
            out.pop('weights')

            # upsample with Yeo-Johnson resp. Box-Cox
            # issue: log_prob would need to be re-evaluated for new IS
            if upsample:
                beta, sigma = resamples[:, :-1], resamples[:, -1]
                beta, _ = powersample(beta, t=1000, method='yeo-johnson')
                sigma, _ = powersample(sigma.unsqueeze(1), t=1000, method='box-cox')
                resamples = torch.cat([beta, sigma], dim=1)

            # store under 'samples' so the update below actually replaces
            # proposed['global']['samples'] (same convention as
            # ImportanceLocal / ImportanceGlobal) instead of nesting the
            # draws under proposed['global']['global']
            out['samples'] = resamples

        # update proposed
        proposed['global'].update(**out)
        return proposed


# =============================================================================
# Mixed Effects
# -----------------------------------------------------------------------------
class ImportanceLocal:
    """Importance weighting of the random effects (local parameters).

    The global parameters are collapsed to a single point per dataset (via
    ``weightedMean``) and every random-effect draw is scored against the
    conditional likelihood and the zero-mean Normal random-effect density.

    Shape letters follow the inline comments: b batch size, m groups,
    n observations per group, d fixed effects, q random effects, s samples.
    """

    def __init__(self, data: dict[str, torch.Tensor]):
        """Cache the batch tensors read by the density methods."""
        # prior hyperparameters
        self.nu_ffx = data['nu_ffx']
        self.tau_ffx = data['tau_ffx']
        self.tau_rfx = data['tau_rfx']
        self.tau_eps = data['tau_eps']
        # masks
        self.mask_d = data['mask_d']
        self.mask_n = data['mask_n']
        self.mask_m = data['mask_m']
        # observations and design matrices
        self.y = data['y'].unsqueeze(-1)
        self.X = data['X']
        self.Z = data['Z']
        # sizes
        self.n = data['n'].unsqueeze(-1)
        self.n_i = data['n_i'].unsqueeze(-1)
        self.q = data['q'].unsqueeze(1)
        self.max_d = data['d'].max()

    def logPriorBeta(self, beta: torch.Tensor) -> torch.Tensor:
        """Masked Normal log prior of the fixed effects; (b, d) -> (b,)."""
        prior = D.Normal(self.nu_ffx, self.tau_ffx + 1e-12)
        masked = prior.log_prob(beta) * self.mask_d
        return masked.sum(dim=1)

    def logPriorSigmasRfx(self, sigmas_rfx: torch.Tensor) -> torch.Tensor:
        """HalfNormal log prior of the random-effect scales; (b, q) -> (b,).

        NOTE(review): ``self.q`` is multiplied in where a mask would be
        expected -- confirm that ``data['q']`` is meant to act as a weight.
        """
        prior = D.HalfNormal(self.tau_rfx + 1e-12)
        weighted = prior.log_prob(sigmas_rfx) * self.q
        return weighted.sum(dim=1)

    def logPriorNoise(self, sigma: torch.Tensor) -> torch.Tensor:
        """HalfNormal log prior of the noise scale; (b,) -> (b,)."""
        return D.HalfNormal(self.tau_eps).log_prob(sigma)

    def logLikelihoodCond(self,
                          ffx: torch.Tensor, # (b, d)
                          sigma: torch.Tensor, # (b)
                          rfx: torch.Tensor, # (b, m, q, s)
                          ) -> torch.Tensor:
        """Per-group Gaussian log likelihood given fixed and random effects.

        The constant ``-0.5 * n_i * log(2 * pi)`` is omitted.

        Returns:
            (b, m, s) log likelihood (up to that constant).
        """
        fixed_part = torch.einsum('bmnd,bd->bmn', self.X, ffx).unsqueeze(-1)
        random_part = torch.einsum('bmnd,bmds->bmns', self.Z, rfx)
        residuals = self.y - fixed_part - random_part  # (b, m, n, s)
        ssr = residuals.square().sum(dim=2)  # (b, m, s)
        noise = sigma.view(-1, 1, 1)
        return (
            - self.n_i * noise.log()
            - 0.5 * ssr / noise.square()
        )

    def logLikelihoodRfx(self,
                         rfx: torch.Tensor, # (b, m, q, s)
                         sigmas_rfx: torch.Tensor, # (b, q)
                         ) -> torch.Tensor:
        """Zero-mean Normal log density of the random effects; -> (b, m, s).

        Entries of ``rfx`` that are exactly zero contribute nothing to the sum.
        """
        scale = sigmas_rfx.unsqueeze(1).unsqueeze(-1) + 1e-12  # (b, 1, q, 1)
        density = D.Normal(torch.zeros_like(scale), scale)
        log_density = density.log_prob(rfx)  # (b, m, q, s)
        nonzero = (rfx != 0)
        return (log_density * nonzero).sum(dim=-2)

    def __call__(self, proposed: dict[str, torch.Tensor],
                 resample: bool = False, upsample: bool = False,
                 ) -> dict[str, torch.Tensor]:
        """Attach importance weights (or resampled draws) to ``proposed['local']``.

        Args:
            proposed: nested dict with 'local' ('samples', 'log_prob') and
                'global' ('samples', optionally 'weights') entries.
            resample: if True, replace the draws by weighted resamples.
            upsample: if True, upsample the (re)sampled draws with
                ``powersample``.

        Returns:
            The same ``proposed`` dict, with ``proposed['local']`` updated in
            place.
        """
        # local draws and their proposal density
        local = proposed['local']
        log_q = local['log_prob'].clone()
        rfx = local['samples'].clone()

        # point summary of the global parameters
        samples_g = proposed['global']['samples'].clone()
        weights_g = proposed['global'].get('weights', None)
        mean_g = weightedMean(samples_g, weights_g)
        ffx = mean_g[:, :self.max_d]
        sigmas_rfx = mean_g[:, self.max_d:-1]
        sigma_eps = mean_g[:, -1]

        # log prior of the global point summary, one value per dataset
        log_prior = (self.logPriorBeta(ffx)
                     + self.logPriorNoise(sigma_eps)
                     + self.logPriorSigmasRfx(sigmas_rfx))

        # log likelihood of every random-effect draw
        log_likelihood = (self.logLikelihoodCond(ffx, sigma_eps, rfx)
                          + self.logLikelihoodRfx(rfx, sigmas_rfx))

        # importance sampling; the proposal log density is scaled by a factor
        factor = 5.
        out = getImportanceWeights(
            log_likelihood, log_prior.view(-1, 1, 1), factor * log_q)
        # one weight per (dataset, group, draw), copied across the q axis
        out['weights'] = out['weights'].unsqueeze(-2).expand(*rfx.shape).clone()

        # resampling
        if resample or upsample:
            draws = rfx.clone()
            if resample:
                draws = replace(draws, out['weights'], t=200)
            if upsample:
                n_up = 1000
                b, m, d, _ = draws.shape
                draws, _ = powersample(
                    draws.reshape(b, m * d, -1), t=n_up, method='yeo-johnson')
                draws = draws.reshape(b, m, d, n_up)
            # weights no longer describe the new draws
            out.pop('weights', None)
            out['samples'] = draws

        # finalize
        local.update(**out)
        return proposed
        
    
# -----------------------------------------------------------------------------
class ImportanceGlobal:
    """Importance weighting of the global parameters of a mixed-effects model.

    The global draws are split along dim 1 into fixed effects
    (``[:max_d]``), random-effect scales (``[max_d:-1]``) and the noise scale
    (``[-1]``). The random effects are collapsed to one point per group via
    ``weightedMean`` and separate sets of weights are computed for the three
    parameter groups.

    Shape letters follow the inline comments: b batch size, m groups,
    n observations per group, d fixed effects, q random effects, s samples.
    """

    def __init__(self, data: dict[str, torch.Tensor]):
        """Cache the batch tensors; trailing axes are added to broadcast over s."""
        self.nu_ffx = data['nu_ffx'].unsqueeze(-1)
        self.tau_ffx = data['tau_ffx'].unsqueeze(-1)
        self.tau_rfx = data['tau_rfx']
        self.tau_eps = data['tau_eps'].unsqueeze(-1)
        self.mask_d = data['mask_d'].unsqueeze(-1)
        self.mask_n = data['mask_n']
        self.mask_m = data['mask_m']
        self.y = data['y'].unsqueeze(-1)
        self.X = data['X']
        self.Z = data['Z']
        self.n = data['n'].unsqueeze(-1)
        self.n_i = data['n_i'].unsqueeze(-1)
        self.q = data['q']
        # largest 'd' in the batch; used as the split index between fixed
        # effects and scale parameters in the global samples
        self.max_d = data['d'].max()
        
    def logPriorBeta(self, beta: torch.Tensor) -> torch.Tensor:
        """Masked Normal log prior of the fixed effects, summed over d."""
        # beta: (b, d, s)
        nu, tau, mask = self.nu_ffx, self.tau_ffx, self.mask_d
        # epsilon keeps the scale strictly positive
        lp = D.Normal(nu, tau + 1e-12).log_prob(beta)
        lp = (lp * mask).sum(dim=1) 
        return lp # (b, s)
    
    def logPriorNoise(self, sigma: torch.Tensor) -> torch.Tensor:
        """HalfNormal log prior of the noise scale."""
        # sigma: (b, s)
        tau = self.tau_eps
        lp = D.HalfNormal(tau).log_prob(sigma)
        return lp # (b, s)
    
    def logPriorSigmasRfx(self, sigmas_rfx: torch.Tensor) -> torch.Tensor:
        """HalfNormal log prior of the random-effect scales, summed over q."""
        # sigmas_rfx: (b, q, s)
        tau = (self.tau_rfx + 1e-12).unsqueeze(-1)
        # NOTE(review): data['q'] is multiplied in where a mask would be
        # expected; if it holds counts rather than 0/1 flags this rescales
        # the log prior -- confirm the intended semantics
        mask = self.q.view(-1, 1, 1)
        lp = D.HalfNormal(tau).log_prob(sigmas_rfx)
        lp = (lp * mask).sum(dim=1) # (b,s)
        return lp 
    
    def logLikelihoodCond(self,
                          ffx: torch.Tensor, # (b, d, s)
                          sigma: torch.Tensor, # (b, s)
                          rfx: torch.Tensor, # (b, m, q)
                          ) -> torch.Tensor:
        """Gaussian log likelihood given fixed effects and point random effects.

        Residuals are summed over groups and observations. The constant
        ``-0.5 * n * log(2 * pi)`` is left out (see the commented line).
        """
        mu_g = torch.einsum('bmnd,bds->bmns', self.X, ffx)
        mu_l = torch.einsum('bmnq,bmq->bmn', self.Z, rfx).unsqueeze(-1)
        eps = self.y - mu_g - mu_l # (b, m, n, s)
        ssr = eps.square().sum((1,2)) # (b, s)
        ll = (
            # - 0.5 * self.n * torch.tensor(2 * torch.pi).log() 
            - self.n * sigma.log()
            - 0.5 * ssr / sigma.square()
            ) 
        return ll # (b, s)
    
        
    def logLikelihoodRfx(self,
                         rfx: torch.Tensor, # (b, m, q)
                         sigmas_rfx: torch.Tensor, # (b, q, s)
                         ) -> torch.Tensor:
        """Zero-mean Normal log density of the point random effects.

        Evaluated for every draw of the scales and summed over groups and
        random effects; entries of ``rfx`` that are exactly zero are skipped.
        """
        rfx = rfx.unsqueeze(-1) # (b, m, q, 1)
        S = sigmas_rfx.unsqueeze(1) + 1e-12 # (b, 1, q, s)
        means = torch.zeros_like(S)
        ll = D.Normal(means, S).log_prob(rfx) # (b, m, q, s) 
        # exact zeros contribute nothing (presumably padding -- TODO confirm)
        mask = (rfx != 0)
        ll = (ll * mask).sum(dim=(1,2)) # (b, s)
        return ll

    
    def solve(self) -> torch.Tensor:
        """Least-squares fixed-effects estimate via the pseudo-inverse of X.

        Groups and observations are flattened into one axis; ``Z`` and the
        masks are not used. Within the visible code this is only referenced
        from a commented-out alternative in ``__call__``.
        """
        b, m, n, d = self.X.shape
        X = self.X.view(b, m*n, -1)
        y = self.y.view(b, -1)
        hat = torch.linalg.pinv(X)
        ffx_hat = torch.einsum('bdn,bn->bd', hat, y)
        return ffx_hat
     
    
    def __call__(self, proposed: dict[str, torch.Tensor],
                 resample: bool = False, upsample: bool = False,
                 ) -> dict[str, torch.Tensor]:
        """Attach importance weights (or resampled draws) to ``proposed['global']``.

        Args:
            proposed: nested dict with 'global' ('samples', 'log_prob') and
                'local' ('samples', optionally 'weights') entries.
            resample: if True, replace the draws by weighted resamples.
            upsample: if True, upsample the (re)sampled draws with
                ``powersample``.

        Returns:
            The same ``proposed`` dict, with ``proposed['global']`` updated in
            place ('samples', 'n_eff', 'sample_efficiency', and 'weights'
            unless resampling/upsampling removed them).
        """
        # unpack
        out = {}
        log_q = proposed['global']['log_prob'].clone()
        samples_g = proposed['global']['samples'].clone()
        ffx = samples_g[:, :self.max_d]
        # epsilon keeps the scales strictly positive
        sigmas_rfx = samples_g[:, self.max_d:-1] + 1e-12
        sigma_eps = samples_g[:, -1] + 1e-12
        samples_l = proposed['local']['samples'].clone()
        weights_l = proposed['local'].get('weights', None)
        # point summary of the random effects; indexed as (b, m, q) below
        rfx = weightedMean(samples_l, weights_l)
        
        # # algebraic solution
        # ffx_hat = self.solve()
        # ffx = ffx_hat.unsqueeze(-1).expand(*ffx.shape)
        
        # pseudo importance sampling for sigma_eps
        mu_g = torch.einsum('bmnd,bds->bmns', self.X, ffx)
        mu_l = torch.einsum('bmnq,bmq->bmn', self.Z, rfx)
        eps = self.y - mu_g - mu_l.unsqueeze(-1) # (b,m,n,s)
        # NOTE(review): squeeze() without a dim also drops a batch axis of
        # size 1; the later steps then rely on broadcasting -- confirm the
        # output shape of maskedStd
        sigma_eps_ = maskedStd(eps, (1,2), self.mask_n.unsqueeze(-1)).squeeze()
        # smallest residual spread across draws serves as the reference
        nu_s = sigma_eps_.min(-1, keepdim=True)[0]
        deltas = (sigma_eps - nu_s).square()
        # weight = 1 / (1 + squared distance to the reference)
        log_weights = -(1+deltas).log()
        weights_s = log_weights.exp()
        # rescale so the weights average to one over the draws
        weights_s = weights_s / weights_s.sum(-1, keepdim=True) * weights_s.shape[-1]
        
        # alternatively calibrate sigma_eps
        # tau_s = sigma_eps_.std(-1, keepdim=True)
        # sigma_eps_new = sigma_eps - sigma_eps.mean(-1, keepdim=True)
        # sigma_eps = (sigma_eps_new + nu_s).abs()
        
        # calibrate sigmas_rfx
        # sigmas_rfx_ = maskedStd(samples_l, 1, self.mask_m.unsqueeze(-1).unsqueeze(-1))
        # nu = sigmas_rfx_.mean(-1)
        # # tau = sigmas_rfx_.std(-1)
        # sigmas_rfx = sigmas_rfx - sigmas_rfx.mean(-1, keepdim=True)
        # sigmas_rfx = (sigmas_rfx + nu).abs()
        out['samples'] = torch.cat([ffx, sigmas_rfx, sigma_eps.unsqueeze(1)], dim=1)
        
        # components
        # the noise prior and the likelihood are evaluated at the reference
        # noise scale nu_s, not at the individual sigma_eps draws
        log_prior_beta = self.logPriorBeta(ffx)
        log_prior_noise = self.logPriorNoise(nu_s)
        log_prior_sigmas_rfx = self.logPriorSigmasRfx(sigmas_rfx)
        # the sigmas_rfx prior is left out here; it enters its own IS step below
        log_prior = (log_prior_beta + log_prior_noise)
        log_likelihood = self.logLikelihoodCond(ffx, nu_s, rfx)
        log_likelihood_rfx = self.logLikelihoodRfx(rfx, sigmas_rfx)
        
        # importance sampling (ffx)
        factor = 3.
        is_results = getImportanceWeights(
            log_likelihood, log_prior, factor * log_q)
        out.update(**is_results)
        # one weight per draw copied to every parameter row, then the
        # sigma_eps row is overwritten with its own weights
        out['weights'] = out['weights'].unsqueeze(-2).expand(*samples_g.shape).clone()
        out['weights'][:, -1] = weights_s
        
        # importance sampling (sigmas_rfx)
        # n_eff / sample_efficiency of this step are discarded; only the
        # weights are written into the sigmas_rfx rows
        factor = 0.1
        out_ = getImportanceWeights(
            log_likelihood_rfx, factor * log_prior_sigmas_rfx, factor * log_q)
        out_['weights'] = out_['weights'].unsqueeze(-2).expand(*sigmas_rfx.shape).clone()
        out['weights'][:, self.max_d:-1] = out_['weights']
        
        # resampling
        # starts from the original samples_g; overwrites out['samples'] set above
        if resample or upsample:
            resamples = samples_g.clone()
            if resample:
                resamples = replace(resamples, out['weights'], t=200)
                out.pop('weights', None)
            if upsample:
                ffx, sigmas = resamples[:, :self.max_d], resamples[:, self.max_d:] + 1e-12
                ffx, _ = powersample(ffx, t=1000, method='yeo-johnson')
                sigmas, _ = powersample(sigmas, t=1000, method='box-cox')
                resamples = torch.cat([ffx, sigmas], dim=1)
                out.pop('weights', None)
            out['samples'] = resamples
        
        # finalize
        proposed['global'].update(**out)
        return proposed
        
    
# # importance sampling (sigma intercept)
# sigma0 = maskedStd(rfx[..., 0], 1, self.mask_m)
# deltas = (sigmas_rfx[:, 0] - sigma0).square()
# log_weights = -(1+deltas).log()
# weights = log_weights.exp()
# weights = weights / weights.sum(-1, keepdim=True) * weights.shape[-1]
# out['weights'][:, self.max_d] = weights

# def marginalLogLikelihoodLooped(beta: torch.Tensor, sigmas: torch.Tensor,
#                                 batch: dict[str, torch.Tensor]) -> torch.Tensor:
#     y, X, Z = batch['y'], batch['X'], batch['Z'], 
#     ds, qs = batch['d'], batch['q']
#     b, s = beta.shape[0], beta.shape[-1]
#     mask = batch['mask_n']
#     sigma = sigmas[:, 0]
#     S = torch.diag_embed(sigmas[:, 1:].permute(0,2,1)).permute(0,2,3,1).square()
#     ll = torch.zeros(b,s)
#     for i in range(b):
#         print('batch', i)
#         y_i = y[i, mask[i]].permute(1,0)
#         X_i = X[i, mask[i]][..., :ds[i]+1]
#         Z_i = Z[i, mask[i]][..., :qs[i]]
#         S_i = S[i, :qs[i], :qs[i]]
#         V_i = torch.einsum('nq,qqs->nqs', Z_i, S_i)
#         V_i = torch.einsum('nqs,mq->nms', V_i, Z_i)
#         I_i = torch.eye(len(V_i)).unsqueeze(-1)
#         V_i +=  I_i * (sigma[i].square() + 1e-4)
#         V_i = V_i.permute(-1, 0, 1)
#         mu_i = torch.einsum('nd,ds->ns', X_i, beta[i]).view(s,-1)
#         mvn_i = D.MultivariateNormal(mu_i, covariance_matrix=V_i)
#         ll[i] = mvn_i.log_prob(y_i)
#     return ll # (b, s)


