import torch
import torch.nn as nn
from torch.distributions import kl_divergence
from attrdict import AttrDict

from torch.distributions import Normal, Categorical
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical

from ..utils.misc import stack, logmeanexp, log_w_weighted_sum_exp
from ..utils.sampling import sample_subset
from .modules import PoolingEncoder, Decoder
from typing import Callable

def grad_U(U: Callable, x: torch.Tensor) -> torch.Tensor:
    """
    Return the gradient of the potential ``U`` with respect to ``x``.

    ``U`` maps ``x`` to a tensor of per-chain potentials; it is summed before
    differentiation, so each chain receives the gradient of its own term
    (assuming chains do not interact inside ``U``).

    Args:
        U: callable returning the potential energy for ``x``.
        x: point(s) at which to evaluate the gradient.

    Returns:
        A tensor shaped like ``x`` holding dU/dx. It carries no autograd
        graph (``create_graph`` is not requested).
    """
    # HMC is run at evaluation time, frequently inside ``torch.no_grad()``.
    # Without re-enabling grad, U(x) would build no graph and
    # torch.autograd.grad would raise. Detaching keeps the gradient
    # computation from reaching into whatever graph produced ``x``.
    with torch.enable_grad():
        x = x.detach().clone().requires_grad_(True)
        potential = U(x).sum()  # sum over batch so one backward pass suffices
        (grad,) = torch.autograd.grad(potential, x)
    return grad

# -----------------------------
# 2. Define the leapfrog integrator
#    which simulates Hamiltonian dynamics
# -----------------------------
def leapfrog(
    U: Callable,
    x: torch.Tensor,
    p: torch.Tensor,
    step_size: float,
    n_steps: int):
    """
    Simulate Hamiltonian dynamics with the leapfrog (Stormer-Verlet) scheme.

    The momentum gets a half step at the start and at the end, with full
    steps in between. The position takes ``n_steps`` full steps of
    ``step_size * p``, which corresponds to an identity mass matrix.

    Args:
        U: potential-energy callable; its gradient comes from ``grad_U``.
        x: starting position.
        p: starting momentum (same shape as ``x``).
        step_size: integration step size.
        n_steps: number of full position updates.

    Returns:
        Tuple ``(x, p)`` with the proposed position and momentum.
    """
    half_step = 0.5 * step_size
    final_index = n_steps - 1

    # Opening half kick on the momentum.
    p = p - half_step * grad_U(U, x)

    for step in range(n_steps):
        # Drift the position by a full step.
        x = x + step_size * p
        # Full momentum kick everywhere except after the final drift,
        # where the closing half kick below takes its place.
        if step != final_index:
            p = p - step_size * grad_U(U, x)

    # Closing half kick on the momentum.
    p = p - half_step * grad_U(U, x)
    return x, p

# -----------------------------
# 3. HMC step: sample momentum, 
#    do leapfrog, accept/reject
# -----------------------------
def hmc_step(
    U: Callable,
    x_current: torch.Tensor,
    step_size: float,
    n_steps: int):
    """
    Run one Hamiltonian Monte Carlo transition for many chains at once.

    The last axis of ``x_current`` is the state dimension D. Every leading
    axis indexes an independent chain, e.g. ``[B, D]`` or ``[S, B, D]``.
    ``U`` must return one potential per chain, i.e. a tensor shaped like
    ``x_current`` without its last axis.

    Returns:
        x_next: same shape as ``x_current``. Each chain holds either its
            leapfrog proposal (if accepted) or its current state.
        accept_mask: boolean tensor shaped like ``x_current[..., 0]``, True
            where the proposal was accepted.
    """
    def hamiltonian(position, momentum):
        # Potential plus unit-mass kinetic energy 0.5 * |p|^2.
        return U(position) + 0.5 * momentum.pow(2).sum(dim=-1)

    # Fresh N(0, I) momentum for every chain.
    momentum = torch.randn_like(x_current)
    energy_before = hamiltonian(x_current, momentum)

    proposal, momentum_end = leapfrog(U, x_current, momentum,
                                      step_size, n_steps)
    energy_after = hamiltonian(proposal, momentum_end)

    # Metropolis rule: accept with probability min(1, exp(H_old - H_new)).
    accept_prob = (energy_before - energy_after).exp()
    accept_mask = torch.rand_like(accept_prob) < accept_prob

    # Add a trailing axis so the per-chain decision broadcasts over D.
    x_next = torch.where(accept_mask[..., None], proposal, x_current)
    return x_next, accept_mask


