from functools import partial

import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import math
import torch
import torch.nn as nn
import time
from models.diffloss import DiffLoss_v2
from models.ectloss import ECTLoss
from dataclasses import dataclass
from models.vision_streaming import DINOv2StreamingEncoder
from models.waveform_streaming import OnlineCausalWaveformDecoder


#### causal decoder
from modules_2.stable_audio_tools.models import create_model_from_config
from modules_2.stable_audio_tools.models.autoencoders import AudioAutoencoder
from modules_2.stable_audio_tools.models.utils import load_ckpt_state_dict
from modules_2.stable_audio_tools.utils.torch_common import copy_state_dict
######

from typing import Optional, List
from .transformer import StreamingTransformer, AggregateTransformer
from einops import rearrange
from modules.blocks import concat_video_with_delta

# Global kernel/backend switches applied at import time.
# ``enable_flash_sdp`` is a setter *function*; it must be called. Assigning
# ``True`` to it would replace the function with a bool without enabling anything.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

@dataclass
class ModelArgs:
    """Hyper-parameters consumed by ``GNSE`` and the sub-modules it builds."""

    # --- streaming transformer ---
    d_model: int = 2048
    num_layers: int = 22
    num_heads: int = 32
    seq_len: int = 300
    per_frame_len: int = 1
    num_kv_heads: Optional[int] = None
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    max_period: float = 10_000
    norm_eps: float = 1e-5
    initializer_range: float = 0.02
    layer_scale: Optional[float] = None
    num_types: int = 2
    input_type: str = 'interleave'
    dtype: Optional[torch.dtype] = None
    device: Optional[torch.device] = None
    rope_scaling: Optional[dict] = None  # {'type': 'linear', 'factor': 2.0}

    # --- regularization ---
    type_drop_p: float = 0.0
    token_dropout_p: float = 0.0
    attn_dropout_p: float = 0.1
    resid_dropout_p: float = 0.1
    ffn_dropout_p: float = 0.1
    drop_path_rate: float = 0.1
    trans_cfg_dropout_prob: float = 0.1
    head_dropout_p: float = 0.1

    # --- modality embedding sizes ---
    audio_embed_dim: int = 128
    video_embed_dim: int = 128

    # --- diffusion head ---
    head_channel: int = 1024
    num_head_block: int = 3
    label_balance: float = 0.5
    diffusion_batch_mul: int = 6
    naive_mar_mlp: bool = False
    seq_len_per_frame: int = 1
    condition_merge: bool = False
    # GNSE builds a head only for 'diffusion_v2' or 'ect'.
    # NOTE(review): the default 'diffusion' matches neither; set it explicitly.
    diffusion_type: str = 'diffusion'

    noise_augmentation: bool = False
    modal_projection: str = 'linear'
    noise_aug_max: float = 0.5
    label_drop_prob: float = 0.0
    audio_pretraining: bool = False
    norm_type: str = 'rms'

    # --- vision aggregation ---
    vision_aggregation: bool = False
    d_aggregate: int = 1024
    num_heads_aggregate: int = 32
    num_layers_aggregate: int = 2
    num_aggregated_tokens: int = 1
    grid_feature_length: int = 540
    aggregate_trans_architecture: str = 'linear'
    aggregation_method: str = 'self'
    norm_type_agg: str = 'rms'
    spatial_ds_rate: int = 2

    # --- EDM-style noise distribution ---
    P_mean: float = -0.4
    P_std: float = 1.0
    sigma_data: float = 0.5

    # training hyperparams
    no_subtract: bool = False
    pre_tok_norm: bool = False

    # ECT hyperparams
    adj_map_func: str = 'sigmoid'
    ect_q: float = 4.0
    ect_c: float = 0.06
    ect_k: float = 8.0
    ect_b: float = 8.0

    # --- options read by GNSE.__init__ that were not declared before ---
    # Appended last so positional construction is unchanged. The defaults are
    # conservative guesses (features off / neutral) — TODO confirm against
    # the training configs that set them.
    disperse_loss: bool = False
    disperse_lambda: float = 0.0
    disperse_temperature: float = 1.0
    film_fuse: bool = False
    grid_delta_fuse: bool = False
    segment_size: Optional[int] = None
    enable_rar: bool = False
    

def _grad_norm(loss, params):
    """Return the global L2 norm of d(loss)/d(params) without touching ``.grad``.

    The graph is retained so the caller can still backpropagate ``loss``.
    Parameters that ``loss`` does not depend on are ignored. If none of them
    receive a gradient, a zero scalar on ``loss.device`` is returned.
    """
    all_grads = torch.autograd.grad(
        loss,
        params,
        retain_graph=True,
        create_graph=False,
        allow_unused=True,
    )
    present = [grad for grad in all_grads if grad is not None]
    if not present:
        return torch.tensor(0., device=loss.device)
    squared_total = sum(grad.pow(2).sum() for grad in present)
    return torch.sqrt(squared_total)


