import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr, pearsonr
import torch
from torch import nn
from metabeta.utils import (
    maskedMean, maskedStd, batchCovary,
    maskedLog, maskedSoftplus, maskedInverseSoftplus, dampen, check
    )
from metabeta.models.summarizers import Summarizer, DeepSet, PoolFormer
from metabeta.models.posteriors import Posterior, CouplingPosterior
# shared mean-squared-error criterion; the plotting helpers take its square root to get RMSE
mse = nn.MSELoss()

# default configuration of the summary network;
# build() merges a user-supplied dict on top of these values
summary_defaults = {
    'type': 'poolformer', # build() accepts 'deepset' or 'poolformer'
    'd_model': 128, 
    'blocks': 4, 
    'd_ff': 128,
    'depth': 1, # only affects non-sparse
    'd_output': 64,
    'heads': 8,
    'dropout': 0.01, # 0.05 for non-sparse
    'activation': 'GELU',
    'sparse': True, # adds the 'sparse-' prefix to the model id
    }

# default configuration of the posterior network;
# build() merges a user-supplied dict on top of these values
posterior_defaults = {
    'type': 'flow-affine', # build() accepts 'flow-affine' or 'flow-spline'
    'flows': 4, # forwarded to the posterior as n_flows
    'd_ff': 128,
    'depth': 3,
    'dropout': 0.01,
    'activation': 'ReLU',
    }

# default model-level configuration;
# build() merges a user-supplied dict on top of these values
model_defaults = {
    'type': 'mfx',
    'seed': 42, # within this file only used as part of the model id
    'd': 5, # read by build() as the number of fixed effects
    'q': 2, # read by ApproximatorMFX.build() as the number of random effects
    'tag': '' # optional suffix appended to the model id
    }

# # -----------------------------------------------------------------------------
# torch.set_printoptions(sci_mode=False)