class ForwardTransition(nn.Module):
    """Gaussian transition anchored on the first half of its input.

    Input ``psi`` has last dimension ``2 * d_model``. A 3-layer MLP predicts a
    mean shift and a log-scale, each of width ``d_model``. The result is
    ``Normal(psi_first_half + shift, 0.01 * exp(log_scale))``.
    """

    def __init__(self, d_model):
        super().__init__()
        width = 2 * d_model
        # The attribute name and layer order are part of the state_dict
        # layout (predictor.0 / .2 / .4), so they must not change.
        self.predictor = nn.Sequential(
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, psi):
        shift, log_scale = self.predictor(psi).chunk(2, dim=-1)
        anchor = psi.chunk(2, dim=-1)[0]
        return Normal(anchor + shift, 0.01 * log_scale.exp())


class BackwardTransition(nn.Module):
    """Gaussian transition anchored on the first half of its input.

    Same architecture as ``ForwardTransition`` but with a 10x larger scale
    multiplier: ``Normal(psi_first_half + shift, 0.1 * exp(log_scale))``.
    """

    def __init__(self, d_model):
        super().__init__()
        width = 2 * d_model
        # The attribute name and layer order are part of the state_dict
        # layout (predictor.0 / .2 / .4), so they must not change.
        self.predictor = nn.Sequential(
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, psi):
        shift, log_scale = self.predictor(psi).chunk(2, dim=-1)
        anchor = psi.chunk(2, dim=-1)[0]
        return Normal(anchor + shift, 0.1 * log_scale.exp())

