import random
from typing import Optional
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from contextlib import contextmanager
import os

from .spatial_encoder import (
    SpatialEncoderUnavailable,
    get_spatial_encoder_class,
)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoTokenizer, 
    PretrainedConfig, 
    PreTrainedModel,
    Qwen2ForCausalLM, 
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration, 
    BitsAndBytesConfig
)
import logging

# (The class definitions below — AudioContextExtractor, Expert, Router, MoELayer, TWNMConfig — are unchanged.)
class AudioContextExtractor(nn.Module):
    """Compress a raw waveform into a fixed-size context vector.

    Two strided 1-D convolutions (each followed by batch norm and ReLU)
    downsample the signal. An adaptive average pool then fixes the temporal
    length at 64 frames. Finally a two-layer MLP maps the flattened features
    to ``output_dim``.

    Args:
        input_dim: Number of input channels expected by the first conv.
        hidden_dim: Channel count of the second conv (the first uses half).
        output_dim: Size of the returned context vector.
    """

    def __init__(self, input_dim=1, hidden_dim=128, output_dim=256):
        super().__init__()
        half_dim = hidden_dim // 2
        pooled_len = 64
        # Layer order (and therefore Sequential indices / state-dict keys)
        # must stay fixed so existing checkpoints keep loading.
        self.cnn = nn.Sequential(
            nn.Conv1d(input_dim, half_dim, kernel_size=512, stride=256, padding=256),
            nn.BatchNorm1d(half_dim),
            nn.ReLU(),
            nn.Conv1d(half_dim, hidden_dim, kernel_size=16, stride=8, padding=8),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(pooled_len),
        )
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * pooled_len, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, audio):
        """Return a ``(batch, output_dim)`` context vector for ``audio``.

        A 2-D input is treated as ``(batch, samples)`` and gets a singleton
        channel axis inserted at dim 1. Any other input is passed to the conv
        stack unchanged.
        """
        x = audio.unsqueeze(1) if audio.dim() == 2 else audio
        x = self.cnn(x)
        x = x.reshape(x.shape[0], -1)
        return self.mlp(x)

class Expert(nn.Module):
    """Pre-norm two-layer feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the intermediate projection.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        layers = [
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack over the last dimension of ``x``."""
        return self.ffn(x)

class Router(nn.Module):
    """Score experts from an audio context vector and a prompt embedding.

    The two inputs are concatenated along the last dimension. The result is
    fed through LayerNorm -> Linear(512) -> ReLU -> Linear(num_experts).
    """

    def __init__(self, audio_ctx_dim=256, prompt_embed_dim=4096, num_experts=5):
        super().__init__()
        joint_dim = audio_ctx_dim + prompt_embed_dim
        self.fc = nn.Sequential(
            nn.LayerNorm(joint_dim),
            nn.Linear(joint_dim, 512),
            nn.ReLU(),
            nn.Linear(512, num_experts),
        )

    def forward(self, audio_ctx, prompt_embed):
        """Return ``(logits, sigmoid(logits))``.

        A 3-D ``prompt_embed`` is mean-pooled over dim 1 before concatenation.
        No mask is applied, so every position along that dim is averaged.
        """
        pooled = prompt_embed.mean(dim=1) if prompt_embed.dim() == 3 else prompt_embed
        scores = self.fc(torch.cat((audio_ctx, pooled), dim=-1))
        return scores, torch.sigmoid(scores)