# base approximator
class Approximator(nn.Module):
    def __init__(self,
                 constrain: bool = True,
                 use_standardization: bool = True,
                 ):
        super().__init__()
        self.constrain = constrain
        self.use_standardization = use_standardization
        self.stats = {}
    
    @staticmethod
    def modelID(s_dict: dict, p_dict: dict, m_dict: dict) -> str:
        prefix = ''
        suffix = ''
        if s_dict['sparse']:
            prefix = 'sparse-'
        if m_dict['tag']:
            suffix = '-' + m_dict['tag']
        summary = f"{prefix}{s_dict['type']}-{s_dict['blocks']}-{s_dict['d_model']}-{s_dict['d_ff']}*{s_dict['depth']}-{s_dict['d_output']}-{s_dict['heads']}-{s_dict['activation']}-{s_dict['dropout']}"
        posterior = f"{p_dict['type']}-{p_dict['flows']}-{p_dict['d_ff']}*{p_dict['depth']}-{p_dict['activation']}-{p_dict['dropout']}"
        return f"{m_dict['type']}-d={m_dict['d']}-q={m_dict['q']}-{summary}-+-{posterior}-seed={m_dict['seed']}{suffix}"

    @property
    def device(self):
        return next(self.parameters()).device

    def inputs(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
        ''' prepare input tensor for the summary network

            abstract hook: the base class only raises NotImplementedError,
            subclasses provide the implementation '''
        raise NotImplementedError 
        
    def targets(self, data: dict[str, torch.Tensor]):
        ''' prepare target tensor for the posterior network

            abstract hook: the base class only raises NotImplementedError,
            subclasses provide the implementation '''
        raise NotImplementedError

    def forward(self, data: dict[str, torch.Tensor], sample: bool = False):
        ''' forward pass over a batch of datasets

            abstract hook: the base class only raises NotImplementedError,
            subclasses provide the implementation '''
        raise NotImplementedError
    
    def standardize(self, x: torch.Tensor, name: str,
                    mask: torch.Tensor | None = None) -> torch.Tensor:
        ''' z-standardization specific for each dataset '''
        # TODO: exclude categorial variables from standardization
        dim = tuple(range(1, x.dim() - 1))
        if mask is not None:
            mean = maskedMean(x, dim, mask=mask)
            std = maskedStd(x, dim, mask=mask, mean=mean) + 1e-12
        else:
            mean = x.mean(dim, keepdim=True)
            std = x.std(dim, keepdim=True) + 1e-12
        # std = std.clamp(max=300)
        self.stats[name] = {'mean': mean, 'std': std}
        out = (x-mean) / std
        if mask is not None:
            out *= mask
        return out
    
    def unpackMoment(self,
                     names_list: list[str],
                     moment: str,
                     device: str|None = None) -> dict[str, torch.Tensor]:
        if device is None:
            return {name: self.stats[name][moment] for name in names_list}
        return {name: self.stats[name][moment].to(device) for name in names_list}
   
    def unpackMean(self, names_list, device=None):
        return self.unpackMoment(names_list, 'mean', device)
        
    def unpackStd(self, names_list, device=None):
        return self.unpackMoment(names_list, 'std', device)
    
    # def runningStandardize(self, x: torch.Tensor, name: str) -> torch.Tensor:
    #     ''' z-standardization with optional momentum-based updates of mean/std '''
    #     mask = (x != 0)
    #     if self.training:
    #         t = self.stats['momentum']
    #         mean = maskedMean(x, mask=mask)
    #         std = maskedStd(x, mask=mask, mean=mean) + 1e-12
    #         if name in self.stats:
    #             mean = t * self.stats[name]['mean'] + (1-t) * mean
    #             std = t * self.stats[name]['std'] + (1-t) * std
    #         self.stats[name] = {'mean': mean, 'std': std}
    #     else:
    #         assert name in self.stats, f'Model in evaluation mode but has no stats for {name}'
    #         mean, std = self.stats[name].values()
    #         mean, std = mean.to(x.device), std.to(x.device)
    #     out = (x-mean)/std
    #     out *= mask
    #     return out
    
    # def unstandardize(self, samples: torch.Tensor, name: str) -> torch.Tensor:
    #     mean, std = self.stats[name].values()
    #     mean, std = mean.to(samples.device).unsqueeze(-1), std.to(samples.device).unsqueeze(-1)
    #     mask = (samples != 0)
    #     out = samples * std + mean
    #     out *= mask
    #     return out
    
    # def quantize(self, x: torch.Tensor) -> torch.Tensor:
    #     ''' assign column-wise quantiles to x '''
    #     b, n, d = x.shape
    #     sorted_indices = x.argsort(dim=1)
    #     ranks = torch.arange(n, device=x.device).view(1,n,1).expand(b, n, d).float()
    #     quantized = torch.empty_like(x, dtype=torch.float)
    #     quantized = quantized.scatter_(1, sorted_indices, ranks)
    #     return quantized / (n - 1)
        
    # def compress(self, x: torch.Tensor) -> torch.Tensor:
    #     ''' decompose x using SVD and keep the first half
    #         of the projected columns '''
    #     k = x.shape[-1] // 2 + 1
    #     U, S, V = torch.linalg.svd(x, full_matrices=False)
    #     U = U[:, :, :k]
    #     S = S[:, :k].unsqueeze(-1)
    #     V = V[:, :k, :k]
    #     compressed = torch.einsum('bnk,bkd->bnd', U, S * V)
    #     return compressed
    

    def moments(self, proposed: torch.Tensor | dict[str, torch.Tensor]
                ) -> tuple[torch.Tensor, torch.Tensor]:
        ''' wrapper for location and scale of the posterior '''
        if 'samples' in proposed:
            return self.posterior.getLocScale(proposed)
        
    def quantiles(self, proposed: torch.Tensor | dict[str, torch.Tensor],
                  roots: list = [.025, .975], calibrate: bool = False) -> torch.Tensor:
        ''' wrapper for desired quantiles of the posterior '''
        if 'samples' in proposed:
            samples = proposed['samples'].clone()
            quantiles = self.posterior.getQuantiles(samples, roots, calibrate)
            return quantiles
        
    def ranks(self,
              proposed: torch.Tensor | dict[str, torch.Tensor],
              targets: torch.Tensor) -> torch.Tensor:
        return self.posterior.getRanks(proposed, targets)

    def examples(self,
                 indices: list[int],
                 batch: dict[str, torch.Tensor],
                 proposed: torch.Tensor | dict[str, torch.Tensor],
                 printer: callable, console_width: int) -> None:
        ''' print true and predicted targets for indexed datasets in batch

            indices: positions of the datasets in the batch to report
            batch: reads 'ffx', 'n' and 'sigma_error', optionally 'sigmas_rfx'
                   and batch['analytical']['ffx']['mu']
            proposed: posterior output; a dict with a 'global' entry is
                      unwrapped to that entry before the moments are computed
            printer: called once per output line with a formatted string
            console_width: length of the dashed separator lines

            tensors are converted with .numpy(), which requires them to live
            on the CPU '''
        ffx = batch['ffx']
        d = ffx.shape[1]
        sigmas_rfx = batch['sigmas_rfx'] if 'sigmas_rfx' in batch else None
        ffx_a = batch['analytical']['ffx']['mu'] if 'analytical' in batch else None
        # local moments are only consumed by the commented-out block further down
        loc_ = None
        if isinstance(proposed, dict) and 'global' in proposed:
            if 'local' in proposed:
                loc_, _ = self.posterior.getLocScale(proposed['local'])
            proposed = proposed['global']
        loc, scale = self.moments(proposed)
        for i in indices:
            n_i = int(batch['n'][i])
            # NOTE(review): the targets() methods in this file read 'sigma_eps',
            # this one reads 'sigma_error' - confirm that the batch provides both
            sigma_i = float(batch['sigma_error'][i])
            # entries equal to zero are treated as absent parameters
            mask = (ffx[i] != 0.)
            beta_i = ffx[i, mask].detach().numpy()
            mean_i = loc[i, :d][mask].detach().numpy()
            std_i = scale[i, :d][mask].detach().numpy()
            # NOTE(review): column d is printed as the predicted noise sigma and
            # columns d+1: as the rfx sigmas - confirm this matches the layout
            # of the posterior samples
            sigma_i_ = loc[i, d].detach().numpy()
            printer(f"\n{console_width * '-'}")
            printer(f"n={n_i}, sigma={sigma_i:.2f}, predicted={sigma_i_:.2f}")
            # FFX
            printer(f"True FFX   : {beta_i}")
            if ffx_a is not None:
                printer(f"Optimal FFX: {ffx_a[i, mask].detach().numpy()}")
            printer(f"Mean FFX   : {mean_i}")
            printer(f"SD FFX     : {std_i}")
            # Variances
            if sigmas_rfx is not None:
                mask_rfx = (sigmas_rfx[i] != 0.)
                sigmas_rfx_i = sigmas_rfx[i][mask_rfx].detach().numpy()
                sigmas_rfx_i_ = loc[i, d+1:][mask_rfx].detach().numpy()
                printer(f"True sigmas RFX: {sigmas_rfx_i}")
                printer(f"Mean sigmas RFX: {sigmas_rfx_i_}")
            # Random Effects
            # if loc_ is not None:
            #     rfx_i = batch['rfx'][i][..., mask_rfx][:5, 0].detach().numpy()
            #     rfx_i_ = loc_[i][:, mask_rfx][:5, 0].detach().numpy()
            #     printer(f"True random intercept (first 5): {rfx_i}")
            #     printer(f"Mean random intercept (first 5): {rfx_i_}")
            printer(f"{console_width * '-'}\n")
            
    @staticmethod
    def plotRecovery(targets: torch.Tensor,
                     names: list[str],
                     means: torch.Tensor,
                     stds: torch.Tensor | None = None,
                     show_error: bool = False,
                     color: str = 'darkgreen',
                     alpha: float = 0.3,
                     return_stats: bool = True) -> None | tuple[float, float]:
        ''' plot true targets against posterior means for entire batch

            draws one scatter panel per entry of names on a square grid.
            Target entries equal to zero are treated as absent and skipped;
            a parameter without any present target gets a hidden panel and
            does not contribute to the averaged statistics.

            targets, means: 2d tensors with one column per name, or 3d tensors
                            that are flattened to 2d first
            stds: optional posterior standard deviations, required if
                  show_error is set
            return_stats: if set, returns (mean RMSE, mean Pearson r) over the
                          plotted parameters, or (nan, nan) if none was plotted

            raises ValueError if show_error is set but stds is missing '''
        # get sizes
        assert means.shape[-1] == len(names), "shape mismatch"
        if show_error and stds is None:
            raise ValueError('show_error=True requires stds')
        D = len(names)
        if targets.dim() == 3:
            targets = targets.view(-1, D)
            means = means.view(-1, D)
            # stds is optional, only reshape it when provided
            if stds is not None:
                stds = stds.view(-1, D)
        mask = (targets != 0.)
        w = int(torch.tensor(D).sqrt().ceil())
        
        # init figure
        # squeeze=False always yields a 2d array of axes, so a single panel
        # (w == 1) can be indexed like any other grid
        fig, axs = plt.subplots(figsize=(8 * w, 7 * w), ncols=w, nrows=w,
                                squeeze=False)
        axs = axs.flatten()
        
        # init stats
        RMSE = 0.
        R = 0.
        denom = D
        
        # make subplots
        for i in range(D):
            ax = axs[i]
            mask_i = mask[..., i]
            targets_i = targets[mask_i, i]
            mean_i = means[mask_i, i].detach()
            
            # skip empty target
            if mask_i.sum() == 0:
                ax.set_visible(False)
                denom -= 1
                continue
            
            # compute stats
            r = pearsonr(targets_i, mean_i)[0]
            R += r
            bias = (targets_i - mean_i).mean()
            rmse = mse(targets_i, mean_i).sqrt()
            RMSE += rmse
            
            # subplot
            ax.set_axisbelow(True)
            ax.grid(True)
            min_val = min(mean_i.min(), targets_i.min()).floor()
            max_val = max(mean_i.max(), targets_i.max()).ceil()
            ax.plot([min_val, max_val],
                    [min_val, max_val],
                    '--', lw=2, zorder=1, color='grey') # diagline
            ax.scatter(targets_i, mean_i,
                       alpha=alpha, color=color, label=names[i])
            if show_error:
                std_i = stds[mask_i, i].detach()
                ax.errorbar(targets_i, mean_i, yerr=std_i,
                            fmt='', alpha=0.3, color='grey',
                            capsize=0, linestyle='none')
            ax.text(
                0.75, 0.1,
                f'r = {r.item():.3f}\nBias = {bias.item():.3f}\nRMSE = {rmse.item():.3f}',
                transform=ax.transAxes,
                ha='center', va='bottom',
                fontsize=16,
                bbox=dict(boxstyle='round',
                          facecolor=(1, 1, 1, 0.7),
                          edgecolor=(0, 0, 0, alpha),
                          ),
            )
            ax.set_xlabel('true', fontsize=20)
            ax.set_ylabel('estimated', fontsize=20)
            ax.legend()
        
        # skip remaining empty subplots
        for i in range(D, w*w):
            axs[i].set_visible(False)
        fig.tight_layout()
        
        # optionally return average statistics
        if return_stats:
            # nothing was plotted, so there is nothing to average
            if denom == 0:
                return float('nan'), float('nan')
            return RMSE/denom, R/denom
        
    
    
    def plotRecoveryGrouped(self,
                            targets: list[torch.Tensor],
                            names: list[list[str]],
                            means: list[torch.Tensor],
                            titles: list[str] = None,
                            marker: str = 'o',
                            alpha: float = 0.2) -> None | tuple[float, float]:
        N = len(names)
        fig, axs = plt.subplots(figsize=(7*N, 7), ncols=N, nrows=1, dpi=300)
        i = 0
        for _targets, _names, _means, title, ax in zip(targets, names, means, titles, axs):
            self._plotRecoveryGrouped(_targets, _names, _means,
                                      ax=ax, title=title, marker=marker,
                                      alpha=alpha, show_y=(i==0))
            i += 1
        fig.tight_layout()
        
        
    @staticmethod
    def _plotRecoveryGrouped(targets: torch.Tensor,
                            names: list[str],
                            means: torch.Tensor,
                            stds: torch.Tensor | None = None,
                            show_error: bool = False,
                            ax: plt.axes = None,
                            title: str = '',
                            marker: str = 'o',
                            colors: list[np.ndarray] | None = None,
                            alpha: float = 0.2,
                            show_y: bool = True) -> None | tuple[float, float]:
        ''' plot true targets against posterior means for entire batch

            all parameters in names are overlaid in the single axes ax (the
            current axes if none is given) and the Pearson r, bias and RMSE,
            averaged over the plotted parameters, are shown in a text box.
            Target entries equal to zero are treated as absent; the text box
            is omitted if no parameter has any present target.

            raises ValueError if show_error is set but stds is missing '''
        # get sizes
        assert means.shape[-1] == len(names), "shape mismatch"
        if colors is not None:
            assert len(colors) >= len(names), "not enough colors provided"
        if show_error and stds is None:
            raise ValueError('show_error=True requires stds')
        if ax is None:
            # fall back to the current axes instead of failing on None
            ax = plt.gca()
        D = len(names)
        if targets.dim() == 3:
            targets = targets.view(-1, D)
            means = means.view(-1, D)
            stds = stds.view(-1, D) if stds is not None else None
        mask = (targets != 0.)
        
        # init figure
        ax.set_title(title, fontsize=30, pad=15)
        ax.set_axisbelow(True)
        ax.grid(True)
        min_val = min(means.min(), targets.min()).floor()
        max_val = max(means.max(), targets.max()).ceil()
        addon = 4 if min_val < 0 else 1
        ax.set_xlim([min_val - addon, max_val + addon], auto=False)
        ax.set_ylim([min_val - addon, max_val + addon], auto=False)
        ax.plot([min_val, max_val],
                [min_val, max_val],
                '--', lw=2, zorder=1, color='grey') # diagline
        
        # init stats
        RMSE = 0.
        R = 0.
        Bias = 0.
        denom = D
        
        # overlay plots
        for i in range(D):
            mask_i = mask[..., i]
            targets_i = targets[mask_i, i]
            mean_i = means[mask_i, i].detach()
            # skip empty target
            if mask_i.sum() == 0:
                denom -= 1
                continue
            
            # compute stats
            r = pearsonr(targets_i, mean_i)[0]
            R += r
            bias = (targets_i - mean_i).mean()
            Bias += bias
            rmse = mse(targets_i, mean_i).sqrt()
            RMSE += rmse
            
            # subplot
            if colors is not None:
                ax.scatter(targets_i, mean_i, marker=marker,
                           alpha=alpha, color=colors[i], label=names[i])
            else:
                ax.scatter(targets_i, mean_i, marker=marker,
                           alpha=alpha, label=names[i])
            if show_error:
                std_i = stds[mask_i, i].detach()
                ax.errorbar(targets_i, mean_i, yerr=std_i,
                            fmt='', alpha=0.3, color='grey',
                            capsize=0, linestyle='none')
        
        # add stats (only if at least one parameter was plotted,
        # otherwise there is nothing to average)
        if denom > 0:
            rmse = RMSE/denom
            bias = Bias/denom
            r = R/denom
            ax.text(
                0.7, 0.1,
                f'r = {r.item():.3f}\nBias = {bias.item():.3f}\nRMSE = {rmse.item():.3f}',
                transform=ax.transAxes,
                ha='center', va='bottom',
                fontsize=22,
                bbox=dict(boxstyle='round',
                          facecolor=(1, 1, 1, 0.7),
                          edgecolor=(0, 0, 0, alpha),
                          ),
            )
        ax.set_xlabel('true', fontsize=26, labelpad=10)
        if show_y:
            ax.set_ylabel('estimated', fontsize=26, labelpad=10) 
        ax.legend(fontsize=22, markerscale=2.5, loc='upper left')


# -----------------------------------------------------------------------------
# fixed effects
class ApproximatorFFX(Approximator):
    def __init__(self,
                 summarizer: Summarizer,
                 posterior: Posterior,
                 model_id: str,
                 constrain: bool = True, # constrains sigma
                 use_standardization: bool = True, # standardizes inputs
                 ):
        ''' wrap a summary network and a posterior network and record the
            number of trainable parameters of each '''
        super().__init__(constrain, use_standardization)
        self.summarizer = summarizer
        self.posterior = posterior

        def countTrainable(module) -> int:
            # total number of elements over all parameters that require gradients
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        self.num_sum = countTrainable(self.summarizer)
        self.num_inf = countTrainable(self.posterior)
        self.id = model_id
        
    @classmethod
    def build(cls,
              d_data: int, # input dimension
              s_dict: dict[str, int|float|str] = summary_defaults,
              p_dict: dict[str, int|float|str] = posterior_defaults,
              m_dict: dict[str, int|float|str] = model_defaults,
              use_standardization: bool = True):
        ''' construct summarizer and posterior from config dicts and wrap them

            s_dict, p_dict, m_dict: partial configurations, missing keys are
                                    filled from the module-level defaults
            d_data: currently ignored, the dimension is always taken from
                    m_dict['d'] (kept in the signature for existing callers)

            raises ValueError for an unknown summarizer or posterior type '''
        s_dict = summary_defaults | s_dict
        p_dict = posterior_defaults | p_dict
        m_dict = model_defaults | m_dict
        # the model config takes precedence over the d_data argument
        d_data = m_dict['d']
        model_id = cls.modelID(s_dict, p_dict, m_dict)
        
        # 1. summary network
        sum_type = s_dict['type']
        if sum_type == 'deepset':
            summarizer = DeepSet(d_input=d_data, **s_dict)
        elif sum_type == 'poolformer':
            summarizer = PoolFormer(d_input=d_data, **s_dict)
        else:
            raise ValueError(f'transformer type {sum_type} unknown')
        
        # dimension variables
        post_type = p_dict['type']
        d_target = d_data
        prior_dims = 2 * (d_data - 1) + 1
        d_context = s_dict['d_output'] + 1 + prior_dims # additional summary variables: dataset size and priors
            
        # 2. posterior network        
        if post_type in ['flow-affine', 'flow-spline']:
            transform = 'affine' if post_type == 'flow-affine' else 'rq'
            posterior = CouplingPosterior(
                d_target=d_target,
                d_context=d_context,
                n_flows=p_dict['flows'],
                transform=transform,
                net_kwargs=p_dict)
        else:
            raise ValueError(f'posterior type {post_type} unknown')
        
        return cls(summarizer, posterior,
                   model_id=model_id,
                   use_standardization=use_standardization)
    
    @property 
    def calibrator(self):
        ''' calibrator attached to the posterior network '''
        posterior = self.posterior
        return posterior.calibrator
    
    def inputs(self, data: dict[str, torch.Tensor]) -> torch.Tensor:
        ''' build the summarizer input: y followed by X without its first
            column, both standardized under the observation mask if enabled '''
        response = data['y'].unsqueeze(-1)
        predictors = data['X'][..., 1:]
        if not self.use_standardization:
            return torch.cat([response, predictors], dim=-1)
        obs_mask = data['mask_n'].unsqueeze(-1)
        response = self.standardize(response, 'y', mask=obs_mask)
        predictors = self.standardize(predictors, 'X', mask=obs_mask)
        return torch.cat([response, predictors], dim=-1)
    
    
    def names(self, data: dict[str, torch.Tensor], local: bool = False) -> np.ndarray:
        ''' LaTeX labels of the targets: one beta per fixed effect,
            followed by the noise sigma (local is not used here) '''
        n_ffx = data['ffx'].shape[1]
        betas = [rf'$\beta_{{{i}}}$' for i in range(n_ffx)]
        return np.array([*betas, r'$\sigma_e$'])
    

    def targets(self, data: dict[str, torch.Tensor], local: bool = False) -> torch.Tensor:
        ''' target tensor: fixed effects followed by the noise sigma as the
            last column (local is not used here) '''
        noise_sd = data['sigma_eps'].unsqueeze(-1).clone()
        fixed = data['ffx'].clone()
        return torch.cat([fixed, noise_sd], dim=-1)
    
    
    def addMetadata(self, summary: torch.Tensor, data: dict) -> torch.Tensor:
        ''' extend the summary with the scaled number of observations and the
            dampened prior parameters; when standardization is enabled the
            priors are rescaled with the cached stds of y and X '''
        # number of observations
        n_obs = data['n'].unsqueeze(-1).sqrt() / 100

        # priors, cloned because they are rescaled in place below;
        # the first entry of the ffx priors is left out
        nu = data['nu_ffx'][..., 1:].clone()
        tau = data['tau_ffx'][..., 1:].clone()
        tau_e = data['tau_eps'].unsqueeze(-1).clone()

        if self.use_standardization:
            b, d = nu.shape
            stds = self.unpackStd(['y', 'X'])
            std_y = stds['y'].view(b, 1)
            slope_scale = stds['X'].view(b, d) / std_y
            nu *= slope_scale
            tau *= slope_scale
            tau_e /= std_y

        # constrain and concat
        parts = [summary, n_obs, dampen(nu), dampen(tau), dampen(tau_e)]
        return torch.cat(parts, dim=-1)
    
    
    def preprocess(self, targets: torch.Tensor, data: dict[str, torch.Tensor]) -> torch.Tensor:
        ''' analytically standardize targets and constrain variance components

            the first target column is dropped (postprocess recovers the
            intercept from the data), so the result holds the slopes followed
            by sigma. With standardization enabled, slopes and sigma are
            rescaled with the cached stds of y and X; with constrain enabled,
            sigma is passed through maskedInverseSoftplus.
            The caller's tensor is left unchanged. '''
        # work on a copy: the slices below are views that are rescaled in place
        targets = targets.clone()
        slopes = targets[:, 1:-1]
        sigma = targets[:, -1:] 
        
        if self.use_standardization:
            b, d = slopes.shape
            std_y, std_X = self.unpackStd(['y', 'X']).values()
            slopes *= std_X.view(b,d) / std_y.view(b,1)
            sigma /= std_y.view(b,1)
            
        if self.constrain:
            sigma = maskedInverseSoftplus(sigma)
            
        return torch.cat([slopes, sigma], dim=-1)
    
    
    def postprocess(self,
                    proposed: dict[str, torch.Tensor],
                    data: dict[str, torch.Tensor]) -> torch.Tensor:
        ''' reverse steps used in preprocessing for samples

            constrains sigma, undoes the standardization, prepends the
            recovered intercept and stores the result back under
            proposed['samples']; proposed is returned unchanged if it has
            no 'samples' entry '''
        if 'samples' not in proposed:
            return proposed

        samples = proposed['samples'].clone()
        b, n_params, _ = samples.shape
        d = n_params - 1 # number of slopes, the last row is sigma

        # constrain sigma
        if self.constrain:
            samples[:, -1] = maskedSoftplus(samples[:, -1])

        # unstandardize
        if self.use_standardization:
            stds = self.unpackStd(['y', 'X'])
            samples *= stds['y'].view(b,1,1)
            samples[:, :-1] /= stds['X'].view(b,d,1)

        # recover intercept from the masked means of y and X
        obs_mask = data['mask_n'].unsqueeze(-1)
        y_mean = maskedMean(data['y'].unsqueeze(-1), 1, obs_mask).view(b,1)
        x_mean = maskedMean(data['X'][..., 1:], 1, obs_mask).view(b,d)
        slopes = samples[:, :-1]
        fitted_mean = torch.einsum('bd,bds->bs', x_mean, slopes)
        intercept = (y_mean - fitted_mean).unsqueeze(1)

        # put everything back together
        proposed['samples'] = torch.cat([intercept, samples], dim=1)
        return proposed


    def forward(self, data: dict[str, torch.Tensor],
                sample=False, n=(500,), **kwargs,
                ) -> dict[str, torch.Tensor|dict]:
        ''' full pass: summarize the data, add metadata, evaluate the posterior

            returns a dict with the posterior 'loss' and the postprocessed
            proposal under 'proposed' -> 'global' '''
        # inputs() has to run first, it caches the moments used by preprocess()
        net_inputs = self.inputs(data)
        std_targets = self.preprocess(self.targets(data), data)
        summary = self.summarizer(net_inputs, mask=data['mask_n'])
        context = self.addMetadata(summary, data)
        loss, proposed = self.posterior(
            context, std_targets, sample=sample, n=n[0], **kwargs)
        result = {'global': self.postprocess(proposed, data)}
        return {'loss': loss, 'proposed': result}
    
    
    def estimate(self, data: dict[str, torch.Tensor], n=(500,)):
        ''' draw n[0] posterior samples for the given data without tracking
            gradients and return the postprocessed proposal '''
        with torch.no_grad():
            h = self.inputs(data)
            # NOTE(review): forward() passes mask=data['mask_n'] to the
            # summarizer, this call does not - confirm that this is intended
            summary = self.summarizer(h)
            summary = self.addMetadata(summary, data)
            # the extra column marks the noise sigma as active; it is created
            # on the device of mask_d so that the concatenation also works
            # when the data does not live on the default device
            mask_d = data['mask_d']
            ones = torch.ones(len(h), 1, device=mask_d.device)
            mask = torch.cat([mask_d, ones], dim=-1).float()
            proposed = self.posterior.estimate(summary, mask, n[0])
            proposed = self.postprocess(proposed, data)
        return proposed
    
    
# -----------------------------------------------------------------------------
# mixed effects
class ApproximatorMFX(Approximator):
    def __init__(self,
                 summarizer_g: Summarizer, # global summarizer
                 summarizer_l: Summarizer, # local summarizer
                 posterior_g: Posterior, # global posterior 
                 posterior_l: Posterior, # local posterior
                 model_id: str,
                 constrain: bool = True, # constrains sigmas
                 use_standardization: bool = True, # standardizes inputs
                 ):
        ''' wrap the global and local summary / posterior networks and record
            the number of trainable parameters of each pair '''
        super().__init__(constrain, use_standardization)
        self.summarizer_g = summarizer_g
        self.summarizer_l = summarizer_l
        self.posterior_g = posterior_g
        self.posterior_l = posterior_l

        def countTrainable(*modules) -> int:
            # total number of elements over all parameters that require gradients
            return sum(p.numel()
                       for module in modules
                       for p in module.parameters() if p.requires_grad)

        self.num_sum = countTrainable(self.summarizer_g, self.summarizer_l)
        self.num_inf = countTrainable(self.posterior_g, self.posterior_l)
        self.id = model_id
    

    @classmethod
    def build(cls,
              s_dict: dict[str, int|float|str] = summary_defaults,
              p_dict: dict[str, int|float|str] = posterior_defaults,
              m_dict: dict[str, int|float|str] = model_defaults,
              use_standardization: bool = True):
        ''' construct the local / global summarizers and posteriors from
            config dicts and wrap them

            s_dict, p_dict, m_dict: partial configurations, missing keys are
                                    filled from the module-level defaults

            the dimensions d, q and r are stored on the returned instance, so
            that models built with different sizes do not overwrite each other.

            raises ValueError for an unknown summarizer or posterior type '''
        s_dict = summary_defaults | s_dict
        p_dict = posterior_defaults | p_dict
        m_dict = model_defaults | m_dict
        d_ffx = m_dict['d']
        d_rfx = m_dict['q']
        n_corr = int((d_ffx-1) * (d_ffx-2) / 2)
        model_id = cls.modelID(s_dict, p_dict, m_dict)
        # class-level copies are kept for backward compatibility only,
        # they are shared by all instances and reflect the latest build
        cls.d = d_ffx
        cls.q = d_rfx
        cls.r = n_corr
        
        # 1. summary networks
        sum_type = s_dict['type']
        s_dict_l = s_dict.copy()
        s_dict_l['d_output'] -= 1
        d_input_l = 1 + (d_ffx - 1) + (d_rfx - 1)
        d_input_g = s_dict_l['d_output'] + 1 # num obs per group 
        
        if sum_type == 'deepset':
            summarizer_cls = DeepSet
        elif sum_type == 'poolformer':
            summarizer_cls = PoolFormer
        else:
            raise ValueError(f'transformer type {sum_type} unknown')
        summarizer_l = summarizer_cls(d_input=d_input_l, **s_dict_l)
        summarizer_g = summarizer_cls(d_input=d_input_g, **s_dict)
        
        # dimension variables
        post_type = p_dict['type']
        d_var = 1 + d_rfx # variance components
        prior_dims = 2 * d_ffx + d_var # ffx prior (nu, tau_f), rfx variance prior (tau_r), noise prior (tau_e)
        d_context_g = s_dict['d_output'] + 2 + prior_dims # global conditional: global summary, num groups, num obs, priors
        # d_context_g += 1 + (d_rfx - 1) # std_y, std_Z
        # d_context_g += cls.r # unique correlation coefs
        d_context_l = d_input_g + d_ffx + d_var # local conditional: local summary, global parameters
        
        # 2. posterior networks
        if post_type in ['flow-affine', 'flow-spline']:
            transform = 'affine' if post_type == 'flow-affine' else 'rq'
            posterior_g = CouplingPosterior(
                d_target=d_ffx+d_var,
                d_context=d_context_g,
                n_flows=p_dict['flows'],
                transform=transform,
                base_type='student',
                net_kwargs=p_dict)
            posterior_l = CouplingPosterior(
                d_target=d_rfx,
                d_context=d_context_l,
                n_flows=p_dict['flows'],
                transform=transform,
                base_type='student',
                net_kwargs=p_dict)
        else:
            raise ValueError(f'posterior type {post_type} unknown')
        
        model = cls(summarizer_g, summarizer_l,
                    posterior_g, posterior_l,
                    model_id=model_id,
                    use_standardization=use_standardization,
                    )
        # per-instance dimensions shadow the shared class attributes
        model.d = d_ffx
        model.q = d_rfx
        model.r = n_corr
        return model
    
    @property 
    def posterior(self):
        ''' alias for the global posterior, used by the base class wrappers '''
        global_posterior = self.posterior_g
        return global_posterior
    
    @property 
    def calibrator(self):
        ''' calibrator attached to the global posterior '''
        global_posterior = self.posterior_g
        return global_posterior.calibrator
    
    @property 
    def calibrator_l(self):
        ''' calibrator attached to the local posterior '''
        local_posterior = self.posterior_l
        return local_posterior.calibrator

    def moments(self, proposed: torch.Tensor | dict[str, torch.Tensor],
                local: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
        ''' location and scale from the local or the global posterior;
            returns None when proposed carries no 'samples' entry '''
        if 'samples' not in proposed:
            return None
        chosen = self.posterior_l if local else self.posterior_g
        return chosen.getLocScale(proposed)
        
    
    def quantiles(self, proposed: torch.Tensor | dict[str, torch.Tensor],
                  roots: list = [.025, .975],
                  calibrate: bool = False,
                  local: bool = False) -> torch.Tensor:
        ''' requested quantiles from the local or the global posterior

            global: a single call on all samples, forwarding optional weights
            local: one call per slice along dim 1 of the samples, stacked
                   back along that dim (weights are not used)
            returns None when proposed carries no 'samples' entry '''
        if 'samples' not in proposed:
            return None
        samples = proposed['samples'].clone()
        weights = proposed.get('weights', None)

        if not local:
            return self.posterior_g.getQuantiles(
                samples, roots, calibrate, weights=weights)

        stacked = []
        for i in range(samples.shape[1]):
            q_i = self.posterior_l.getQuantiles(samples[:, i], roots, calibrate)
            stacked.append(q_i.unsqueeze(1))
        return torch.cat(stacked, dim=1)
    

    def inputs(self, data: dict[str, torch.Tensor], long: bool = False) -> torch.Tensor:
        ''' build the summarizer input: y, columns 1..d-1 of X and columns
            1..q-1 of Z, each standardized under the observation mask if
            enabled (long is not used) '''
        parts = {
            'y': data['y'].unsqueeze(-1),
            'X': data['X'][..., 1:self.d],
            'Z': data['Z'][..., 1:self.q],
        }
        if self.use_standardization:
            obs_mask = data['mask_n'].unsqueeze(-1)
            for key in ('y', 'X', 'Z'):
                parts[key] = self.standardize(parts[key], key, obs_mask)
        return torch.cat([parts['y'], parts['X'], parts['Z']], dim=-1)
    
    
    def names(self, data: dict[str, torch.Tensor], local: bool = False) -> np.ndarray:
        ''' LaTeX labels of the targets

            local: one alpha per random effect
            global: betas, then one sigma per random effect, then the noise sigma '''
        if local:
            n_rfx = data['rfx'].shape[2]
            return np.array([rf'$\alpha_{{{i}}}$' for i in range(n_rfx)])

        betas = [rf'$\beta_{{{i}}}$' for i in range(data['ffx'].shape[1])]
        sigmas = [rf'$\sigma_{i}$' for i in range(data['sigmas_rfx'].shape[1])]
        return np.array(betas + sigmas + [r'$\sigma_e$'])


    def targets(self, data: dict[str, torch.Tensor], local: bool = False) -> torch.Tensor:
        ''' target tensor: the random effects if local, otherwise the fixed
            effects, rfx sigmas and noise sigma joined along the last dim '''
        if local:
            return data['rfx']
        parts = (data['ffx'], data['sigmas_rfx'], data['sigma_eps'].unsqueeze(-1))
        return torch.cat(parts, dim=-1)


    def addMetadata(self,
                    summary: torch.Tensor,
                    data: dict,
                    local: bool = False) -> torch.Tensor:
        ''' append summary tensor with n_obs and priors

            local: only the scaled number of observations per group is added
            global: the scaled number of groups and observations plus the
                    dampened prior parameters are added; with standardization
                    enabled the priors are rescaled with the cached stds '''
        if local:
            # number of group observations
            n_group = data['n_i'].unsqueeze(-1).sqrt() / 10
            return torch.cat([summary, n_group], dim=-1)

        # number of groups and total number of observations
        n_groups = data['m'].unsqueeze(-1).sqrt() / 10
        n_total = data['n'].unsqueeze(-1).sqrt() / 10

        # prior params, cloned because they are rescaled in place below
        nu_f = data['nu_ffx'].clone()
        tau_f = data['tau_ffx'].clone()
        tau_r = data['tau_rfx'].clone()
        tau_e = data['tau_eps'].unsqueeze(-1).clone()

        if self.use_standardization:
            # standardize priors: everything is divided by std_y, all but the
            # first column are additionally multiplied by the covariate std
            b, d, q = len(nu_f), self.d, self.q
            stds = self.unpackStd(['y', 'X', 'Z'])
            y_scale = stds['y'].view(b,1)
            x_scale = stds['X'].view(b,d-1)
            z_scale = stds['Z'].view(b,q-1)
            for prior, cov_scale in ((nu_f, x_scale), (tau_f, x_scale), (tau_r, z_scale)):
                prior /= y_scale
                prior[:, 1:] *= cov_scale
            tau_e /= y_scale

        # reduce absolute size for better NN handling
        parts = [summary, n_groups, n_total,
                 dampen(nu_f), dampen(tau_f), dampen(tau_r), dampen(tau_e)]
        return torch.cat(parts, dim=-1)


    def preprocess(self,
                   targets: torch.Tensor,
                   data: dict[str, torch.Tensor],
                   local: bool = False) -> torch.Tensor:
        ''' map targets to the scale the posterior networks are trained on

        Works on a clone, so the passed tensor is not modified.
        With standardization enabled the parameters are analytically
        rescaled using the stored means / stds of y, X and Z.
        For global targets with self.constrain set, the standard deviations
        are additionally mapped through maskedInverseSoftplus.
        '''
        params = targets.clone()
        standardized = self.use_standardization

        # stored moments of y, X and Z
        if standardized:
            mean_y, mean_X, mean_Z = self.unpackMean(['y', 'X', 'Z']).values()
            std_y, std_X, std_Z = self.unpackStd(['y', 'X', 'Z']).values()

        # local parameters: random effects only
        if local:
            if not standardized:
                return params
            b, q = len(params), self.q

            # slopes scale with std_Z / std_y
            scaled = params / std_y.view(b, 1, 1)
            scaled[..., 1:] *= std_Z.view(b, 1, q - 1)

            # the intercept absorbs the mean of Z weighted by the raw slopes
            offset = (mean_Z.view(b, 1, q - 1) * params[..., 1:]).sum(2)
            scaled[..., 0] = (params[..., 0] + offset) / std_y.view(b, 1)
            return scaled

        # global parameters: views into the clone, split along dim 1
        b, d, q = len(params), self.d, self.q
        ffx = params[:, :d]
        sigmas_rfx = params[:, d:-1]
        sigma_eps = params[:, -1:]

        if standardized:
            col_y = std_y.view(b, 1)
            flat_y = std_y.view(b)

            # fixed effects: slopes scale with std_X / std_y,
            # the intercept is shifted by the weighted predictor means and mean_y
            ffx_std = ffx / col_y
            ffx_std[:, 1:] *= std_X.view(b, d - 1)
            mean_Xb = (mean_X.view(b, d - 1) * ffx[:, 1:]).sum(1)
            ffx_std[:, 0] = (ffx[:, 0] + mean_Xb - mean_y.view(b)) / flat_y

            # residual std (in place on the cloned column)
            sigma_eps /= col_y

            # random effect stds; slopes scale with std_Z / std_y
            sigmas_std = sigmas_rfx / col_y
            sigmas_std[:, 1:] *= std_Z.view(b, q - 1)

            # intercept std: add data['cov_sum'] on the variance scale
            intercept_var = sigmas_rfx[:, 0].square() + data['cov_sum']
            sigmas_std[:, 0] = intercept_var.sqrt() / flat_y

            ffx, sigmas_rfx = ffx_std, sigmas_std

        # project positives to reals
        if self.constrain:
            sigmas_rfx = maskedInverseSoftplus(sigmas_rfx + 1e-6)
            sigma_eps = maskedInverseSoftplus(sigma_eps + 1e-6)

        return torch.cat([ffx, sigmas_rfx, sigma_eps], dim=-1)

    
    def postprocess(self,
                    proposed: dict[str, dict[str, torch.Tensor]],
                    data: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
        ''' reverse steps used in preprocessing for samples

        Replaces proposed['local']['samples'] and proposed['global']['samples']
        with versions on the original data scale and returns the same
        `proposed` dict (it is updated in place; the sample tensors themselves
        are cloned before being modified).
        If proposed['global'] holds no 'samples', `proposed` is returned untouched.
        '''
        if 'samples' not in proposed['global']:
            return proposed
        
        # stored moments of y, X and Z (only needed to undo standardization)
        if self.use_standardization:
            mean_y, mean_X, mean_Z = self.unpackMean(['y', 'X', 'Z']).values()
            std_y, std_X, std_Z = self.unpackStd(['y', 'X', 'Z']).values()
    
        # local postprocessing
        rfx_ = proposed['local']['samples'].clone()
        if self.use_standardization:
            # local samples are unpacked as a 4-d tensor; the views below treat
            # dim 2 as the rfx dim and broadcast over dims 1 and 3
            # (presumably groups and samples - verify against the posterior)
            b, m, q, s = rfx_.shape
            
            # unstandardize rfx: slopes scale with std_y / std_Z
            rfx = rfx_ * std_y.view(b,1,1,1)
            rfx[..., 1:, :] /= std_Z.view(b,1,q-1,1)
            # the intercept gives back the mean of Z weighted by the raw slopes
            mean_Zb = (mean_Z.view(b,1,q-1,1) * rfx[..., 1:, :]).sum(2)
            rfx[..., 0, :] = rfx_[..., 0, :] * std_y.view(b,1,1) - mean_Zb
            
            # patch samples
            # from here on rfx_ holds the UNstandardized random effects;
            # the covariance correction further below relies on that
            rfx_ = rfx
        proposed['local']['samples'] = rfx_
        
        # global postprocessing
        samples = proposed['global']['samples'].clone()
        b, _, s = samples.shape
        # note: q from the local shape above is replaced by self.q here
        d, q = self.d, self.q
        # split along dim 1: first d rows are ffx, the last row is sigma_eps,
        # everything in between are the rfx sigmas (these are views into `samples`)
        ffx_, sigmas_rfx_, sigma_eps_ = samples[:, :d], samples[:, d:-1], samples[:, -1:] 
        
        # constrain stds to be positive
        if self.constrain:
            sigmas_rfx_ = maskedSoftplus(sigmas_rfx_)
            sigma_eps_ = maskedSoftplus(sigma_eps_)
            
        # analytical unstandardization
        if self.use_standardization:
            
            # unstandardize ffx
            onesies = (data['d'] == 1)
            ffx_[onesies, 0] = 0 # in pure intercept models the standardized intercept is 0
            ffx = ffx_ * std_y.view(b,1,1)
            ffx[:, 1:] /= std_X.view(b,d-1,1)
            mean_Xb = (mean_X.view(b,d-1,1) * ffx[:, 1:]).sum(1)
            ffx[:, 0] = ffx_[:, 0] * std_y.view(b,1) - mean_Xb + mean_y.view(b,1)
            
            # unstandardize sigmas
            # (sigma_eps_ is rescaled in place)
            sigma_eps_ *= std_y.view(b,1,1)
            sigmas_rfx = sigmas_rfx_ * std_y.view(b,1,1)
            sigmas_rfx[:, 1:] /= std_Z.view(b,q-1,1) # random slopes
            
            # sigma intercept with cov_sum
            # prepend a column of ones to mean_Z so the intercept gets weight 1
            ones = torch.ones_like(mean_X[..., 0:1])
            mean_Z1 = torch.cat([ones, mean_Z], dim=-1)
            # sample-mean of the unstandardized rfx, weighted by [1, mean_Z]
            weighted = rfx_.mean(-1) * mean_Z1.view(b,1,q)
            # NOTE(review): batchCovary presumably returns one (q, q) covariance
            # matrix per batch entry, computed over the unmasked groups - confirm
            cov = batchCovary(weighted, data['mask_m'])
            # sum of all covariance entries except the [0, 0] element
            cov_sum = (cov.sum((-1,-2)) - cov[:, 0,0]).unsqueeze(-1)
            # subtract on the variance scale; clamp keeps the sqrt argument positive
            sigma_0 = (sigmas_rfx[:, 0].square() - cov_sum).clamp(min=1e-12).sqrt()
            # ub = sigma_0.mean(-1).view(-1).topk(3)[0][-1]
            # sigmas_rfx[:, 0] = sigma_0.clamp(max=ub)
            sigmas_rfx[:, 0] = sigma_0
            
            # alternative: simplified sigma intercept
            # mean_Zsigma = (mean_Z.view(b,d-1,1).abs() * sigmas_rfx[:, 1:]).sum(-1)
            # sigmas_rfx[:, 0] -= mean_Zsigma
            
            # patch samples
            ffx_ = ffx
            sigmas_rfx_ = sigmas_rfx
            
            
        # put everything back together along the parameter dim
        proposed['global']['samples'] = torch.cat([ffx_, sigmas_rfx_, sigma_eps_], dim=1)
        return proposed
    
    
    def forward(self, data: dict[str, torch.Tensor],
                sample=False, n=(300,200), log_prob=False, **kwargs
                ) -> dict[str, torch.Tensor|dict]:
        ''' run both inference stages on a batch and return loss and proposals

        Returns a dict with
            'loss':     global posterior loss plus the local loss summed over
                        its last dim and divided by data['m']
            'proposed': {'global': ..., 'local': ...} as returned by the two
                        posteriors, passed through self.postprocess
        n[0] / n[1] are forwarded as `n` to the global / local posterior.
        `log_prob` and extra keyword arguments are accepted but not used here.
        '''
        # prepare
        inputs = self.inputs(data)
        b, m, _, _ = inputs.shape

        # local summaries, one per group, extended with the group sizes
        summaries = self.addMetadata(
            self.summarizer_l(inputs, data['mask_n']), data, local=True)

        # global summary; the group mask is only passed outside of training
        if self.training:
            group_mask = None
        else:
            group_mask = data['mask_m']
        summary = self.summarizer_g(summaries, group_mask)
        context_g = self.addMetadata(summary, data, local=False)

        # global inference
        targets_g = self.preprocess(
            self.targets(data, local=False), data, local=False)
        loss, proposed_g = self.posterior_g(
            context_g, targets_g, sample=sample, n=n[0])

        # local inference, conditioned on global parameters:
        # mean of the global samples when sampling, otherwise the true targets
        targets_l = self.preprocess(
            self.targets(data, local=True), data, local=True)
        if sample:
            conditioning = proposed_g['samples'].mean(-1).to(summaries.device)
        else:
            conditioning = targets_g
        conditioning = conditioning.view(b, 1, -1).expand(b, m, -1)
        context_l = torch.cat([summaries, conditioning], dim=-1)
        loss_l, proposed_l = self.posterior_l(
            context_l, targets_l, sample=sample, n=n[1])

        # postprocessing
        proposed = self.postprocess(
            {'global': proposed_g, 'local': proposed_l}, data)
        loss += loss_l.sum(-1) / data['m']
        return {'loss': loss, 'proposed': proposed}


    def estimate(self, data: dict[str, torch.Tensor], n=(300,200)) -> dict[str, dict]:
        ''' sample from both posteriors without targets and without gradients

        n[0] / n[1] are forwarded to the global / local posterior's estimate().
        Returns {'global': ..., 'local': ...} as produced by the posteriors,
        passed through self.postprocess.
        '''
        with torch.no_grad():
            proposed = {}
            inputs = self.inputs(data)
            b, m, _, _ = inputs.shape

            # parameter masks for the posteriors
            # the sigma_eps entry is always active; create it on the same device
            # as the other masks, otherwise torch.cat fails for non-CPU batches
            mask_eps = torch.ones(b, 1, device=data['mask_d'].device)
            mask_g = torch.cat([
                data['mask_d'], # ffx
                data['mask_q'], # sigmas rfx
                mask_eps, # sigma eps
                ], dim=-1).float()
            mask_l = data['mask_q'].unsqueeze(1).expand(b,m,-1).float()

            # summaries
            summaries = self.summarizer_l(inputs, data['mask_n'])
            summaries = self.addMetadata(summaries, data, local=True)
            summary = self.summarizer_g(summaries, data['mask_m'])

            # global inference
            context_g = self.addMetadata(summary, data, local=False)
            proposed['global'] = self.posterior_g.estimate(context_g, mask_g, n[0])

            # local inference, conditioned on the mean of the global samples
            # (moved to the summaries' device before concatenation, as in forward)
            global_params = proposed['global']['samples'].mean(-1).to(summaries.device)
            global_params = global_params.view(b, 1, -1).expand(b, m, -1)
            context_l = torch.cat([summaries, global_params], dim=-1)
            proposed['local'] = self.posterior_l.estimate(context_l, mask_l, n[1])

            # postprocessing
            proposed = self.postprocess(proposed, data)
        return proposed