class NP_GIBBS_HS(nn.Module):
    """Latent neural process whose latent is refined with HMC at test time.

    Components:
      * ``denc``: context encoder whose output is used as a tensor (``theta``).
      * ``lenc``: context encoder returning a distribution over the latent
        ``z`` (it is sampled with ``rsample``).
      * ``dec``: decoder applied to ``cat([theta, z])`` replicated at every
        input; it returns a distribution over ``y``.

    ``predict`` decodes a given ``z`` directly (``predict_v0``). Without a
    ``z``, it runs HMC over ``z`` conditioned on the context
    (``predict_v11``). ``predict_v1`` ... ``predict_v10`` are earlier
    experimental variants kept for reference; ``predict`` does not call them.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):
        """Build the encoders, the decoder and the two transition networks.

        The decoder consumes codes of width ``dim_hid + dim_lat``. The
        transition networks are not referenced by any method of this class.
        They are presumably kept for checkpoint compatibility (verify before
        removing).
        """

        super().__init__()

        self.denc = PoolingEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        self.lenc = PoolingEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                dim_lat=dim_lat,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        self.dec = Decoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_enc=dim_hid+dim_lat,
                dim_hid=dim_hid,
                depth=dec_depth)
        
        self.forward_transition = ForwardTransition(d_model=dim_lat)
        self.backward_transition = BackwardTransition(d_model=dim_lat)

    def predict(self, xc, yc, xt, z=None, num_samples=None, *args, **kwargs):
        """Return the predictive distribution over ``y`` at ``xt``.

        If a latent ``z`` is supplied (the training branch of ``forward``, or
        ``sample`` called with ``z``), it is decoded directly by
        ``predict_v0``; any extra arguments are ignored on that path.
        Otherwise the HMC-based ``predict_v11`` is used, and extra positional
        and keyword arguments (``all``, ``avg``, ``iters``, ...) are forwarded
        to it.
        """
        if z is not None:
            # predict_v11 rejects an external z ("Eval Only"), so without
            # this branch training would always fail.
            return self.predict_v0(xc, yc, xt, z=z, num_samples=num_samples)
        return self.predict_v11(xc, yc, xt, None, num_samples, *args, **kwargs)

    def predict_v11(self, xc, yc, xt, z=None, num_samples=None, all=False,
                    avg=False, iters=5, step_size=0.01, n_leapfrog=4):
        """Predict after refining the latent ``z`` by HMC (``theta`` held fixed).

        ``z`` is initialised with a draw from ``lenc(xc, yc)``. Then ``iters``
        HMC steps target ``-log p(yc | xc, theta, z) - log N(cat([theta, z]); 0, I)``.

        Args:
            xc, yc, xt: context inputs/outputs and target inputs. Each is
                replicated ``num_samples`` times by the project helper
                ``stack``, which presumably adds a leading sample axis.
            z: must be None (evaluation only).
            num_samples: number of parallel chains/samples, or None.
            all: if True, return the list of predictive distributions, one
                per HMC iteration.
            avg: if True (and ``all`` is False), decode the average of the
                HMC latents across iterations instead of the final latent.
            iters: number of HMC transitions; must be >= 1.
            step_size: leapfrog step size.
            n_leapfrog: leapfrog steps per HMC transition.

        Returns:
            A decoder distribution at ``xt``, or a list of them when ``all``.

        Raises:
            ValueError: if ``iters < 1``.
            AssertionError: if ``z`` is given, or if the mean acceptance rate
                is not above 0.5.
        """
        assert z is None, "Eval Only"
        if iters < 1:
            raise ValueError(f"iters must be >= 1, got {iters}")
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)

        theta = self.denc(xc, yc)
        z = self.lenc(xc, yc).rsample()

        def target(var):
            # Potential energy of latent ``var``: negative context
            # log-likelihood plus a standard-normal prior on the full code.
            # The theta part of that prior is constant in ``var``.
            encoded = torch.cat([theta, var], -1)
            _encoded_xc = stack(encoded, xc.shape[-2], -2)
            py_xc = self.dec(_encoded_xc, xc)
            log_lik = py_xc.log_prob(yc).sum(-1).sum(-1)
            log_prior = Normal(loc=0, scale=1).log_prob(encoded).sum(-1)
            return -log_lik - log_prior

        def decode(latent):
            # Predictive distribution at the targets for a given latent.
            encoded = torch.cat([theta, latent], -1)
            return self.dec(stack(encoded, xt.shape[-2], -2), xt)

        # --------------------------------------------------
        # MCMC
        # --------------------------------------------------
        py_list = []
        avg_var = 0
        accept = 0.
        var = z
        for _ in range(iters):
            var, accept_mask = hmc_step(target, var, step_size, n_leapfrog)
            accept += accept_mask.float().mean().item()
            avg_var = avg_var + var / iters
            py_list.append(decode(var))
        accept /= iters
        assert accept > 0.5, accept
        # --------------------------------------------------
        if all:
            return py_list
        if avg:
            # The original code set ``encoded = avg_var`` and then ignored
            # it; decode the averaged latent as the flag intends.
            return decode(avg_var)
        return py_list[-1]

    def predict_v10(self, xc, yc, xt, z=None, num_samples=None): # theta only
        """Legacy variant: HMC over the full code ``cat([theta, z])``.

        Each potential evaluation scores a random held-out context point
        (``num_aug = exclude_len = 1``) under a standard-normal prior. Runs 5
        HMC transitions with step 0.01 and 2 leapfrog steps. ``predict``
        does not call this method.
        """
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples) # (50, 16, 26, 1)
        yc = stack(yc, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        assert z is None, "Eval Only"
        def target(var):
            S, B, T, C = xc.shape
            encoded = var
            theta, z = torch.chunk(encoded, 2, dim=-1)
            ctx_len = xc.shape[-2]
            # num_aug = min(ctx_len, 5)
            # exclude_len = ctx_len//num_aug
            num_aug = 1
            exclude_len = 1

            idx = torch.rand(
                num_aug, ctx_len).to(xt.device).argsort(dim=1)
            keep_idx = idx[:, :ctx_len-exclude_len]
            no_idx = idx[:, ctx_len-exclude_len:]

            ex_xc = xc.unsqueeze(0).repeat_interleave(num_aug, dim=0)
            ex_yc = yc.unsqueeze(0).repeat_interleave(num_aug, dim=0)
            # _idx_x = keep_idx.view(num_aug, 1, 1, -1, 1).expand(-1, S, B, -1, ex_xc.size(-1))
            # _idx_y = keep_idx.view(num_aug, 1, 1, -1, 1).expand(-1, S, B, -1, ex_yc.size(-1))
            no_idx_x = no_idx.view(num_aug, 1, 1, -1, 1).expand(-1, S, B, -1, xc.size(-1))
            no_idx_y = no_idx.view(num_aug, 1, 1, -1, 1).expand(-1, S, B, -1, yc.size(-1))
            no_xc = ex_xc.gather(dim=-2, index=no_idx_x)
            no_yc = ex_yc.gather(dim=-2, index=no_idx_y)
            # k_xc = ex_xc.gather(dim=-2, index=_idx_x)
            # k_yc = ex_yc.gather(dim=-2, index=_idx_y)
            # ex_theta = self.denc(k_xc, k_yc)
            # ex_pz = self.lenc(k_xc, k_yc)
            # logq = ex_pz.log_prob(z).sum(-1) # (num_aug, 50, 16)
            _encoded = stack(stack(encoded, no_xc.shape[-2], -2), num_aug, 0)
            ex_py = self.dec(_encoded, no_xc)
            logp = ex_py.log_prob(no_yc).sum(-1).sum(-1) # (num_aug, 50, 16)
            
            logprior = Normal(loc=0, scale=1).log_prob(encoded).sum(-1) # (50, 16)
            
            # return -logp.mean(0) -logq.mean(0) - logprior
            return -logp.mean(0) - logprior
        accept = 0.

        theta = self.denc(xc, yc)
        pz = self.lenc(xc, yc)
        z = pz.rsample()
        encoded = torch.cat([theta, z], -1) # (50, 16, 256)
        _encoded = stack(encoded, xt.shape[-2], -2)
        py = self.dec(_encoded, xt)
        y = py.rsample() # (50, 16, 26, 1)
        var = encoded
        # --------------------------------------------------
        # MCMC
        # --------------------------------------------------
        iters = 5
        for i in range(iters):
            var, accept_mask = hmc_step(target, var, 0.01, 2)
            accept += accept_mask.float().mean().item()
        accept /= iters
        assert accept > 0.5, accept
        encoded = var
        # --------------------------------------------------
        _encoded = stack(encoded, xt.shape[-2], -2)
        py = self.dec(_encoded, xt)
        return py

    def predict_v9(self, xc, yc, xt, z=None, num_samples=None, all=False, avg=False, iters=10): # theta only
        """Legacy variant: HMC over the full code ``cat([theta, z])``.

        The potential is the negative context log-likelihood plus a
        standard-normal prior. ``all`` returns one predictive per iteration;
        ``avg`` decodes the iteration-averaged code. ``predict`` does not call
        this method.
        """
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        assert z is None, "Eval Only"
        def target(var):
            encoded = var

            _encoded_xc = stack(encoded, xc.shape[-2], -2)  # (50, 16, 26, 256)
            py_xc = self.dec(_encoded_xc, xc)
            logqz = py_xc.log_prob(yc).sum(-1).sum(-1)
            logprior = Normal(loc=0, scale=1).log_prob(encoded).sum(-1)
            
            return -logqz - logprior
        accept = 0.

        theta = self.denc(xc, yc)
        pz = self.lenc(xc, yc)
        z = pz.rsample()
        encoded = torch.cat([theta, z], -1) # (50, 16, 256)
        _encoded = stack(encoded, xt.shape[-2], -2)
        py = self.dec(_encoded, xt)
        y = py.rsample() # (50, 16, 26, 1)
        var = encoded
        # --------------------------------------------------
        # MCMC
        # --------------------------------------------------
        py_list = []
        avg_var = 0
        for i in range(iters):
            var, accept_mask = hmc_step(target, var, 0.01, 5)
            accept += accept_mask.float().mean().item()
            avg_var += var / iters
            _encoded = stack(var, xt.shape[-2], -2)
            py = self.dec(_encoded, xt)
            py_list.append(py)
        accept /= iters
        assert accept > 0.5, accept
        encoded = var
        # --------------------------------------------------
        if all:
            return py_list
        if avg:
            encoded = avg_var
        _encoded = stack(encoded, xt.shape[-2], -2)
        py = self.dec(_encoded, xt)
        return py


    def predict_v8failed(self, xc, yc, xt, z=None, num_samples=None): # theta, y joint
        """Legacy (abandoned) variant: joint HMC over the code and target ``y``.

        The HMC state is ``cat([code, flattened y_target])``. Inside
        ``target``, ``torch.chunk(encoded, 2)`` splits theta/z correctly only
        when ``dim_hid == dim_lat``, and ``y.unsqueeze(-1)`` assumes
        ``dim_y == 1``. ``predict`` does not call this method.
        """
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        assert z is None, "Eval Only"
        def target(var):
            # ``enc_dim`` is bound in the enclosing scope before HMC starts.
            encoded, y = torch.split(var, [enc_dim, var.size(-1)-enc_dim], dim=-1)
            y = y.unsqueeze(-1)
            theta, z = torch.chunk(encoded, 2, dim=-1)
            pz = self.lenc(xc, yc)
            logqz = pz.log_prob(z).sum(-1)
            _encoded = stack(encoded, xt.shape[-2], -2)  # (50, 16, 26, 256)
            py = self.dec(_encoded, xt)
            logp = py.log_prob(y).sum(-1).sum(-1)


            # _encoded = stack(encoded, xt.shape[-2], -2)  # (50, 16, 26, 256)
            # py = self.dec(_encoded, xt)
            # _encoded_xc = stack(encoded, xc.shape[-2], -2)  # (50, 16, 26, 256)
            # py_xc = self.dec(_encoded_xc, xc)
            # logqz = py_xc.log_prob(yc).sum(-1).sum(-1)
            # logp = py.log_prob(y).sum(-1).sum(-1)
            # logprior = Normal(loc=0, scale=1).log_prob(encoded).sum(-1)
            
            # return -logp - logqz - logprior
            return -logp -logqz
        accept = 0.

        theta = self.denc(xc, yc)
        pz = self.lenc(xc, yc)
        z = pz.rsample()
        encoded = torch.cat([theta, z], -1) # (50, 16, 256)
        # Width of the code part of the HMC state. It was hard-coded to 256,
        # which is only correct when dim_hid + dim_lat == 256.
        enc_dim = encoded.size(-1)
        _encoded = stack(encoded, xt.shape[-2], -2)
        py = self.dec(_encoded, xt)
        y = py.rsample() # (50, 16, 26, 1)
        var = torch.cat([encoded, y.view(y.size(0),y.size(1),-1)], dim=-1)
        # --------------------------------------------------
        # MCMC
        # --------------------------------------------------
        iters = 5
        for i in range(iters):
            var, accept_mask = hmc_step(target, var, 0.01, 2)
            accept += accept_mask.float().mean().item()
        accept /= iters
        assert accept > 0.5, accept
        encoded, y = torch.split(var, [enc_dim, var.size(-1)-enc_dim], dim=-1)
        y = y.unsqueeze(-1)
        # --------------------------------------------------
        _encoded = stack(encoded, xt.shape[-2], -2)
        py = self.dec(_encoded, xt)
        return py
    def predict_v8(self, xc, yc, xt, z=None, num_samples=None): # theta, y joint
        """Legacy variant: joint HMC over the code and target ``y``.

        The potential combines the target log-likelihood of the sampled ``y``,
        the context log-likelihood, and a standard-normal prior on the code.
        ``y.unsqueeze(-1)`` assumes ``dim_y == 1``. ``predict`` does not call
        this method.
        """
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        assert z is None, "Eval Only"
        def target(var):
            # ``enc_dim`` is bound in the enclosing scope before HMC starts.
            encoded, y = torch.split(var, [enc_dim, var.size(-1)-enc_dim], dim=-1)
            y = y.unsqueeze(-1)
            # theta, z = torch.chunk(encoded, 2, dim=-1)
            # pz = self.lenc(xc, yc)
            # logqz = pz.log_prob(z).sum(-1)
            # _encoded = stack(encoded, xt.shape[-2], -2)  # (50, 16, 26, 256)
            # py = self.dec(_encoded, xt)
            # logp = py.log_prob(y).sum(-1).sum(-1)

            _encoded = stack(encoded, xt.shape[-2], -2)  # (50, 16, 26, 256)
            py = self.dec(_encoded, xt)
            _encoded_xc = stack(encoded, xc.shape[-2], -2)  # (50, 16, 26, 256)
            py_xc = self.dec(_encoded_xc, xc)
            logqz = py_xc.log_prob(yc).sum(-1).sum(-1)
            logp = py.log_prob(y).sum(-1).sum(-1)
            logprior = Normal(loc=0, scale=1).log_prob(encoded).sum(-1)
            
            return -logp - logqz -logprior 
        accept = 0.

        theta = self.denc(xc, yc)
        pz = self.lenc(xc, yc)
        z = pz.rsample()
        encoded = torch.cat([theta, z], -1) # (50, 16, 256)
        # Width of the code part of the HMC state. It was hard-coded to 256,
        # which is only correct when dim_hid + dim_lat == 256.
        enc_dim = encoded.size(-1)
        _encoded = stack(encoded, xt.shape[-2], -2)
        py = self.dec(_encoded, xt)
        y = py.rsample() # (50, 16, 26, 1)
        var = torch.cat([encoded, y.view(y.size(0),y.size(1),-1)], dim=-1)
        # --------------------------------------------------
        # MCMC
        # --------------------------------------------------
        iters = 5
        for i in range(iters):
            var, accept_mask = hmc_step(target, var, 0.01, 2)
            accept += accept_mask.float().mean().item()
        accept /= iters
        assert accept > 0.5, accept
        encoded, y = torch.split(var, [enc_dim, var.size(-1)-enc_dim], dim=-1)
        y = y.unsqueeze(-1)
        # --------------------------------------------------
        _encoded = stack(encoded, xt.shape[-2], -2)
        py = self.dec(_encoded, xt)
        return py
    def predict_v7(self, xc, yc, xt, z=None, num_samples=None, y=None):
        """Legacy variant: sequential decoding with importance resampling.

        Targets are decoded one at a time, ordered by a per-coordinate
        squared distance to the nearest context input. Before each step, the
        sampled codes are resampled across the sample axis with
        ``Categorical(logits=recon + prior - qz)``. When ``y`` is given, the
        appended pseudo-observations are noisy copies of the true targets.
        Requires an integer ``num_samples`` (used in ``repeat``).
        ``predict`` does not call this method.
        """
        dist, _ = torch.min((xt[...,None,:] - xc[...,None,:,:])**2, dim=-2)
        order = torch.argsort(dist, dim=-2)
        inv_order = torch.argsort(order, dim=-2)
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        if y is not None:
            y = stack(y, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        init_xt = xt.clone()
        assert z is None, "Eval Only"
        py_list = []
        for i in range(xt.shape[-2]):
            theta = self.denc(xc, yc)
            pz = self.lenc(xc, yc)
            z = pz.rsample()
            encoded = torch.cat([theta, z], -1) # (50, 16, 256)
            # --------------------------------------------------
            # encoded filtering
            # --------------------------------------------------
            prior = Normal(loc=0, scale=1).log_prob(z).sum(-1)
            # NOTE(review): sets pz's scale to 1 in place before scoring qz.
            pz.scale = pz.scale/pz.scale
            qz = pz.log_prob(z).sum(-1)
            _encoded = stack(encoded, init_xc.shape[-2], -2)  # (50, 16, 26, 256)
            _py = self.dec(_encoded, init_xc)
            recon = _py.log_prob(init_yc).sum(-1).sum(-1)
            score = recon + prior - qz
            score = torch.moveaxis(score, 0, -1)
            dist = Categorical(logits=score)
            indices = dist.sample([score.size(-1)]) # (50,16)
            encoded = torch.gather(encoded, 0, indices.unsqueeze(-1).expand_as(encoded))
            if i > 0:
                yc = torch.gather(yc, 0, indices[...,None,None].expand_as(yc))
            # --------------------------------------------------
            encoded = stack(encoded, 1, -2)
            idx = order[...,i:i+1, :]
            _xt = torch.gather(xt, -2, idx[None,...].repeat(num_samples, 1, 1, 1))
            py = self.dec(encoded, _xt)
            if y is not None:
                yt = torch.gather(y, -2, idx[None,...].repeat(num_samples, 1, 1, 1))
                yt = Normal(loc=yt, scale=py.scale).rsample()
            else:
                yt = py.rsample()
            xc = torch.cat([xc, _xt], dim=-2)
            yc = torch.cat([yc, yt], dim=-2)
            py_list.append(py)
        loc = torch.cat([py.loc for py in py_list], -2)
        loc = torch.gather(loc, -2, inv_order[None,...].repeat(num_samples, 1, 1, 1))
        scale = torch.cat([py.scale for py in py_list], -2)
        scale = torch.gather(scale, -2, inv_order[None,...].repeat(num_samples, 1, 1, 1))
        py = Normal(loc=loc, scale=scale)
        return py
    def predict_v6(self, xc, yc, xt, z=None, num_samples=None):
        """Legacy variant: iterated importance resampling with pseudo-targets.

        Runs 5 rounds. In every round but the last, the model predicts at
        midpoints between the context inputs and a random permutation of
        them, then appends those pseudo-observations to the original
        context. ``predict`` does not call this method.
        """
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        init_xt = xt.clone()
        samplez = False
        iters = 5
        assert z is None, "Eval Only"
        for i in range(iters):
            theta = self.denc(xc, yc)
            if z is None or samplez:
                pz = self.lenc(xc, yc)
                z = pz.rsample()
                samplez=True
            encoded = torch.cat([theta, z], -1) # (50, 16, 256)
            # --------------------------------------------------
            # encoded filtering
            # --------------------------------------------------
            # _encoded = stack(encoded, xt.shape[-2], -2)
            # _py = self.dec(_encoded, xt)
            if samplez:
                prior = Normal(loc=0, scale=1).log_prob(z).sum(-1)
                # NOTE(review): sets pz's scale to 1 in place before scoring qz.
                pz.scale = pz.scale/pz.scale
                qz = pz.log_prob(z).sum(-1)
                _encoded = stack(encoded, init_xc.shape[-2], -2)  # (50, 16, 26, 256)
                py = self.dec(_encoded, init_xc)
                recon = py.log_prob(init_yc).sum(-1).sum(-1)
                # print("recon", recon[:,0])
                # print("prior", prior[:,0])
                # print("qz", qz[:,0])
                score = recon + prior - qz
                score = torch.moveaxis(score, 0, -1)
                dist = Categorical(logits=score)
                indices = dist.sample([score.size(-1)]) # (50,16)
                encoded = torch.gather(encoded, 0, indices.unsqueeze(-1).expand_as(encoded))
            # --------------------------------------------------
            if i != iters-1:
                idx = torch.randperm(init_xc.shape[-2]).to(init_xc.device)
                shuffled = torch.gather(init_xc, -2, idx[None, :, None].expand_as(init_xc))
                xt = 0.5*(init_xc+shuffled)
            else:
                xt = init_xt
            encoded = stack(encoded, xt.shape[-2], -2)
            py = self.dec(encoded, xt)
            yt = py.rsample()
            xc = torch.cat([init_xc, xt], dim=-2)
            yc = torch.cat([init_yc, yt], dim=-2)
        return py

    def predict_v5(self, xc, yc, xt, z=None, num_samples=None, **kwargs):
        """Legacy variant: one round of importance resampling of the code.

        Codes are resampled across the sample axis with
        ``Categorical(logits=recon + prior - qz)`` and then decoded at
        ``xt``. Extra keyword arguments are accepted and ignored.
        ``predict`` does not call this method.
        """
        # print("num_samples", num_samples)
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        init_xt = xt.clone()
        samplez = False
        iters = 1
        assert z is None, "Eval Only"
        for i in range(iters):
            theta = self.denc(xc, yc)
            if z is None or samplez:
                pz = self.lenc(xc, yc)
                z = pz.rsample()
                samplez=True
            encoded = torch.cat([theta, z], -1) # (50, 16, 256)
            # --------------------------------------------------
            # encoded filtering
            # --------------------------------------------------
            # _encoded = stack(encoded, xt.shape[-2], -2)
            # _py = self.dec(_encoded, xt)
            if samplez:
                prior = Normal(loc=0, scale=1).log_prob(z).sum(-1)
                # NOTE(review): sets pz's scale to 1 in place before scoring qz.
                pz.scale = pz.scale/pz.scale
                qz = pz.log_prob(z).sum(-1)
                _encoded = stack(encoded, init_xc.shape[-2], -2)  # (50, 16, 26, 256)
                py = self.dec(_encoded, init_xc)
                recon = py.log_prob(init_yc).sum(-1).sum(-1)
                score = recon + prior - qz
                score = torch.moveaxis(score, 0, -1)
                dist = Categorical(logits=score)
                indices = dist.sample([score.size(-1)]) # (50,16)
                encoded = torch.gather(encoded, 0, indices.unsqueeze(-1).expand_as(encoded))
            # --------------------------------------------------
            encoded = stack(encoded, xt.shape[-2], -2)
            py = self.dec(encoded, xt)
            yt = py.rsample()
            xc = torch.cat([init_xc, xt], dim=-2)
            yc = torch.cat([init_yc, yt], dim=-2)
        return py
    def predict_v4(self, xc, yc, xt, z=None, num_samples=None):
        """Legacy variant: resampling driven by an inflated proposal density.

        ``lenc``'s scale is multiplied by 4 before sampling, and the codes are
        resampled by their own log-density. From the second round on, the
        previous pseudo-targets are resampled with the same indices.
        ``predict`` does not call this method.
        """
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        iters = 2
        samplez = False
        for i in range(iters):
            theta = self.denc(xc, yc)
            if z is None or samplez:
                pz = self.lenc(xc, yc)
                pz.scale = pz.scale*4
                z = pz.rsample()
                samplez = True
            encoded = torch.cat([theta, z], -1) # (50, 16, 256)
            # --------------------------------------------------
            # encoded filtering
            # --------------------------------------------------
            if samplez:
                score = pz.log_prob(z).sum(-1)
                score = torch.moveaxis(score, 0, -1)
                dist = Categorical(logits=score)
                indices = dist.sample([score.size(-1)]) # (50,16)
                encoded = torch.gather(encoded, 0, indices.unsqueeze(-1).expand_as(encoded))
                if i > 0:
                    _xt = torch.gather(prev_xt, 0, indices[...,None,None].expand_as(prev_xt))
                    _yt = torch.gather(prev_yt, 0, indices[...,None,None].expand_as(prev_yt))
                    xc = torch.cat([prev_xc, _xt], dim=-2)
                    yc = torch.cat([prev_yc, _yt], dim=-2)
            # --------------------------------------------------
            encoded = stack(encoded, xt.shape[-2], -2)
            py = self.dec(encoded, xt)
            yt = py.rsample()
            prev_xc = xc.clone()
            prev_xt = xt.clone()
            prev_yc = yc.clone()
            prev_yt = yt.clone()
            xc = torch.cat([xc, xt], dim=-2)
            yc = torch.cat([yc, yt], dim=-2)
        return py
    def predict_v3(self, xc, yc, xt, z=None, num_samples=None):
        """Legacy variant: two rounds of reconstruction-weighted resampling.

        After the first round, the context is replaced by the sampled target
        predictions. ``predict`` does not call this method.
        """
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        iters = 2
        samplez = False
        for _ in range(iters):
            theta = self.denc(xc, yc)
            if z is None or samplez:
                pz = self.lenc(xc, yc)
                z = pz.rsample()
                samplez=True
            encoded = torch.cat([theta, z], -1) # (50, 16, 256)
            # --------------------------------------------------
            # encoded filtering
            # --------------------------------------------------
            _encoded = stack(encoded, init_xc.shape[-2], -2)  # (50, 16, 26, 256)
            py = self.dec(_encoded, init_xc)
            recon = py.log_prob(init_yc).sum(-1).sum(-1)
            score = torch.moveaxis(recon, 0, -1)
            dist = Categorical(logits=score)
            indices = dist.sample([score.size(-1)]) # (50,16)
            encoded = torch.gather(encoded, 0, indices.unsqueeze(-1).expand_as(encoded))
            # --------------------------------------------------
            encoded = stack(encoded, xt.shape[-2], -2)
            py = self.dec(encoded, xt)
            yt = py.rsample()
            # xc = torch.cat([init_xc, xt], dim=-2)
            # yc = torch.cat([init_yc, yt], dim=-2)
            xc = xt
            yc = yt
        return py
    def predict_v2(self, xc, yc, xt, z=None, num_samples=None):
        """Legacy variant: augment the context with predictions at shuffled midpoints.

        ``predict`` does not call this method.
        """
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        init_xt= xt.clone()
        init_xc = xc.clone()
        init_yc = yc.clone()
        iters = 2
        samplez = False
        for i in range(iters):
            theta = self.denc(xc, yc)
            if z is None or samplez:
                pz = self.lenc(xc, yc)
                z = pz.rsample()
                samplez = True
            encoded = torch.cat([theta, z], -1) # (50, 16, 256)
            encoded = stack(encoded, xt.shape[-2], -2)
            if i != iters-1:
                sh_idx = torch.randperm(xt.shape[-2])
                shuffled = init_xt[...,sh_idx,:].clone()
                xt = 0.5*(init_xt+shuffled)
            else:
                xt = init_xt
            py = self.dec(encoded, xt)
            yt = py.rsample()
            xc = torch.cat([xc, xt], dim=-2)
            yc = torch.cat([yc, yt], dim=-2)
        return py
    def predict_v1(self, xc, yc, xt, z=None, num_samples=None):
        """Legacy variant: two rounds of reconstruction-weighted resampling.

        Each round's target predictions are appended to the original context.
        ``predict`` does not call this method.
        """
        xt = stack(xt, num_samples)
        xc = stack(xc, num_samples)
        yc = stack(yc, num_samples)
        init_xc = xc.clone()
        init_yc = yc.clone()
        samplez = False
        for _ in range(2):
            theta = self.denc(xc, yc)
            if z is None or samplez:
                pz = self.lenc(xc, yc)
                z = pz.rsample()
                samplez=True
            encoded = torch.cat([theta, z], -1) # (50, 16, 256)
            # --------------------------------------------------
            # encoded filtering
            # --------------------------------------------------
            _encoded = stack(encoded, init_xc.shape[-2], -2)  # (50, 16, 26, 256)
            py = self.dec(_encoded, init_xc)
            recon = py.log_prob(init_yc).sum(-1).sum(-1)
            score = torch.moveaxis(recon, 0, -1)
            dist = Categorical(logits=score)
            indices = dist.sample([score.size(-1)]) # (50,16)
            encoded = torch.gather(encoded, 0, indices.unsqueeze(-1).expand_as(encoded))
            # --------------------------------------------------
            encoded = stack(encoded, xt.shape[-2], -2)
            py = self.dec(encoded, xt)
            yt = py.rsample()
            xc = torch.cat([init_xc, xt], dim=-2)
            yc = torch.cat([init_yc, yt], dim=-2)
        return py
    def predict_v0(self, xc, yc, xt, z=None, num_samples=None):
        """Plain latent-NP prediction: decode ``cat([theta, z])`` at ``xt``.

        If ``z`` is None, it is drawn from ``lenc(xc, yc)``, with a leading
        sample axis of size ``num_samples`` when that is given. This is the
        path ``predict`` uses whenever a ``z`` is supplied.
        """
        theta = stack(self.denc(xc, yc), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        encoded = torch.cat([theta, z], -1)
        encoded = stack(encoded, xt.shape[-2], -2)
        return self.dec(encoded, stack(xt, num_samples))

    def sample(self, xc, yc, xt, z=None, num_samples=None):
        """Return the mean (``loc``) of the predictive distribution at ``xt``."""
        pred_dist = self.predict(xc, yc, xt, z, num_samples)
        return pred_dist.loc

    def forward(self, batch, num_samples=None, reduce_ll=True, iters=10, avg=False, **kwargs):
        """Compute training losses or evaluation log-likelihoods.

        In training mode, ``z`` is sampled from ``lenc(batch.x, batch.y)``.
        When ``num_samples > 1``, the loss is the importance-weighted bound
        ``-logmeanexp(recon + log p(z) - log q(z))``; otherwise it is the
        ELBO ``-recon + KL(q || p)``. Either way the loss is normalised by
        the number of points. In eval mode, ``predict`` (HMC path, with
        ``iters``/``avg``) gives ``ctx_ll``/``tar_ll``. This assumes the first
        ``batch.xc.shape[-2]`` points of ``batch.x`` are the context (TODO
        confirm with the data pipeline).

        Returns:
            An ``AttrDict`` with ``loss`` (plus ``recon``/``kld`` when
            ``num_samples`` is None or 1) in training, or ``ctx_ll``/``tar_ll``
            in evaluation.
        """
        outs = AttrDict()
        if self.training:
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # ``num_samples`` may be None; comparing None > 1 raises TypeError.
            if num_samples is not None and num_samples > 1:
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples, iters=iters, avg=avg)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y]*num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            num_ctx = batch.xc.shape[-2]
            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]
        return outs