def process_tensor(t: torch.Tensor, per_frame_len: int, diffusion_batch_mul: int = 4) -> torch.Tensor:
    """Flatten ``[bsz, seq_len, dim]`` into per-frame rows and tile them for the diffusion head.

    Args:
        t: tensor of shape ``[bsz, seq_len, dim]``.
        per_frame_len: number of consecutive tokens grouped into one diffusion sample.
        diffusion_batch_mul: number of times the flattened rows are repeated along dim 0.

    Returns:
        ``[diffusion_batch_mul * bsz * seq_len, dim]`` when ``per_frame_len == 1``, otherwise
        ``[diffusion_batch_mul * bsz * (seq_len // per_frame_len), per_frame_len, dim]``.

    Raises:
        AssertionError: if ``seq_len`` is not divisible by ``per_frame_len``.
    """
    bsz, seq_len, _ = t.shape
    if per_frame_len == 1:
        return t.reshape(bsz * seq_len, -1).repeat(diffusion_batch_mul, 1)

    assert seq_len % per_frame_len == 0, (
        f"seq_len ({seq_len}) must be divisible by per_frame_len ({per_frame_len})"
    )
    # One row-major reshape groups consecutive tokens: [B, S, D] -> [B * S/p, p, D].
    t = t.reshape(bsz * (seq_len // per_frame_len), per_frame_len, -1)
    return t.repeat(diffusion_batch_mul, 1, 1)

def process_mask(m: torch.Tensor, per_frame_len: int, diffusion_batch_mul: int = 4) -> torch.Tensor:
    """Mask counterpart of ``process_tensor``: flatten ``[bsz, seq_len]`` and tile it.

    Args:
        m: tensor of shape ``[bsz, seq_len]``.
        per_frame_len: number of consecutive positions grouped into one diffusion sample.
        diffusion_batch_mul: number of times the flattened rows are repeated along dim 0.

    Returns:
        ``[diffusion_batch_mul * bsz * seq_len]`` when ``per_frame_len == 1``, otherwise
        ``[diffusion_batch_mul * bsz * (seq_len // per_frame_len), per_frame_len]``.

    Raises:
        AssertionError: if ``seq_len`` is not divisible by ``per_frame_len``.
    """
    bsz, seq_len = m.shape
    if per_frame_len == 1:
        return m.reshape(bsz * seq_len).repeat(diffusion_batch_mul)

    assert seq_len % per_frame_len == 0, (
        f"seq_len ({seq_len}) must be divisible by per_frame_len ({per_frame_len})"
    )
    # One row-major reshape groups consecutive positions: [B, S] -> [B * S/p, p].
    m = m.reshape(bsz * (seq_len // per_frame_len), per_frame_len)
    return m.repeat(diffusion_batch_mul, 1)


class GNSE(nn.Module):
    """ Generative neural sound engine
    """
    def __init__(self, config: ModelArgs):
        """Build projections, the optional vision aggregator, the AR transformer and the diffusion head.

        Args:
            config: model hyper-parameters. Options that some ``ModelArgs`` versions
                do not declare (``disperse_*``, ``film_fuse``, ``grid_delta_fuse``,
                ``segment_size``, ``enable_rar``) fall back to conservative defaults
                instead of raising ``AttributeError``.

        Raises:
            NotImplementedError: if ``config.modal_projection`` is not ``'linear'``.
            ValueError: if ``config.diffusion_type`` is neither ``'diffusion_v2'`` nor
                ``'ect'``. Without this check the model would be built with no
                ``self.diffloss`` and fail later in ``forward``/sampling.
        """
        super().__init__()
        self.config = config
        self.input_type = config.input_type
        self.initializer_range = config.initializer_range
        self.modal_projection = config.modal_projection
        self.per_frame_len = config.per_frame_len
        self.audio_pretraining = config.audio_pretraining
        self.vision_aggregation = config.vision_aggregation
        self.aggregation_method = config.aggregation_method
        self.condition_merge = config.condition_merge
        self.diffusion_type = config.diffusion_type
        self.no_subtract = config.no_subtract
        self.spatial_ds_rate = config.spatial_ds_rate
        # Read defensively: these fields are not declared on every ModelArgs version.
        # Defaults are conservative guesses — TODO confirm against training configs.
        self.disperse_loss = getattr(config, "disperse_loss", False)
        self.disperse_lambda = getattr(config, "disperse_lambda", 0.0)
        self.disperse_temperature = getattr(config, "disperse_temperature", 1.0)
        self.film_fuse = getattr(config, "film_fuse", False)
        self.grid_delta_fuse = getattr(config, "grid_delta_fuse", False)
        self.segment_size = getattr(config, "segment_size", None)
        self.pre_tok_norm = config.pre_tok_norm
        self.enable_rar = getattr(config, "enable_rar", False)

        # --- modality projections into d_model ---
        if config.modal_projection == 'linear':
            if self.vision_aggregation:
                self.video_proj = nn.Sequential(
                    nn.Linear(config.d_aggregate, config.d_model),
                    nn.SiLU(),
                    nn.Linear(config.d_model, config.d_model),
                )
                # Input channels double when the temporal delta is concatenated
                # onto the grid features (i.e. unless ``no_subtract``).
                self.video_proj_3d = nn.Sequential(
                    nn.Conv3d(
                        in_channels=2 * config.video_embed_dim if not config.no_subtract else config.video_embed_dim,
                        out_channels=config.d_aggregate,
                        kernel_size=(1, 1, 1),
                        bias=True,
                    ),
                    nn.SiLU(),
                    nn.Conv3d(
                        in_channels=config.d_aggregate,
                        out_channels=config.d_aggregate,
                        kernel_size=(1, 3, 3),
                        stride=(1, self.spatial_ds_rate, self.spatial_ds_rate),
                        padding=(0, 1, 1),
                        bias=False,
                    ),
                )
            else:
                self.video_proj = nn.Sequential(
                    nn.Linear(config.video_embed_dim, config.d_model),
                    nn.SiLU(),
                    nn.Linear(config.d_model, config.d_model),
                )
            self.audio_proj = nn.Sequential(
                nn.Linear(config.audio_embed_dim, config.d_model),
                nn.SiLU(),
                nn.Linear(config.d_model, config.d_model),
            )
        else:
            raise NotImplementedError

        # --- optional per-frame vision aggregation ---
        if self.vision_aggregation:
            self.aggregate_transformer = AggregateTransformer(
                d_model=config.d_aggregate,
                num_heads=config.num_heads_aggregate,
                num_layers=config.num_layers_aggregate,
                grid_feature_length=config.grid_feature_length,
                num_aggregated_tokens=config.num_aggregated_tokens,
                layer_scale=config.layer_scale,
                ffn_dim_multiplier=config.ffn_dim_multiplier,
                ffn_dropout_p=config.ffn_dropout_p,
                attn_dropout_p=config.attn_dropout_p,
                resid_dropout_p=config.resid_dropout_p,
                multiple_of=config.multiple_of,
                num_kv_heads=config.num_kv_heads,
                initializer_range=config.initializer_range,
                aggregate_trans_architecture=config.aggregate_trans_architecture,
                aggregation_method=config.aggregation_method,
                norm_type_agg=config.norm_type_agg,
            )
            scale = 1.0 / math.sqrt(config.d_aggregate)
            self.num_aggregated_tokens = config.num_aggregated_tokens
            self.aggregated_tokens = nn.Parameter(scale * torch.randn(config.num_aggregated_tokens, config.d_aggregate))
            if self.pre_tok_norm:
                # Replaces the boolean flag with the norm module; callers test its truthiness.
                self.pre_tok_norm = nn.RMSNorm(config.d_aggregate, eps=1e-5)

        # --- autoregressive backbone ---
        self.transformer = StreamingTransformer(
            d_model=config.d_model,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            seq_len=config.seq_len,
            num_types=config.num_types,
            type_drop_p=config.type_drop_p,
            layer_scale=config.layer_scale,
            drop_path_rate=config.drop_path_rate,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
            ffn_dropout_p=config.ffn_dropout_p,
            multiple_of=config.multiple_of,
            num_kv_heads=config.num_kv_heads,
            attn_dropout_p=config.attn_dropout_p,
            resid_dropout_p=config.resid_dropout_p,
            initializer_range=config.initializer_range,
            token_dropout_p=config.token_dropout_p,
            max_period=config.max_period,
            noise_augmentation=config.noise_augmentation,
            k_max=config.noise_aug_max,
            input_type=config.input_type,
            trans_cfg_dropout_prob=config.trans_cfg_dropout_prob,
            audio_pretraining=config.audio_pretraining,
            num_aggregated_tokens=config.num_aggregated_tokens,
            norm_type=config.norm_type,
            condition_merge=config.condition_merge,
            enable_rar=self.enable_rar,
            rope_scaling=config.rope_scaling,
        )

        # --- diffusion head ---
        # With condition_merge two transformer outputs are concatenated per audio token.
        z_channels = 2 * config.d_model if self.condition_merge else config.d_model
        if config.diffusion_type == 'diffusion_v2':
            self.diffloss = DiffLoss_v2(
                target_channels=config.audio_embed_dim,
                z_channels=z_channels,
                model_channels=config.head_channel,
                num_res_blocks=config.num_head_block,
                seq_len_per_frame=config.seq_len_per_frame,
                P_mean=config.P_mean,
                P_std=config.P_std,
                sigma_data=config.sigma_data,
                label_drop_prob=config.label_drop_prob,
                initializer_range=config.initializer_range,
                label_balance=config.label_balance,
                naive_mar_mlp=config.naive_mar_mlp,
                condition_merge=config.condition_merge,
                head_dropout_p=config.head_dropout_p,
            )
        elif config.diffusion_type == 'ect':
            self.diffloss = ECTLoss(
                target_channels=config.audio_embed_dim,
                z_channels=z_channels,
                model_channels=config.head_channel,
                num_res_blocks=config.num_head_block,
                seq_len_per_frame=config.seq_len_per_frame,
                P_mean=config.P_mean,
                P_std=config.P_std,
                sigma_data=config.sigma_data,
                label_drop_prob=config.label_drop_prob,
                initializer_range=config.initializer_range,
                label_balance=config.label_balance,
                naive_mar_mlp=config.naive_mar_mlp,
                condition_merge=config.condition_merge,
                adj_map_func=config.adj_map_func,
                ect_q=config.ect_q,
                ect_c=config.ect_c,
                ect_k=config.ect_k,
                ect_b=config.ect_b,
            )
        else:
            raise ValueError(
                f"Unsupported diffusion_type {config.diffusion_type!r}; "
                "expected 'diffusion_v2' or 'ect'."
            )
        self.diffusion_batch_mul = config.diffusion_batch_mul
        self.initialize_weights()

    def initialize_weights(self):  
        self.video_proj.apply(self._init_linear)
        self.video_proj_3d.apply(self._init_linear)
        self.audio_proj.apply(self._init_linear)

    def _init_linear(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            torch.nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _init_weights(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            std = self.initializer_range
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            m.weight.data.normal_(mean=0.0, std=std)
            if m.bias is not None:
                m.bias.data.zero_()

    def forward_loss(self, z: torch.Tensor, target: torch.Tensor):
        """Flatten/tile ``z`` and ``target`` with ``process_tensor`` and score them with ``self.diffloss``.

        Returns whatever ``self.diffloss`` returns (``forward`` unpacks it as
        ``(weighted_dsm_loss, dsm_loss, logvar)``).
        """
        frame_len = self.per_frame_len
        batch_mul = self.diffusion_batch_mul
        flat_target = process_tensor(target, frame_len, batch_mul)
        flat_z = process_tensor(z, frame_len, batch_mul)
        return self.diffloss(z=flat_z, target=flat_target)

    def forward(
        self, 
        video: Optional[torch.Tensor] = None, 
        audio: Optional[torch.Tensor] = None, 
        target: Optional[torch.Tensor] = None, 
        global_step: Optional[int] = None,
        total_steps: Optional[int] = None, 
        perm_ratio: Optional[float] = None,
    ):
        """Training pass: project the inputs, run the AR transformer, and score with the diffusion head.

        Args:
            video: video condition features; required unless in audio-pretraining mode.
                With vision aggregation, dim 4 is permuted into channels
                (presumably ``[B, T, H, W, C]`` — TODO confirm).
            audio: audio input tokens, passed through ``audio_proj``.
            target: diffusion targets, flattened by ``process_tensor`` in ``forward_loss``.
            global_step, total_steps, perm_ratio: forwarded to ``self.transformer``.

        Returns:
            ``(total_loss, weighted_dsm_loss, dsm_loss, logvar, disperse_loss)``.
            ``total_loss`` equals ``weighted_dsm_loss``. ``disperse_loss`` is currently
            always a zero scalar tensor; no dispersive term is computed here.
        """
        if self.audio_pretraining:
            assert audio is not None, "audio must be provided in audio_pretraining mode"
            audio = self.audio_proj(audio)
            video = None
            # No video path in this mode; still return a zero term for a stable 5-tuple.
            disperse_loss = torch.tensor(0.0, device=audio.device)
        else:
            assert video is not None and audio is not None, "Both video and audio must be provided"
            assert video.shape[0] == audio.shape[0]

            disperse_loss = torch.tensor(0.0, device=video.device)
            if self.vision_aggregation:
                if not self.no_subtract:
                    video = concat_video_with_delta(video)
                video = video.permute(0, 4, 1, 2, 3).contiguous()  # [B, C, T, H, W]
                video = self.video_proj_3d(video)
                B, D, T, _, _ = video.shape

                # Every frame's spatial grid becomes one token sequence for the aggregator.
                video = rearrange(video, 'b c t h w -> (b t) (h w) c')
                if self.pre_tok_norm:
                    video = self.pre_tok_norm(video)
                video = self.aggregate_transformer(
                    video,
                    self.aggregated_tokens.unsqueeze(0).expand(B * T, -1, -1)
                )  # [B * T, num_agg, d_aggregate]

                video = video.view(B, T, -1, D)  # [B, T, num_agg, d_agg]
                if video.shape[2] == 1:
                    video = video.squeeze(2)  # [B, T, d_agg]
                else:
                    video = video.view(B, T * video.shape[2], D)
            video = self.video_proj(video)
            audio = self.audio_proj(audio)

        z = self.transformer(
            video, audio,
            global_step=global_step,
            total_steps=total_steps,
            perm_ratio=perm_ratio
        )

        if self.condition_merge:
            # Outputs alternate along the sequence; pair each even-position output with
            # the following odd-position one (the head expects 2 * d_model channels).
            even_tokens = z[:, ::2, :]
            odd_tokens = z[:, 1::2, :]
            z = torch.cat([even_tokens, odd_tokens], dim=-1)

        weighted_dsm_loss, dsm_loss, logvar = self.forward_loss(z=z, target=target)
        total_loss = weighted_dsm_loss
        return total_loss, weighted_dsm_loss, dsm_loss, logvar, disperse_loss
    
    def _encode_video(self, video):
        """Turn raw video condition features into ``d_model`` tokens via ``video_proj``.

        With vision aggregation, the path mirrors ``forward``: optional temporal
        delta concat, 3D conv projection, optional pre-token RMSNorm, then per-frame
        aggregation. The RMSNorm was missing here before, which made this path
        diverge from training whenever ``pre_tok_norm`` is enabled.
        """
        if self.vision_aggregation:
            if not self.no_subtract:
                video = concat_video_with_delta(video)
            video = video.permute(0, 4, 1, 2, 3).contiguous()  # [B, C, T, H, W]
            video = self.video_proj_3d(video)
            B, D, T, _, _ = video.shape
            video = rearrange(video, 'b c t h w -> (b t) (h w) c')
            if self.pre_tok_norm:
                video = self.pre_tok_norm(video)
            video = self.aggregate_transformer(
                video,
                self.aggregated_tokens.unsqueeze(0).expand(B * T, -1, -1)
            )  # [B * T, num_agg, d_aggregate]
            video = video.view(B, T, -1, D)  # [B, T, num_agg, d_agg]
            if video.shape[2] == 1:
                video = video.squeeze(2)  # [B, T, d_agg]
            else:
                video = video.view(B, T * video.shape[2], D)
        video = self.video_proj(video)
        return video

    def transformer_sampling(
        self, 
        input_token: torch.Tensor,
        type_id: int, 
        input_pos: torch.Tensor, 
        trans_cfg_scale: float = 1.0,
        inference_noise: float = 0.2,
        audio_pretraining: bool = False,
    ):
        """Project one token by modality and run a single cached ``transformer.inference`` step.

        ``type_id == 0`` goes through ``video_proj``. ``type_id == 1`` goes through
        ``audio_proj`` and is then mixed with Gaussian noise
        (``inference_noise * noise + (1 - inference_noise) * token``). Any other id is
        passed through unprojected. When ``trans_cfg_scale > 1`` and not in
        audio-pretraining mode, the inference output is split in half along dim 0
        into (cond, uncond) and combined with classifier-free guidance.
        """
        if type_id == 0:
            input_token = self.video_proj(input_token)
        elif type_id == 1:
            projected = self.audio_proj(input_token)
            gaussian = torch.randn_like(projected)
            input_token = inference_noise * gaussian + (1 - inference_noise) * projected

        use_cfg = trans_cfg_scale > 1.0 and not audio_pretraining
        if not use_cfg:
            return self.transformer.inference(input_token, type_id, input_pos, audio_pretraining=audio_pretraining)

        stacked = self.transformer.inference(input_token, type_id, input_pos, trans_cfg_scale=trans_cfg_scale)
        cond_z, uncond_z = torch.split(stacked, len(stacked) // 2, dim=0)
        return uncond_z + (cond_z - uncond_z) * trans_cfg_scale
    
    
    def bos_sampling(
        self, 
        input_token: torch.Tensor,
        type_id: int, 
        input_pos: torch.Tensor, 
        trans_cfg_scale: float = 1.0,
        inference_noise: float = 0.2,
        audio_pretraining: bool = False,
    ):
        """Run the transformer on the (already ``d_model``-sized) BOS audio token.

        No projection is applied; the token is only mixed with Gaussian noise.
        Unlike ``transformer_sampling``, ``trans_cfg_scale`` is not passed to
        ``transformer.inference``, so the caller must supply the doubled CFG batch
        itself. The output is then split into (cond, uncond) halves and guided.
        """
        assert type_id == int(1)
        gaussian = torch.randn_like(input_token)
        noisy_token = inference_noise * gaussian + (1 - inference_noise) * input_token
        out = self.transformer.inference(noisy_token, type_id, input_pos)

        use_cfg = trans_cfg_scale > 1.0 and not audio_pretraining
        if not use_cfg:
            return out
        cond_z, uncond_z = torch.split(out, len(out) // 2, dim=0)
        return uncond_z + (cond_z - uncond_z) * trans_cfg_scale
    
    
    
    def diffusion_sampling(
        self, 
        z: torch.Tensor, 
        cfg: float = 1.0, 
        cfg_schedule: str = "linear", 
        device: torch.device = 'cuda', 
        temperature: float = 1.0, 
        net_autoguidance: nn.Module = None,
        **edm_kwargs
    ):
        """Sample token latents from the diffusion head conditioned on ``z``.

        Returns a clone of ``self.diffloss.sample(...)``'s output.

        NOTE(review): ``cfg_schedule`` is accepted but not forwarded to
        ``self.diffloss.sample``, so it currently has no effect.
        """
        sample_kwargs = dict(
            cfg=cfg,
            temperature=temperature,
            device=device,
            net_autoguidance=net_autoguidance,
            **edm_kwargs,
        )
        latent = self.diffloss.sample(z, **sample_kwargs)
        return latent.clone()
    
    def decode_one_token(
        self, 
        input_token: torch.Tensor, 
        input_pos: torch.Tensor, 
        type_id: int,
        diff_cfg_scale: float = 1.0,
        cfg_schedule: str = "linear", 
        trans_cfg_flag: bool = True, 
        trans_cfg_scale: float = 1.0, 
        temperature: float = 1.0,
        net_autoguidance: nn.Module = None,
        inference_noise: float = 0.2,
        audio_pretraining: bool = False,
        **edm_kwargs, 
    ):
        """Feed one token through the transformer; after a video token (type 0), sample the next audio token.

        The transformer step (modality projection, audio noise mixing, optional
        transformer CFG) is exactly ``transformer_sampling``. For ``type_id == 0``
        the resulting hidden state is decoded by the diffusion head. For any other
        type, ``None`` is returned.

        ``trans_cfg_flag`` is accepted for backward compatibility but unused.
        """
        assert input_pos.shape[-1] == 1
        hidden = self.transformer_sampling(
            input_token,
            type_id,
            input_pos,
            trans_cfg_scale=trans_cfg_scale,
            inference_noise=inference_noise,
            audio_pretraining=audio_pretraining,
        )
        if type_id == int(0):
            return self.diffusion_sampling(
                hidden,
                cfg=diff_cfg_scale,
                cfg_schedule=cfg_schedule,
                temperature=temperature,
                device=hidden.device,
                net_autoguidance=net_autoguidance,
                **edm_kwargs
            )
        return None

    def decode_n_tokens_condition_merge(
        self, 
        cond_combined: torch.Tensor, 
        cur_token: torch.Tensor, 
        input_pos: torch.Tensor, 
        num_new_tokens: int,
        cfg_interval: int = -1, 
        diff_cfg_scale: float = 1.0, 
        cfg_schedule: str = "linear", 
        temperature: float = 1.0,  
        trans_cfg_scale: float = 1.0, 
        net_autoguidance: nn.Module = None,
        inference_noise: float = 0.2,
        audio_pretraining: bool = False,
        **edm_kwargs, 
    ):
        """Generate ``num_new_tokens // 2`` audio tokens in the condition-merge layout.

        Each step consumes two sequence positions: the previous audio token
        (type 1) and the next video condition token ``cond_combined[:, step + 1]``
        (type 0). Their transformer outputs are concatenated along the last dim
        and decoded by the diffusion head into the next audio token.

        ``input_pos`` may be an int or a tensor. It is converted to a 1-element
        int32 tensor on ``cur_token``'s device and advanced in place.
        ``cfg_interval`` is accepted but unused.

        Returns:
            list of ``[bs, 1, -1]``-shaped sampled audio tokens, one per step.
        """
        batch = cur_token.size(0)
        steps = int(num_new_tokens // 2)

        if torch.is_tensor(input_pos):
            input_pos = input_pos.to(device=cur_token.device, dtype=torch.int)
            if input_pos.dim() == 0:
                input_pos = input_pos.view(1)
        else:
            input_pos = torch.tensor([int(input_pos)], device=cur_token.device, dtype=torch.int)

        step_kwargs = dict(
            trans_cfg_scale=trans_cfg_scale,
            inference_noise=inference_noise,
            audio_pretraining=audio_pretraining,
        )
        generated = []
        for step in tqdm(range(steps), desc="transformer sampling...."):
            # position for the previous audio token
            audio_hidden = self.transformer_sampling(cur_token, type_id=int(1), input_pos=input_pos, **step_kwargs)
            input_pos += 1

            # position for the next video condition token
            video_in = cond_combined[:, int(step + 1)].view(batch, 1, -1)
            video_hidden = self.transformer_sampling(video_in, type_id=int(0), input_pos=input_pos, **step_kwargs)
            input_pos += 1

            merged = torch.cat([audio_hidden, video_hidden], dim=-1)
            sampled = self.diffusion_sampling(
                merged,
                cfg=diff_cfg_scale,
                cfg_schedule=cfg_schedule,
                temperature=temperature,
                device=merged.device,
                net_autoguidance=net_autoguidance,
                **edm_kwargs
            )
            cur_token = sampled.view(batch, 1, -1)
            generated.append(cur_token.clone())
        return generated

    
    @torch.no_grad()
    def offline_sample_tokens_condition_merge_pca(
        self, 
        cond: torch.Tensor, 
        max_new_tokens: int, 
        context_mode: str = "none", # ("pi", "ntk", "sliding", "none")
        vision_aggregation: bool = True,
        audio_pretraining: bool = False,
        trans_cfg_scale: float = 1.0, 
        cfg_interval: int = -1, 
        cfg_schedule: str = "linear", 
        diff_cfg_scale: float = 1.0, 
        temperature: float = 1.0, 
        net_autoguidance: nn.Module = None,
        inference_noise: float = 0.2,
        **edm_kwargs
    ):
        """Autoregressively sample audio tokens for a whole, pre-extracted video condition.

        Sequence layout: position 0 holds the BOS audio token and position 1 the
        first video condition token. Each later step feeds the last sampled audio
        token, then the next video condition token. The two transformer outputs are
        concatenated and decoded by the diffusion head into one audio token.

        Args:
            cond: video condition features. With ``vision_aggregation`` dim 4 is
                permuted into channels (presumably ``[B, T, H, W, C]`` — TODO confirm).
                Otherwise they are used directly as per-frame condition tokens.
            max_new_tokens: must equal ``2 * cond.shape[1] - 1``.
            context_mode: context handling, one of "pi", "ntk", "sliding", "none".
            vision_aggregation: run the 3D-conv + aggregation encoder on ``cond``.
            audio_pretraining: forwarded to the samplers; disables transformer CFG.
            trans_cfg_scale: transformer CFG scale (> 1 enables it).
            cfg_interval: forwarded to ``decode_n_tokens_condition_merge`` (unused there).
            cfg_schedule, diff_cfg_scale, temperature, net_autoguidance, **edm_kwargs:
                forwarded to ``diffusion_sampling``.
            inference_noise: weight of Gaussian noise mixed into audio input tokens.

        Returns:
            float32 tensor ``[B, cond.shape[1], audio_embed_dim]`` of sampled audio tokens.
        """
        assert max_new_tokens == int(cond.shape[1] * 2 - 1)
        device = cond.device

        # --- condition tokens (same encoder path as ``forward``, minus video_proj) ---
        if vision_aggregation:
            if not self.no_subtract:
                video = concat_video_with_delta(cond)  # appends the temporal delta on the last dim
            else:
                video = cond
            video = video.permute(0, 4, 1, 2, 3).contiguous()  # [B, C, T, H, W]
            video = self.video_proj_3d(video)
            B, D, T, _, _ = video.shape
            video = rearrange(video, 'b c t h w -> (b t) (h w) c')
            if self.pre_tok_norm:
                # ``forward`` normalizes grid tokens before aggregation; mirror it here.
                video = self.pre_tok_norm(video)
            video = self.aggregate_transformer(
                video,
                self.aggregated_tokens.unsqueeze(0).expand(B * T, -1, -1)
            )  # [B * T, num_agg, d_aggregate]

            video = video.view(B, T, -1, D)  # [B, T, num_agg, D]
            if video.shape[2] == 1:
                cond_combined = video.squeeze(2)  # [B, T, D]
            else:
                cond_combined = video.view(B, T * video.shape[2], D)
            if self.film_fuse:
                # NOTE(review): ``self.video_film_fuse`` is not created in the visible
                # ``__init__``; this branch raises AttributeError unless it is defined elsewhere.
                delta_tokens = cond_combined - torch.roll(cond_combined, 1, 1)
                delta_tokens[:, 0].zero_()
                cond_combined = self.video_film_fuse(cond_combined, delta_tokens)
        else:
            cond_combined = cond

        # --- context extension and KV caches ---
        train_seq_len = int(self.transformer.seq_len)
        # Total positions: 1 BOS audio token + ``max_new_tokens`` interleaved tokens.
        T_new = 1 + max_new_tokens
        if context_mode in ("pi", "ntk"):
            scale = max(1.0, float(T_new) / float(train_seq_len))
            self.transformer.set_context_extension(context_mode, factor=scale)
        elif context_mode == "sliding":
            self.transformer.set_context_extension("sliding", window_size=train_seq_len)
        else:
            self.transformer.set_context_extension("none")
        max_seq_length = T_new

        max_batch_size = cond.shape[0]
        # The CFG batch is doubled only when CFG is actually applied: bos_sampling and
        # transformer_sampling both skip it in audio_pretraining mode.
        use_trans_cfg = trans_cfg_scale > 1.0 and not audio_pretraining
        max_batch_size_cfg = max_batch_size * 2 if use_trans_cfg else max_batch_size
        infer_dtype = self.video_proj[0].weight.dtype
        with torch.device(device):
            self.transformer.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=infer_dtype)

        # One audio token per video frame (T_new == 2 * cond.shape[1]).
        seq = torch.empty((max_batch_size, T_new // 2, self.config.audio_embed_dim), dtype=torch.float, device=device)

        # --- first audio token: BOS(audio) -> V(cond_0) -> diffusion ---
        pos = torch.tensor([0], device=device, dtype=torch.int)
        z = self.bos_sampling(
            self.transformer.init_audio_token.expand(max_batch_size_cfg, -1),
            type_id=int(1),
            input_pos=pos,
            trans_cfg_scale=trans_cfg_scale,
            inference_noise=inference_noise,
            audio_pretraining=audio_pretraining,
        )
        pos = pos + 1
        v_init_token = self.transformer_sampling(
            cond_combined[:, 0],
            type_id=int(0),
            input_pos=pos,
            trans_cfg_scale=trans_cfg_scale,
            inference_noise=inference_noise,
            audio_pretraining=audio_pretraining,
        )
        z = torch.cat([z, v_init_token], dim=-1)
        next_token = self.diffusion_sampling(
            z,
            cfg=diff_cfg_scale,
            cfg_schedule=cfg_schedule,
            temperature=temperature,
            device=z.device,
            net_autoguidance=net_autoguidance,
            **edm_kwargs
        )
        seq[:, 0] = next_token

        # --- remaining audio tokens ---
        pos = pos + 1
        generated_tokens = self.decode_n_tokens_condition_merge(
            cond_combined,
            next_token,
            pos,
            max_new_tokens - 1,
            cfg_interval,
            diff_cfg_scale,
            cfg_schedule,
            temperature,
            trans_cfg_scale,
            net_autoguidance=net_autoguidance,
            inference_noise=inference_noise,
            audio_pretraining=audio_pretraining,
            **edm_kwargs
        )
        # A single-frame condition yields no further tokens; torch.cat([]) would raise.
        if generated_tokens:
            seq[:, 1:] = torch.cat(generated_tokens, dim=1)
        return seq
    
    @torch.no_grad()
    def online_sample_tokens_from_video(
        self,
        frames: torch.Tensor, 
        dino: "DINOv2StreamingEncoder",   
        *,
        stage1_model: "AudioAutoencoder",
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        z_mean: torch.Tensor, 
        z_std: torch.Tensor,
        noncausal_right_margin_latents: int = 0,
        context_mode: str = "none", # {"pi","ntk","sliding","none"}
        trans_cfg_scale: float = 1.0,
        diff_cfg_scale: float = 1.0,
        cfg_schedule: str = "constant",
        temperature: float = 1.0,
        inference_noise: float = 0.2,
        audio_pretraining: bool = False,
        vision_aggregation: bool = True,
        measure_latency: bool = True,
        cfg_interval: int = -1, 
        net_autoguidance: nn.Module = None,
        **edm_kwargs
    ):
        """
        Online pipeline:
        For each frame t:
            frame -> DINO grid -> (optional downsample) -> PCA -> Δ(grid)  (t=0: zero)
                -> Concat([grid, Δ(grid)]) -> 3D Conv (T=1) -> flatten -> Aggregator
                -> cond_token_t
        Interleave with AR decoding:
            BOS(audio) -> V(cond_t0) -> diffusion -> A_0
            then for t>=1:  [A_{t-1}] -> V(cond_t) -> diffusion -> A_t
        Returns:
        seq: [B, T, audio_embed_dim], generated audio tokens
        (optional) latency dict in ms
        """
        # device = frames.device
        assert frames.ndim == 5 and frames.shape[2] == 3, "frames must be [B,T,3,H,W]"
        assert self.vision_aggregation is True, "only vision_aggregation==True."
        assert getattr(self, "film_fuse", False) is False, "film_fuse is not supported here."
        
        B, T, _, H, W = frames.shape

        train_seq_len = int(self.transformer.seq_len)
        T_new = 1 + (2*T - 1)
        # === RoPE extension mode & KV caches ===
        if context_mode in ("pi","ntk"):
            scale = max(1.0, float(T_new) / float(train_seq_len))
            self.transformer.set_context_extension(context_mode, factor=scale)
        elif context_mode == "sliding":
            self.transformer.set_context_extension("sliding", window_size=train_seq_len)
        else:
            self.transformer.set_context_extension("none")
        
        max_batch_size = B
        max_batch_size_cfg = max_batch_size * (2 if trans_cfg_scale > 1.0 and not audio_pretraining else 1)
        
        infer_dtype = next(self.transformer.parameters()).dtype
        with torch.device(device):
            self.transformer.setup_caches(
                max_batch_size=max_batch_size_cfg, 
                max_seq_length=T_new, 
                dtype=infer_dtype
            )

        # create an empty tensor of the expected final shape and fill in the current tokens
        out_dtype = infer_dtype
        seq = torch.empty((B, T, self.config.audio_embed_dim), dtype=out_dtype, device=device)
        
        online_dec: OnlineCausalWaveformDecoder
        if not hasattr(self, "_online_dec") or self._online_dec is None:
            self._online_dec = OnlineCausalWaveformDecoder(
                stage1_model=stage1_model,
                z_mean=z_mean,
                z_std=z_std,
                batch_size=B,
                total_T=T,
                device=device,
                use_cuda_graph=True,
                dtype_latents=out_dtype,
            )
        else:
            self._online_dec.reset_buffers(total_T=T)
        online_dec = self._online_dec
        

        ### init token decoding
        pos = torch.tensor([0], device=device, dtype=torch.long)
        z = self.bos_sampling(
            self.transformer.init_audio_token.expand(max_batch_size_cfg, -1),
            type_id=int(1),
            input_pos=pos,
            trans_cfg_scale=trans_cfg_scale,
            inference_noise=inference_noise,
            audio_pretraining=audio_pretraining,
        )
        pos = pos + 1
        prev_grid_pca = None
        
        frame0_gpu = frames[:, 0].to(device, non_blocking=True) 
        cond_t0, prev_grid_pca = self._build_cond_token_for_frame(
            dino=dino,
            frame_bchw=frame0_gpu,          # [B,3,H,W]
            prev_grid_pca=prev_grid_pca,
            use_vision_aggregation=vision_aggregation,
            infer_dtype=infer_dtype,
        )

        # V(cond_t0)
        v_init_token = self.transformer_sampling(
            cond_t0, type_id=int(0),
            input_pos=pos,
            trans_cfg_scale=trans_cfg_scale,
            inference_noise=inference_noise,
            audio_pretraining=audio_pretraining,
        )
        
        # concat with BOS(audio)
        z_cat = torch.cat([z, v_init_token], dim=-1)

        # diffusion -> A_0
        a0 = self.diffusion_sampling(
            z_cat,
            cfg=diff_cfg_scale,
            cfg_schedule=cfg_schedule,
            temperature=temperature,
            device=z_cat.device,
            net_autoguidance=net_autoguidance,
            **edm_kwargs
        )
        seq[:, 0] = a0
        pos = pos + 1
        _ = online_dec.push_token(a0)
        
        lat = None
        if measure_latency:
            lat = {"frame_ms": []}
        # ==== t=1..T-1 ====
        cur_token = a0.view(B, 1, -1)
        if measure_latency and torch.cuda.is_available():
            torch.cuda.synchronize()
        # t_frame_start = time.perf_counter()
        for t in range(1, T):
            t_frame_start = time.perf_counter()

            # audio step (A_{t-1} -> z_out)
            z_out = self.transformer_sampling(
                cur_token, type_id=int(1),
                input_pos=pos,
                trans_cfg_scale=trans_cfg_scale,
                inference_noise=inference_noise,
                audio_pretraining=audio_pretraining,
            )
            pos = pos + 1

            frame_t_gpu = frames[:, t].to(device, non_blocking=True)
            cond_t, prev_grid_pca = self._build_cond_token_for_frame(
                dino=dino,
                frame_bchw=frame_t_gpu,
                prev_grid_pca=prev_grid_pca,
                use_vision_aggregation=vision_aggregation,
                infer_dtype=infer_dtype,
            )

            v_out = self.transformer_sampling(
                cond_t, type_id=int(0),
                input_pos=pos,
                trans_cfg_scale=trans_cfg_scale,
                inference_noise=inference_noise,
                audio_pretraining=audio_pretraining,
            )
            pos = pos + 1

            # [z_out, v_out] -> diffusion -> A_t
            z_token = torch.cat([z_out, v_out], dim=-1)
            a_t = self.diffusion_sampling(
                z_token,
                cfg=diff_cfg_scale,
                cfg_schedule=cfg_schedule,
                temperature=temperature,
                device=z_token.device,
                net_autoguidance=net_autoguidance,
                **edm_kwargs
            )
            seq[:, t] = a_t
            cur_token = a_t.view(B, 1, -1)
            _ = online_dec.push_token(a_t)
            
            if measure_latency:
                torch.cuda.synchronize(device)
            t_gen_end = time.perf_counter()
            frame_ms = ((t_gen_end - t_frame_start)) * 1000.0
            lat["frame_ms"].append(frame_ms)
        
        online_dec.finalize()
        N_emit = online_dec.emitted_samples
        waveform = online_dec.audio[:, :, :N_emit]
        if measure_latency:
            return seq, _, lat
        return seq

    @torch.no_grad()
    def _build_cond_token_for_frame(
        self,
        dino: "DINOv2StreamingEncoder",
        frame_bchw: torch.Tensor,
        prev_grid_pca: torch.Tensor | None,
        *,
        infer_dtype: torch.dtype,
        use_vision_aggregation: bool = True,
    ):
        """Encode one video frame into a single conditioning vector.

        Only the vision_aggregation=True / film_fuse=False path is implemented:

          1. ``dino.get_grid_pca`` yields this frame's PCA grid and its delta
             relative to ``prev_grid_pca``.
          2. Grid and delta are concatenated on the channel axis, moved to
             channels-first with a singleton time axis, and passed through
             ``self.video_proj_3d``.
          3. Every spatial position of the conv output becomes one token; the
             learned ``self.aggregated_tokens`` queries attend over them via
             ``self.aggregate_transformer``.
          4. The aggregated queries are collapsed into one vector per sample.

        Args:
            dino: streaming encoder exposing ``get_grid_pca``.
            frame_bchw: the current frame (presumably [B, 3, H, W]).
            prev_grid_pca: PCA grid of the previous frame, or None for the first.
            infer_dtype: accepted for interface compatibility; not used at
                present (the casts to it are disabled).
            use_vision_aggregation: must be True.

        Returns:
            ``(cond_t, grid_pca)``: ``cond_t`` is [B, D_agg] with a single
            aggregation query, otherwise [B, num_agg * D_agg]; ``grid_pca``
            ([B, h', w', C_pca]) should be passed back as ``prev_grid_pca`` on
            the next frame.
        """
        assert use_vision_aggregation, "This simplified implementation requires vision_aggregation=True."

        batch = frame_bchw.shape[0]
        grid_pca, delta_pca = dino.get_grid_pca(frame_bchw, prev_grid_pca)

        # [B, h, w, 2C] -> [B, 2C, 1, h, w], laid out channels_last_3d for the conv.
        conv_in = (
            torch.cat([grid_pca, delta_pca], dim=-1)
            .permute(0, 3, 1, 2)
            .unsqueeze(2)
            .contiguous(memory_format=torch.channels_last_3d)
        )
        conv_out = self.video_proj_3d(conv_in)  # [B, D_agg, 1, h2, w2]

        # Equivalent of "b c t h w -> (b t) (h w) c": one token per spatial cell.
        n, ch, steps, hh, ww = conv_out.shape
        tokens = conv_out.permute(0, 2, 3, 4, 1).reshape(n * steps, hh * ww, ch)

        queries = self.aggregated_tokens.unsqueeze(0).expand(batch, -1, -1)
        aggregated = self.aggregate_transformer(tokens, queries)  # [B, num_agg, D_agg]

        # Several queries are flattened into one vector; a single query is just squeezed.
        if aggregated.shape[1] != 1:
            return aggregated.reshape(batch, -1), grid_pca
        return aggregated[:, 0, :], grid_pca


# Architecture presets for the GNSE family, keyed by their GNSE_models name.
# All presets share d_model=1024 and num_heads=16; they trade transformer depth
# (num_layers) against head capacity (head_channel / num_head_block).
# NOTE(review): the original factories annotated every preset as "300M"
# parameters; that figure is not verified here.
_GNSE_PRESETS = {
    'GNSE-T': dict(d_model=1024, num_layers=18, num_heads=16, head_channel=1280, num_head_block=8),
    'GNSE-T-50': dict(d_model=1024, num_layers=19, num_heads=16, head_channel=1024, num_head_block=8),
    'GNSE-T-30': dict(d_model=1024, num_layers=20, num_heads=16, head_channel=768, num_head_block=8),
    'GNSE-T-10': dict(d_model=1024, num_layers=22, num_heads=16, head_channel=512, num_head_block=6),
}


def _build_gnse(preset: str, **kwargs):
    """Instantiate GNSE from the named entry of ``_GNSE_PRESETS``.

    Extra keyword arguments are forwarded to ``ModelArgs``. A keyword that
    names a preset field (e.g. ``num_layers``) overrides the preset value
    instead of raising ``TypeError: got multiple values for keyword argument``.
    A fresh dict is built per call, so the shared presets are never mutated.
    """
    return GNSE(ModelArgs(**{**_GNSE_PRESETS[preset], **kwargs}))


def gnse_default(**kwargs):
    """GNSE-T: 18 layers, head_channel=1280, 8 head blocks."""
    return _build_gnse('GNSE-T', **kwargs)


def gnse_default_2_50head(**kwargs):
    """GNSE-T-50: 19 layers, head_channel=1024, 8 head blocks."""
    return _build_gnse('GNSE-T-50', **kwargs)


def gnse_default_2_30head(**kwargs):
    """GNSE-T-30: 20 layers, head_channel=768, 8 head blocks."""
    return _build_gnse('GNSE-T-30', **kwargs)


def gnse_default_2_10head(**kwargs):
    """GNSE-T-10: 22 layers, head_channel=512, 6 head blocks."""
    return _build_gnse('GNSE-T-10', **kwargs)


# Public registry: model name -> factory accepting ModelArgs keyword overrides.
GNSE_models = {
    'GNSE-T': gnse_default,
    'GNSE-T-50': gnse_default_2_50head,
    'GNSE-T-30': gnse_default_2_30head,
    'GNSE-T-10': gnse_default_2_10head,
}