class MoELayer(nn.Module):
    """Fuse whisper, spatial and combined embeddings into the decoder width.

    Routing is disabled. The whisper expert and the four spatial experts are
    blended with equal weights, and that blend is averaged with the combined
    expert's output. The router is still constructed, so its parameters stay
    in the state dict, but it is never called.
    """

    def __init__(self, whisper_dim=768, spatial_dim=768, combined_dim=768, prompt_dim=4096, output_dim=4096):
        super().__init__()
        # Construction order is kept fixed so parameter initialisation and
        # state-dict keys are unchanged.
        self.whisper_expert = Expert(whisper_dim, 2048, output_dim)
        self.spatial_experts = nn.ModuleList([Expert(spatial_dim, 2048, output_dim) for _ in range(4)])
        self.combined_expert = Expert(combined_dim, 2048, output_dim)
        self.router = Router(prompt_embed_dim=prompt_dim)

    def forward(self, whisper_embeds, spatial_embeds, combined_embeds, audio_context, prompt_embeds):
        """Return ``(fused, None)``.

        ``fused`` is the mean of ``combined_expert(combined_embeds)`` and an
        equal-weight blend of the whisper expert on ``whisper_embeds`` and of
        every spatial expert on the same ``spatial_embeds``. The weighting
        expects 3-D inputs, presumably (batch, seq, dim) — TODO confirm.
        ``audio_context`` and ``prompt_embeds`` are accepted but not used.
        """
        combined_out = self.combined_expert(combined_embeds)

        per_modality = [self.whisper_expert(whisper_embeds)]
        per_modality += [expert(spatial_embeds) for expert in self.spatial_experts]
        stacked = torch.stack(per_modality, dim=2)

        n_modalities = 1 + len(self.spatial_experts)
        uniform = torch.full(
            (whisper_embeds.size(0), n_modalities),
            1.0 / n_modalities,
            dtype=whisper_embeds.dtype,
            device=whisper_embeds.device,
        )
        # (batch, n) -> (batch, 1, n, 1) to broadcast over sequence and features.
        blended = (stacked * uniform[:, None, :, None]).sum(dim=2)

        fused = torch.stack([combined_out, blended], dim=0).mean(dim=0)
        return fused, None

