import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from latentplan.models.autoencoders import SymbolWiseTransformer
from latentplan.models.transformers import *
from latentplan.models.ein import EinLinear

# -----------------------------
# Low-level VQ ops
# -----------------------------
class VectorQuantization(Function):
    """
    Nearest-neighbour code assignment (non-differentiable).

    Maps every D-dim vector in ``inputs`` to the index of the closest row of
    ``codebook`` under squared Euclidean distance.
    """
    @staticmethod
    def forward(ctx, inputs, codebook):
        """
        Args:
            inputs: tensor of shape [..., D]; need not be contiguous.
            codebook: tensor of shape [K, D].
        Returns:
            Long tensor of shape [...] with the nearest code index per vector.
        """
        with torch.no_grad():
            D = codebook.size(1)
            lead_shape = inputs.shape[:-1]
            # reshape (not view): transposed/sliced inputs are not viewable as [N, D]
            x = inputs.reshape(-1, D)
            cb_sqr = torch.sum(codebook ** 2, dim=1)                 # [K]
            x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)           # [N,1]
            # ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x·e, fused into a single addmm
            distances = torch.addmm(cb_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
            _, idx_flat = torch.min(distances, dim=1)                # [N]
            # pass the Size object so a 1-D input yields a 0-d result
            indices = idx_flat.view(lead_shape)                      # [...]
            ctx.mark_non_differentiable(indices)
            return indices

    @staticmethod
    def backward(ctx, grad_output):
        raise RuntimeError("VectorQuantization is not differentiable.")

class VectorQuantizationStraightThrough(Function):
    """
    Nearest-code lookup with a straight-through gradient estimator.

    Forward snaps each D-dim vector of ``inputs`` to its nearest codebook row.
    Backward copies the incoming gradient to ``inputs`` unchanged and, if the
    codebook needs a gradient, scatter-adds it onto the selected rows.
    """

    @staticmethod
    def forward(ctx, inputs, codebook):
        """
        Args:
            inputs: tensor whose last dim equals ``codebook.size(1)``.
            codebook: [K, D] tensor of code vectors.
        Returns:
            (codes, flat_indices): ``codes`` has the shape of ``inputs``;
            ``flat_indices`` is the 1-D tensor of chosen code ids.
        """
        flat_indices = vq(inputs, codebook).view(-1)
        ctx.save_for_backward(flat_indices, codebook)
        ctx.mark_non_differentiable(flat_indices)
        selected = torch.index_select(codebook, dim=0, index=flat_indices)
        return selected.view_as(inputs), flat_indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        wants_input_grad, wants_codebook_grad = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        # straight-through: d(codes)/d(inputs) treated as identity
        grad_inputs = grad_output.clone() if wants_input_grad else None
        if not wants_codebook_grad:
            return grad_inputs, None
        flat_indices, codebook = ctx.saved_tensors
        per_vector = grad_output.contiguous().view(-1, codebook.size(1))
        grad_codebook = torch.zeros_like(codebook).index_add_(0, flat_indices, per_vector)
        return grad_inputs, grad_codebook

# Functional aliases: ``vq`` -> code indices only;
# ``vq_st`` -> (straight-through codes, flat indices).
vq, vq_st = VectorQuantization.apply, VectorQuantizationStraightThrough.apply


# --------------------------------------------------------
# EMA codebook for residual quantization of masked/unmasked streams
# --------------------------------------------------------
class VQEmbeddingEMA(nn.Module):
    """
    EMA VQ codebook with straight-through assignment.

    The codebook is stored in buffers (``embedding``, ``ema_count``, ``ema_w``)
    and is updated by exponential moving averages, not by gradients.
    - assign_from: tensor used to compute indices (usually residual of masked or full)
    - update_from: tensor used to update EMA (must be the *matching residual* stream)
    """
    def __init__(self, K, D, decay=0.99, eps=1e-5):
        super().__init__()
        emb = torch.empty(K, D)
        emb.uniform_(-1./K, 1./K)
        self.decay = decay
        self.eps = eps
        self.register_buffer("embedding", emb)          # [K,D]
        self.register_buffer("ema_count", torch.ones(K))
        self.register_buffer("ema_w", emb.clone())

    @torch.no_grad()
    def _ema_update(self, indices, x_update):
        """
        One EMA step of the codebook.

        ema_count <- decay * ema_count + (1 - decay) * per-code assignment counts
        ema_w     <- decay * ema_w     + (1 - decay) * per-code sums of x_update
        embedding <- ema_w / (ema_count + eps)

        Args:
            indices: integer code ids, one per vector of ``x_update``.
            x_update: [..., D] residual stream for this level.
        """
        K, D = self.embedding.size()
        x_flat = x_update.reshape(-1, D)
        idx = indices.reshape(-1)
        counts = torch.bincount(idx, minlength=K).float()
        self.ema_count.mul_(self.decay).add_((1 - self.decay) * counts)

        dw = torch.zeros_like(self.ema_w)               # [K,D] per-code sums
        dw.index_add_(0, idx, x_flat)
        self.ema_w.mul_(self.decay).add_((1 - self.decay) * dw)

        denom = self.ema_count.add(self.eps).unsqueeze(-1)
        self.embedding.data.copy_(self.ema_w / denom)   # in-place, keeps buffer registration

    def straight_through(self, x_assign, x_update):
        """
        x_assign: residual tensor used to compute indices + ST codes
        x_update: matching residual tensor used for EMA update
        Returns: z_q (ST), z_bar (non-ST), indices
        """
        z_q, idx_flat = vq_st(x_assign.contiguous(), self.embedding)   # z_q gets grad to x_assign
        z_q = z_q.contiguous()
        # Look up the non-ST codes BEFORE the EMA step, so z_bar holds the same
        # vectors as z_q. index_select copies, so the in-place update below does
        # not alter z_bar. Computing z_bar after the update made residuals and
        # commitment targets use codes the decoder never saw.
        z_bar = torch.index_select(self.embedding, 0, idx_flat).view_as(x_assign).contiguous()
        indices = idx_flat.view(*x_assign.shape[:-1])
        if self.training:
            self._ema_update(indices, x_update)
        return z_q, z_bar, indices


class VQEmbedding(nn.Module):
    """Gradient-trained (non-EMA) codebook backed by ``nn.Embedding``; used when ``ma_update=False``."""
    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        self.embedding.weight.data.uniform_(-1./K, 1./K)

    def straight_through(self, x_assign, x_update=None):
        """
        Quantize ``x_assign`` against the codebook.

        Args:
            x_assign: [..., D] tensor used to compute indices and ST codes.
            x_update: accepted for call-compatibility with
                ``VQEmbeddingEMA.straight_through(x_assign, x_update)``, which
                callers invoke with two positional arguments. It is ignored,
                because this codebook has no EMA update.
        Returns:
            z_q (ST codes; gradient flows to ``x_assign`` only, since the weight
            passed to ``vq_st`` is detached), z_bar (codes with gradient to the
            embedding weight), and indices shaped like ``x_assign.shape[:-1]``.
        """
        z_q, idx_flat = vq_st(x_assign.contiguous(), self.embedding.weight.detach())
        z_q = z_q.contiguous()
        z_bar = torch.index_select(self.embedding.weight, 0, idx_flat).view_as(x_assign).contiguous()
        return z_q, z_bar, idx_flat.view(*x_assign.shape[:-1])


class VQStepWiseTransformer(nn.Module):
    """
    Transformer autoencoder with a shared residual vector-quantization bottleneck.

    * Encoder: linear per-step embedding plus a learned positional embedding,
      ``config.n_layer`` ``Block``s, a temporal bottleneck (``MaxPool1d`` with
      kernel/stride ``latent_step``, or an ``AsymBlock``), and a projection to
      ``latent_size`` (= ``config.trajectory_embd``) channels.
    * Quantizer: one codebook applied ``n_levels`` times. Each level quantizes
      the residual left by the previous levels.
    * Decoder: latents are concatenated with the conditioning state, mixed,
      expanded back to full length, decoded, and projected to ``transition_dim``
      channels. The last channel is a sigmoid terminal output.
    """

    def __init__(self, config, feature_dim):
        super().__init__()
        self.K = config.K
        self.latent_size = config.trajectory_embd
        self.embedding_dim = config.n_embd
        self.trajectory_length = config.block_size
        self.block_size = config.block_size
        self.observation_dim = feature_dim
        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim
        self.latent_step = config.latent_step
        self.state_conditional = config.state_conditional

        self.masking = getattr(config, "masking", "none")
        self.bottleneck = getattr(config, "bottleneck", "pooling")
        self.n_levels = getattr(config, "n_levels", 2)
        self.ma_update = getattr(config, "ma_update", True)
        self.assign_from = getattr(config, "assign_from", "masked")   # "masked" or "full"
        self.update_from = getattr(config, "update_from", "masked")   # "masked" or "full"
        self.commitment_beta = getattr(config, "commitment_beta", 1)

        # transition_dim counts the joined per-step channels PLUS the terminal
        # channel (decode() emits terminals as its last output). The encoder sees
        # transition_dim features when terminals are appended, and
        # transition_dim - 1 otherwise.
        self.use_term_channel = getattr(config, "use_term_channel", True)
        self.encoder_input_dim = (
            self.transition_dim if self.use_term_channel else self.transition_dim - 1
        )

        self.encoder = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.decoder = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.pos_emb = nn.Parameter(torch.zeros(1, self.trajectory_length, self.embedding_dim))

        # per-step inputs -> model width; model width -> outputs / latents
        self.embed = nn.Linear(self.encoder_input_dim, self.embedding_dim)
        self.predict = nn.Linear(self.embedding_dim, self.transition_dim)
        self.cast_embed = nn.Linear(self.embedding_dim, self.latent_size)
        self.latent_mixing = nn.Linear(self.latent_size + self.observation_dim, self.embedding_dim)

        if self.bottleneck == "pooling":
            self.latent_pooling = nn.MaxPool1d(self.latent_step, stride=self.latent_step)
        elif self.bottleneck == "attention":
            self.latent_pooling = AsymBlock(config, self.trajectory_length // self.latent_step)
            self.expand = AsymBlock(config, self.trajectory_length)
        else:
            raise ValueError(f"Unknown bottleneck type {self.bottleneck}")

        self.ln_f = nn.LayerNorm(self.embedding_dim)
        self.drop = nn.Dropout(config.embd_pdrop)

        # A single residual codebook shared by every depth level.
        if self.ma_update:
            self.codebook = VQEmbeddingEMA(
                self.K, self.latent_size,
                decay=getattr(config, "ema_decay", 0.99),
                eps=getattr(config, "ema_eps", 1e-5)
            )
        else:
            self.codebook = VQEmbedding(self.K, self.latent_size)

        # When True, residuals are detached between levels, so gradients do not
        # flow back through the residual path into earlier levels.
        self.stop_grad_between_levels = getattr(config, "stop_grad_between_levels", False)

    def codebook_weight(self):
        """Return the tensor [K, latent_size] for tying with the AR prior."""
        if isinstance(self.codebook, VQEmbeddingEMA):
            return self.codebook.embedding
        else:
            return self.codebook.embedding.weight

    def _encode_core(self, enc_inputs):
        """
        Run the encoder on already-assembled inputs.

        Args:
            enc_inputs: [B, T, encoder_input_dim]; cast to float32. T <= block_size.
        Returns:
            [B, T', latent_size] continuous (pre-quantization) features.
        """
        x = enc_inputs.to(dtype=torch.float32)
        B, T, _ = x.size()
        assert T <= self.block_size, "Block size exhausted."

        tok = self.embed(x)                              # [B,T,emb]
        pos = self.pos_emb[:, :T, :]
        h = self.drop(tok + pos)
        h = self.encoder(h)
        if self.bottleneck == "pooling":
            # MaxPool1d pools over the last axis, so move time there and back
            h = self.latent_pooling(h.transpose(1, 2)).transpose(1, 2)  # [B,T',emb]
        else:
            h = self.latent_pooling(h)                                   # [B,T',emb]
        z = self.cast_embed(h)                                           # [B,T',latent_size]
        return z

    def encode(self, joined_inputs, terminals=None):
        """
        Convenience wrapper when you want to call encode directly.
        If use_term_channel==True, terminals must be provided, and we will concat them.
        """
        if self.use_term_channel:
            assert terminals is not None, "terminals must be provided when use_term_channel=True"
            enc_inputs = torch.cat([joined_inputs, terminals], dim=2)
        else:
            enc_inputs = joined_inputs
        return self._encode_core(enc_inputs)

    def decode(self, latents, state):
        """
        Decode latents back to per-step transitions.

        Args:
            latents: [B, T', latent_size] tensor.
            state: conditioning state, reshaped to [B, 1, observation_dim]. It is
                broadcast over T' for the latent mixing, where it is zeroed if
                ``state_conditional`` is False. It is always added onto the
                predicted observation channels.
        Returns:
            [B, T, transition_dim]. Channels 1..observation_dim have ``state``
            added, so they are predicted as offsets. The last channel goes
            through a sigmoid (terminal probability).
        """
        B, Tprime, _ = latents.shape
        state_flat = state.view(B, 1, -1).repeat(1, Tprime, 1)
        if not self.state_conditional:
            state_flat = torch.zeros_like(state_flat)

        x = torch.cat([state_flat, latents], dim=-1)
        x = self.latent_mixing(x)
        if self.bottleneck == "pooling":
            x = torch.repeat_interleave(x, self.latent_step, dim=1)
        else:
            x = self.expand(x)
        x = x + self.pos_emb[:, :x.shape[1]]
        x = self.decoder(x)
        x = self.ln_f(x)
        out = self.predict(x)                              # [B,T,transition_dim]
        out[:, :, -1] = torch.sigmoid(out[:, :, -1])       # terminals
        out[:, :, 1:self.observation_dim+1] += state.view(B, 1, -1)
        return out

    def forward(self, joined_inputs, padding_vector, state, terminals=None):
        """
        Encode, residual-quantize, and decode one batch.

        Args:
            joined_inputs: [B, T, D] per-step features.
            padding_vector: D values written over steps flagged by ``terminals``
                (masked stream only).
            state: conditioning state forwarded to :meth:`decode`.
            terminals: [B, T, 1] terminal flags; required when ``use_term_channel``.
        Returns:
            (joined_pred, quantised, z_bars, feat_full, feat_masked).
            ``quantised`` holds the straight-through codes per level, and
            ``z_bars`` holds the plain codes per level (used by the caller's
            partial-sum commitment loss).
        """
        B, T, D = joined_inputs.size()
        # masked stream: hide channel 0 everywhere and the last channel of the
        # final two steps
        feat_mask = torch.ones_like(joined_inputs)
        feat_mask[:, :, 0] = 0.0
        feat_mask[:, -2:, -1] = 0.0
        x_masked = joined_inputs * feat_mask

        # overwrite terminated steps of the masked stream with the padding vector
        if terminals is not None:
            padded = torch.as_tensor(padding_vector, dtype=torch.float32, device=x_masked.device).repeat(B, T, 1)
            tmask = (1 - terminals).repeat(1, 1, D)
            x_masked = x_masked * tmask + (1 - tmask) * padded

        # Build encoder inputs with/without terminal channel
        if self.use_term_channel:
            assert terminals is not None, "terminals must be provided when use_term_channel=True"
            enc_full   = torch.cat([joined_inputs, terminals], dim=2)
            enc_masked = torch.cat([x_masked,       terminals], dim=2)
        else:
            enc_full   = joined_inputs
            enc_masked = x_masked

        feat_full   = self._encode_core(enc_full)      # [B,T',D_lat]
        feat_masked = self._encode_core(enc_masked)    # [B,T',D_lat]

        # choose assignment/update streams
        assign_in = feat_masked if self.assign_from == "masked" else feat_full
        update_in = feat_full  if self.update_from  == "full"   else feat_masked

        # residual VQ: the shared codebook is applied n_levels times
        res_assign = assign_in
        res_update = update_in
        quantised, z_bars = [], []
        for _ in range(self.n_levels):
            z_q, z_bar, _ = self.codebook.straight_through(res_assign, res_update)
            quantised.append(z_q)
            z_bars.append(z_bar)
            if self.stop_grad_between_levels:
                res_assign = (res_assign - z_bar).detach()
                res_update = (res_update - z_bar).detach()
            else:
                res_assign = (res_assign - z_bar)
                res_update = (res_update - z_bar)

        latents_st = torch.stack(quantised, dim=0).sum(0)  # [B,T',D_lat]
        joined_pred = self.decode(latents_st, state)
        return joined_pred, quantised, z_bars, feat_full, feat_masked



class VQContinuousVAE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = VQStepWiseTransformer(config, config.observation_dim)
        self.trajectory_embd = config.trajectory_embd
        self.vocab_size = config.vocab_size
        self.block_size = config.block_size
        self.observation_dim = config.observation_dim
        self.masking = getattr(config, "masking", "none")
        self.action_dim = config.action_dim
        self.trajectory_length = config.block_size
        self.transition_dim = config.transition_dim
        self.action_weight = config.action_weight
        self.reward_weight = config.reward_weight
        self.value_weight = config.value_weight
        self.position_weight = config.position_weight
        self.first_action_weight = config.first_action_weight
        self.sum_reward_weight = config.sum_reward_weight
        self.last_value_weight = config.last_value_weight
        self.latent_step = config.latent_step
        self.padding_vector = torch.zeros(self.transition_dim)
        self.apply(self._init_weights)

    def get_block_size(self): return self.block_size
    def set_padding_vector(self, padding): self.padding_vector = padding

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_(); module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        decay, no_decay = set(), set()
        whitelist = (torch.nn.Linear, EinLinear)
        blacklist = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn
                if pn.endswith("bias"): no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist): decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist): no_decay.add(fpn)
        if isinstance(self.model, (SymbolWiseTransformer, VQStepWiseTransformer)):
            no_decay.add('model.pos_emb')
            if self.model.bottleneck == "attention":
                no_decay.update({'model.latent_pooling.query', 'model.expand.query',
                                 'model.latent_pooling.attention.in_proj_weight',
                                 'model.expand.attention.in_proj_weight'})
        param_dict = {pn: p for pn, p in self.named_parameters()}
        assert len(decay & no_decay) == 0
        assert len(param_dict.keys() - (decay | no_decay)) == 0
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)

    @torch.no_grad()
    def encode(self, joined_inputs, terminals=None):
        """
        Returns indices for each residual level: [L, B, T'].
        """
        B, T, D = joined_inputs.size()
        # mask channel 0 for encoder invariance
        mask = torch.ones_like(joined_inputs)
        mask[:, :, 0] = 0.0
        #print("joined_inputs", joined_inputs.shape)
        mask[:, -2:, -1] = 0.0
        xin = joined_inputs * mask

        # terminal padding for features (keeps context clean beyond episode end)
        if terminals is not None:
            padded = torch.as_tensor(self.padding_vector, dtype=torch.float32, device=xin.device).repeat(B, T, 1)
            tmask = (1 - terminals).repeat(1, 1, D)
            xin = xin * tmask + (1 - tmask) * padded

        # Build encoder input (concat terminals if requested by model)
        if self.model.use_term_channel:
            assert terminals is not None, "terminals required when use_term_channel=True"
            enc_in = torch.cat([xin, terminals], dim=2)
        else:
            enc_in = xin

        # run per-level NN assignments on masked (or full) stream based on config.assign_from
        feat_full   = self.model._encode_core(torch.cat([joined_inputs, terminals], dim=2) if self.model.use_term_channel else joined_inputs) if (self.model.assign_from == "full") else None
        feat_masked = self.model._encode_core(enc_in) if (self.model.assign_from == "masked") else None
        assign_in = feat_masked if self.model.assign_from == "masked" else feat_full

        # idxs = []
        # res = assign_in
        # for cb in self.model.codebooks:
        #     emb = cb.embedding if isinstance(cb, VQEmbeddingEMA) else cb.embedding.weight
        #     idx = vq(res, emb)  # [B,T']
        #     idxs.append(idx)
        #     z_bar = F.embedding(idx.view(-1), emb).view_as(res)
        #     res = res - z_bar
        emb = self.model.codebook_weight()  # [K, D_lat]
        idxs = []
        res = assign_in
        for _ in range(self.model.n_levels):
            idx = vq(res, emb)  # [B,T']
            idxs.append(idx)
            z_bar = F.embedding(idx.view(-1), emb).view_as(res)
            res = res - z_bar
        return torch.stack(idxs, dim=0)  # [L,B,T']
        return torch.stack(idxs, dim=0)

    @torch.no_grad()
    def encode_runtime(self, joined_inputs, terminals=None):
        """
        Returns indices for each residual level: [L, B, T'].
        """
        B, T, D = joined_inputs.size()
        # mask channel 0 for encoder invariance
        mask = torch.ones_like(joined_inputs)
        mask[:, :, 0] = 0.0
        #print("joined_inputs", joined_inputs.shape)
        #mask[:, -2:, -1] = 0.0
        xin = joined_inputs * mask

        # terminal padding for features (keeps context clean beyond episode end)
        if terminals is not None:
            padded = torch.as_tensor(self.padding_vector, dtype=torch.float32, device=xin.device).repeat(B, T, 1)
            tmask = (1 - terminals).repeat(1, 1, D)
            xin = xin * tmask + (1 - tmask) * padded

        # Build encoder input (concat terminals if requested by model)
        if self.model.use_term_channel:
            assert terminals is not None, "terminals required when use_term_channel=True"
            enc_in = torch.cat([xin, terminals], dim=2)
        else:
            enc_in = xin

        # run per-level NN assignments on masked (or full) stream based on config.assign_from
        feat_full   = self.model._encode_core(torch.cat([joined_inputs, terminals], dim=2) if self.model.use_term_channel else joined_inputs) if (self.model.assign_from == "full") else None
        feat_masked = self.model._encode_core(enc_in) if (self.model.assign_from == "masked") else None
        assign_in = feat_masked if self.model.assign_from == "masked" else feat_full

        # idxs = []
        # res = assign_in
        # for cb in self.model.codebooks:
        #     emb = cb.embedding if isinstance(cb, VQEmbeddingEMA) else cb.embedding.weight
        #     idx = vq(res, emb)  # [B,T']
        #     idxs.append(idx)
        #     z_bar = F.embedding(idx.view(-1), emb).view_as(res)
        #     res = res - z_bar
        emb = self.model.codebook_weight()  # [K, D_lat]
        idxs = []
        res = assign_in
        for _ in range(self.model.n_levels):
            idx = vq(res, emb)  # [B,T']
            idxs.append(idx)
            z_bar = F.embedding(idx.view(-1), emb).view_as(res)
            res = res - z_bar
        return torch.stack(idxs, dim=0)  # [L,B,T']


    @torch.no_grad()
    def encode_with_soft(self, joined_inputs, terminals=None, tau: float = 0.5,
                         sample_from_soft: bool = True, topk: int | None = None):
        """
        Returns:
          indices: Long [L,B,T'] (sampled from Q_tau if sample_from_soft else NN)
          soft_targets: Float [L,B,T',K] (temperature-softened categorical over codes)
        """
        # Build encoder input exactly like in encode()
        B, T, D = joined_inputs.size()
        mask = torch.ones_like(joined_inputs); mask[:, :, 0] = 0.0
        xin = joined_inputs * mask
        if terminals is not None:
            padded = torch.as_tensor(self.padding_vector, dtype=torch.float32, device=xin.device).repeat(B, T, 1)
            tmask = (1 - terminals).repeat(1, 1, D)
            xin = xin * tmask + (1 - tmask) * padded

        if self.model.use_term_channel:
            assert terminals is not None, "terminals required when use_term_channel=True"
            enc_in = torch.cat([xin, terminals], dim=2)
        else:
            enc_in = xin

        feat = self.model._encode_core(enc_in)  # [B,T',D_lat]
        emb = self.model.codebook_weight()       # [K,D_lat]
        K = emb.size(0)

        indices = []
        soft_all = []
        res = feat
        for _ in range(self.model.n_levels):
            # distances: [B,T',K] using ||x||^2 + ||e||^2 - 2 x·e
            x = res.view(-1, res.size(-1))                # [B*T', D_lat]
            e = emb                                      # [K, D_lat]
            x_sqr = (x ** 2).sum(dim=1, keepdim=True)     # [B*T',1]
            e_sqr = (e ** 2).sum(dim=1).view(1, -1)       # [1,K]
            d = x_sqr + e_sqr - 2.0 * (x @ e.t())         # [B*T',K]
            logits = (-d / (2.0 * max(tau, 1e-8))).view(res.size(0), res.size(1), K)  # [B,T',K]

            # optionally restrict support for memory/compute
            if topk is not None and topk < K:
                topv, topi = torch.topk(logits, topk, dim=-1)
                mask_logits = torch.full_like(logits, -float("inf"))
                mask_logits.scatter_(-1, topi, topv)
                logits = mask_logits

            soft = torch.softmax(logits, dim=-1)  # [B,T',K]
            soft_all.append(soft)

            if sample_from_soft:
                idx = torch.multinomial(soft.view(-1, K), num_samples=1).view(res.size(0), res.size(1))  # [B,T']
            else:
                # nearest neighbor
                idx = torch.argmax(soft, dim=-1)  # [B,T']

            indices.append(idx)

            z_bar = F.embedding(idx.view(-1), emb).view_as(res)
            res = res - z_bar

        return torch.stack(indices, dim=0), torch.stack(soft_all, dim=0)


    def decode(self, latent, state):
        return self.model.decode(latent, state)

    def decode_from_indices(self, indices, state):
        if indices.dim() == 2:
            indices = indices.unsqueeze(0)  # [1,B,T']
        L, B, Tprime = indices.shape
        Dlat = self.trajectory_embd
        latent = None
        for l in range(L):
            cb = self.model.codebook
            emb = cb.embedding if isinstance(cb, VQEmbeddingEMA) else cb.embedding.weight
            z = F.embedding(indices[l].reshape(-1), emb).view(B, Tprime, Dlat)
            latent = z if latent is None else (latent + z)

        if self.model.bottleneck == "attention":
            pad_T = self.trajectory_length // self.latent_step
            latent = torch.cat([latent, torch.zeros(B, pad_T, latent.size(-1), device=latent.device)], dim=1)
        return self.model.decode(latent, state[:, None, :])

    def forward(self, joined_inputs, targets=None, mask=None, terminals=None, returnx=False):
        joined_inputs = joined_inputs.to(dtype=torch.float32)
        x = joined_inputs
        B, T, D = x.size()

        # terminal padding into features
        padded = torch.as_tensor(self.padding_vector, dtype=torch.float32, device=x.device).repeat(B, T, 1)
        if terminals is not None:
            tmask = (1 - terminals).repeat(1, 1, D)
            x = x * tmask + (1 - tmask) * padded

        # if mask is None:
        #     mask = torch.ones_like(x)
        # mask[:, :, 0] = 0.0  # don't train on value channel by default
        # choose conditioning state (here: from the penultimate step)
        state = x[:, -2, 1:(self.observation_dim + 1)]

        # >>>>>>>>>>>>>  pass terminals through  <<<<<<<<<<<<<<
        #reconstructed, quantised, feat_full, feat_masked = self.model(x, state, terminals=terminals)
        reconstructed, quantised, z_bars, feat_full, feat_masked = self.model(x, self.padding_vector, state, terminals=terminals)
        pred_traj  = reconstructed[:, :, :-1]          # [B,T,D]
        pred_terms = reconstructed[:, :, -1, None]

        # ----- losses -----
        #current_state_loss = next_state_loss = value_loss = first_action_loss = torch.tensor(0.0, device=x.device)
        #reconstruction_loss = loss_vq = loss_commit = None

        if targets is not None:
            weights = torch.cat([
                torch.ones(1, device=x.device) * self.value_weight,
                torch.ones(self.observation_dim, device=x.device),
                torch.ones(self.action_dim, device=x.device) * self.action_weight,
                torch.ones(self.action_dim, device=x.device) * self.action_weight,
                torch.ones(self.action_dim, device=x.device) * self.action_weight,
                torch.ones(1, device=x.device) * self.value_weight,
            ])

            mse_tail = F.mse_loss(pred_traj[:, -2:, :], joined_inputs[:, -2:, :], reduction='none') * weights[None, None, :]
            mse_ctx  = F.mse_loss(pred_traj[:, :-2, :], joined_inputs[:, :-2, :], reduction='none') * weights[None, None, :]
            #print(mask.shape, mse_tail.shape, mse_ctx.shape)

            ce = 0.0
            if terminals is not None:
                ce = F.binary_cross_entropy(pred_terms, torch.clamp(terminals.float(), 0.0, 1.0))

            first_action_loss = F.mse_loss(
                joined_inputs[:, -2, (1+self.observation_dim):(1+self.observation_dim+self.action_dim)],
                pred_traj[:,  -2, (1+self.observation_dim):(1+self.observation_dim+self.action_dim)]
            )
            #value_loss = F.mse_loss(joined_inputs[:, -2:, 0].mean(dim=1), pred_traj[:, -2:, 0].mean(dim=1))
            value_cur = F.mse_loss(pred_traj[:, -2, 0], joined_inputs[:, -2, 0])
            value_next = F.mse_loss(pred_traj[:, -1, 0], joined_inputs[:, -1, 0])
            macro_loss = F.mse_loss(joined_inputs[:, -2, -1], pred_traj[:, -2, -1])
            #macro_cur = F.mse_loss(joined_inputs[:, -1, -1], pred_traj[:, -1, -1])
            current_state_loss = F.mse_loss(
                joined_inputs[:, -2, 1:(self.observation_dim+1)], pred_traj[:, -2, 1:(self.observation_dim+1)]
            )
            next_state_loss = F.mse_loss(
                joined_inputs[:, -1, 1:(self.observation_dim+1)], pred_traj[:, -1, 1:(self.observation_dim+1)]
            )

            if terminals is None:
                term_mask = torch.ones(B, T, 1, device=x.device, dtype=x.dtype)
            else:
                term_mask = 1 - terminals
            term_mask = term_mask.repeat(1, 1, D)

            reconstruction_loss = (mse_tail * mask[:, -2:, :] * term_mask[:, -2:, :]).mean() \
                                  + 0.1 * (mse_ctx * mask[:, :-2, :] * term_mask[:, :-2, :]).mean() \
                                  + ce \
                                  + value_cur + value_next + macro_loss + next_state_loss + first_action_loss

            # masked vs unmasked feature consistency
            #loss_vq = F.mse_loss(feat_full, feat_masked)
            loss_vq = 0
            #loss_vq = F.mse_loss(feat_masked, feat_masked)

            # residual commitment proxy with z_q
            # loss_commit = 0.0
            # res = feat_masked if self.model.assign_from == "masked" else feat_full
            # for z_q in quantised:
            #     loss_commit += F.mse_loss(res, z_q.detach())
            #     res = res - z_q.detach()
            #

            Z = feat_masked
            z_hat = 0.0
            loss_commit = 0.0
            for zb in z_bars:
                # use non-ST codes and stop-grad on the partial sums: sg[hat Z^(d)]
                z_hat = (z_hat + zb).detach()
                loss_commit += F.mse_loss(Z, z_hat)
            loss_commit = loss_commit / len(z_bars) * self.model.commitment_beta

            #loss_commit = loss_commit / len(quantised) * self.model.commitment_beta

        return reconstructed, reconstruction_loss, loss_vq, loss_commit, current_state_loss, next_state_loss, (value_next+value_cur+macro_loss), first_action_loss





