#!/usr/bin/env python3
"""
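Load the trained prompt decoder, either from a full training checkpoint
(weights + config) or from a bare weights file with explicitly specified
dimensions.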
"""

import torch
import torch.nn as nn
from pathlib import Path
from decoder_train import ResponseToPromptDecoder, DecoderConfig

def load_trained_decoder(decoder_path: str = "prompt_decoder/prompt_decoder.pt"):
    """Load a trained decoder from a full training checkpoint.

    The checkpoint is read with ``torch.load`` and must be a mapping that
    contains ``'config'`` and ``'decoder_state_dict'``. ``'final_loss'`` is
    optional and only used for the log line.

    Args:
        decoder_path: Path (``str`` or ``pathlib.Path``) to the checkpoint file.

    Returns:
        A ``(decoder, config)`` tuple: a ``ResponseToPromptDecoder`` with the
        saved weights loaded, and the config object stored in the checkpoint.

    Raises:
        FileNotFoundError: If ``decoder_path`` does not exist.
        KeyError: If the checkpoint lacks ``'config'`` or
            ``'decoder_state_dict'``.
    """
    decoder_path = Path(decoder_path)

    if not decoder_path.exists():
        raise FileNotFoundError(f"Decoder not found: {decoder_path}")

    # The checkpoint stores a pickled config object, not just tensors.
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses arbitrary
    # classes, so full unpickling has to be requested explicitly.
    # SECURITY: unpickling can execute arbitrary code -- only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(decoder_path, map_location='cpu', weights_only=False)

    config = checkpoint['config']

    def _dim(name, default):
        # Read a dimension from either an attribute-style config object or a
        # plain dict, falling back to the historical default when absent.
        if isinstance(config, dict):
            return config.get(name, default)
        return getattr(config, name, default)

    # Rebuild the architecture with the saved dimensions so the state_dict
    # shapes line up.
    decoder = ResponseToPromptDecoder(
        input_dim=_dim('response_dim', 1024),
        output_dim=_dim('prompt_dim', 4096),
    )

    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    print(f"✅ Loaded decoder from {decoder_path}")
    print(f"📊 Final loss: {checkpoint.get('final_loss', 'N/A')}")

    return decoder, config

def load_decoder_weights_only(
    weights_path: str = "prompt_decoder/decoder_weights.pt",
    input_dim: int = 1024,
    output_dim: int = 4096,
):
    """Load a decoder from a bare ``state_dict`` file (no saved config).

    Because the file carries no configuration, the architecture dimensions
    must be supplied here. They must match the dimensions used when the
    weights were saved, otherwise ``load_state_dict`` raises on the shape
    mismatch.

    Args:
        weights_path: Path (``str`` or ``pathlib.Path``) to the weights file.
        input_dim: Decoder input size. The default 1024 was previously
            hard-coded and described as the SAE output dimension.
        output_dim: Decoder output size. The default 4096 was previously
            hard-coded and described as the model hidden size.

    Returns:
        A ``ResponseToPromptDecoder`` with the weights loaded.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
    """
    weights_path = Path(weights_path)

    if not weights_path.exists():
        raise FileNotFoundError(f"Weights not found: {weights_path}")

    decoder = ResponseToPromptDecoder(input_dim=input_dim, output_dim=output_dim)
    # A state_dict holds only tensors, so the restricted (safe) unpickler is
    # sufficient and avoids executing arbitrary pickled code.
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    decoder.load_state_dict(state_dict)

    print(f"✅ Loaded decoder weights from {weights_path}")
    return decoder

if __name__ == "__main__":
    import sys

    # Prefer the full checkpoint (it carries the config); fall back to the
    # bare weights file with default dimensions if anything goes wrong.
    try:
        decoder, config = load_trained_decoder()
        print("📊 Decoder loaded successfully!")
        # Assumes ResponseToPromptDecoder exposes an indexable `decoder`
        # module whose first/last entries have in_features/out_features
        # (e.g. nn.Linear) -- TODO confirm against decoder_train.
        print(f"📊 Input dimension: {decoder.decoder[0].in_features}")
        print(f"📊 Output dimension: {decoder.decoder[-1].out_features}")
    except Exception as e:
        print(f"❌ Failed to load decoder: {e}")
        print("Trying to load weights only...")
        try:
            decoder = load_decoder_weights_only()
            print("✅ Loaded decoder weights successfully!")
        except Exception as e2:
            print(f"❌ Failed to load weights: {e2}")
            # Both strategies failed: signal it through the exit status so
            # shell scripts / CI do not treat this run as a success.
            sys.exit(1)