class TWNMConfig(PretrainedConfig):
    """Configuration for the TWNM model.

    Stores the locations of the Whisper model, the decoder model and the
    spatial-encoder checkpoint, plus LoRA hyper-parameters. Any extra keyword
    arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "twnm"

    def __init__(
        self,
        whisper_model_name: str = 'openai/whisper-small',
        decoder_model_name: str = './assets/checkpoints/qwen2-audio-llm-extracted',
        spatial_encoder_ckpt_path: Optional[str] = "assets/checkpoints/spatial_encoder/loss=0.4612.ckpt",
        lora_r: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.whisper_model_name = whisper_model_name
        self.decoder_model_name = decoder_model_name
        self.spatial_encoder_ckpt_path = spatial_encoder_ckpt_path
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

class TWNM(PreTrainedModel):
    """Audio-language model that splices fused audio embeddings into a Qwen2 decoder.

    Audio is encoded by a Whisper encoder and by a spatial encoder. The two
    feature streams are aligned in time, fused by ``MoELayer`` and inserted
    into the decoder's input embeddings at the ``<AcousticTokens>``
    placeholder of the prompt.

    ``__init__`` freezes Whisper, the spatial encoder and the custom fusion
    modules. It runs the decoder through ``prepare_model_for_kbit_training``
    and, when ``peft_config`` is given, wraps it with LoRA adapters.
    """

    config_class = TWNMConfig
    _supports_gradient_checkpointing = True
    # Marker in prompt strings where the audio embeddings are inserted.
    _ACOUSTIC_PLACEHOLDER = "<AcousticTokens>"

    def __init__(self, config: TWNMConfig, quantization_config: Optional[BitsAndBytesConfig] = None, peft_config: Optional[LoraConfig] = None):
        """Build all sub-modules and the tokenizer.

        Args:
            config: Model locations and LoRA hyper-parameters.
            quantization_config: Optional bitsandbytes config for the decoder.
            peft_config: Optional LoRA config. When given, the decoder is
                wrapped with ``get_peft_model``.

        Raises:
            SpatialEncoderUnavailable: If no spatial encoder implementation is found.
        """
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        # Whisper is frozen and kept in eval mode; only its encoder is used.
        self.whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(config.whisper_model_name)
        self.whisper = WhisperForConditionalGeneration.from_pretrained(config.whisper_model_name).eval()
        for p in self.whisper.parameters():
            p.requires_grad = False

        try:
            SpatialEncoderCls = get_spatial_encoder_class()
        except SpatialEncoderUnavailable as exc:
            # Message: "TWNM initialization failed: no usable SpatialEncoder implementation found."
            raise SpatialEncoderUnavailable(
                "TWNM 初始化失败：未检测到可用的 SpatialEncoder 实现。"
            ) from exc

        self.spatial_encoder = SpatialEncoderCls(dim_input=4, dim_hidden=96, num_layers=8).eval()
        self._load_spatial_encoder_weights(config.spatial_encoder_ckpt_path)
        for p in self.spatial_encoder.parameters():
            p.requires_grad = False

        self.audio_context_extractor = AudioContextExtractor()
        self.decoder = Qwen2ForCausalLM.from_pretrained(
            config.decoder_model_name,
            trust_remote_code=True,
            quantization_config=quantization_config,
        )

        self.moe_layer = MoELayer(
            whisper_dim=768,
            spatial_dim=768,
            combined_dim=768,
            prompt_dim=self.decoder.config.hidden_size,
            output_dim=self.decoder.config.hidden_size,
        )

        # Maps flattened spatial features (F * H = 129 * 192) to 768 so they
        # can be added to the Whisper encoder states in ``forward_encoder``.
        self.spatial_proj = nn.Linear(129 * 192, 768)

        self.tokenizer = AutoTokenizer.from_pretrained(
            config.decoder_model_name,
            trust_remote_code=True
        )
        new_special_tokens = {"additional_special_tokens": ["|<think>|", "|</think>|", "|<answer>|", "|</answer>|"]}
        # Grow the embedding matrix only if new tokens were actually added.
        if self.tokenizer.add_special_tokens(new_special_tokens) > 0:
            self.decoder.resize_token_embeddings(len(self.tokenizer))

        # Fall back to EOS for missing PAD/BOS; BOS is prepended to every prompt.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.bos_token is None:
            self.tokenizer.bos_token = self.tokenizer.eos_token

        self.decoder = prepare_model_for_kbit_training(self.decoder)

        print("Freezing parameters for custom encoder modules (spatial_proj, audio_context_extractor, moe_layer)...")
        for module in [self.spatial_proj, self.audio_context_extractor, self.moe_layer]:
            for param in module.parameters():
                param.requires_grad = False
        print("Custom encoder modules frozen successfully.")

        if peft_config is not None:
            self.decoder = get_peft_model(self.decoder, peft_config)
            print("Successfully attached new LoRA adapters for GRPO training.")

    def _load_spatial_encoder_weights(self, ckpt_path):
        """Load spatial-encoder weights from the checkpoint at ``ckpt_path``.

        The checkpoint must contain a ``'state_dict'`` entry. A leading
        ``"model."`` prefix is stripped from each key before loading. Loading
        is non-strict; missing or unexpected keys are logged as a warning
        instead of raising. When ``ckpt_path`` is None, loading is skipped
        (with a warning) and the encoder keeps its constructor-initialised
        weights.
        """
        if ckpt_path is None:
            logging.warning("spatial_encoder_ckpt_path is None; skipping spatial encoder weight loading.")
            return

        # NOTE(review): since PyTorch 2.6, torch.load defaults to
        # weights_only=True. If this checkpoint stores non-tensor objects
        # (e.g. hyper-parameters), loading may fail — confirm.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        prefix = "model."  # Adjust to the prefix actually used in the checkpoint.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in ckpt['state_dict'].items()
        }

        incompatible = self.spatial_encoder.load_state_dict(state_dict, strict=False)
        missing = getattr(incompatible, "missing_keys", ())
        unexpected = getattr(incompatible, "unexpected_keys", ())
        if missing or unexpected:
            logging.warning(
                "Spatial encoder checkpoint mismatch: %d missing keys, %d unexpected keys.",
                len(missing), len(unexpected),
            )

    def change_to_policy(self):
        """Activate the decoder adapter named ``"policy"``.

        Raises:
            ValueError: If the decoder has no ``set_adapter`` method.
        """
        if not hasattr(self.decoder, 'set_adapter'):
            raise ValueError("The decoder model does not support adapters. Ensure it is a PEFT model.")
        self.decoder.set_adapter("policy")

    def change_to_default(self):
        """Appears intended to switch the decoder to adapter-free behaviour.

        NOTE(review): on a PEFT ``PeftModel``, ``disable_adapter()`` is a
        context manager. Calling it without ``with`` does not disable the
        adapters, so this method currently has no effect. For a scoped switch,
        use ``with self.disable_adapters():``. Confirm the intended semantics
        before relying on this method.

        Raises:
            ValueError: If the decoder has no ``set_adapter`` method.
        """
        if not hasattr(self.decoder, 'set_adapter'):
            raise ValueError("The decoder model does not support adapters. Ensure it is a PEFT model.")
        self.decoder.disable_adapter()

    @contextmanager
    def disable_adapters(self):
        """Context manager that runs its body with the decoder's adapters disabled.

        This delegates to ``self.decoder.disable_adapter()`` when the decoder
        provides it. Otherwise it logs an error and runs the body unchanged.
        """
        # self.decoder is expected to be the PeftModel itself.
        peft_model = self.decoder

        # Use PEFT's own context manager when the decoder is a PEFT model.
        if hasattr(peft_model, 'disable_adapter') and callable(getattr(peft_model, 'disable_adapter')):
            with peft_model.disable_adapter():
                yield
        else:
            logging.error("The model does not support disabling adapters via a context manager.")
            # Not a PEFT model: run the body without any adapter changes.
            yield

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on the decoder and on the encoder path in ``forward``."""
        if not self._supports_gradient_checkpointing:
            raise ValueError("Gradient checkpointing not supported.")
        self.decoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on the decoder and on the encoder path."""
        if self._supports_gradient_checkpointing:
            self.decoder.gradient_checkpointing_disable()
        self.gradient_checkpointing = False

    def forward_encoder(self, audios, prompt):
        """Encode multichannel audio into decoder-width embeddings.

        Args:
            audios: Waveform tensor. Channels are averaged over dim 1 to form
                the mono signal fed to Whisper at 16 kHz (assumed shape
                (batch, channels, samples) — TODO confirm).
            prompt: Prompt strings, one per batch item, each containing the
                acoustic placeholder.

        Returns:
            ``(moe_output, None)``. ``moe_output`` has the Whisper encoder's
            time length and the decoder's hidden size.
        """
        spatial_embeds = self.spatial_encoder.forward_as_encoder(audios)

        audios_mono = audios.mean(dim=1)
        waveforms = audios_mono.detach().to(torch.float32).cpu().numpy()
        mels = self.whisper_feature_extractor(waveforms, sampling_rate=16000, return_tensors='pt')['input_features'].to(audios.device)
        whisper_embeds = self.whisper.model.encoder(mels)['last_hidden_state']

        B, F, T, H = spatial_embeds.shape
        # NOTE(review): there is no permute before this reshape, so it
        # reinterprets memory rather than moving F next to H. spatial_proj is
        # frozen and presumably trained with this layout, so it is left as-is.
        spatial_reshaped = spatial_embeds.reshape(B, T, F * H)
        spatial_projected = self.spatial_proj(spatial_reshaped)

        # Resample the spatial sequence to Whisper's time length, then sum the streams.
        resampled_transposed = torch.nn.functional.interpolate(
            spatial_projected.transpose(1, 2),
            size=whisper_embeds.shape[1],
            mode='linear',
            align_corners=False,
        )
        spatial_aligned = resampled_transposed.transpose(1, 2)
        combined_embeds = torch.add(spatial_aligned, whisper_embeds)

        audio_context = self.audio_context_extractor(audios_mono)
        prompt_embeds = self.get_prompt_embeds(prompt, audios)
        moe_output, _ = self.moe_layer(
            whisper_embeds=whisper_embeds,
            spatial_embeds=spatial_aligned,
            combined_embeds=combined_embeds,
            audio_context=audio_context,
            prompt_embeds=prompt_embeds,
        )
        return moe_output, None

    def forward(self, samples, encoder_hidden_states: Optional[torch.Tensor] = None):
        """Compute the decoder loss for a batch.

        Args:
            samples: Dict with ``"text"`` (target strings), an optional
                ``"task"`` (prompt prefixes, defaulting to ``"AAC"`` for every
                item) and, unless ``encoder_hidden_states`` is given,
                ``"audios"``.
            encoder_hidden_states: Precomputed output of ``forward_encoder``.

        Returns:
            Dict with the decoder's ``"loss"`` and ``"logits"``.

        Raises:
            ValueError: If neither ``encoder_hidden_states`` nor ``samples["audios"]`` is available.
        """
        text = samples["text"]
        task = samples.get("task", ["AAC"] * len(text))
        prompt = [self._with_placeholder(t) for t in task]

        if encoder_hidden_states is None:
            if "audios" not in samples:
                raise ValueError("`encoder_hidden_states` was not provided, so `samples` must contain 'audios'.")
            audios = samples["audios"]
            if self.training and self.gradient_checkpointing:
                encoder_hidden_states, _ = checkpoint.checkpoint(self.forward_encoder, audios, prompt, use_reentrant=True)
            else:
                encoder_hidden_states, _ = self.forward_encoder(audios, prompt)

        encoder_hidden_states = encoder_hidden_states.to(self.decoder.dtype)
        encoder_atts = torch.ones(encoder_hidden_states.size()[:-1], dtype=torch.long, device=encoder_hidden_states.device)
        input_embeds, input_mask, decoder_targets = self.prepare_inputs_labels_for_multimodal(encoder_hidden_states, encoder_atts, prompt, text)
        decoder_output = self.decoder(inputs_embeds=input_embeds, attention_mask=input_mask, labels=decoder_targets, return_dict=True)
        return {"loss": decoder_output.loss, "logits": decoder_output.logits}

    def _with_placeholder(self, text):
        """Append `` <AcousticTokens>`` to ``text``."""
        return text + " " + self._ACOUSTIC_PLACEHOLDER

    def _split_prompts(self, prompt):
        """Split each prompt at the acoustic placeholder.

        Returns ``(lefts, rights)``. Each left part is prefixed with the
        tokenizer's BOS token. A prompt that does not contain exactly one
        placeholder is kept whole on the left, with an empty right part.
        """
        lefts, rights = [], []
        for p in prompt:
            parts = p.split(self._ACOUSTIC_PLACEHOLDER)
            left, right = parts if len(parts) == 2 else (p, "")
            lefts.append(self.tokenizer.bos_token + left)
            rights.append(right)
        return lefts, rights

    def _embed_texts(self, texts, device):
        """Tokenize ``texts`` and embed them with the decoder's input embeddings.

        Tokenization uses longest padding and adds no special tokens.
        Returns ``(tokens, embeds)``.
        """
        tokens = self.tokenizer(texts, add_special_tokens=False, padding="longest", return_tensors="pt").to(device)
        embeds = self.decoder.get_input_embeddings()(tokens.input_ids.long())
        return tokens, embeds

    def prepare_inputs_labels_for_multimodal(self, audio_embeds, atts, prompt, text=None):
        """Build decoder inputs of the form ``[prompt_left, audio, prompt_right(, text)]``.

        Args:
            audio_embeds: Audio embeddings concatenated along dim 1.
            atts: Attention mask for ``audio_embeds``.
            prompt: Prompt strings containing the acoustic placeholder.
            text: Optional target strings. EOS is appended to each one.

        Returns:
            ``(input_embeds, input_mask, decoder_targets)``. Targets are -100
            everywhere except on the non-padding text tokens. They are None
            when ``text`` is None.
        """
        device = audio_embeds.device
        prompt_left, prompt_right = self._split_prompts(prompt)
        left_tokens, left_embeds = self._embed_texts(prompt_left, device)
        right_tokens, right_embeds = self._embed_texts(prompt_right, device)

        input_embeds = torch.cat([left_embeds, audio_embeds, right_embeds], dim=1)
        input_mask = torch.cat([left_tokens.attention_mask, atts, right_tokens.attention_mask], dim=1)

        if text is None:
            return input_embeds, input_mask, None

        text_tokens, text_embeds = self._embed_texts([t + self.tokenizer.eos_token for t in text], device)
        # Mask by attention mask (not pad id): PAD may equal EOS, and the
        # appended EOS must stay a training target.
        targets = text_tokens.input_ids.masked_fill(text_tokens.attention_mask == 0, -100)
        prefix_ignore = torch.full((input_mask.shape[0], input_mask.shape[1]), -100, dtype=torch.long, device=device)
        decoder_targets = torch.cat([prefix_ignore, targets], dim=1)
        input_embeds = torch.cat([input_embeds, text_embeds], dim=1)
        input_mask = torch.cat([input_mask, text_tokens.attention_mask], dim=1)
        return input_embeds, input_mask, decoder_targets

    def get_prompt_embeds(self, prompt, audios):
        """Return the prompt embeddings (left and right of the placeholder) concatenated along dim 1."""
        prompt_left, prompt_right = self._split_prompts(prompt)
        _, left_embeds = self._embed_texts(prompt_left, audios.device)
        _, right_embeds = self._embed_texts(prompt_right, audios.device)
        return torch.cat([left_embeds, right_embeds], dim=1)

    def generate(self, input_ids=None, attention_mask=None, audio=None, encoder_hidden_states: Optional[torch.Tensor] = None, **kwargs):
        """Generate text conditioned on audio and the prompt in ``input_ids``.

        The prompt is rebuilt by decoding ``input_ids`` (special tokens
        skipped) and appending the acoustic placeholder. ``attention_mask`` is
        accepted for API compatibility but not used. ``eos_token_id`` and
        ``pad_token_id`` default to the tokenizer's values unless passed in
        ``kwargs``. A stray ``prompt`` kwarg is dropped.

        Raises:
            ValueError: If neither ``audio`` nor ``encoder_hidden_states`` is
                given, or if ``input_ids`` is None.
        """
        if encoder_hidden_states is None and audio is None:
            raise ValueError("Either 'audio' or 'encoder_hidden_states' must be provided.")
        if input_ids is None:
            raise ValueError("'input_ids' must be provided; it is decoded to build the text prompt.")

        prompt = [self._with_placeholder(t) for t in self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)]

        if encoder_hidden_states is None:
            encoder_hidden_states, _ = self.forward_encoder(audio, prompt)

        encoder_hidden_states = encoder_hidden_states.to(self.decoder.dtype)
        encoder_atts = torch.ones(encoder_hidden_states.size()[:-1], dtype=torch.long, device=encoder_hidden_states.device)
        input_embeds, input_mask, _ = self.prepare_inputs_labels_for_multimodal(
            encoder_hidden_states, encoder_atts, prompt
        )

        # The decoder's generate() does not accept a 'prompt' argument.
        kwargs.pop("prompt", None)
        kwargs.setdefault("eos_token_id", self.tokenizer.eos_token_id)
        kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)

        return self.decoder.generate(
            inputs_embeds=input_embeds,
            attention_mask=input_mask,
            **kwargs
        )