class TransformerPrior(nn.Module):
    """
    AR prior over residual-quantized tokens with arbitrary residual depth D.
    Training-only (no soft labels). Teacher-forced across time and depth.
    """

    def __init__(self, config):
        """Build the spatial trunk, the depth head and the token embeddings.

        Reads ``K``, ``n_embd``, ``n_layer``, ``block_size``, ``observation_dim`` and
        (optionally, default 0.1) ``embd_pdrop`` from ``config``; ``config`` is also
        passed to every ``Block`` of the trunk. The residual depth ``n_levels`` is
        fixed at 2 here, so ``pos_emb_D`` holds two depth embeddings.
        """
        super().__init__()
        # ---- hyper-parameters ----
        self.K = config.K
        self.n_embd = config.n_embd
        self.n_layer = config.n_layer
        self.block_size = config.block_size   # number of learned spatial position embeddings
        # NOTE(review): depth is hard-coded to 2 even though forward() accepts any D.
        self.n_levels = 2
        self.embd_pdrop = getattr(config, "embd_pdrop", 0.1)
        self.observation_dim = config.observation_dim
        width = self.n_embd

        # ---- positional / state embeddings ----
        # Construction order is kept unchanged: it determines RNG consumption at init.
        self.pos_emb_T = nn.Parameter(torch.zeros(1, self.block_size, width))  # per position
        self.pos_emb_D = nn.Parameter(torch.zeros(1, self.n_levels, width))    # per depth
        self.state_emb = nn.Linear(self.observation_dim, width)
        self.start_ctx = nn.Parameter(torch.zeros(1, 1, width))  # not read by _spatial_trunk
        self.drop = nn.Dropout(self.embd_pdrop)

        # ---- spatial transformer trunk ----
        trunk_layers = [Block(config) for _ in range(self.n_layer)]
        self.blocks = nn.Sequential(*trunk_layers)
        self.ln_f = nn.LayerNorm(width)

        # ---- depth head: small MLP followed by a shared projection to K logits ----
        self.depth_mlp = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, width),
        )
        self.head = nn.Linear(width, self.K, bias=False)

        # Token table used by _embed_ids whenever no codebook is tied.
        self.tok_emb_fallback = nn.Embedding(self.K, width)

        # Codebook-tying state (managed by tie_codebook / untie_codebook).
        self._tied_codebook = None
        self._code_to_model = None
        self._use_tied = False

        self.apply(self._init_weights)

    # ---------- init / optimizer ----------
    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias); nn.init.ones_(m.weight)

    def configure_optimizers(self, train_config):
        """Return an AdamW optimizer with GPT-style weight-decay parameter groups.

        Weights of ``nn.Linear`` / ``EinLinear`` layers are decayed. Biases,
        ``nn.LayerNorm`` / ``nn.Embedding`` weights and the raw parameters
        ``pos_emb_T``, ``pos_emb_D``, ``start_ctx`` are not. Asserts that every
        parameter lands in exactly one group. Reads ``weight_decay``,
        ``learning_rate`` and ``betas`` from ``train_config``.
        """
        decayed_types = (torch.nn.Linear, EinLinear)
        undecayed_types = (torch.nn.LayerNorm, torch.nn.Embedding)
        decay, no_decay = set(), set()

        for mod_name, module in self.named_modules():
            for par_name, _ in module.named_parameters():
                full_name = f"{mod_name}.{par_name}" if mod_name else par_name
                is_weight = par_name.endswith("weight")
                if par_name.endswith("bias"):
                    no_decay.add(full_name)
                elif is_weight and isinstance(module, decayed_types):
                    decay.add(full_name)
                elif is_weight and isinstance(module, undecayed_types):
                    no_decay.add(full_name)

        # Bare nn.Parameters owned directly by this module (no layer type to inspect).
        no_decay |= {'pos_emb_T', 'pos_emb_D', 'start_ctx'}

        params = dict(self.named_parameters())
        assert not (decay & no_decay)
        assert not (params.keys() - (decay | no_decay))

        groups = [
            {"params": [params[name] for name in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [params[name] for name in sorted(no_decay)], "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(groups, lr=train_config.learning_rate, betas=train_config.betas)

    # ---------- tie/untie codebook ----------
    def tie_codebook(self, codebook_weight: torch.Tensor):
        """Tie token embeddings to external codebook e(k) [K, D_lat]; learn proj if needed."""
        self._tied_codebook = codebook_weight
        D_lat = int(codebook_weight.size(1))
        self._code_to_model = nn.Linear(D_lat, self.n_embd, bias=False) if D_lat != self.n_embd else nn.Identity()
        self._use_tied = True

    def untie_codebook(self):
        self._tied_codebook = None
        self._code_to_model = None
        self._use_tied = False

    # ---------- embedding helpers ----------
    def _embed_ids(self, ids: torch.Tensor) -> torch.Tensor:
        """ids: Long [...]; returns Float [..., n_embd]."""
        if self._use_tied:
            e = F.embedding(ids, self._tied_codebook.to(ids.device))  # [..., D_lat]
            return self._code_to_model(e)                             # [..., n_embd]
        else:
            return self.tok_emb_fallback(ids)                         # [..., n_embd]

    def _embed_all_depths(self, indices_all: torch.Tensor) -> torch.Tensor:
        """indices_all: [D,B,T] -> emb_all: [D,B,T,n_embd]"""
        D, B, T = indices_all.shape
        return torch.stack([self._embed_ids(indices_all[d]) for d in range(D)], dim=0)

    # ---------- trunk over space ----------
    def _spatial_trunk(self, sum_depth: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """
        sum_depth: [B,T,n_embd] (sum of depth embeddings at each position)
        Returns h: [B, Tctx, n_embd] for Tctx = min(T, block_size)
        """
        B, T, _ = sum_depth.shape
        # shift right by 1: u_1=start token, u_t = sum_depth(t-1)
        #print(self.start_ctx.expand(B, -1, -1).shape, sum_depth[:, :-1, :].shape)
        #u_prev_full = torch.cat([self.start_ctx.expand(B, -1, -1), sum_depth[:, :-1, :]], dim=1)  # [B,T,n_embd]
        u_prev_full = sum_depth[:, :-1, :]
        # choose the last Tctx = min(T, block_size) steps
        Tctx = min(u_prev_full.size(1), self.block_size)
        u_prev = u_prev_full

        pos = self.pos_emb_T[:, :-1, :]
        st  = self.state_emb(state).unsqueeze(1)           # [B,1,n_embd]
        x = self.drop(u_prev + pos + st)
        x = self.blocks(x)
        h = self.ln_f(x)                                   # [B,Tctx,n_embd]
        return h

    # ---------- forward: CE over all depths ----------
    def forward(self,
                indices_all: torch.Tensor,   # Long [D,B,T]
                state: torch.Tensor,         # Float [B, obs_dim]
                depth_weights: torch.Tensor | None = None,  # optional [D] weights
                label_smoothing: float = 0.0):
        """
        Returns (loss, logits_all), where logits_all is a list of length D with shapes [B, Ttgt, K].
        Ttgt = min(T, block_size) - 1
        """
        D, B, T = indices_all.shape
        #print("indices_all", indices_all.shape)
        self.n_levels = D
        assert T >= 2

        # Pre-embed and form per-position sums
        emb_all  = self._embed_all_depths(indices_all)         # [D,B,T,n_embd]
        #print("emb_all", emb_all.shape)
        sum_depth= emb_all.sum(dim=0)                           # [B,T,n_embd]
        #print("sum_depth", sum_depth.shape)
        # Spatial trunk
        h = self._spatial_trunk(sum_depth, state)               # [B,Tctx,n_embd]
        #print(h.shape)
        #print("Tctx", Tctx)
        Ttgt = T - 1
        #print(Ttgt)
        #print("indices_all", indices_all.shape)
        if Ttgt <= 0:
            raise ValueError("Context too short; increase block_size or sequence length T.")

        # Align targets to the last Ttgt positions
        targets_all = indices_all[:, :, 1:]                 # [D,B,Ttgt]
        h_step = h[:, :, :]                                    # [B,Ttgt,n_embd]
        peD = self.pos_emb_D[:, :D, :].squeeze(0)               # [D,n_embd]

        # CE over depths
        losses = []
        logits_all = []
        for d in range(D):
            # partial sum at same position using gold depths < d
            if d == 0:
                partial = torch.zeros_like(h_step)
            else:
                partial = emb_all[:d, :, 1:, :].sum(dim=0)  # [B,Ttgt,n_embd]
            x_d = self.depth_mlp(h_step + peD[d].view(1,1,-1) + partial)
            logits_d = self.head(x_d)                           # [B,Ttgt,K]
            logits_all.append(logits_d)

            # hard CE (no soft labels)
            if label_smoothing > 0.0:
                V = logits_d.size(-1)
                with torch.no_grad():
                    true = torch.zeros_like(logits_d).scatter_(-1, targets_all[d].unsqueeze(-1), 1.0)
                    smooth = label_smoothing / V
                    target_dist = (1.0 - label_smoothing) * true + smooth
                logp = F.log_softmax(logits_d, dim=-1)
                loss_d = (-(target_dist * logp).sum(dim=-1)).mean()
            else:
                loss_d = F.cross_entropy(
                    logits_d.reshape(-1, self.K), targets_all[d].reshape(-1)
                )
            losses.append(loss_d)

        if depth_weights is None:
            loss = sum(losses) / float(D)
        else:
            w = depth_weights.to(logits_all[0].device).view(D)
            w = w / (w.sum() + 1e-8)
            loss = sum(w[d] * losses[d] for d in range(D))
        return loss, logits_all



    # =========================
    # SAMPLING HELPERS
    # =========================
    @torch.no_grad()
    def _embed_codes_to_model(self, ids: torch.Tensor) -> torch.Tensor:
        # convenience: same as _embed_ids but under @no_grad for sampling code
        if self._use_tied:
            e = F.embedding(ids, self._tied_codebook.to(ids.device))
            return self._code_to_model(e)
        else:
            return self.tok_emb_fallback(ids)

    @torch.no_grad()
    def _last_hidden_from_ctx(self, ctx_indices_all: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """
        ctx_indices_all: Long [D,B,Tctx]
        state:           Float [B,obs_dim]
        returns: h_last: Float [B, n_embd] for the NEXT position
        """
        D, B, Tctx = ctx_indices_all.shape
        # embed + sum across depths per position
        #print("ctx_indices_all", ctx_indices_all.shape)
        emb_all  = self._embed_all_depths(ctx_indices_all)     # [D,B,Tctx,n_embd]
        sum_depth= emb_all.sum(dim=0)                          # [B,Tctx,n_embd]
        # choose the last Tctx = min(T, block_size) steps
        Tctx = min(sum_depth.size(1), self.block_size)
        pos = self.pos_emb_T[:, :Tctx, :]
        st  = self.state_emb(state).unsqueeze(1)           # [B,1,n_embd]
        x = self.drop(sum_depth + pos + st)
        x = self.blocks(x)
        h = self.ln_f(x)                                   # [B,Tctx,n_embd]
        #return h
        #print(h.shape)
        h_last = h[:, -1, :]                                   # [B,n_embd]
        return h_last

    @torch.no_grad()
    def topk_next_stacks_no_replacement_multi(self,
                                              model,
                                              ctx_indices_all: torch.Tensor,  # [D,B,Tctx]
                                              state: torch.Tensor,  # [B,obs_dim]
                                              K_topk: int = 8,
                                              N_per_coarse: int = 4,
                                              coarse_temperature: float = 1.0,
                                              fine_temperature: float = 1.0,
                                              deeper_policy: str = "sample",  # or "argmax"
                                              topk_each: int | None = None,
                                              return_flat: bool = True,
                                              # uniqueness controls:
                                              unique: bool = False,
                                              oversample_factor: int = 3,
                                              strict_unique: bool = False):
        """
        Like your topk_next_stacks_no_replacement, but for each coarse Top-K token we draw
        N_per_coarse deeper stacks and (optionally) enforce uniqueness per coarse.

        Returns:
          if return_flat:
            next_ids:  Long  [D, B, K_topk * N_per_coarse]
            joint_p:   Float [B,     K_topk * N_per_coarse]
          else:
            next_ids:  Long  [D, B, K_topk, N_per_coarse]
            joint_p:   Float [B,     K_topk, N_per_coarse]
        """
        device = ctx_indices_all.device
        D, B, Tctx = ctx_indices_all.shape
        peD = self.pos_emb_D[:, :D, :].squeeze(0)  # [D, n_embd]

        # 1) spatial context
        h_last = self._last_hidden_from_ctx(ctx_indices_all, state)  # [B, n_embd]

        # 2) coarse logits and Top-K (no replacement)
        x1 = self.depth_mlp(h_last + peD[0].view(1, -1))  # [B, n_embd]
        logits0 = self.head(x1) / max(1e-8, float(coarse_temperature))  # [B, Kv]
        logp0 = F.log_softmax(logits0, dim=-1)  # [B, Kv]
        k = min(K_topk, logits0.size(-1))
        _, topi = torch.topk(logits0, k=k, dim=-1)  # [B, K]
        coarse_ids = topi  # [B, K]
        coarse_lp = logp0.gather(-1, coarse_ids)  # [B, K]

        # Degenerate D==1 case: only coarse token
        if D == 1:
            if unique:
                # There are at most K unique coarse tokens; duplicate handling is trivial here.
                pass
            if return_flat:
                next_ids = coarse_ids.unsqueeze(0).repeat_interleave(N_per_coarse, dim=2)  # [1,B,K*N]
                joint_lp = coarse_lp.unsqueeze(-1).expand(B, K_topk, N_per_coarse).reshape(B, K_topk * N_per_coarse)
                return next_ids, torch.exp(joint_lp).clamp_max(1.0)
            else:
                next_ids = coarse_ids.unsqueeze(0).unsqueeze(-1).expand(1, B, K_topk, N_per_coarse)  # [1,B,K,N]
                joint_lp = coarse_lp.unsqueeze(-1).expand(B, K_topk, N_per_coarse)
                return next_ids, torch.exp(joint_lp).clamp_max(1.0)

        # 3) Expand per coarse and sample deeper depths *with replacement*, but oversample
        M_total = N_per_coarse if not unique else (N_per_coarse * max(1, oversample_factor))
        BK = B * K_topk
        BKM = BK * M_total

        h_rep = (h_last.unsqueeze(1).expand(B, K_topk, -1).reshape(BK, -1)
                 .unsqueeze(1).expand(BK, M_total, -1).reshape(BKM, -1))  # [B*K*M, n_embd]
        coarse_rep = coarse_ids.unsqueeze(-1).expand(B, K_topk, M_total).reshape(-1)  # [B*K*M]
        partial = self._embed_codes_to_model(coarse_rep)  # [B*K*M, n_embd]

        ids_per_depth = [coarse_rep.clone()]  # list of [B*K*M]
        lp_per_depth = [coarse_lp.unsqueeze(-1).expand(B, K_topk, M_total).reshape(-1)]  # [B*K*M]

        for d in range(1, D):
            xd = self.depth_mlp(h_rep + peD[d].view(1, -1) + partial)  # [B*K*M, n_embd]
            logitsd = self.head(xd) / max(1e-8, float(fine_temperature))  # [B*K*M, Kv]
            #logitsd = self.head(xd) / max(1e-8, 0.7)
            if topk_each is not None and topk_each > 0:
                k2 = min(topk_each, logitsd.size(-1))
                v, i = torch.topk(logitsd, k=k2, dim=-1)
                masked = torch.full_like(logitsd, -float("inf"))
                masked.scatter_(dim=-1, index=i, src=v)
                logitsd = masked
            if deeper_policy == "sample":
                pd = torch.softmax(logitsd, dim=-1)
                ids_d = torch.multinomial(pd, num_samples=1, replacement=False).squeeze(-1)  # [B*K*M]
                lp_d = torch.log(pd.gather(-1, ids_d.unsqueeze(-1)).squeeze(-1) + 1e-12)
            else:
                ids_d = torch.argmax(logitsd, dim=-1)
                lp_d = F.log_softmax(logitsd, dim=-1).gather(-1, ids_d.unsqueeze(-1)).squeeze(-1)

            ids_per_depth.append(ids_d)
            lp_per_depth.append(lp_d)
            partial = partial + self._embed_codes_to_model(ids_d)

        next_ids_all = torch.stack(ids_per_depth, dim=0).view(D, B, K_topk, M_total)  # [D,B,K,M_total]
        joint_lp_all = torch.stack(lp_per_depth, dim=0).sum(dim=0).view(B, K_topk, M_total)  # [B,K,M_total]

        if not unique:
            # Return first N_per_coarse samples per coarse as-is
            next_ids = next_ids_all[:, :, :, :N_per_coarse]  # [D,B,K,N]
            joint_lp = joint_lp_all[:, :, :N_per_coarse]  # [B,K,N]
        else:
            # 4) Per (b,k) keep top-N unique stacks by joint log-prob
            next_ids = torch.empty(D, B, K_topk, N_per_coarse, dtype=next_ids_all.dtype, device=device)
            joint_lp = torch.empty(B, K_topk, N_per_coarse, dtype=joint_lp_all.dtype, device=device)

            for b in range(B):
                for k_idx in range(K_topk):
                    stacks = next_ids_all[:, b, k_idx, :]  # [D, M_total]
                    scores = joint_lp_all[b, k_idx, :]  # [M_total]

                    order = torch.argsort(scores, descending=True)
                    stacks = stacks[:, order]  # [D, M_total]
                    scores = scores[order]  # [M_total]

                    # Greedy top-N unique by scanning in score order
                    picked = []
                    seen = set()
                    # NOTE: D is small (e.g., 2..4), so tolist() is fine
                    for m in range(stacks.size(1)):
                        key = tuple(stacks[:, m].tolist())
                        if key not in seen:
                            seen.add(key)
                            picked.append(m)
                            if len(picked) == N_per_coarse:
                                break

                    n_found = len(picked)
                    if n_found < N_per_coarse:
                        if strict_unique:
                            raise RuntimeError(f"Only {n_found} unique stacks available for (b={b}, k={k_idx}). "
                                               f"Increase oversample_factor or loosen constraints.")
                        # pad by repeating the best so shapes stay consistent
                        if n_found == 0:
                            picked = [0]
                            n_found = 1
                        pad = N_per_coarse - n_found
                        sel = torch.tensor(picked, device=device, dtype=torch.long)
                        next_ids[:, b, k_idx, :n_found] = stacks[:, sel]
                        joint_lp[b, k_idx, :n_found] = scores[sel]
                        next_ids[:, b, k_idx, n_found:] = stacks[:, sel[:1]].expand(D, pad)
                        joint_lp[b, k_idx, n_found:] = scores[sel[:1]].expand(pad)
                    else:
                        sel = torch.tensor(picked, device=device, dtype=torch.long)
                        next_ids[:, b, k_idx, :] = stacks[:, sel]
                        joint_lp[b, k_idx, :] = scores[sel]

        # Convert to probabilities and pack shapes as requested
        joint_p = torch.exp(joint_lp).clamp_max(1.0)
        if return_flat:
            return next_ids.view(D, B, K_topk * N_per_coarse), joint_p.view(B, K_topk * N_per_coarse)
        else:
            return next_ids, joint_p

    @torch.no_grad()
    def topk_next_stacks_no_replacement(self,
                                        model,
                                        ctx_indices_all: torch.Tensor,  # [D,B,Tctx]
                                        state: torch.Tensor,            # [B,obs_dim]
                                        K_topk: int = 8,
                                        temperature: float = 1.0,
                                        deeper_policy: str = "argmax",  # or "sample"
                                        topk_each: int | None = None):
        """
        Stage-1: Given context, propose K distinct next-position stacks (D tokens) WITHOUT replacement.
        Strategy: Top-K on coarse (depth 0), then greedily (or sampled) fill deeper depths
                  conditioned on each coarse candidate.
        Returns:
          next_ids:   Long  [D,B,K]   (stack per candidate)
          joint_logp: Float [B,K]     (sum of log-probs across depths)
        """
        device = ctx_indices_all.device
        D, B, Tctx = ctx_indices_all.shape
        peD = self.pos_emb_D[:, :D, :].squeeze(0)              # [D,n_embd]

        # ----- 1) spatial context -----
        h_last = self._last_hidden_from_ctx(ctx_indices_all, state)   # [B,n_embd]

        # ----- 2) depth-1 (coarse) logits and Top-K -----
        x1 = self.depth_mlp(h_last + peD[0].view(1, -1))             # [B,n_embd]
        logits0 = self.head(x1) / max(1e-8, float(temperature))      # [B,Kv]
        logp0 = F.log_softmax(logits0, dim=-1)                       # [B,Kv]
        topv, topi = torch.topk(logits0, k=min(K_topk, logits0.size(-1)), dim=-1)   # [B,K]
        coarse_ids = topi                                             # [B,K] (no replacement)
        coarse_lp  = logp0.gather(-1, coarse_ids)                     # [B,K]

        # ----- 3) expand to K candidates; fill deeper depths -----
        # Prepare batch of B*K rows
        BK = B * coarse_ids.size(1)
        h_rep = h_last.unsqueeze(1).expand(B, coarse_ids.size(1), self.n_embd).reshape(BK, self.n_embd)  # [B*K,D]
        c_flat = coarse_ids.reshape(-1)                                # [B*K]
        partial = self._embed_codes_to_model(c_flat)                   # [B*K,n_embd]

        all_ids = [c_flat.clone()]   # list of [B*K]
        all_lp  = [coarse_lp.reshape(-1)]  # list of [B*K]

        for d in range(1, D):
            xd = self.depth_mlp(h_rep + peD[d].view(1, -1) + partial)   # [B*K,n_embd]
            logitsd = self.head(xd) / max(1e-8, float(temperature))     # [B*K,Kv]
            if topk_each is not None and topk_each > 0:
                # optional pruning for speed
                k = min(topk_each, logitsd.size(-1))
                v, i = torch.topk(logitsd, k=k, dim=-1)
                masked = torch.full_like(logitsd, -float("inf"))
                masked.scatter_(dim=-1, index=i, src=v)
                logitsd = masked

            if deeper_policy == "sample":
                pd = torch.softmax(logitsd, dim=-1)
                ids_d = torch.multinomial(pd, num_samples=1, replacement=False).squeeze(-1)  # [B*K]
                lp_d  = torch.log(pd.gather(-1, ids_d.unsqueeze(-1)).squeeze(-1) + 1e-12)
            else:
                ids_d = torch.argmax(logitsd, dim=-1)                                       # [B*K]
                lp_d  = F.log_softmax(logitsd, dim=-1).gather(-1, ids_d.unsqueeze(-1)).squeeze(-1)

            all_ids.append(ids_d)
            all_lp.append(lp_d)
            partial = partial + self._embed_codes_to_model(ids_d)

        # stack and reshape back to [D,B,K]
        next_ids = torch.stack(all_ids, dim=0).reshape(D, B, -1)        # [D,B,K]
        joint_logp = torch.stack(all_lp, dim=0).sum(dim=0).reshape(B, -1)  # [B,K]
        return next_ids, torch.exp(joint_logp).clamp_max(1.0)

    @torch.no_grad()
    def sample_next_stacks_with_replacement(self,
                                            ctx_indices_all: torch.Tensor,  # [D,B,Tctx]
                                            state: torch.Tensor,            # [B,obs_dim]
                                            M_samples: int = 8,
                                            coarse_temperature: float = 1.0,
                                            fine_temperature: float = 1.0,
                                            deeper_policy: str = "argmax",
                                            topk_each: int | None = None):
        """
        Stage-2 (or generic one-step sampler): Given context, sample M next-position stacks WITH replacement.
        Coarse is sampled from softmax; deeper depths greedy by default (or sampled if requested).
        Returns:
          next_ids:   Long  [D,B,M]
          joint_logp: Float [B,M]
        """
        device = ctx_indices_all.device
        D, B, Tctx = ctx_indices_all.shape
        peD = self.pos_emb_D[:, :D, :].squeeze(0)                      # [D,n_embd]

        # 1) spatial context
        h_last = self._last_hidden_from_ctx(ctx_indices_all, state)    # [B,n_embd]
        # 2) coarse distribution and M samples (with replacement)
        x1 = self.depth_mlp(h_last + peD[0].view(1, -1))               # [B,n_embd]
        logits0 = self.head(x1) / max(1e-8, float(coarse_temperature))        # [B,Kv]
        p0 = torch.softmax(logits0, dim=-1)                            # [B,Kv]
        coarse_ids = torch.multinomial(p0, num_samples=M_samples, replacement=True)  # [B,M]
        coarse_lp  = torch.log(p0.gather(-1, coarse_ids) + 1e-12)                  # [B,M]

        # 3) expand B*M rows; fill deeper depths
        BM = B * M_samples
        h_rep = h_last.unsqueeze(1).expand(B, M_samples, self.n_embd).reshape(BM, self.n_embd)  # [B*M,D]
        c_flat = coarse_ids.reshape(-1)                                       # [B*M]
        partial = self._embed_codes_to_model(c_flat)                          # [B*M,n_embd]

        all_ids = [c_flat.clone()]     # list of [B*M]
        all_lp  = [coarse_lp.reshape(-1)]

        for d in range(1, D):
            xd = self.depth_mlp(h_rep + peD[d].view(1, -1) + partial)         # [B*M,n_embd]
            logitsd = self.head(xd) / max(1e-8, float(fine_temperature))           # [B*M,Kv]

            if topk_each is not None and topk_each > 0:
                k = min(topk_each, logitsd.size(-1))
                v, i = torch.topk(logitsd, k=k, dim=-1)
                masked = torch.full_like(logitsd, -float("inf"))
                masked.scatter_(dim=-1, index=i, src=v)
                logitsd = masked

            if deeper_policy == "sample":
                pd = torch.softmax(logitsd, dim=-1)
                ids_d = torch.multinomial(pd, num_samples=1, replacement=False).squeeze(-1)
                lp_d  = torch.log(pd.gather(-1, ids_d.unsqueeze(-1)).squeeze(-1) + 1e-12)
                # print("depth:", d, pd[0])
                # topk_vals, topk_idx = torch.topk(pd[0], k=10)
                # print("Top 10 values:", topk_vals)
                # print("Their indices:", topk_idx)
            else:
                ids_d = torch.argmax(logitsd, dim=-1)
                lp_d  = F.log_softmax(logitsd, dim=-1).gather(-1, ids_d.unsqueeze(-1)).squeeze(-1)

            all_ids.append(ids_d)
            all_lp.append(lp_d)
            partial = partial + self._embed_codes_to_model(ids_d)

        next_ids   = torch.stack(all_ids, dim=0).reshape(D, B, M_samples)     # [D,B,M]
        joint_logp = torch.stack(all_lp,  dim=0).sum(dim=0).reshape(B, M_samples)
        return next_ids, torch.exp(joint_logp).clamp_max(1.0)

    @torch.no_grad()
    def two_stage_expand(self,
                         model,
                         ctx_indices_all: torch.Tensor,  # [D,B,Tctx]
                         state: torch.Tensor,  # [B,obs_dim]
                         K_topk: int = 8,  # stage-1 breadth
                         N_per_coarse: int = 4,  # NEW: samples per coarse in stage-1
                         M_samples: int = 8,  # stage-2 per-branch samples
                         coarse_temperature_stage1: float = 1.0,
                         fine_temperature_stage1: float = 1.0,
                         coarse_temperature_stage2: float = 1.0,
                         fine_temperature_stage2: float = 1.0,
                         deeper_policy_stage1: str = "argmax",
                         deeper_policy_stage2: str = "argmax",
                         topk_each: int | None = None,
                         # shape controls
                         return_flat_stage1: bool = True,  # if True: stage1 -> [D,B,K*N]
                         return_flat_stage2: bool = True,  # if True: stage2 -> [D,B,K*N,M]
                         # uniqueness options for stage-1 (only used if you call *_multi)
                         unique_stage1: bool = True,
                         oversample_factor: int = 3):
        """
        Stage-1: Top-K coarse (no replacement) and for each coarse draw N_per_coarse stacks (depths 1..D-1).
        Stage-2: For each stage-1 candidate branch, sample M_samples next-step stacks WITH replacement.

        Returns:
          stage1:
            if return_flat_stage1:  next1_ids [D,B,K*N], next1_prob [B,K*N]
            else:                   next1_ids [D,B,K,N], next1_prob [B,K,N]
          stage2:
            if return_flat_stage2:  next2_ids [D,B,K*N,M], next2_logp [B,K*N,M]
            else:                   next2_ids [D,B,K,N,M], next2_logp [B,K,N,M]
        """
        D, B, Tctx = ctx_indices_all.shape

        # ---------- Stage 1 ----------
        if N_per_coarse == 1:
            # vanilla top-K (no replacement), one stack per coarse
            next1_ids, next1_prob = self.topk_next_stacks_no_replacement(
                model, ctx_indices_all, state,
                K_topk=K_topk,
                temperature=coarse_temperature_stage1,
                deeper_policy=deeper_policy_stage1,
                topk_each=topk_each
            )  # next1_ids: [D,B,K], next1_prob: [B,K] (probabilities)
            if return_flat_stage1:
                # flatten trivially
                pass  # already [D,B,K] and [B,K]
            else:
                # add a size-1 N dim for consistency
                next1_ids = next1_ids.unsqueeze(-1)  # [D,B,K,1]
                next1_prob = next1_prob.unsqueeze(-1)  # [B,K,1]
            K_eff = K_topk
        else:
            # multi per coarse (with optional uniqueness)
            next1_ids, next1_prob = self.topk_next_stacks_no_replacement_multi(
                model, ctx_indices_all, state,
                K_topk=K_topk,
                N_per_coarse=N_per_coarse,
                coarse_temperature=coarse_temperature_stage1,
                fine_temperature=fine_temperature_stage1,
                deeper_policy=deeper_policy_stage1,
                topk_each=topk_each,
                return_flat=return_flat_stage1,
                unique=unique_stage1,
                oversample_factor=oversample_factor
            )  # returns probs (not logp)
            K_eff = K_topk * N_per_coarse

        # ---------- Build extended contexts for stage 2 ----------
        if return_flat_stage1:
            # next1_ids: [D,B,K_eff]
            BK = B * K_eff
            ctx_rep = ctx_indices_all.repeat_interleave(K_eff, dim=1)  # [D, B*K_eff, Tctx]
            next1_flat = next1_ids.view(D, -1)  # [D, B*K_eff]
            ctx_ext = torch.cat([ctx_rep, next1_flat.unsqueeze(-1)], dim=2)  # [D, B*K_eff, Tctx+1]
            state_rep = state.unsqueeze(1).expand(B, K_eff, state.size(-1)).reshape(BK, -1)
        else:
            # next1_ids: [D,B,K,N]
            K, N = next1_ids.shape[2], next1_ids.shape[3]
            BK = B * K * N
            ctx_rep = (ctx_indices_all
                       .repeat_interleave(K * N, dim=1))  # [D, B*K*N, Tctx]
            next1_flat = next1_ids.view(D, -1)  # [D, B*K*N]
            ctx_ext = torch.cat([ctx_rep, next1_flat.unsqueeze(-1)], dim=2)  # [D, B*K*N, Tctx+1]
            state_rep = (state.unsqueeze(1)
                         .expand(B, K * N, state.size(-1))
                         .reshape(BK, -1))

        # ---------- Stage 2: sample M per branch ----------
        next2_ids, next2_logp = self.sample_next_stacks_with_replacement(
            ctx_ext, state_rep,
            M_samples=M_samples,
            coarse_temperature=coarse_temperature_stage2,
            fine_temperature=fine_temperature_stage2,
            deeper_policy=deeper_policy_stage2,
            topk_each=topk_each
        )  # next2_ids: [D, B*K_eff, M], next2_logp: [B*K_eff, M] (LOG-probs)

        # ---------- Reshape stage-2 outputs ----------
        if return_flat_stage1:
            if return_flat_stage2:
                # [D, B, K_eff, M], [B, K_eff, M]
                next2_ids = next2_ids.view(D, B, K_eff, M_samples)
                next2_logp = next2_logp.view(B, K_eff, M_samples)
            else:
                # split K_eff back into (K, N)
                K, N = K_topk, N_per_coarse
                next2_ids = next2_ids.view(D, B, K, N, M_samples)
                next2_logp = next2_logp.view(B, K, N, M_samples)
        else:
            # already have K,N grouping from stage-1
            if return_flat_stage2:
                next2_ids = next2_ids.view(D, B, K * N, M_samples)
                next2_logp = next2_logp.view(B, K * N, M_samples)
            else:
                next2_ids = next2_ids.view(D, B, K, N, M_samples)
                next2_logp = next2_logp.view(B, K, N, M_samples)

        return (next1_ids, next1_prob), (next2_ids, next2_logp)


    @torch.no_grad()
    def propose_fine_for_topk_coarse_and_select(
        self,
        vq_ae,                              # your VQContinuousVAE
        ctx_indices_all: torch.Tensor,      # [D,B,Tctx] context tokens (all depths)
        state: torch.Tensor,                # [B, obs_dim] conditioning state (same as training)
        K_coarse: int = 8,                  # Top-K coarse proposals (no replacement)
        M_fine: int = 8,                    # samples of deeper depths per coarse
        temperature_coarse: float = 1.0,    # temp for depth-1 logits
        temperature_fine: float = 1.0,      # temp for deeper depths
        fine_policy: str = "sample",        # "sample" or "argmax" for depths 2..D
        topk_each: int | None = None,       # optional per-depth top-k pruning for fine
        score_mode: str = "value_last",     # "value_last" | "value_mean"
        value_index: int = 0,               # channel index of value in decoder output
        terminal_index: int | None = -1,    # channel index of terminal flag in output, or None
        measure_tail: bool = True,          # for attention bottleneck: measure new-frame length
        chunk_decode: int | None = None,    # if not None, decode candidates in chunks to save VRAM
        return_all: bool = False            # if True, also return all candidates + scores
    ):
        """
        Returns:
          best_ids:    Long [D,B]          # chosen next-step stack (depth-major)
          best_score:  Float [B]           # score of the chosen candidate
          best_recon:  Float [B, T_tail, D_out]  # decoded T_tail frames for the chosen candidate
          (optional) extras dict if return_all=True
        """
        device = ctx_indices_all.device
        D, B, Tctx = ctx_indices_all.shape
        assert D >= 1, "Need at least one residual depth."
        peD = self.pos_emb_D[:, :D, :].squeeze(0)  # [D, n_embd]

        # ---------- 1) Spatial context for NEXT position ----------
        h_last = self._last_hidden_from_ctx(ctx_indices_all, state)  # [B, n_embd]

        # ---------- 2) Top-K coarse tokens (no replacement) ----------
        x1 = self.depth_mlp(h_last + peD[0].view(1, -1))             # [B, n_embd]
        logits0 = self.head(x1) / max(1e-8, float(temperature_coarse))  # [B, Kv]
        logp0   = F.log_softmax(logits0, dim=-1)
        k = min(K_coarse, logits0.size(-1))
        topv, topi = torch.topk(logits0, k=k, dim=-1)                # [B, K_coarse]
        coarse_ids = topi                                            # [B, K]
        coarse_lp  = logp0.gather(-1, coarse_ids)                    # [B, K]

        # ---------- 3) For each coarse, sample M deeper depths ----------
        # We will create B*K*M rows, each representing one candidate next-step stack.
        BK  = B * K_coarse
        BKM = BK * M_fine

        # Repeat h_last for (K*M) rows; embed coarse and repeat across M_fine
        h_rep = h_last.unsqueeze(1).expand(B, K_coarse, self.n_embd) \
                        .reshape(BK, self.n_embd)                    \
                        .unsqueeze(1).expand(BK, M_fine, self.n_embd) \
                        .reshape(BKM, self.n_embd)                   # [B*K*M, n_embd]

        coarse_flat = coarse_ids.unsqueeze(-1).expand(B, K_coarse, M_fine).reshape(-1)  # [B*K*M]
        partial = self._embed_ids(coarse_flat)                                         # [B*K*M, n_embd]

        # Keep track of ids/logp per depth
        ids_per_depth = [coarse_flat.clone()]     # list of [B*K*M]
        lp_per_depth  = [coarse_lp.unsqueeze(-1).expand(B, K_coarse, M_fine).reshape(-1)]  # [B*K*M]

        for d in range(1, D):
            xd = self.depth_mlp(h_rep + peD[d].view(1, -1) + partial)                   # [B*K*M, n_embd]
            logitsd = self.head(xd) / max(1e-8, float(temperature_fine))                # [B*K*M, Kv]

            if topk_each is not None and topk_each > 0:
                kk = min(topk_each, logitsd.size(-1))
                v, i = torch.topk(logitsd, k=kk, dim=-1)
                masked = torch.full_like(logitsd, -float('inf'))
                masked.scatter_(dim=-1, index=i, src=v)
                logitsd = masked

            if fine_policy == "sample":
                pd    = torch.softmax(logitsd, dim=-1)
                ids_d = torch.multinomial(pd, num_samples=1, replacement=True).squeeze(-1)   # [B*K*M]
                lp_d  = torch.log(pd.gather(-1, ids_d.unsqueeze(-1)).squeeze(-1) + 1e-12)
            else:  # greedy
                ids_d = torch.argmax(logitsd, dim=-1)                                        # [B*K*M]
                lp_d  = F.log_softmax(logitsd, dim=-1).gather(-1, ids_d.unsqueeze(-1)).squeeze(-1)

            ids_per_depth.append(ids_d)
            lp_per_depth.append(lp_d)
            partial = partial + self._embed_ids(ids_d)

        # Stack and reshape to [D, B, K, M]
        next_ids = torch.stack(ids_per_depth, dim=0).view(D, B, K_coarse, M_fine)
        joint_logp = torch.stack(lp_per_depth, dim=0).sum(dim=0).view(B, K_coarse, M_fine)  # [B,K,M]

        # ---------- 4) Build full sequences (context + 1 step) for ALL candidates ----------
        # [D, B*K*M, Tctx+1]
        ctx_rep = ctx_indices_all.repeat_interleave(K_coarse * M_fine, dim=1)   # [D, B*K*M, Tctx]
        #print(ctx_rep.shape)
        nxt_flat = next_ids.view(D, -1)                                         # [D, B*K*M]
        full_flat = torch.cat([ctx_rep, nxt_flat.unsqueeze(-1)], dim=2)         # [D, B*K*M, Tctx+1]

        # Repeat state across branches
        state_rep = state.unsqueeze(1).expand(B, K_coarse * M_fine, state.size(-1)) \
                         .reshape(B*K_coarse*M_fine, -1)                            # [B*K*M, obs_dim]

        # ---------- 5) Decode all candidates (causal decoder attends to context) ----------
        def _decode_in_chunks(indices_flat, state_flat, step: int | None):
            if chunk_decode is None:
                return vq_ae.decode_from_indices(indices_flat, state_flat)
            # chunking along batch dimension to save memory
            outs = []
            total = indices_flat.shape[1]
            bs = chunk_decode
            for s in range(0, total, bs):
                e = min(s + bs, total)
                outs.append(vq_ae.decode_from_indices(indices_flat[:, s:e, :], state_flat[s:e]))
            return torch.cat(outs, dim=0)

        # For attention bottleneck, we may want to measure how many *new* frames belong to the appended step.
        if measure_tail and getattr(vq_ae.model, "bottleneck", "pooling") != "pooling":
            base_dec = vq_ae.decode_from_indices(ctx_indices_all, state)        # [B, T_base, D_out]
            T_base = base_dec.size(1)
            # Decode one small chunk just to read total length; or decode all and reuse.
            recon_all = _decode_in_chunks(full_flat, state_rep, None)           # [B*K*M, T_full, D_out]
            T_full = recon_all.size(1)
            T_tail = max(1, T_full - T_base)
        else:
            # pooling path: each latent expands by latent_step
            recon_all = _decode_in_chunks(full_flat, state_rep, None)           # [B*K*M, T_full, D_out]
            T_tail = vq_ae.latent_step

        # ---------- 6) Score each candidate by a return proxy ----------
        # Take the last T_tail frames (the new step) and compute a score per row
        #print(recon_all.shape)
        tail = recon_all[:, -T_tail:, :]
        #print(tail.shape)# [B*K*M, T_tail, D_out]
        if score_mode == "value_mean":
            score_row = tail[:, :, value_index].mean(dim=1)                     # [B*K*M]
        else:  # "value_last"
            score_row = tail[:, -1, value_index]                                # [B*K*M]

        if terminal_index is not None:
            # Optional: downweight terminal predictions on the new step (if applicable)
            term_prob = tail[:, -1, terminal_index]
            # You can threshold or just penalize linearly. Here we subtract a penalty.
            score_row = score_row - 0.5 * term_prob

        # Reshape to [B,K,M], then pick best over K*M for each batch
        scores = score_row.view(B, K_coarse, M_fine)                             # [B,K,M]
        flat_idx = scores.view(B, -1).argmax(dim=-1)                             # [B]
        best_scores = scores.view(B, -1).gather(-1, flat_idx.unsqueeze(-1)).squeeze(-1)  # [B]

        # ---------- 7) Gather the best tokens and recon ----------
        # best next-step tokens: flatten K*M and index-select
        next_ids_flat = next_ids.view(D, B, K_coarse * M_fine)                   # [D,B,K*M]
        best_ids = next_ids_flat.gather(
            2, flat_idx.view(1, B, 1).expand(D, B, 1)
        ).squeeze(-1)                                                            # [D,B]

        # best recon rows: map (b, flat) -> row = b*(K*M) + flat
        rows = (torch.arange(B, device=device) * (K_coarse * M_fine) + flat_idx).long()
        best_recon = recon_all.index_select(0, rows)                             # [B, T_full, D_out]
        best_recon = best_recon[:, -T_tail:, :]                                  # keep just the new-step tail

        if return_all:
            extras = {
                "all_next_ids": next_ids,               # [D,B,K,M]
                "all_joint_logp": joint_logp,           # [B,K,M]
                "all_scores": scores,                   # [B,K,M]
                "all_recon_tail": tail.view(B, K_coarse, M_fine, T_tail, -1)  # [B,K,M,T_tail,D_out]
            }
            return best_ids, best_scores, best_recon, extras

        return best_ids, best_scores, best_recon
