import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import logging
from typing import Dict, List, Optional, Tuple, Union, Any
from functools import partial

# Soft target for one scale: a dict of tensors keyed by name. The code in this
# file reads the keys 'floor_targets', 'ceil_targets' and 'weights'.
SoftTarget = Dict[str, torch.Tensor]
# Targets for all scales: each entry is either a plain tensor (hard targets)
# or a SoftTarget dict.
HierarchicalTargets = List[Union[torch.Tensor, SoftTarget]]

from .genar_transformer import (
    GeneAdaLNSelfAttn, 
    GeneAdaLNBeforeHead, 
    ConditionProcessor,
    DropPath
)
from .film_layer import FiLMLayer
from .gene_identity_pooling import GeneIdentityPooling

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class GeneGroupUpsampling(nn.Module):
    """Upsample group embeddings from one scale to the next, finer scale.

    For every pair of adjacent entries in ``scale_dims`` the module keeps

    * a group mapping (``group_mappings``) listing, for each source position,
      the target positions that copy its embedding, and
    * an MLP (``upsample_transforms``) applied to the source embeddings before
      they are copied.

    Both dictionaries use keys of the form ``'scale_{i}_to_{i+1}'``.

    Args:
        embed_dim: Size of the embedding (last) dimension.
        scale_dims: Number of positions at each scale, in the order the
            scales are processed.
        num_genes: Total number of genes; stored on the module.
        dropout: Dropout probability of the last layer of each transform.
    """

    def __init__(
        self,
        embed_dim: int,
        scale_dims: Tuple[int, ...],
        num_genes: int = 200,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.scale_dims = scale_dims
        self.num_genes = num_genes

        self.group_mappings = self._compute_group_mappings()

        # One transform per adjacent scale pair, keyed like the mappings.
        self.upsample_transforms = nn.ModuleDict()
        for i in range(len(scale_dims) - 1):
            self.upsample_transforms[f'scale_{i}_to_{i+1}'] = nn.Sequential(
                nn.Linear(embed_dim, embed_dim * 2),
                nn.GELU(),
                nn.Linear(embed_dim * 2, embed_dim),
                nn.LayerNorm(embed_dim),
                nn.Dropout(dropout),
            )

        logger.info("🧬 Gene Group Upsampling initialized:")
        logger.info("   Scale dims: %s", scale_dims)
        logger.info("   Number of upsampling transforms: %d", len(self.upsample_transforms))
        for key, mapping in self.group_mappings.items():
            logger.info("   %s: %d mappings", key, len(mapping))

    def _compute_group_mappings(self) -> Dict[str, List[List[int]]]:
        """Build the source -> targets mapping for every adjacent scale pair.

        Target position ``t`` is assigned to source position
        ``t * source_dim // target_dim``. When ``target_dim`` is a multiple of
        ``source_dim`` this yields equal contiguous groups; otherwise the
        group sizes differ by at most one. Every target position is assigned
        to exactly one source, so no target is left without an embedding.

        Returns:
            Dict mapping ``'scale_{i}_to_{i+1}'`` to a list with one entry per
            source position, each entry being the list of its target indices.
        """
        mappings: Dict[str, List[List[int]]] = {}

        for i, (source_dim, target_dim) in enumerate(zip(self.scale_dims, self.scale_dims[1:])):
            mapping: List[List[int]] = [[] for _ in range(max(source_dim, 0))]
            if source_dim > 0:
                for target_idx in range(target_dim):
                    mapping[target_idx * source_dim // target_dim].append(target_idx)
            mappings[f'scale_{i}_to_{i+1}'] = mapping

        return mappings

    def _default_device(self) -> torch.device:
        """Return the device used when there is no input tensor to infer it from."""
        first_param = next(self.parameters(), None)
        if first_param is not None:
            return first_param.device
        # No parameters to consult (single-scale configuration): keep the
        # historical heuristic.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(
        self,
        source_embeddings: Optional[torch.Tensor],
        source_scale_idx: int,
        target_scale_idx: int,
    ) -> torch.Tensor:
        """Upsample ``source_embeddings`` to the size of the target scale.

        Args:
            source_embeddings: Tensor of shape ``[B, source_dim, embed_dim]``,
                or ``None``.
            source_scale_idx: Index of the source scale in ``scale_dims``.
            target_scale_idx: Index of the target scale in ``scale_dims``.

        Returns:
            Tensor of shape ``[B, target_dim, embed_dim]``. For ``None`` input
            a zero tensor of shape ``[1, target_dim, embed_dim]`` is returned.
            If no mapping exists for the requested pair of scales, the result
            comes from ``_interpolate_upsample``.
        """
        target_dim = self.scale_dims[target_scale_idx]

        if source_embeddings is None:
            return torch.zeros(1, target_dim, self.embed_dim, device=self._default_device())

        B, source_dim, embed_dim = source_embeddings.shape

        mapping_key = f'scale_{source_scale_idx}_to_{target_scale_idx}'
        if mapping_key not in self.group_mappings:
            return self._interpolate_upsample(source_embeddings, target_dim)

        # Invert the mapping into a target -> source lookup table.
        # -1 marks target positions that no available source maps to.
        source_of_target = [-1] * target_dim
        for source_idx, target_indices in enumerate(self.group_mappings[mapping_key]):
            if source_idx >= source_dim:
                break
            for target_idx in target_indices:
                if target_idx < target_dim:
                    source_of_target[target_idx] = source_idx

        # The transform acts on the last dimension only, so all source
        # positions can be processed in a single call.
        if mapping_key in self.upsample_transforms:
            transformed = self.upsample_transforms[mapping_key](source_embeddings)
        else:
            transformed = source_embeddings

        covered = [idx >= 0 for idx in source_of_target]
        if not any(covered):
            return transformed.new_zeros(B, target_dim, embed_dim)

        gather_index = torch.tensor(
            [max(idx, 0) for idx in source_of_target],
            dtype=torch.long,
            device=transformed.device,
        )
        upsampled = transformed[:, gather_index, :]

        if not all(covered):
            # Uncovered targets stay zero, as they would in a zero-initialised
            # output buffer.
            uncovered = torch.tensor(
                [not flag for flag in covered],
                dtype=torch.bool,
                device=transformed.device,
            ).view(1, -1, 1)
            upsampled = upsampled.masked_fill(uncovered, 0.0)

        return upsampled

    def _interpolate_upsample(self, source_embeddings: torch.Tensor, target_dim: int) -> torch.Tensor:
        """Fallback upsampling for scale pairs that have no group mapping.

        A single source position is broadcast to all targets; otherwise the
        sequence dimension is linearly interpolated to ``target_dim``.
        """
        source_dim = source_embeddings.shape[1]

        if source_dim == 1:
            return source_embeddings.expand(-1, target_dim, -1)

        # F.interpolate resizes the last dimension, so move the sequence
        # dimension there and back.
        channels_first = source_embeddings.transpose(1, 2)
        resized = F.interpolate(channels_first, size=target_dim, mode='linear', align_corners=False)
        return resized.transpose(1, 2)


class MultiScaleGeneVAR(nn.Module):    
    def __init__(
        self,
        vocab_size: int,
        num_genes: int = 200,
        scale_dims: Tuple[int, ...] = (1, 4, 8, 40, 100, 200),

        embed_dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 12,
        mlp_ratio: float = 4.0,

        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,

        histology_feature_dim: int = 1024,
        spatial_coord_dim: int = 2,
        condition_embed_dim: int = 768,
        cond_drop_rate: float = 0.1,

        norm_eps: float = 1e-6,
        shared_aln: bool = False,
        attn_l2_norm: bool = True,
        device: str = 'cuda',
        adaptive_sigma_alpha: float = 0.1,
        adaptive_sigma_beta: float = 1.0
    ):
        """Build all sub-modules of the multi-scale model.

        Args:
            vocab_size: Number of token classes; size of the token embedding
                table and of the output head.
            num_genes: Number of genes; size of the gene identity embedding.
            scale_dims: Number of positions predicted at each scale.
            embed_dim: Transformer embedding width.
            num_heads: Attention heads per transformer block.
            num_layers: Number of transformer blocks.
            mlp_ratio: Forwarded to each transformer block.
            drop_rate: Forwarded to each transformer block.
            attn_drop_rate: Forwarded to each transformer block.
            drop_path_rate: Upper end of the linearly increasing per-block
                drop-path schedule (block 0 gets 0).
            histology_feature_dim: Input width of the histology features.
            spatial_coord_dim: Input width of the spatial coordinates.
            condition_embed_dim: Width of the condition embedding.
            cond_drop_rate: Stored on the module.
            norm_eps: Forwarded to the blocks and the head normalisation.
            shared_aln: Forwarded to each transformer block.
            attn_l2_norm: Forwarded to each transformer block.
            device: Stored on the module as given.
            adaptive_sigma_alpha: Slope of the Gaussian target width,
                ``sigma = alpha * target + beta``.
            adaptive_sigma_beta: Offset of the Gaussian target width.
        """
        super().__init__()

        self.num_genes = num_genes
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.cond_drop_rate = cond_drop_rate
        self.device = device
        self.adaptive_sigma_alpha = adaptive_sigma_alpha
        self.adaptive_sigma_beta = adaptive_sigma_beta

        self.scale_dims = scale_dims
        self.num_scales = len(scale_dims)

        self.histology_feature_dim = histology_feature_dim
        self.spatial_coord_dim = spatial_coord_dim
        self.condition_embed_dim = condition_embed_dim

        self.condition_processor = ConditionProcessor(
            histology_dim=histology_feature_dim,
            spatial_dim=spatial_coord_dim,
            condition_embed_dim=condition_embed_dim
        )

        # Token-value embedding and per-gene identity embedding.
        self.gene_embedding = nn.Embedding(vocab_size, embed_dim)
        self.gene_identity_embedding = nn.Embedding(num_genes, embed_dim)

        self.film_layer = FiLMLayer(
            condition_dim=embed_dim,
            feature_dim=embed_dim,
            hidden_dim=embed_dim // 2
        )

        self.gene_identity_pooling = GeneIdentityPooling(
            num_genes=num_genes,
            scale_dims=scale_dims,
            embed_dim=embed_dim,
            enable_pooling=True
        )

        self.gene_upsampling = GeneGroupUpsampling(
            embed_dim=embed_dim,
            scale_dims=scale_dims,
            num_genes=num_genes
        )

        # One position table per scale. At scale i the input sequence is the
        # start token plus the positions of scales 0..i, hence the length
        # 1 + sum(scale_dims[:i + 1]).
        self.hierarchical_pos_embedding = nn.ModuleDict()
        for i in range(self.num_scales):
            max_cumulative_length = 1 + sum(self.scale_dims[:i + 1])
            self.hierarchical_pos_embedding[f'scale_{i}'] = nn.Embedding(max_cumulative_length, embed_dim)

        self.scale_embedding = nn.Embedding(self.num_scales, embed_dim)

        self.start_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)

        # Drop-path rate grows linearly from 0 to drop_path_rate over the blocks.
        dpr = torch.linspace(0, drop_path_rate, num_layers).tolist()
        self.transformer_blocks = nn.ModuleList([
            GeneAdaLNSelfAttn(
                block_idx=i,
                embed_dim=embed_dim,
                condition_dim=condition_embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[i],
                norm_eps=norm_eps,
                shared_aln=shared_aln,
                attn_l2_norm=attn_l2_norm,
            )
            for i in range(num_layers)
        ])

        self.head_norm = GeneAdaLNBeforeHead(embed_dim, condition_embed_dim, norm_eps)
        self.output_head = nn.Linear(embed_dim, vocab_size)

        self.init_weights()
        

    
    def _count_parameters(self) -> int:
        """Count the number of trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    def _get_hierarchical_position_embedding(self, scale_idx: int, seq_len: int, device: torch.device) -> torch.Tensor:

        embedding_layer = self.hierarchical_pos_embedding[f'scale_{scale_idx}']
        
        pos_indices = torch.arange(seq_len, device=device)
        
        pos_embed = embedding_layer(pos_indices)
        
        return pos_embed.unsqueeze(0)
    
    def init_weights(self, init_std: float = 0.02):
        
        def _init_weights(module):
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=init_std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.trunc_normal_(module.weight, std=init_std)
            elif isinstance(module, nn.LayerNorm):
                if hasattr(module, 'bias') and module.bias is not None:
                    nn.init.zeros_(module.bias)
                if hasattr(module, 'weight') and module.weight is not None:
                    nn.init.ones_(module.weight)
        
        self.apply(_init_weights)

    def _create_hierarchical_targets(self, target_genes: torch.Tensor) -> HierarchicalTargets:

        hierarchical_targets = []

        target_genes_float = target_genes.float().unsqueeze(1) # -> [B, 1, 200]

        for _, dim in enumerate(self.scale_dims):
            if dim == self.num_genes:
                hard_targets = target_genes.long()
                hard_targets = torch.clamp(hard_targets, 0, self.vocab_size - 1)
                hierarchical_targets.append(hard_targets)
            else:
                pooled_targets = F.adaptive_avg_pool1d(target_genes_float, output_size=dim)
                pooled_targets = pooled_targets.squeeze(1) # -> [B, dim]
                
                floor_targets = torch.floor(pooled_targets).long()
                ceil_targets = torch.ceil(pooled_targets).long()
                weights = pooled_targets - floor_targets.float()  
                
                floor_targets = torch.clamp(floor_targets, 0, self.vocab_size - 1)
                ceil_targets = torch.clamp(ceil_targets, 0, self.vocab_size - 1)
                
                boundary_mask = ceil_targets >= self.vocab_size
                if boundary_mask.any():
                    weights = torch.where(boundary_mask, torch.zeros_like(weights), weights)
                    ceil_targets = torch.where(boundary_mask, floor_targets, ceil_targets)
                
                soft_target = {
                    'floor_targets': floor_targets,
                    'ceil_targets': ceil_targets,
                    'weights': weights
                }
                
                hierarchical_targets.append(soft_target)
            
        return hierarchical_targets

    def _create_gaussian_target_distribution(self, target: torch.Tensor, device: torch.device) -> torch.Tensor:

        vocab_size = self.vocab_size
        
        vocab_indices = torch.arange(vocab_size, device=device, dtype=torch.float32)
        
        mu = target.float().unsqueeze(-1)  # [B, seq_len, 1]

        sigma = self.adaptive_sigma_alpha * mu + self.adaptive_sigma_beta
        
        sigma = torch.clamp(sigma, min=1e-6, max=100.0)  
        
        x = vocab_indices.view(1, 1, -1)  # [1, 1, vocab_size]
        
        squared_diff = (x - mu) ** 2
        gaussian_unnormalized = torch.exp(-squared_diff / (2 * sigma ** 2))
        
        gaussian_unnormalized = torch.clamp(gaussian_unnormalized, min=1e-10, max=1.0)
        
        normalization_factor = gaussian_unnormalized.sum(dim=-1, keepdim=True)
        normalization_factor = torch.clamp(normalization_factor, min=1e-10)
        
        target_dist = gaussian_unnormalized / normalization_factor
        
        if torch.isnan(target_dist).any() or torch.isinf(target_dist).any():
            target_dist = torch.zeros_like(gaussian_unnormalized)
            target_indices = target.long().unsqueeze(-1)  
            target_dist.scatter_(-1, target_indices, 1.0)
        
        return target_dist

    def _compute_soft_label_loss(self, logits: torch.Tensor, target: Union[torch.Tensor, SoftTarget]) -> torch.Tensor:

        if isinstance(target, torch.Tensor):
            target_dist = self._create_gaussian_target_distribution(target, logits.device)
            
            log_probs = F.log_softmax(logits, dim=-1)  
            
            if torch.isinf(log_probs).any():
                log_probs = torch.clamp(log_probs, min=-50.0, max=50.0)
            
            kl_loss = F.kl_div(log_probs, target_dist, reduction='batchmean', log_target=False)
            
            if torch.isnan(kl_loss) or torch.isinf(kl_loss):
                return F.cross_entropy(logits.reshape(-1, self.vocab_size), target.reshape(-1))
            
            return kl_loss
        
        if not isinstance(target, dict) or 'floor_targets' not in target:
            raise ValueError("Soft target must be a dict containing 'floor_targets', 'ceil_targets', 'weights'")
        
        floor_targets = target['floor_targets']  # [B, seq_len]
        ceil_targets = target['ceil_targets']    # [B, seq_len]
        weights = target['weights']              # [B, seq_len], interpolation weights
        
        B, seq_len, _ = logits.shape
        target_seq_len = floor_targets.shape[1]  
        
        if seq_len != target_seq_len:
            raise ValueError(f"Logits seq_len ({seq_len}) doesn't match target seq_len ({target_seq_len})")
        
        log_probs = F.log_softmax(logits, dim=-1)  
        
        target_dist = torch.zeros_like(log_probs)  # [B, seq_len, vocab_size]
        
        batch_indices = torch.arange(B, device=logits.device).unsqueeze(1).expand(-1, target_seq_len)
        seq_indices = torch.arange(target_seq_len, device=logits.device).unsqueeze(0).expand(B, -1)
        
        floor_probs = 1.0 - weights  # [B, target_seq_len]
        target_dist[batch_indices, seq_indices, floor_targets] = floor_probs
        
        ceil_mask = (ceil_targets != floor_targets)
        if ceil_mask.any():
            ceil_probs = weights * ceil_mask.float()  # [B, target_seq_len]
            target_dist[batch_indices, seq_indices, ceil_targets] = ceil_probs
        
        eps = 1e-8
        target_dist = target_dist + eps
        target_dist = target_dist / target_dist.sum(dim=-1, keepdim=True)  
        
        if torch.isnan(target_dist).any():
            return F.cross_entropy(logits.reshape(-1, self.vocab_size), floor_targets.reshape(-1))
        
        if torch.isinf(log_probs).any():
            log_probs = torch.clamp(log_probs, min=-50.0, max=50.0)
        
        kl_loss = F.kl_div(log_probs, target_dist, reduction='batchmean', log_target=False)
        
        if torch.isnan(kl_loss) or torch.isinf(kl_loss):
            return F.cross_entropy(logits.reshape(-1, self.vocab_size), floor_targets.reshape(-1))
        
        return kl_loss

    def forward(
        self,
        histology_features: torch.Tensor,   # [B, 1024]
        spatial_coords: torch.Tensor,       # [B, 2]
        target_genes: Optional[torch.Tensor] = None,  # [B, 200] for training
        top_k: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:

        condition_embed = self.condition_processor(histology_features, spatial_coords)
        
        if self.training:
            if target_genes is None:
                raise ValueError("target_genes must be provided during training.")
            return self.forward_training(condition_embed, target_genes)
        else:
            if target_genes is not None:
                return self.forward_training(condition_embed, target_genes)
            else:
                return self.forward_inference(condition_embed, top_k=top_k)
    
    def forward_training(
        self,
        condition_embed: torch.Tensor,      # [B, condition_embed_dim]
        target_genes: torch.Tensor          # [B, num_genes]
    ) -> Dict[str, torch.Tensor]:
        """Teacher-forced pass over all scales, returning loss and predictions.

        For every scale the input sequence is the start token, the embedded
        ground-truth tokens of all previous scales, and one query position
        per element of the current scale. The queries are zeros at the first
        scale and otherwise the upsampled embeddings of the previous scale's
        predictions (blended 0.7/0.3 with the gene identity embeddings at the
        scale whose size equals ``num_genes``). The last ``scale_dim``
        outputs are turned into logits and scored against the scale's target.

        Args:
            condition_embed: Condition embedding, ``[B, condition_embed_dim]``.
            target_genes: Ground-truth tokens, ``[B, num_genes]``.

        Returns:
            Dict with

            * ``'loss'``: mean of the per-scale losses,
            * ``'loss_final'``: loss of the scale whose size equals ``num_genes``,
            * ``'predictions'``: argmax tokens of that scale, as float,
            * ``'targets'``: ``target_genes`` as float.

        Raises:
            ValueError: If no entry of ``scale_dims`` equals ``num_genes``,
                in which case no final prediction could be produced.
        """
        if self.num_genes not in self.scale_dims:
            raise ValueError(
                f"scale_dims {tuple(self.scale_dims)} must contain num_genes "
                f"({self.num_genes}) to produce final predictions"
            )

        B = condition_embed.shape[0]
        device = condition_embed.device

        hierarchical_targets = self._create_hierarchical_targets(target_genes)

        # Teacher-forcing tokens per scale, computed once up front. Soft
        # targets are rounded to whichever of floor/ceil carries more weight.
        teacher_tokens = []
        for scale_target in hierarchical_targets:
            if isinstance(scale_target, dict):
                teacher_tokens.append(torch.where(
                    scale_target['weights'] > 0.5,
                    scale_target['ceil_targets'],
                    scale_target['floor_targets'],
                ))
            else:
                teacher_tokens.append(scale_target)

        start_token = self.start_token.expand(B, -1, -1)  # [B, 1, D]

        scale_embeddings = []
        total_loss = 0.0
        final_predictions = None
        final_loss = torch.tensor(0.0, device=device)

        for scale_idx, (scale_dim, scale_target) in enumerate(zip(self.scale_dims, hierarchical_targets)):
            is_final_scale = scale_dim == self.num_genes

            if scale_idx == 0:
                x = start_token
                target_positions = torch.zeros(B, scale_dim, self.embed_dim, device=device)
            else:
                all_prev_tokens = torch.cat(teacher_tokens[:scale_idx], dim=1)  # [B, cumulative_length]
                input_embed = self.gene_embedding(all_prev_tokens)  # [B, cumulative_length, D]
                x = torch.cat([start_token, input_embed], dim=1)  # [B, 1 + cumulative_length, D]

                target_positions = self.gene_upsampling(
                    scale_embeddings[-1],
                    source_scale_idx=scale_idx - 1,
                    target_scale_idx=scale_idx
                )
                if is_final_scale:
                    identity_embeddings = self.gene_identity_embedding.weight.unsqueeze(0).expand(B, -1, -1)
                    target_positions = 0.7 * target_positions + 0.3 * identity_embeddings

            x = torch.cat([x, target_positions], dim=1)  # [B, 1 + cumulative_length + scale_dim, D]

            current_seq_len = x.shape[1]
            scale_embed = self.scale_embedding(torch.tensor([scale_idx], device=device)).view(1, 1, -1)
            pos_embed = self._get_hierarchical_position_embedding(scale_idx, current_seq_len, device)
            x = x + pos_embed + scale_embed

            # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
            causal_mask = torch.triu(
                torch.ones(current_seq_len, current_seq_len, device=device) * float('-inf'),
                diagonal=1,
            )

            for block in self.transformer_blocks:
                x = block(x, condition_embed, causal_mask)

            # Only the query positions of the current scale are predicted.
            x_for_prediction = x[:, -scale_dim:, :]  # [B, scale_dim, D]
            x_for_prediction = self.head_norm(x_for_prediction, condition_embed)

            scale_conditions = self.gene_identity_pooling.get_scale_conditions(
                scale_idx=scale_idx,
                batch_size=B,
                gene_identity_embedding=self.gene_identity_embedding,
                device=device
            )

            if scale_conditions is not None:
                x_for_prediction = self.film_layer(x_for_prediction, scale_conditions)
                if is_final_scale:
                    logger.debug("Applied FiLM modulation for %s genes (final scale)", self.num_genes)
                else:
                    logger.debug("Applied FiLM modulation for scale %s (dim=%s)", scale_idx, scale_dim)
            elif is_final_scale:
                # No pooled conditions available: modulate with the raw
                # per-gene identity embeddings.
                identity_conditions = self.gene_identity_embedding.weight.unsqueeze(0).expand(B, -1, -1)
                x_for_prediction = self.film_layer(x_for_prediction, identity_conditions)
                logger.debug("Applied FiLM modulation for %s genes (final scale - fallback)", self.num_genes)

            logits = self.output_head(x_for_prediction)  # [B, scale_dim, vocab_size]

            loss = self._compute_soft_label_loss(logits, scale_target)
            total_loss += loss

            # The next scale's queries are built from this scale's own
            # predictions (argmax, so no gradient flows through this path).
            predicted_tokens = torch.argmax(logits, dim=-1)  # [B, scale_dim]
            scale_embeddings.append(self.gene_embedding(predicted_tokens))  # [B, scale_dim, D]

            if is_final_scale:
                final_predictions = predicted_tokens.float()
                final_loss = loss

        return {
            'loss': total_loss / self.num_scales,
            'loss_final': final_loss,
            'predictions': final_predictions,
            'targets': target_genes.float()
        }

    def forward_inference(
        self,
        condition_embed: torch.Tensor,      # [B, condition_embed_dim]
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """Sample tokens scale by scale, coarse to fine.

        At every scale the input sequence is the start token, the embedded
        tokens sampled at all previous scales, and one query position per
        element of the current scale (built as in ``forward_training``). The
        tokens of the current scale are then sampled from the filtered,
        temperature-scaled logits.

        Args:
            condition_embed: Condition embedding, ``[B, condition_embed_dim]``.
            temperature: Divides the logits before sampling; must be positive.
            top_k: If given and positive, only the ``top_k`` largest logits
                per position are kept (capped at the vocabulary size).
            top_p: If given and below 1, only the smallest set of tokens whose
                cumulative probability reaches ``top_p`` is kept per position.
                The most likely token is always kept.
            seed: If given, ``torch.manual_seed(seed)`` is called once before
                sampling starts.

        Returns:
            Dict with ``'generated_sequence'``: the tokens sampled at the
            last scale, as float, shape ``[B, scale_dims[-1]]``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        B = condition_embed.shape[0]
        device = condition_embed.device

        # Seed once. Re-seeding inside the loop would make every scale draw
        # from the same random state.
        if seed is not None:
            torch.manual_seed(seed)

        start_token = self.start_token.expand(B, -1, -1)  # [B, 1, D]

        all_generated_scale_tokens = []
        scale_embeddings = []

        for scale_idx, scale_dim in enumerate(self.scale_dims):
            if scale_idx == 0:
                x = start_token
                target_positions = torch.zeros(B, scale_dim, self.embed_dim, device=device)
            else:
                all_prev_tokens = torch.cat(all_generated_scale_tokens, dim=1)  # [B, cumulative_length]
                input_embed = self.gene_embedding(all_prev_tokens)  # [B, cumulative_length, D]
                x = torch.cat([start_token, input_embed], dim=1)  # [B, 1 + cumulative_length, D]

                target_positions = self.gene_upsampling(
                    scale_embeddings[-1],
                    source_scale_idx=scale_idx - 1,
                    target_scale_idx=scale_idx
                )
                if scale_dim == self.num_genes:
                    identity_embeddings = self.gene_identity_embedding.weight.unsqueeze(0).expand(B, -1, -1)
                    target_positions = 0.7 * target_positions + 0.3 * identity_embeddings

            x = torch.cat([x, target_positions], dim=1)  # [B, 1 + cumulative_length + scale_dim, D]

            current_seq_len = x.shape[1]
            scale_embed = self.scale_embedding(torch.tensor([scale_idx], device=device)).view(1, 1, -1)
            pos_embed = self._get_hierarchical_position_embedding(scale_idx, current_seq_len, device)
            x = x + pos_embed + scale_embed

            # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
            causal_mask = torch.triu(
                torch.ones(current_seq_len, current_seq_len, device=device) * float('-inf'),
                diagonal=1,
            )

            for block in self.transformer_blocks:
                x = block(x, condition_embed, causal_mask)

            x_for_prediction = x[:, -scale_dim:, :]  # [B, scale_dim, D]
            x_for_prediction = self.head_norm(x_for_prediction, condition_embed)

            scale_conditions = self.gene_identity_pooling.get_scale_conditions(
                scale_idx=scale_idx,
                batch_size=B,
                gene_identity_embedding=self.gene_identity_embedding,
                device=device
            )

            if scale_conditions is not None:
                x_for_prediction = self.film_layer(x_for_prediction, scale_conditions)
            elif scale_dim == self.num_genes:
                identity_conditions = self.gene_identity_embedding.weight.unsqueeze(0).expand(B, -1, -1)
                x_for_prediction = self.film_layer(x_for_prediction, identity_conditions)

            logits = self.output_head(x_for_prediction)  # [B, scale_dim, vocab_size]
            logits = logits / temperature

            if top_k is not None and top_k > 0:
                # Cap k so that a too-large top_k means "no filtering"
                # instead of an error from torch.topk.
                k = min(top_k, logits.shape[-1])
                kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
                logits = logits.masked_fill(logits < kth_value, float('-inf'))

            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                # Probability mass strictly before each token in sorted order:
                # a token is dropped once that mass already reaches top_p.
                mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
                remove_sorted = mass_before >= top_p
                remove_sorted[..., 0] = False  # always keep the best token
                remove = torch.zeros_like(remove_sorted).scatter(-1, sorted_indices, remove_sorted)
                logits = logits.masked_fill(remove, float('-inf'))

            probabilities = F.softmax(logits, dim=-1)

            flat_probabilities = probabilities.reshape(-1, probabilities.shape[-1])
            sampled_tokens = torch.multinomial(flat_probabilities, num_samples=1).view(B, scale_dim)

            all_generated_scale_tokens.append(sampled_tokens)
            scale_embeddings.append(self.gene_embedding(sampled_tokens))  # [B, scale_dim, D]

        return {
            'generated_sequence': all_generated_scale_tokens[-1].float()
        }

    def inference(
        self,
        histology_features: torch.Tensor,
        spatial_coords: torch.Tensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        seed: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """Sample a gene sequence for the given conditions without gradients.

        Puts the module into eval mode (it is not switched back afterwards),
        computes the condition embedding and delegates to
        ``forward_inference`` with the sampling options passed through.

        Returns:
            The dict returned by ``forward_inference``.
        """
        self.eval()

        with torch.no_grad():
            condition_embed = self.condition_processor(histology_features, spatial_coords)
            result = self.forward_inference(
                condition_embed,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                seed=seed,
            )

        return result
    
    def save_checkpoint(self, save_path: str, epoch: Optional[int] = None):

        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'num_genes': self.num_genes,
            'scale_dims': self.scale_dims,
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'num_layers': self.num_layers,
            'histology_feature_dim': self.histology_feature_dim,
            'spatial_coord_dim': self.spatial_coord_dim,
            'condition_embed_dim': self.condition_embed_dim,
            'adaptive_sigma_alpha': self.adaptive_sigma_alpha,
            'adaptive_sigma_beta': self.adaptive_sigma_beta,
            'epoch': epoch
        }
        
        torch.save(checkpoint, save_path)
    
    @classmethod
    def load_checkpoint(cls, ckpt_path: str, device: str = 'cuda') -> 'MultiScaleGeneVAR':
        """Rebuild a model from a checkpoint written by ``save_checkpoint``.

        The model is constructed from the configuration stored in the
        checkpoint, the saved weights are loaded, and the model is moved to
        ``device``.

        NOTE(review): ``torch.load`` unpickles the file; only load
        checkpoints from trusted sources.

        Args:
            ckpt_path: Path of the checkpoint file.
            device: Used as ``map_location`` for loading, passed to the
                constructor, and used as the final device of the model.

        Returns:
            The restored model.
        """
        checkpoint = torch.load(ckpt_path, map_location=device)

        # Configuration entries that must be present in the checkpoint.
        required_keys = (
            'vocab_size',
            'num_genes',
            'scale_dims',
            'embed_dim',
            'num_heads',
            'num_layers',
            'histology_feature_dim',
            'spatial_coord_dim',
            'condition_embed_dim',
        )
        config = {}
        for key in required_keys:
            config[key] = checkpoint[key]

        # Older checkpoints lack the sigma settings; fall back to the defaults.
        config['adaptive_sigma_alpha'] = checkpoint.get('adaptive_sigma_alpha', 0.1)
        config['adaptive_sigma_beta'] = checkpoint.get('adaptive_sigma_beta', 1.0)
        config['device'] = device

        model = cls(**config)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)

        return model
    
    def get_model_info(self) -> Dict:

        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        
        condition_params = sum(p.numel() for p in self.condition_processor.parameters())
        transformer_params = sum(p.numel() for p in self.transformer_blocks.parameters())
        embedding_params = self.gene_embedding.weight.numel() + sum(p.numel() for p in self.hierarchical_pos_embedding.parameters()) + self.scale_embedding.weight.numel()
        output_params = sum(p.numel() for p in self.head_norm.parameters()) + sum(p.numel() for p in self.output_head.parameters())
        
        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'condition_processor_parameters': condition_params,
            'transformer_parameters': transformer_params,
            'embedding_parameters': embedding_params,
            'output_parameters': output_params,
            'num_genes': self.num_genes,
            'scale_dims': self.scale_dims,
            'num_scales': self.num_scales,
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'num_layers': self.num_layers,
            'vocab_size': self.vocab_size,
            'total_sequence_length': self.num_genes + 1
        }

    def enable_kv_cache(self):

        for block in self.transformer_blocks:
            block.enable_kv_cache(True)
    
    def disable_kv_cache(self):

        for block in self.transformer_blocks:
            block.enable_kv_cache(False)
    
    def enable_multi_scale_gene_modulation(self):

        self.gene_identity_pooling.enable()
    
    def disable_multi_scale_gene_modulation(self):

        self.gene_identity_pooling.disable()

    def _compute_weighted_cross_entropy_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:

        vocab_size = logits.shape[-1]
        
        token_ids = torch.arange(vocab_size, device=logits.device, dtype=torch.float32)  # [vocab_size]
        target_values = targets.float().unsqueeze(1)  # [total_predictions, 1]
        
        distances = torch.abs(token_ids.unsqueeze(0) - target_values)  # [total_predictions, vocab_size]
        
        sigma = vocab_size * 0.1  
        weights = torch.exp(-distances ** 2 / (2 * sigma ** 2))  # [total_predictions, vocab_size]
        
        log_probs = F.log_softmax(logits, dim=-1)  # [total_predictions, vocab_size]
        
        weighted_log_probs = log_probs * weights  # [total_predictions, vocab_size]
        
        target_log_probs = weighted_log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # [total_predictions]
        
        loss = -target_log_probs.mean()
        
        return loss

# Alias: GENAR refers to the same class object as MultiScaleGeneVAR.
GENAR = MultiScaleGeneVAR