
"""
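v2: add a Whisper encoder for semantic information extraction.
v3: replace Llama2-7B with Llama3.1-8B (original note: "only llama2 supports Chinese!!").
    The decoder currently loaded by this file is Qwen2 (Qwen2ForCausalLM).

Add a universal projection module:
    replace the single MLP projector with an MoE-MLP.

Add task-aware embedding (the task prompt also conditions the MoE router).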
"""

import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_model import BaseModel
from .spatial_encoder import (
    SpatialEncoderUnavailable,
    get_spatial_encoder_class,
)

from peft import LoraConfig, TaskType, get_peft_model
from transformers import LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen2ForCausalLM
from typing import Optional

class AudioContextExtractor(nn.Module):
    """Summarise a mono waveform into a fixed-size context vector for the MoE router.

    Two strided 1-D convolutions (each followed by BatchNorm + ReLU) encode the
    waveform, an adaptive average pool fixes the time axis at 64 steps, and a
    two-layer MLP maps the flattened ``hidden_dim * 64`` features to ``output_dim``.
    """

    def __init__(self, input_dim=1, hidden_dim=128, output_dim=256):
        super().__init__()
        half_dim = hidden_dim // 2
        pooled_steps = 64  # output length of the AdaptiveAvgPool1d below

        # Convolutional front-end over the time axis.
        conv_stack = [
            nn.Conv1d(input_dim, half_dim, kernel_size=512, stride=256, padding=256),
            nn.BatchNorm1d(half_dim),
            nn.ReLU(),
            nn.Conv1d(half_dim, hidden_dim, kernel_size=16, stride=8, padding=8),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(pooled_steps),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Compress the flattened (channels x pooled_steps) map.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * pooled_steps, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, audio):
        """Return a ``[B, output_dim]`` context vector.

        A 2-D ``audio`` is treated as ``[B, samples]`` and given a channel axis;
        any other rank is passed to the convolutions unchanged (a 3-D input is
        expected to already be ``[B, input_dim, samples]``).
        """
        x = audio.unsqueeze(1) if audio.dim() == 2 else audio
        pooled = self.cnn(x)               # [B, hidden_dim, 64]
        flat = pooled.flatten(start_dim=1)  # [B, hidden_dim * 64]
        return self.mlp(flat)

class Expert(nn.Module):
    """Feed-forward expert: LayerNorm -> Linear -> GELU -> Linear over the last dim."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        layers = (
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``[..., input_dim]`` to ``[..., output_dim]``."""
        projected = self.ffn(x)
        return projected

class Router(nn.Module):
    """MoE router: (audio context, prompt embedding) -> per-expert logits and gates."""

    def __init__(self, audio_ctx_dim=256, prompt_embed_dim=3584, num_experts=3):
        super().__init__()
        in_features = audio_ctx_dim + prompt_embed_dim
        self.fc = nn.Sequential(
            nn.LayerNorm(in_features),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_experts),
        )

    def forward(self, audio_ctx, prompt_embed):
        """Score every expert.

        Args:
            audio_ctx: ``[B, audio_ctx_dim]``.
            prompt_embed: ``[B, prompt_embed_dim]``, or ``[B, seq_len, prompt_embed_dim]``,
                which is mean-pooled over the sequence axis first.

        Returns:
            ``(logits, weights)``, both ``[B, num_experts]``. ``weights`` are independent
            sigmoid gates, not a softmax distribution over experts.
        """
        if prompt_embed.dim() == 3:
            prompt_summary = prompt_embed.mean(dim=1)
        else:
            prompt_summary = prompt_embed
        features = torch.cat((audio_ctx, prompt_summary), dim=-1)
        logits = self.fc(features)
        return logits, logits.sigmoid()

class MoELayer(nn.Module):
    """MoE projector fusing Whisper and spatial features into the decoder width.

    * ``1 + num_spatial_experts`` router-gated modality experts: one Whisper expert
      (applied to ``whisper_embeds``) followed by the spatial experts (each applied to
      ``spatial_embeds``). Their outputs are summed with the router's gate weights.
    * One always-on joint expert applied to ``combined_embeds``.

    The final output is the equal-weight average of the joint expert's output and
    the gated modality mixture.
    """

    def __init__(self, whisper_dim=768, spatial_dim=768, combined_dim=768,
                 prompt_dim=3584, output_dim=3584,
                 expert_hidden_dim=2048, num_spatial_experts=4, audio_ctx_dim=256):
        """
        Args:
            whisper_dim / spatial_dim / combined_dim: input widths of the three streams.
            prompt_dim: width of the prompt embeddings fed to the router.
            output_dim: output width (the decoder hidden size).
            expert_hidden_dim: hidden width of every expert FFN.
            num_spatial_experts: number of router-gated spatial experts.
            audio_ctx_dim: width of the audio context vector fed to the router.
        """
        super().__init__()
        self.output_dim = output_dim

        # Router-gated modality experts. Gate index 0 is the Whisper expert; indices
        # 1..num_spatial_experts are the spatial experts, in order.
        self.whisper_expert = Expert(whisper_dim, expert_hidden_dim, output_dim)
        self.spatial_experts = nn.ModuleList(
            [Expert(spatial_dim, expert_hidden_dim, output_dim) for _ in range(num_spatial_experts)]
        )

        # Always-on joint expert over the combined (whisper + spatial) features.
        self.combined_expert = Expert(combined_dim, expert_hidden_dim, output_dim)

        self.router = Router(
            audio_ctx_dim=audio_ctx_dim,
            prompt_embed_dim=prompt_dim,
            num_experts=1 + num_spatial_experts,
        )

    def forward(self, whisper_embeds, spatial_embeds, combined_embeds,
                audio_context, prompt_embeds, router_label=None, teacher_forcing_ratio=0.0):
        """
        Args:
            whisper_embeds: ``[B, T, whisper_dim]`` (TWNM passes T=1500).
            spatial_embeds: ``[B, T, spatial_dim]``.
            combined_embeds: ``[B, T, combined_dim]``.
            audio_context: ``[B, audio_ctx_dim]``.
            prompt_embeds: ``[B, seq_len, prompt_dim]`` (mean-pooled by the router).
            router_label: optional ``[B, num_gated_experts]`` target gates. In training
                mode, with probability ``teacher_forcing_ratio``, they replace the
                predicted gates.
            teacher_forcing_ratio: probability of using ``router_label``.

        Returns:
            ``(output, router_logits)`` with ``output`` of shape ``[B, T, output_dim]``
            and ``router_logits`` of shape ``[B, num_gated_experts]``.
        """
        # 1. Always-on joint expert.
        combined_output = self.combined_expert(combined_embeds)

        # 2. Router gates for the modality experts.
        router_logits, gate_weights = self.router(audio_context, prompt_embeds)

        # The condition order is kept so random.random() is only consumed when
        # teacher forcing is possible.
        if self.training and router_label is not None and random.random() < teacher_forcing_ratio:
            # Teacher forcing: use the ground-truth gates. Match the predicted gates'
            # dtype/device so a float32 label cannot promote bf16 activations to float32.
            gate_weights = router_label.to(device=gate_weights.device, dtype=gate_weights.dtype)

        modality_outputs = [self.whisper_expert(whisper_embeds)]
        modality_outputs.extend(expert(spatial_embeds) for expert in self.spatial_experts)

        stacked = torch.stack(modality_outputs, dim=2)                        # [B, T, E, D]
        gated_mixture = (stacked * gate_weights[:, None, :, None]).sum(dim=2)  # [B, T, D]

        # 3. Equal-weight average of the joint expert and the gated mixture.
        final_output = torch.stack([combined_output, gated_mixture], dim=0).mean(dim=0)
        return final_output, router_logits

class TWNM(BaseModel):
    """Audio-language model: Whisper + spatial encoders -> MoE projector -> Qwen2 (LoRA).

    Pipeline (see ``forward_encoder``):
      * stereo audio -> frozen pretrained spatial encoder -> per-frame linear
        projection to 768 dims, linearly resampled to Whisper's frame count;
      * channel-averaged (mono) audio -> frozen Whisper-small encoder;
      * both streams are layer-normalised, summed, and fused by ``MoELayer`` into
        the decoder's hidden size;
      * the fused frames replace the ``<AcousticTokens>`` placeholder of the prompt,
        and the sequence is fed to a LoRA-wrapped Qwen2 causal LM.
    """

    def __init__(self, config, spatial_encoder_ckpt_path: Optional[str] = "assets/checkpoints/spatial_encoder/loss=0.4612.ckpt", lora_pretrain_ckpt_path = None, is_inference=False):
        """Build all sub-modules and load pretrained weights.

        Args:
            config: dict-like config (``.get`` is used), forwarded to ``BaseModel``.
                ``teacher_forcing_ratio`` (default 0.5) is read here.
            spatial_encoder_ckpt_path: checkpoint whose ``'state_dict'`` entry holds
                the spatial encoder weights (a ``model.`` key prefix is stripped).
                ``None`` skips loading; the encoder then keeps its random
                initialisation and a warning is printed.
            lora_pretrain_ckpt_path: optional state dict of this model saved before
                LoRA was added, loaded with ``strict=False``. ``None`` or ``"none"``
                disables it.
            is_inference: passed to ``LoraConfig(inference_mode=...)``.
        """
        super().__init__(config)

        # Teacher forcing only takes effect when a router label reaches MoELayer.
        self.teacher_forcing_ratio = config.get("teacher_forcing_ratio", 0.5)
        if self.teacher_forcing_ratio > 0:
            print(f"Teacher forcing for MoE router is enabled with ratio: {self.teacher_forcing_ratio}")

        # Whisper-small from the Hugging Face hub; encoder/decoder weights are frozen.
        model_name = 'openai/whisper-small'
        self.whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
        self.whisper = WhisperForConditionalGeneration.from_pretrained(model_name)
        for p in self.whisper.model.parameters():
            p.requires_grad = False

        # Spatial encoder: resolve the implementation, build it, load weights, freeze.
        try:
            SpatialEncoderCls = get_spatial_encoder_class()
        except SpatialEncoderUnavailable as exc:
            raise SpatialEncoderUnavailable(
                "TWNM 初始化失败：未检测到可用的 SpatialEncoder 实现。"
            ) from exc

        self.spatial_encoder = SpatialEncoderCls(
            dim_input=4,  # two-channel complex STFT (real/imag per channel, per original note)
            dim_hidden=96,
            num_layers=8,
            # remaining arguments use the implementation's defaults
        )

        self.whisper_norm = nn.LayerNorm(768)
        self.spatial_norm = nn.LayerNorm(768)

        self._load_spatial_encoder_weights(spatial_encoder_ckpt_path)
        for p in self.spatial_encoder.parameters():
            p.requires_grad = False

        # Decoder: local Qwen2 LLM checkpoint, loaded in bfloat16.
        decoder_name = './assets/checkpoints/qwen2-audio-llm-extracted'
        self.decoder = Qwen2ForCausalLM.from_pretrained(
            decoder_name,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # attn_implementation="flash_attention_2"
        )
        self.decoder.gradient_checkpointing_enable()

        self.tokenizer = AutoTokenizer.from_pretrained(decoder_name, trust_remote_code=True)

        # Reasoning/answer tags; the embedding matrix is resized to cover them.
        new_special_tokens = {
            "additional_special_tokens": ["|<think>|", "|</think>|", "|<answer>|", "|</answer>|"]
        }
        self.tokenizer.add_special_tokens(new_special_tokens)
        self.decoder.resize_token_embeddings(len(self.tokenizer))

        # Fall back to EOS for missing pad/bos tokens (BOS is prepended to every prompt).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.bos_token is None:
            self.tokenizer.bos_token = self.tokenizer.eos_token

        # LoRA adapters on the attention q/v projections. To train only the MoE
        # components instead, freeze self.decoder's parameters here.
        peft_config = LoraConfig(
            target_modules=["q_proj", "v_proj"],
            task_type=TaskType.CAUSAL_LM,
            inference_mode=is_inference,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
        )
        self.decoder = get_peft_model(self.decoder, peft_config)

        # Per-frame projection of the spatial features: F*H = 129*192 -> 768 (Whisper width).
        self.spatial_proj = nn.Linear(129 * 192, 768)

        # MoE components.
        self.audio_context_extractor = AudioContextExtractor()
        self.moe_layer = MoELayer(
            whisper_dim=768,
            spatial_dim=768,
            combined_dim=768,
            prompt_dim=self.decoder.config.hidden_size,
            output_dim=self.decoder.config.hidden_size
        )

        if lora_pretrain_ckpt_path is not None and lora_pretrain_ckpt_path != "none":
            self._load_pre_lora_weights(lora_pretrain_ckpt_path)

        # Convert trainable components to bfloat16 to match the decoder.
        for module in (
            self.spatial_proj,
            self.audio_context_extractor,
            self.moe_layer,
            self.whisper_norm,
            self.spatial_norm,
        ):
            module.to(torch.bfloat16)

    def _load_spatial_encoder_weights(self, ckpt_path):
        """Strictly load ``checkpoint['state_dict']`` into ``self.spatial_encoder``.

        Keys starting with ``model.`` have that prefix removed; other keys are kept.
        ``None`` skips loading with a warning.
        """
        if ckpt_path is None:
            print("[WARNING] spatial_encoder_ckpt_path is None: the spatial encoder keeps its random initialisation.")
            return
        # Load on CPU to avoid allocating GPU memory for the checkpoint.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        prefix = "model."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint['state_dict'].items()
        }
        self.spatial_encoder.load_state_dict(state_dict)

    def _load_pre_lora_weights(self, ckpt_path):
        """Non-strictly load a state dict of this model saved before LoRA was added."""
        print(f"--- Loading weights from pre-LoRA checkpoint: {ckpt_path} ---")

        # Load on CPU to avoid allocating GPU memory for the checkpoint.
        pretrain_state_dict = torch.load(ckpt_path, map_location="cpu")

        # strict=False: the LoRA adapter parameters are expected to be absent.
        missing_keys, unexpected_keys = self.load_state_dict(pretrain_state_dict, strict=False)

        print("--- Weights loaded with strict=False ---")

        if missing_keys:
            print(f"[INFO] Missing keys ({len(missing_keys)}): These are new parameters in your LoRA model that were not in the checkpoint (this is expected).")
            # Show only a few keys to avoid flooding the log.
            print("       Examples:", missing_keys[:5])

        if unexpected_keys:
            print(f"[WARNING] Unexpected keys ({len(unexpected_keys)}): These keys were in the checkpoint but are not in the current model.")
            print("         Examples:", unexpected_keys[:5])

    def print_module_parameters(self):
        """Print parameter counts for every sub-module and the overall total."""
        def count(module):
            return sum(p.numel() for p in module.parameters())

        whisper_num_params = count(self.whisper)
        decoder_num_params = count(self.decoder)
        spatial_encoder_num_params = count(self.spatial_encoder)

        # MoE-related components.
        spatial_proj_params = count(self.spatial_proj)
        audio_ctx_params = count(self.audio_context_extractor)
        moe_layer_params = count(self.moe_layer)

        total_moe_params = spatial_proj_params + audio_ctx_params + moe_layer_params

        print(f"=== 模型参数统计 ===")
        print(f"Whisper Encoder: {whisper_num_params:,}")
        print(f"Spatial Encoder: {spatial_encoder_num_params:,}")
        print(f"Qwen2 Decoder: {decoder_num_params:,}")
        print(f"--- MoE组件 ---")
        print(f"Spatial投影层: {spatial_proj_params:,}")
        print(f"Audio上下文提取器: {audio_ctx_params:,}")
        print(f"MoE层(Router+Experts): {moe_layer_params:,}")
        print(f"MoE总参数: {total_moe_params:,}")
        print(f"=== 总参数量: {whisper_num_params + spatial_encoder_num_params + decoder_num_params + total_moe_params:,} ===")

    @staticmethod
    def _build_prompts(samples, batch_size):
        """Turn ``samples["task"]`` (default: ``"AAC"`` per item) into placeholder prompts."""
        task = samples["task"] if "task" in samples else ["AAC"] * batch_size
        # For QA, task entries are instructions; the audio goes after them.
        return [t + " <AcousticTokens>" for t in task]

    def _split_prompts(self, prompt):
        """Split each prompt at ``<AcousticTokens>``; prepend BOS to the left part.

        Raises:
            ValueError: if a prompt does not contain exactly one placeholder.
        """
        prompt_left = []
        prompt_right = []
        for p in prompt:
            parts = p.split("<AcousticTokens>")
            if len(parts) != 2:
                raise ValueError(
                    f"Prompt must contain exactly one '<AcousticTokens>' placeholder, got: {p!r}"
                )
            prompt_left.append(self.tokenizer.bos_token + parts[0])
            prompt_right.append(parts[1])
        return prompt_left, prompt_right

    def _tokenize_and_embed(self, texts, device):
        """Tokenize ``texts`` (no automatic special tokens, padded to the longest) and embed them."""
        tokens = self.tokenizer(
            texts,
            add_special_tokens=False,
            padding="longest",
            return_tensors="pt",
        ).to(device)
        embeds = self.decoder.get_input_embeddings()(tokens.input_ids.long())
        return tokens, embeds

    def _embed_prompt_halves(self, prompt, device):
        """Return ``(left_tokens, left_embeds, right_tokens, right_embeds)`` for ``prompt``."""
        prompt_left, prompt_right = self._split_prompts(prompt)
        left_tokens, left_embeds = self._tokenize_and_embed(prompt_left, device)
        right_tokens, right_embeds = self._tokenize_and_embed(prompt_right, device)
        return left_tokens, left_embeds, right_tokens, right_embeds

    def prepare_inputs_labels_for_multimodal(
        self, audio_embeds, atts, prompt, text=None
    ):
        """Assemble decoder inputs: ``[prompt_left | audio | prompt_right (| text)]``.

        Args:
            audio_embeds: ``[B, T_audio, hidden]`` audio frames.
            atts: ``[B, T_audio]`` attention mask for the audio frames.
            prompt: list of prompts, each with exactly one ``<AcousticTokens>``.
            text: optional target texts. EOS is appended to each, and their tokens
                are both appended to the inputs and used as labels.

        Returns:
            ``(input_embeds, input_mask, decoder_targets)``. ``decoder_targets`` is
            ``None`` without ``text``. Otherwise it is -100 on every prompt/audio and
            padding position, and the text token ids elsewhere.
        """
        device = audio_embeds.device
        left_tokens, left_embeds, right_tokens, right_embeds = self._embed_prompt_halves(prompt, device)

        input_embeds = torch.cat([left_embeds, audio_embeds, right_embeds], dim=1)
        input_mask = torch.cat(
            [left_tokens.attention_mask, atts, right_tokens.attention_mask], dim=1
        )

        decoder_targets = None
        if text is not None:
            # Append EOS so the model learns when to stop.
            text_tokens, text_embeds = self._tokenize_and_embed(
                [t + self.tokenizer.eos_token for t in text], device
            )

            # Padding positions are ignored by the loss.
            targets = text_tokens.input_ids.masked_fill(
                text_tokens.attention_mask == 0, -100
            )
            # No loss on prompt/audio positions.
            empty_targets = torch.full(input_mask.shape, -100, dtype=torch.long, device=device)
            decoder_targets = torch.cat([empty_targets, targets], dim=1)

            input_embeds = torch.cat([input_embeds, text_embeds], dim=1)
            input_mask = torch.cat([input_mask, text_tokens.attention_mask], dim=1)

        return input_embeds, input_mask, decoder_targets

    def get_prompt_embeds(self, prompt, audios):
        """Embed the prompt text (left + right of the placeholder, concatenated) for the router."""
        _, left_embeds, _, right_embeds = self._embed_prompt_halves(prompt, audios.device)
        return torch.cat([left_embeds, right_embeds], dim=1)

    def forward_encoder(self, audios, prompt=None, router_label=None):
        """Encode stereo audio into decoder-width frames via both encoders and the MoE layer.

        Args:
            audios: ``[B, 2, samples]`` stereo waveforms (16 kHz is passed to Whisper's
                feature extractor).
            prompt: prompts conditioning the router. Defaults to the ``"AAC"`` prompt
                for every item.
            router_label: optional target gates for teacher forcing.

        Returns:
            ``(moe_output, router_logits)``. ``moe_output`` is ``[B, T_whisper, hidden]``
            (T_whisper is 1500 for Whisper-small's 30 s window).
        """
        if prompt is None:
            prompt = self._build_prompts({}, audios.shape[0])

        # 1. Spatial encoder on the stereo input; the original annotation gives [B, 129, T, 192].
        spatial_embeds = self.spatial_encoder.forward_as_encoder(audios).to(torch.bfloat16)

        # 2. Whisper on the channel-averaged signal (feature extraction runs on CPU/NumPy).
        audios_mono = audios.mean(dim=1)  # [B, samples]
        audios_mono_cpu_numpy = audios_mono.to(torch.float32).cpu().numpy()
        mels = self.whisper_feature_extractor(
            audios_mono_cpu_numpy, sampling_rate=16000, return_tensors='pt'
        )['input_features'].to(audios.device)
        whisper_embeds = self.whisper.model.encoder(mels)['last_hidden_state'].to(torch.bfloat16)

        # 3. Per-frame projection of the spatial features. Move time ahead of frequency
        #    before flattening so each row holds one frame's F*H features; a bare
        #    reshape of [B, F, T, H] to [B, T, F*H] would mix frequency and time.
        batch, n_freq, n_frames, hidden = spatial_embeds.shape
        spatial_frames = spatial_embeds.permute(0, 2, 1, 3).reshape(batch, n_frames, n_freq * hidden)
        spatial_projected = self.spatial_proj(spatial_frames)  # [B, T, 768]

        # 4. Linearly resample the time axis to Whisper's frame count.
        #    (This is linear interpolation, not average pooling.)
        target_len = whisper_embeds.shape[1]
        spatial_aligned = F.interpolate(
            spatial_projected.transpose(1, 2),  # [B, 768, T]
            size=target_len,
            mode='linear',
            align_corners=False,
        ).transpose(1, 2)  # [B, target_len, 768]

        whisper_embeds = self.whisper_norm(whisper_embeds)
        spatial_aligned = self.spatial_norm(spatial_aligned)

        # 5. Joint stream: element-wise sum of the two normalised encoders.
        combined_embeds = spatial_aligned + whisper_embeds

        # 6. Router conditioning: audio context from the mono signal and the prompt embeddings.
        audio_context = self.audio_context_extractor(audios_mono.to(torch.bfloat16))
        prompt_embeds = self.get_prompt_embeds(prompt, audios)

        # 7. MoE fusion into the decoder hidden size.
        moe_output, router_logits = self.moe_layer(
            whisper_embeds=whisper_embeds,
            spatial_embeds=spatial_aligned,
            combined_embeds=combined_embeds,
            audio_context=audio_context,
            prompt_embeds=prompt_embeds,
            router_label=router_label,
            teacher_forcing_ratio=self.teacher_forcing_ratio
        )
        return moe_output, router_logits

    def forward(self, samples):
        """Training step: returns ``{"loss", "ce_loss", "logits"}`` from the decoder.

        ``samples`` must contain ``"audios"`` and ``"text"``; ``"task"`` is optional
        (defaults to ``"AAC"``).
        """
        audios = samples["audios"]
        text = samples["text"]
        # Router supervision is disabled (previously samples["router_label"]), so
        # teacher forcing never triggers.
        router_label = None

        prompt = self._build_prompts(samples, audios.shape[0])
        encoder_hidden_states, _ = self.forward_encoder(audios, prompt, router_label)

        encoder_atts = torch.ones(
            encoder_hidden_states.shape[:-1], dtype=torch.long, device=encoder_hidden_states.device
        )

        input_embeds, input_mask, decoder_targets = self.prepare_inputs_labels_for_multimodal(
            encoder_hidden_states, encoder_atts, prompt, text
        )

        decoder_output = self.decoder(
            input_ids=None,
            inputs_embeds=input_embeds,
            attention_mask=input_mask,
            labels=decoder_targets,
            return_dict=True,
        )

        # The router BCE loss (BCEWithLogitsLoss on router logits vs. labels) is
        # disabled together with the router labels; the total loss is the LM loss.
        ce_loss = decoder_output.loss
        total_loss = ce_loss

        return {
            "loss": total_loss,
            "ce_loss": ce_loss,
            "logits": decoder_output.logits,
        }

    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=3,
        max_length=30,
        min_length=2,
        top_p=0.9,
        repetition_penalty=1.0,
    ):
        """Generate text for ``samples["audios"]``; returns decoded strings.

        ``max_length``/``min_length`` bound the number of NEW tokens. Special tokens
        are kept in the decoded output.
        """
        audios = samples["audios"].to(self.device)
        prompt = self._build_prompts(samples, audios.shape[0])

        encoder_hidden_states, _ = self.forward_encoder(audios, prompt)
        encoder_atts = torch.ones(
            encoder_hidden_states.shape[:-1], dtype=torch.long, device=encoder_hidden_states.device
        )

        input_embeds, input_mask, _ = self.prepare_inputs_labels_for_multimodal(
            encoder_hidden_states, encoder_atts, prompt
        )

        outputs = self.decoder.generate(
            inputs_embeds=input_embeds,
            attention_mask=input_mask,
            max_new_tokens=max_length,
            min_new_tokens=min_length,
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            temperature=1.0,
            num_beams=num_beams,
            repetition_penalty=repetition_penalty,
            # NOTE(review): hard-coded id, presumably Qwen2's <|endoftext|>; confirm it
            # equals tokenizer.eos_token_id, which is appended to the training targets.
            eos_token_id=151643
        )
        # batch_decode never accepted add_special_tokens (it was ignored); keeping
        # special tokens is stated explicitly instead.
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=False)

if __name__ == "__main__":
    # Smoke test: build the model and run one training forward pass on random
    # stereo audio (4 clips of 16000 * 30 samples, i.e. 30 s at 16 kHz).
    config = {
        "encoder_conf": {
            "encoder_strategy" : "lora"
        },
        "decoder_conf": {
            "decoder_strategy" : "lora"
        }
    }
    model = TWNM(config)
    inputs = torch.rand(4, 2, 16000*30)  # [B, 2 channels, samples]
    sample = {
        "audios": inputs,
        "text": ["testing audio content", "another test audio", "music sample", "speech recording"]
    }
    # TWNM.forward returns a dict; unpacking it into two names raised ValueError.
    outputs = model.forward(sample)
    total_loss = outputs["loss"]
    logits = outputs["logits"]
    print(f'total_loss: {total_loss.item():.4f}')
    print(f'logits shape: {logits.shape}')
