import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Optional, Tuple, Dict, Literal, List, Any, Union
import math
import copy
from tqdm import tqdm

class AbductionTransformer(nn.Module):
    
    def __init__(self, encoder, decoder):
        """Store the encoder/decoder and cache the NVIB layer's prior settings.

        Args:
            encoder: Callable invoked on each pair slice; must expose an
                ``nvib_layer`` attribute carrying ``kappa``, ``delta``,
                ``prior_mu``, ``prior_var`` and ``prior_alpha``.
            decoder: Callable invoked as ``decoder(input_seq, output_seq, context)``.
        """
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

        # Mirror the NVIB hyper-parameters locally so the KL helpers can read
        # them straight from ``self``.
        nvib_layer = encoder.nvib_layer
        for attr_name in ("kappa", "delta", "prior_mu", "prior_var", "prior_alpha"):
            setattr(self, attr_name, getattr(nvib_layer, attr_name))
    
    def forward(
        self,
        pairs: torch.Tensor,
        dropout_eval: bool,
        mode: Literal["mean", "gradient_ascent"],
        gaussian_kl_coeff: Optional[float] = 1e-3,
        dirichlet_kl_coeff: Optional[float] = 1,
        **mode_kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the leave-one-out loss over a set of input/output pairs.

        Every pair is encoded once; each pair is then decoded using a context
        built from all *other* pairs, and the per-pair losses are averaged.

        Args:
            pairs: Tensor unpacked by ``_encode_pairs`` as four axes; the pair
                axis is the third-from-last one and the second-from-last axis
                holds the input (index 0) and output (index 1) channels.
            dropout_eval: Forwarded to the helper methods.
            mode: ``"mean"`` decodes from the merged context directly;
                ``"gradient_ascent"`` first refines the merged context with
                ``_get_gradient_ascent_context``.
            gaussian_kl_coeff: Weight of the Gaussian KL term (``None`` is
                treated as 0).
            dirichlet_kl_coeff: Weight of the Dirichlet KL term (``None`` is
                treated as 0).
            **mode_kwargs: ``num_steps`` and ``lr`` are required for
                ``"gradient_ascent"``.

        Returns:
            ``(total_loss, metrics)`` where ``metrics`` always contains
            ``"total_loss"`` plus whatever the encoder KL and the selected
            mode reported.

        Raises:
            AssertionError: If fewer than two pairs are supplied, or a required
                ``mode_kwargs`` entry is missing.
            ValueError: If ``mode`` is not supported.
        """
        # The pair axis is -3 (this is the axis every helper indexes with
        # ``pairs[..., i, :, :]``); -4 is the leading batch axis.
        num_pairs = pairs.shape[-3]
        assert num_pairs > 1, f"Number of pairs should be greater than 1, got {num_pairs}."

        # The signature advertises Optional coefficients; treat None as "off".
        gaussian_kl_coeff = 0.0 if gaussian_kl_coeff is None else gaussian_kl_coeff
        dirichlet_kl_coeff = 0.0 if dirichlet_kl_coeff is None else dirichlet_kl_coeff

        num_steps = mode_kwargs.get("num_steps", 1)
        lr = mode_kwargs.get("lr", 0.01)

        # First gather KL metrics if using a variational encoder
        metrics = {}
        nvib_outputs, kl_metrics = self._encode_pairs(pairs, dropout_eval)

        prior_kl_loss = None
        if kl_metrics:
            metrics.update(kl_metrics)
            prior_kl_loss = (
                gaussian_kl_coeff * kl_metrics["gaussian_kl"]
                + dirichlet_kl_coeff * kl_metrics["dirichlet_kl"]
            )

        # Now implement the different modes
        if mode == "mean":
            # Leave-one-out evaluation with the merged context of the other pairs
            loss, mode_metrics = self._leave_one_out_loss(nvib_outputs, pairs, dropout_eval)
            metrics.update(mode_metrics)
        elif mode == "gradient_ascent":
            for arg in ["num_steps", "lr"]:
                assert arg in mode_kwargs, f"'{arg}' argument required for 'gradient_ascent' training mode"

            losses = []
            pair_metrics = []

            N = len(nvib_outputs)
            for i in range(N):
                context_indices = [j for j in range(N) if j != i]
                context_dicts = [nvib_outputs[j] for j in context_indices]

                context_pairs = pairs[..., context_indices, :, :]
                pair_to_eval = pairs[..., i, :, :]

                input_seq, true_output_seq = self._prepare_sequences_for_decoder(
                    pair_to_eval, dropout_eval=dropout_eval
                )

                best_context, updated_decoder_params = self._get_gradient_ascent_context(
                    context_dicts, num_steps, lr,
                    context_pairs,
                    gaussian_kl_coeff=gaussian_kl_coeff,
                    dirichlet_kl_coeff=dirichlet_kl_coeff
                )
                if updated_decoder_params is not None:
                    grid_logits = torch.nn.utils.stateless.functional_call(
                        self.decoder,
                        updated_decoder_params,
                        (input_seq, true_output_seq, best_context)
                    )
                else:
                    grid_logits = self.decoder(input_seq, true_output_seq, best_context)

                loss_i, metrics_i = self._compute_loss(grid_logits, true_output_seq)

                losses.append(loss_i)
                pair_metrics.append(metrics_i)

            loss = torch.stack(losses).mean()

            mode_metrics = {
                key: torch.stack([m[key] for m in pair_metrics]).mean()
                for key in pair_metrics[0]
            }
            metrics.update(mode_metrics)
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        total_loss = loss
        # Add the weighted KL as soon as *either* term is active: a zero weight
        # on one term must not silently discard the other one.
        if prior_kl_loss is not None and (gaussian_kl_coeff > 0 or dirichlet_kl_coeff > 0):
            # Out-of-place add so the reconstruction loss tensor is not mutated.
            total_loss = loss + prior_kl_loss
            metrics["prior_kl_loss"] = prior_kl_loss

        metrics["total_loss"] = total_loss
        return total_loss, metrics
    
    def _encode_pairs(
        self, pairs: torch.Tensor, dropout_eval: bool
    ) -> Tuple[List[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]:
        """Encode every pair separately and, when available, compute prior KLs.

        Args:
            pairs: 4-D tensor; axis 1 enumerates the pairs.
            dropout_eval: Accepted for interface symmetry; not read here.

        Returns:
            ``(nvib_outputs, kl_metrics)``. ``nvib_outputs`` holds one encoder
            dict per pair, each tagged with ``'original_base_shape'`` (the size
            of axis 0 of ``pairs``). ``kl_metrics`` is empty unless the first
            encoder output contains both ``'mu'`` and ``'logvar'``; otherwise it
            carries ``"gaussian_kl"`` and ``"dirichlet_kl"``.
        """
        leading = pairs.shape[0]
        # Unpacking three trailing axes also enforces that ``pairs`` is 4-D.
        num_pairs, _io_axis, _seq_axis = pairs.shape[1:]

        nvib_outputs = []
        for idx in range(num_pairs):
            # Keep a singleton pair axis so the encoder sees the same rank.
            encoded = self.encoder(pairs[:, idx, :, :].unsqueeze(1))
            encoded['original_base_shape'] = leading
            nvib_outputs.append(encoded)

        head = nvib_outputs[0]
        if not ('mu' in head and 'logvar' in head):
            return nvib_outputs, {}

        # Swap the first two axes of each latent, then stack the pairs on a new
        # axis 1; the padding mask is stacked without the swap.
        all_mu = torch.stack([out['mu'].transpose(0, 1) for out in nvib_outputs], dim=1)
        all_logvar = torch.stack([out['logvar'].transpose(0, 1) for out in nvib_outputs], dim=1)
        all_alpha = torch.stack([out['alpha'].transpose(0, 1) for out in nvib_outputs], dim=1)
        all_mask = torch.stack([out['memory_key_padding_mask'] for out in nvib_outputs], dim=1)

        kl_metrics = {
            "gaussian_kl": self.kl_gaussian(all_mu, all_logvar, all_alpha, all_mask),
            "dirichlet_kl": self.kl_dirichlet(all_alpha, all_mask),
        }
        return nvib_outputs, kl_metrics
    
    def _leave_one_out_loss(
        self, 
        nvib_outputs: List[Dict[str, torch.Tensor]], 
        pairs: torch.Tensor, 
        dropout_eval: bool
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Average the decoding loss of each pair given all the other pairs.

        For every index ``i`` the encoder outputs of the remaining pairs are
        merged into one context, pair ``i`` is decoded against it and scored
        with ``_compute_loss``.

        Returns:
            ``(mean_loss, metrics)`` where ``metrics`` holds the mean of every
            tracked per-pair metric (currently only ``"grid_loss"``).
        """
        per_pair_losses = []
        tracked = {"grid_loss": []}

        for held_out in range(len(nvib_outputs)):
            # Everything except the held-out pair forms the context.
            others = nvib_outputs[:held_out] + nvib_outputs[held_out + 1:]

            input_seq, true_output_seq = self._prepare_sequences_for_decoder(
                pairs[..., held_out, :, :], dropout_eval=dropout_eval
            )
            merged_context = self.merge_nvib_outputs(others)

            grid_logits = self.decoder(input_seq, true_output_seq, merged_context)
            pair_loss, pair_metrics = self._compute_loss(grid_logits, true_output_seq)

            per_pair_losses.append(pair_loss)
            for name, value in pair_metrics.items():
                if name in tracked:
                    tracked[name].append(value)

        mean_loss = torch.stack(per_pair_losses).mean()
        metrics = {name: torch.stack(values).mean() for name, values in tracked.items()}
        return mean_loss, metrics
    
    def _prepare_sequences_for_decoder(
        self, pair: torch.Tensor, dropout_eval: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a pair into a BOS-prefixed input sequence and its target.

        Args:
            pair: Tensor whose second-from-last axis holds the input channel at
                index 0 and the output channel at index 1.
            dropout_eval: Accepted for interface symmetry; not read here.

        Returns:
            ``(input_seq, output_seq)``: the input channel with a single zero
            token prepended on the last axis, and the untouched output channel.
        """
        input_channel = pair[..., 0, :]
        output_channel = pair[..., 1, :]

        # Build the BOS token with the input's own dtype/device: a default
        # float32 token would make torch.cat promote integer token ids to float.
        bos_token = input_channel.new_zeros((*input_channel.shape[:-1], 1))

        input_seq = torch.cat([bos_token, input_channel], dim=-1)
        return input_seq, output_channel
    
    def _compute_loss(
        self,
        grid_logits: torch.Tensor,
        output_seq: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Token-level cross-entropy with defensive sanitising of bad values.

        Non-finite logits are zeroed and all logits clamped to ``[-20, 20]``;
        targets are clamped into the valid class range; non-finite per-token
        losses (and a non-finite mean) are replaced by ``1``.

        Args:
            grid_logits: Logits whose last two axes are (sequence, classes).
            output_seq: Target class ids; cast to ``long`` before use.

        Returns:
            ``(total_loss, metrics)`` with ``metrics`` holding ``"grid_loss"``
            and ``"total_loss"`` (both the same scalar).
        """
        targets = output_seq.long()

        # Collapse leading axes and move classes to axis 1, the layout
        # cross_entropy expects for sequence targets.
        logits = grid_logits.reshape(-1, *grid_logits.shape[-2:]).transpose(1, 2)
        logits = torch.where(torch.isfinite(logits), logits, torch.zeros_like(logits))
        logits = logits.clamp(min=-20.0, max=20.0)

        class_count = logits.size(1)
        flat_targets = targets.reshape(-1, targets.shape[-1])
        flat_targets = flat_targets.clamp(min=0, max=class_count - 1)

        try:
            token_losses = F.cross_entropy(logits, flat_targets, reduction='none')
        except Exception as e:
            # Best-effort fallback: same quantity via an explicit log-softmax.
            log_probs = F.log_softmax(logits, dim=1)
            token_losses = -log_probs.gather(1, flat_targets.unsqueeze(1)).squeeze(1)

        token_losses = torch.where(
            torch.isfinite(token_losses), token_losses, torch.ones_like(token_losses)
        )

        grid_losses = token_losses.reshape_as(targets)
        grid_loss = grid_losses.mean()
        if not torch.isfinite(grid_loss):
            grid_loss = torch.tensor(1.0, device=grid_losses.device)

        total_loss = grid_loss
        return total_loss, {"grid_loss": grid_loss, "total_loss": total_loss}

    def get_first_padded_logits(
        self,
        grid_logits: torch.Tensor,
        num_cols: torch.Tensor,
        max_rows: int,
        max_cols: int
    ) -> torch.Tensor:
        """Gather one logit vector per row boundary from a flattened grid.

        For each row id ``r`` in ``1 .. max_rows - 1`` the gathered sequence
        position is ``max_cols * r - (max_cols - num_cols)``.

        Args:
            grid_logits: ``(*batch, seq_len, D)`` logits.
            num_cols: Per-sample column count, viewable as ``(*batch, 1)``.
            max_rows: Number of rows in the padded grid.
            max_cols: Number of columns in the padded grid.

        Returns:
            ``(*batch, max_rows - 1, D)`` tensor of the gathered logits.
        """
        *lead_dims, _seq_len, depth = grid_logits.shape

        rows = torch.arange(1, max_rows, device=grid_logits.device)
        cols = num_cols.view(*lead_dims, 1)
        positions = max_cols * rows - (max_cols - cols)

        # Broadcast each position across the logit axis so gather copies the
        # whole vector at that sequence index.
        index = positions.unsqueeze(-1).expand(*lead_dims, max_rows - 1, depth).long()
        return grid_logits.gather(dim=-2, index=index)

    def copy_padded_logits_to_next_row_start(
        self,
        grid_logits: torch.Tensor,
        padded_logits: torch.Tensor,
        max_rows: int,
        max_cols: int
    ) -> torch.Tensor:
        """Write ``padded_logits`` at the first position of rows 1..max_rows-1.

        Args:
            grid_logits: ``(*batch, seq_len, D)`` logits; left unmodified.
            padded_logits: ``(*batch, max_rows - 1, D)`` vectors to insert.
            max_rows: Number of rows in the padded grid.
            max_cols: Number of columns in the padded grid.

        Returns:
            A copy of ``grid_logits`` where sequence positions
            ``max_cols, 2 * max_cols, ...`` hold the given vectors.
        """
        *lead_dims, _seq_len, _depth = grid_logits.shape

        # Sequence index of the first cell of each row after the first one.
        row_starts = torch.arange(1, max_rows, device=grid_logits.device) * max_cols
        row_starts = row_starts.view(*([1] * len(lead_dims)), -1).expand(*lead_dims, -1)
        index = row_starts.unsqueeze(-1).expand_as(padded_logits)

        return grid_logits.clone().scatter_(dim=-2, index=index, src=padded_logits)

    def gaussian_mixture_kl_divergence(self, mus, log_vars, alphas):
        """Mean weighted pairwise KL between the per-pair diagonal Gaussians.

        Args:
            mus: ``(batch, num_pairs, seq_len, hidden)`` means.
            log_vars: Log-variances, same shape as ``mus``.
            alphas: Scores with a trailing singleton axis; soft-maxed over the
                sequence axis to weight the components.

        Returns:
            Scalar: for every ordered pair ``(i, j)`` with ``i != j`` the
            component-wise KL is weighted by the product of both pairs'
            weights, summed, normalised by ``num_pairs * (num_pairs - 1)`` and
            averaged over the batch.
        """
        _batch, num_pairs, _seq_len, _hidden = mus.shape

        weights = torch.softmax(alphas.squeeze(-1), dim=2)
        variances = torch.exp(log_vars)

        # Axis 1 indexes the "from" pair (i), axis 2 the "to" pair (j).
        mean_gap_sq = (mus.unsqueeze(2) - mus.unsqueeze(1)).pow(2)
        half_log_ratio = 0.5 * (log_vars.unsqueeze(1) - log_vars.unsqueeze(2))
        half_trace = 0.5 * (variances.unsqueeze(2) + mean_gap_sq) / (variances.unsqueeze(1) + 1e-8)

        per_component = (half_log_ratio + half_trace - 0.5).sum(dim=-1)
        joint_weights = weights.unsqueeze(2) * weights.unsqueeze(1)
        pairwise = (per_component * joint_weights).sum(dim=-1)

        # Drop the i == j diagonal before averaging.
        off_diagonal = 1.0 - torch.eye(num_pairs, device=mus.device).unsqueeze(0)
        per_batch = (pairwise * off_diagonal).sum(dim=(1, 2)) / (num_pairs * (num_pairs - 1))

        return torch.mean(per_batch)
    
    
    def kl_gaussian(self, mu, logvar, alpha, memory_key_padding_mask, **kwargs):
        """KL of the latent Gaussians against the stored prior, weighted by alpha.

        4-D inputs first have their two leading axes merged and the resulting
        first two axes swapped; afterwards axis 0 is the one reduced over.

        Args:
            mu: Latent means.
            logvar: Latent log-variances (clamped to ``[-10, 10]``).
            alpha: Per-position weights with a trailing singleton axis.
            memory_key_padding_mask: Boolean mask, ``True`` at padded positions.
            **kwargs: Ignored.

        Returns:
            Scalar KL averaged over the remaining batch axis.
        """
        logvar = logvar.clamp(min=-10, max=10)

        if mu.dim() == 4:
            def fold(tensor):
                return tensor.reshape(-1, *tensor.shape[2:]).transpose(0, 1)

            mu, logvar, alpha, memory_key_padding_mask = (
                fold(t) for t in (mu, logvar, alpha, memory_key_padding_mask)
            )

        padded = memory_key_padding_mask.unsqueeze(-1)

        # Number of non-padded positions (at least 1 to avoid a zero scale).
        k0 = (~memory_key_padding_mask).sum(0).clamp(min=1)
        n = torch.clamp(k0 / self.kappa, min=1e-2)

        # Normalise alpha over the non-padded positions.
        alpha = alpha.masked_fill(padded, 0)
        alpha0_q = alpha.transpose(2, 0).sum(-1).clamp(min=1e-4)
        expected_pi = alpha.squeeze(-1) / alpha0_q

        prior_var = self.prior_var + 1e-8
        var_ratio = (logvar.exp().clamp(min=1e-8) / prior_var).clamp(min=1e-8)
        mean_term = (mu - self.prior_mu) ** 2 / prior_var

        per_dim = var_ratio + mean_term - 1 - torch.log(var_ratio)
        per_position = per_dim.masked_fill(padded, 0).mean(-1)

        per_batch = 0.5 * k0 * (per_position * expected_pi).sum(0)
        per_batch = per_batch / n
        return per_batch.mean()

    def kl_dirichlet(self, alpha, memory_key_padding_mask, **kwargs):
        """KL between the summed pseudo-counts and the stored Dirichlet prior.

        4-D inputs first have their two leading axes merged and the resulting
        first two axes swapped; afterwards axis 0 is the one reduced over.
        The gamma/digamma arithmetic runs in float64 and the result is cast
        back to float32.

        Args:
            alpha: Pseudo-counts with a trailing singleton axis.
            memory_key_padding_mask: Boolean mask, ``True`` at padded positions.
            **kwargs: Ignored.

        Returns:
            Scalar float32 KL averaged over the remaining batch axis.
        """
        if alpha.dim() == 4:
            alpha = alpha.reshape(-1, *alpha.shape[2:]).transpose(0, 1)
            memory_key_padding_mask = memory_key_padding_mask.reshape(
                -1, *memory_key_padding_mask.shape[2:]
            ).transpose(0, 1)

        # Padded positions contribute only a tiny pseudo-count.
        alpha = alpha.masked_fill(memory_key_padding_mask.unsqueeze(-1), 1e-4)

        k0 = (~memory_key_padding_mask).sum(0).clamp(min=1)
        n = torch.clamp(k0 / self.kappa, min=1e-2)

        alpha0_q = alpha.sum(0).squeeze(-1).to(torch.float64).clamp(min=1e-2)

        prior_total = self.prior_alpha + self.delta * (n - 1)
        alpha0_p = (torch.ones_like(alpha0_q) * prior_total).to(torch.float64).clamp(min=1e-2)

        # Per-component concentrations under an even split across k0 positions.
        alpha0_q_k = torch.clamp(alpha0_q / k0, min=1e-2)
        alpha0_p_k = torch.clamp(alpha0_p / k0, min=1e-2)

        log_norm_gap = torch.lgamma(alpha0_q) - torch.lgamma(alpha0_p)
        digamma_gap = -torch.digamma(alpha0_q) + torch.digamma(alpha0_q_k)
        component_gap = torch.lgamma(alpha0_p_k) - torch.lgamma(alpha0_q_k)

        kl = (log_norm_gap + (alpha0_q - alpha0_p) * digamma_gap + k0 * component_gap) / n
        return kl.mean().to(torch.float32)

    def merge_nvib_outputs(self, nvib_outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Concatenate several encoder output dicts into a single context dict.

        Merge rules, per key:
          * tensors of rank < 2 and non-tensor values: the first value is kept;
          * ``'memory_key_padding_mask'``: concatenated along axis 1;
          * tensors of rank >= 3 whose axis-1 size matches the next dict's:
            concatenated along axis 0;
          * any other tensor: concatenated along axis 1;
          * tuples: merged element-wise with the tensor rules above (without
            the mask special case);
          * ``'original_base_shape'``: copied from the first dict.

        Args:
            nvib_outputs: Encoder output dicts to merge; may hold one entry.

        Returns:
            The merged dict, or ``{}`` when ``nvib_outputs`` is empty.
        """
        if not nvib_outputs:
            return {}

        def _concat(tensors, forced_dim=None):
            """Concatenate ``tensors`` along the axis chosen by the merge rules."""
            first = tensors[0]
            if first.dim() < 2:
                return first
            if forced_dim is not None:
                return torch.cat(tensors, dim=forced_dim)
            # With a single entry there is no second tensor to compare against;
            # compare with itself instead of indexing past the end of the list.
            peer = tensors[1] if len(tensors) > 1 else first
            if first.dim() >= 3 and first.shape[1] == peer.shape[1]:
                return torch.cat(tensors, dim=0)
            return torch.cat(tensors, dim=1)

        # Collect keys in first-seen order so the merged dict is deterministic.
        all_keys = list(dict.fromkeys(key for nvib_dict in nvib_outputs for key in nvib_dict))

        merged_dict = {}
        for key in all_keys:
            if key == 'original_base_shape':
                continue

            values = [nvib_dict[key] for nvib_dict in nvib_outputs if key in nvib_dict]
            head = values[0]

            if isinstance(head, torch.Tensor):
                forced_dim = 1 if key == 'memory_key_padding_mask' else None
                merged_dict[key] = _concat(values, forced_dim)
            elif isinstance(head, tuple):
                merged_elements = []
                for position in range(len(head)):
                    elements = [value[position] for value in values]
                    if isinstance(elements[0], torch.Tensor):
                        merged_elements.append(_concat(elements))
                    else:
                        merged_elements.append(elements[0])
                merged_dict[key] = tuple(merged_elements)
            else:
                merged_dict[key] = head

        if 'original_base_shape' in nvib_outputs[0]:
            merged_dict['original_base_shape'] = nvib_outputs[0]['original_base_shape']

        return merged_dict

    def _get_gradient_ascent_context(
        self,
        nvib_outputs: List[Dict[str, torch.Tensor]],
        num_steps,
        lr,
        context_pairs,
        gaussian_kl_coeff=None,
        dirichlet_kl_coeff=None,
        update_decoder: bool = False,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]]]:
        """Refine the merged context by gradient steps on the context pairs.

        Detached copies of the first four entries of the merged ``'z'`` tuple
        are optimised with Adam (cosine-annealed learning rate) to lower the
        reconstruction loss of ``context_pairs``. The iterate with the lowest
        recorded loss is returned via a straight-through update: the returned
        tensors carry the optimised values while gradients still flow to the
        original encoder outputs.

        Args:
            nvib_outputs: Encoder output dicts forming the context.
            num_steps: Number of optimisation steps (0 returns the context
                values unchanged).
            lr: Adam learning rate for the latents.
            context_pairs: Pairs the refined context must reconstruct.
            gaussian_kl_coeff: Accepted for interface compatibility; not used
                in the inner objective.
            dirichlet_kl_coeff: Accepted for interface compatibility; not used
                in the inner objective.
            update_decoder: Also fine-tune a deep copy of the decoder
                (learning rate 1e-4) alongside the latents.

        Returns:
            ``(updated_context, updated_decoder_params)``; the second item is
            ``None`` unless ``update_decoder`` is set, in which case it maps
            parameter names to straight-through updated tensors suitable for a
            functional decoder call.
        """
        base_context = self.merge_nvib_outputs(nvib_outputs)

        latent_names = ('z', 'pi', 'mu', 'logvar')
        z, pi, mu, logvar = [
            tensor.clone().detach().requires_grad_(True) for tensor in base_context['z'][:4]
        ]
        params_to_optimize = [z, pi, mu, logvar]

        optim_obj = optim.Adam(params_to_optimize, lr=lr)
        # T_max must stay positive even when no step is requested.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim_obj, T_max=max(num_steps, 1), eta_min=0
        )

        input_seq, true_output_seq = self._prepare_sequences_for_decoder(
            context_pairs, dropout_eval=False
        )

        def _latent_snapshot():
            # Detach first so snapshots do not keep autograd history alive.
            return {
                name: tensor.detach().clone()
                for name, tensor in zip(latent_names, params_to_optimize)
            }

        def _decoder_snapshot(module):
            return {name: param.detach().clone() for name, param in module.named_parameters()}

        if update_decoder:
            original_decoder_state = _decoder_snapshot(self.decoder)
            temp_decoder = copy.deepcopy(self.decoder)
            optim_decoder = optim.Adam(temp_decoder.parameters(), lr=1e-4)
            decoder_grad_params = [p for p in temp_decoder.parameters() if p.requires_grad]
            # Index k of every snapshot list holds the state that produced
            # loss_history[k], so the initial state is stored up front.
            decoder_param_snapshots = [_decoder_snapshot(temp_decoder)]
            active_decoder = temp_decoder
        else:
            original_decoder_state = None
            temp_decoder = None
            optim_decoder = None
            decoder_grad_params = []
            decoder_param_snapshots = []
            active_decoder = self.decoder

        param_snapshots = [_latent_snapshot()]
        loss_history = []

        for _ in tqdm(range(num_steps)):
            current_context = base_context.copy()
            current_context['z'] = (z, pi, mu, logvar)
            for key, tensor in (('mu', mu), ('logvar', logvar), ('pi', pi)):
                if key in current_context:
                    current_context[key] = tensor

            # The inner loop needs gradients even if the caller disabled them.
            with torch.enable_grad():
                grid_logits = active_decoder(input_seq, true_output_seq, current_context)
                loss, _ = self._compute_loss(grid_logits, true_output_seq)

            loss_history.append(loss.detach().clone())

            # Differentiate only w.r.t. the tensors optimised here. A plain
            # loss.backward() would also accumulate into the .grad of the
            # shared decoder parameters and leak into the outer training step.
            grad_targets = params_to_optimize + decoder_grad_params
            grads = torch.autograd.grad(loss, grad_targets, allow_unused=True)
            for target, grad in zip(grad_targets, grads):
                target.grad = grad

            torch.nn.utils.clip_grad_norm_(params_to_optimize, max_norm=1.0)
            with torch.no_grad():
                optim_obj.step()
                if update_decoder:
                    optim_decoder.step()
                lr_scheduler.step()

            param_snapshots.append(_latent_snapshot())
            if update_decoder:
                decoder_param_snapshots.append(_decoder_snapshot(temp_decoder))

        # loss_history[k] was measured with param_snapshots[k]; with no steps
        # at all, fall back to the initial (unchanged) state.
        best_idx = int(torch.argmin(torch.stack(loss_history))) if loss_history else 0
        best_snapshot = param_snapshots[best_idx]

        # Straight-through: forward value is the optimised latent, backward
        # gradient goes to the original encoder output.
        orig_z, orig_pi, orig_mu, orig_logvar = base_context['z'][:4]
        orig_z = orig_z + (best_snapshot['z'] - orig_z).detach()
        orig_pi = orig_pi + (best_snapshot['pi'] - orig_pi).detach()
        orig_mu = orig_mu + (best_snapshot['mu'] - orig_mu).detach()
        orig_logvar = orig_logvar + (best_snapshot['logvar'] - orig_logvar).detach()

        updated_context = base_context.copy()
        updated_context['z'] = (orig_z, orig_pi, orig_mu, orig_logvar)
        for key, tensor in (('mu', orig_mu), ('logvar', orig_logvar), ('pi', orig_pi)):
            if key in updated_context:
                updated_context[key] = tensor

        updated_decoder_params = None
        if update_decoder:
            best_decoder_snapshot = decoder_param_snapshots[best_idx]

            updated_decoder_params = {}
            for name, orig_param in self.decoder.named_parameters():
                if name in best_decoder_snapshot and name in original_decoder_state:
                    delta = best_decoder_snapshot[name] - original_decoder_state[name]
                    updated_decoder_params[name] = orig_param + delta.detach()
                else:
                    updated_decoder_params[name] = orig_param

        return updated_context, updated_decoder_params
    
    def generate_output(
        self,
        pairs: torch.Tensor,
        input_seq: torch.Tensor,
        test_pair: Optional[torch.Tensor] = None,
        true_output_seq: Optional[torch.Tensor] = None,
        dropout_eval: bool = True,
        mode: Literal["mean", "first", "gradient_ascent"] = "mean",
        num_steps: int = 10,
        lr: float = 0.01,
        update_decoder: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, Any], Dict[str, torch.Tensor]]:
        """
        Greedily generate an output token sequence, one position at a time.

        Args:
            pairs: Context input-output pairs
            input_seq: Input sequence to generate from. Only used when
                ``test_pair`` is omitted, in which case it is forwarded to the
                decoder unchanged (NOTE(review): no BOS token is prepended on
                this path; confirm callers pass a decoder-ready sequence).
            test_pair: Optional test pair. When given, both the decoder input
                (BOS-prefixed) and the ground truth are derived from it and
                override ``input_seq`` / ``true_output_seq``.
            true_output_seq: Optional ground truth for evaluation
            dropout_eval: Forwarded to the pair encoder
            mode: How to build the context: ``"mean"`` merges all encoded
                pairs, ``"first"`` uses the first one only,
                ``"gradient_ascent"`` refines the merged context
            num_steps: Number of gradient steps for gradient ascent mode
            lr: Learning rate for gradient ascent mode
            update_decoder: Also fine-tune a decoder copy in gradient ascent mode

        Returns:
            output_seq: Generated output sequence
            context_dict: Dictionary containing the context used
            eval_metrics: Evaluation metrics if a ground truth is available

        Raises:
            ValueError: If ``mode`` is not supported.
        """
        nvib_outputs, _ = self._encode_pairs(pairs, dropout_eval)
        updated_decoder_params = None
        if mode == "mean":
            context = self.merge_nvib_outputs(nvib_outputs)
        elif mode == "first":
            context = nvib_outputs[0]
        elif mode == "gradient_ascent":
            context, updated_decoder_params = self._get_gradient_ascent_context(
                nvib_outputs, num_steps=num_steps, lr=lr, context_pairs=pairs,
                gaussian_kl_coeff=0, dirichlet_kl_coeff=0, update_decoder=update_decoder
            )
        else:
            raise ValueError(f"Unsupported generation mode: {mode}")

        with torch.no_grad():
            if test_pair is not None:
                # The test pair supplies both the decoder input and the target.
                input_seq, true_output_seq = self._prepare_sequences_for_decoder(
                    test_pair, dropout_eval=False
                )

            if true_output_seq is not None:
                output_seq = torch.zeros_like(true_output_seq)
            else:
                # Without a ground truth, generate as many tokens as the
                # context pairs' sequences contain.
                output_seq = input_seq.new_zeros((input_seq.shape[0], pairs.shape[-1]))
            seq_len = output_seq.shape[1]

            logits = None
            for i in range(seq_len):
                if updated_decoder_params is not None:
                    logits = torch.nn.utils.stateless.functional_call(
                        self.decoder,
                        updated_decoder_params,
                        (input_seq, output_seq, context)
                    )
                else:
                    logits = self.decoder(input_seq, output_seq, context)

                # Greedy decoding: commit the most likely token at position i.
                output_seq[:, i] = torch.argmax(logits[:, i, :], dim=-1)

            eval_metrics = {}
            # Scored on the logits of the final decoding step.
            if true_output_seq is not None and logits is not None:
                _, eval_metrics = self._compute_loss(logits, true_output_seq)

        return output_seq, {"context": context}, eval_metrics