import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DistributedSampler
from pathlib import Path
from typing import Dict, List, Tuple
from dataclasses import dataclass
import argparse
from tqdm import tqdm
import yaml
import wandb
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
import matplotlib.pyplot as plt
import numpy as np
import math
from collections import defaultdict

logger = get_logger(__name__)

@dataclass
class DecoderConfig:
    """Configuration for training the prompt decoder.

    Groups the data location and split ratios, model options, optimisation
    hyper-parameters, data-loader settings, output paths, loss selection,
    GPU selection and Weights & Biases settings.
    """

    # Data configuration
    vectorized_data_dir: str = "datasets/vectorized_sae_llama3b_layers_10"
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    # NOTE: AccelerateTrainer.load_data sizes the splits from train_ratio and
    # val_ratio only; the test split is whatever remains.
    test_ratio: float = 0.1
    
    # Model configuration
    sae_path: str = "sae/sae_llama3b_layers_10.pth"
    fine_tune: bool = False
    normalize: bool = False
    hidden_layer: int = 0  # if > 0, use an MLP: (4096*32) -> hidden_layer -> 4096
    
    # Training configuration
    batch_size: int = 32
    learning_rate: float = 3e-4
    num_epochs: int = 100
    weight_decay: float = 1e-5
    warmup_steps: int = 1000
    max_grad_norm: float = 1.0
    
    # Precision configuration
    # NOTE: AccelerateTrainer.__init__ passes mixed_precision="no" to the
    # Accelerator regardless of this value.
    mixed_precision: str = "no"  # "fp16", "bf16", "no" - fp32 by default
    
    # Data-loader configuration
    dataloader_pin_memory: bool = True
    dataloader_num_workers: int = 1
    
    # Output configuration
    output_dir: str = "prompt_decoder"
    intermediate_eval_frequency: int = 50  # run an intermediate validation every N training batches
    
    # Loss configuration
    loss_type: str = "mse"  # "mse" or "cosine"
    
    # GPU configuration
    gpu_ids: str = "0"  # GPU IDs, e.g. "0", "1", "0,1"
    
    # Test configuration - the fixed_test parameter was removed; both kinds of test now always run
    
    # Wandb configuration
    use_wandb: bool = False
    wandb_project: str = "prompt-decoder"
    wandb_run_name: str = "decoder-training"

class PromptDecoder(nn.Module):
    """Map a flattened SAE representation to a prompt embedding.

    Two architectures are available, selected by ``hidden_layer``:

    * ``hidden_layer <= 0``: one linear map followed by ``LayerNorm``.
    * ``hidden_layer > 0``: ``Linear -> LayerNorm -> GELU -> Linear -> LayerNorm``.

    Linear weights are drawn from ``N(0, std^2)`` with
    ``std = sqrt(var_out / (fan_in * var_in))``; biases start at zero.

    Args:
        device: Device on which ``decoder_weight`` / ``decoder_bias`` are
            allocated in the linear variant. The MLP variant and the
            ``LayerNorm`` modules are created on the default device.
        dtype: Dtype of ``decoder_weight`` / ``decoder_bias`` at construction
            time. ``forward`` always computes in fp32 and up-casts parameters.
        normalize: Accepted for interface compatibility; not used by this class.
        hidden_layer: Hidden width of the MLP; ``0`` selects the linear variant.
        input_dim: Size of the input vector (default ``4096 * 32``).
        output_dim: Size of the predicted embedding (default ``4096``).
    """

    def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, normalize: bool = False,
                 hidden_layer: int = 0, input_dim: int = 4096 * 32, output_dim: int = 4096):
        super().__init__()
        self.hidden_layer = hidden_layer
        self.device = device
        self.dtype = dtype
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Legacy attributes kept for backward compatibility. The weight
        # initialisation below does not read them; it derives its own std.
        self.init_mean = 0.0
        self.init_std = 0.0025

        if hidden_layer > 0:
            # MLP structure: input_dim -> hidden_layer -> output_dim
            self.mlp = nn.Sequential(
                nn.Linear(input_dim, hidden_layer),
                nn.LayerNorm(hidden_layer),
                nn.GELU(),
                nn.Linear(hidden_layer, output_dim),
                nn.LayerNorm(output_dim),
            )

            # One (fan_in, fan_out, var_in, var_out) entry per Linear layer, in order.
            init_config = [
                (input_dim, hidden_layer, 0.01007843, 0.00969696),
                (hidden_layer, output_dim, 0.00969696, 0.00969696),
            ]
            linear_layers = [layer for layer in self.mlp if isinstance(layer, nn.Linear)]

            layer_stds = []
            with torch.no_grad():
                for layer, (fan_in, fan_out, var_in, var_out) in zip(linear_layers, init_config):
                    std = math.sqrt(var_out / (fan_in * var_in))
                    nn.init.normal_(layer.weight, mean=0.0, std=std)
                    nn.init.zeros_(layer.bias)
                    print(f"Initialized Linear({fan_in}, {fan_out}) with std = {std:.6f}")
                    layer_stds.append(std)

            logger.info(f"✅ Created MLP structure: {input_dim} -> {hidden_layer} -> {output_dim} with LayerNorm")
            # Report the std values that were actually applied to each layer.
            logger.info("MLP initialized with mean=0.0, std=" + ", ".join(f"{s:.6f}" for s in layer_stds))
        else:
            # Linear structure with custom random initialisation.
            fan_in, fan_out = input_dim, output_dim
            var_in, var_out = 0.01007843, 0.00969696
            std = math.sqrt(var_out / (fan_in * var_in))

            self.decoder_weight = nn.Parameter(torch.randn(output_dim, input_dim, device=device, dtype=dtype) * std)
            self.decoder_bias = nn.Parameter(torch.zeros(output_dim, device=device, dtype=dtype))

            print(f"Initialized Linear({fan_in}, {fan_out}) with std = {std:.6f}")

            self.layer_norm = nn.LayerNorm(output_dim)

            logger.info(f"✅ Created linear structure: {input_dim} -> {output_dim} with LayerNorm and custom initialization")
            logger.info(f"decoder_weight: mean = {self.decoder_weight.data.mean().item():.6f}, std = {self.decoder_weight.data.std().item():.6f}")
            logger.info(f"decoder_bias: mean = {self.decoder_bias.data.mean().item():.6f}, std = {self.decoder_bias.data.std().item():.6f}")

    def forward(self, x):
        """Predict prompt embeddings for ``x`` in fp32.

        Non-finite inputs are sanitised with ``torch.nan_to_num`` (NaN -> 0,
        +/-Inf -> +/-1e4). If the result still contains NaN or Inf, an
        all-zero tensor of the same shape is returned instead.

        Args:
            x: Input tensor whose last dimension matches the decoder input size.

        Returns:
            The fp32 prediction, or zeros when the prediction is non-finite.
        """
        if not torch.isfinite(x).all():
            logger.warning("Input contains NaN or Inf in forward pass")
            x = torch.nan_to_num(x, nan=0.0, posinf=1e4, neginf=-1e4)

        # Always compute in fp32.
        x = x.float()

        if self.hidden_layer > 0:
            # Module.float() casts in place and is a no-op for fp32 parameters.
            self.mlp.float()
            output = self.mlp(x)
        else:
            # Rebind .data so the Parameter objects (and any optimizer that
            # already references them) stay the same.
            for param in (self.decoder_weight, self.decoder_bias):
                if param.dtype != torch.float32:
                    param.data = param.data.float()
            self.layer_norm.float()

            output = F.linear(x, self.decoder_weight, self.decoder_bias)
            output = self.layer_norm(output)

        if not torch.isfinite(output).all():
            logger.warning("Output contains NaN or Inf in forward pass")
            # NOTE: this fresh zero tensor is not connected to the autograd
            # graph, so callers cannot back-propagate through it.
            return torch.zeros_like(output)

        return output

class VectorizedDataset(Dataset):
    """Dataset pairing response SAE representations with prompt embeddings.

    Item ``i`` is a dict with the keys ``'response_sae'`` and
    ``'prompt_embedding'`` holding the ``i``-th entry of each input.
    """

    def __init__(self, response_sae_repr: torch.Tensor, prompt_embeddings: torch.Tensor):
        # Both inputs must describe the same number of samples.
        assert len(response_sae_repr) == len(prompt_embeddings), (
            f"Data length mismatch: {len(response_sae_repr)} vs {len(prompt_embeddings)}"
        )

        self.response_sae_repr = response_sae_repr
        self.prompt_embeddings = prompt_embeddings

    def __len__(self):
        """Return the number of (response, prompt) pairs."""
        return len(self.response_sae_repr)

    def __getitem__(self, idx):
        """Return the pair stored at position ``idx`` as a dict."""
        sample = {}
        sample['response_sae'] = self.response_sae_repr[idx]
        sample['prompt_embedding'] = self.prompt_embeddings[idx]
        return sample

class AccelerateTrainer:
    """使用Accelerate的训练器"""
    
    def __init__(self, config: DecoderConfig):
        """Create the accelerator, model, loss, optimizer and LR scheduler.

        Args:
            config: Training configuration.

        Raises:
            ValueError: If ``config.loss_type`` is neither ``"mse"`` nor ``"cosine"``.
        """
        self.config = config

        # Environment variables have to be in place before the Accelerator
        # is constructed.
        visible_gpus = getattr(config, 'gpu_ids', None)
        if visible_gpus:
            os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpus

        tracking_enabled = config.use_wandb
        if tracking_enabled:
            os.environ["WANDB_RUN_NAME"] = config.wandb_run_name

        # Precision is pinned to fp32 here.
        self.accelerator = Accelerator(
            mixed_precision="no",
            gradient_accumulation_steps=1,
            log_with="wandb" if tracking_enabled else None,
            project_dir=config.output_dir if tracking_enabled else None,
        )
        self.device = self.accelerator.device

        # Model, built in fp32 on the accelerator's device.
        self.model = PromptDecoder(
            device=str(self.device),
            dtype=torch.float32,
            normalize=config.normalize,
            hidden_layer=config.hidden_layer
        )

        # Loss function.
        loss_name = config.loss_type
        if loss_name not in ("mse", "cosine"):
            raise ValueError(f"Unknown loss_type: {config.loss_type}. Must be 'mse' or 'cosine'")
        self.criterion = nn.MSELoss() if loss_name == "mse" else self.cosine_loss

        # Optimizer and cosine-annealing schedule over the configured epochs.
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.num_epochs
        )

        # Configuration summary, printed by the main process only.
        if self.accelerator.is_main_process:
            summary_lines = (
                f"🔧 Auto-generated run_name: {config.wandb_run_name}",
                f"🔧 Using GPUs: {config.gpu_ids}",
                f"🔧 Number of processes: {self.accelerator.num_processes}",
                f"🔧 Mixed precision: {config.mixed_precision}",
                f"🔧 Fine-tune: {config.fine_tune}",
                f"🔧 Normalize: {config.normalize}",
                f"🔧 Loss type: {config.loss_type}",
                f"🔧 Batch size: {config.batch_size}",
                f"🔧 Learning rate: {config.learning_rate}",
                f"🔧 Num epochs: {config.num_epochs}",
            )
            for line in summary_lines:
                print(line)

        # Experiment tracking.
        if tracking_enabled and self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                project_name=config.wandb_project,
                config=vars(config)
            )
    
    def cosine_loss(self, predicted, target):
        """Return the batch mean of ``1 - cosine_similarity(predicted, target)``.

        Both inputs are L2-normalised along dim 1 before the similarity is
        taken along dim 1. Lower is better.
        """
        unit_predicted, unit_target = (
            F.normalize(vectors, p=2, dim=1) for vectors in (predicted, target)
        )
        similarity = F.cosine_similarity(unit_predicted, unit_target, dim=1)
        distance = 1 - similarity
        return distance.mean()
    
    def load_data(self):
        """Load the vectorized tensors, prepare the model and build the loaders.

        The method

        1. loads ``response_sae_repr.pt``, ``prompt_embeddings.pt`` and
           ``metadata.json`` from ``config.vectorized_data_dir``;
        2. marks every model parameter trainable and, when ``config.fine_tune``
           is set and the linear variant is used, overwrites the decoder
           parameters with tensors read from ``config.sae_path``;
        3. rebuilds the optimizer together with the LR scheduler bound to it;
        4. splits the samples into train/val/test (``random_state=42``);
        5. passes model, optimizer and loaders through ``accelerator.prepare``.

        Raises:
            FileNotFoundError: If the data directory or a required file is missing.
            RuntimeError: If one of the tensor files cannot be loaded.
        """
        data_dir = Path(self.config.vectorized_data_dir)
        is_main = self.accelerator.is_main_process

        if is_main:
            print(f"📁 Loading data from: {data_dir}")

        if not data_dir.exists():
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        response_file = data_dir / "response_sae_repr.pt"
        prompt_file = data_dir / "prompt_embeddings.pt"

        if not response_file.exists():
            raise FileNotFoundError(f"Response file not found: {response_file}")
        if not prompt_file.exists():
            raise FileNotFoundError(f"Prompt file not found: {prompt_file}")

        # Load both tensors on CPU; the original error is chained for debugging.
        if is_main:
            print(f"🔄 Loading {response_file}...")
        try:
            response_sae_repr = torch.load(response_file, map_location="cpu")
        except Exception as e:
            raise RuntimeError(f"Failed to load response_sae_repr.pt: {e}") from e

        if is_main:
            print(f"🔄 Loading {prompt_file}...")
        try:
            prompt_embeddings = torch.load(prompt_file, map_location="cpu")
        except Exception as e:
            raise RuntimeError(f"Failed to load prompt_embeddings.pt: {e}") from e

        # metadata.json is still required and parsed, as before, but its
        # content is not used further in this method.
        with open(data_dir / "metadata.json", "r") as f:
            json.load(f)

        logger.info(f"Data loaded. Shape: response {response_sae_repr.shape}, prompt {prompt_embeddings.shape}")

        # The model is already randomly initialised; just make it trainable.
        for param in self.model.parameters():
            param.requires_grad = True
        logger.info("✅ Model initialized with random weights")

        # Fine-tuning: overwrite the random initialisation with SAE parameters.
        if self.config.fine_tune:
            logger.info("🔄 Fine-tuning mode: Loading SAE parameters to override random initialization")
            try:
                state_dict = torch.load(self.config.sae_path, map_location="cpu")

                if self.config.hidden_layer > 0:
                    # Nothing is copied for the MLP variant; only a warning is emitted.
                    logger.warning("⚠️  Fine-tuning with MLP structure: SAE parameters may not match MLP architecture")
                else:
                    with torch.no_grad():
                        self.model.decoder_weight.data = state_dict['encoder.weight'].T.to(self.device, dtype=torch.float32)
                        self.model.decoder_bias.data = state_dict['b_dec'].to(self.device, dtype=torch.float32)

                    if self.config.normalize:
                        with torch.no_grad():
                            eps = 1e-10
                            self.model.decoder_weight.data = F.normalize(self.model.decoder_weight.data + eps, dim=0)
                        logger.info("✅ Applied normalization to loaded SAE parameters")

                    logger.info("✅ Successfully loaded SAE parameters for fine-tuning")

            except Exception as e:
                logger.error(f"Failed to load SAE parameters for fine-tuning: {e}")
                raise

        # Recreate the optimizer, and rebuild the scheduler on top of it.
        # A scheduler left bound to the previous optimizer would keep annealing
        # that optimizer's learning rate and never touch the one used for training.
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=self.config.num_epochs
        )

        # Split the data: train and val sizes come from the configured ratios,
        # the test split is whatever remains.
        total_samples = len(response_sae_repr)
        train_size = int(total_samples * self.config.train_ratio)
        val_size = int(total_samples * self.config.val_ratio)

        train_indices, temp_indices = train_test_split(range(total_samples), train_size=train_size, random_state=42)
        val_indices, test_indices = train_test_split(temp_indices, train_size=val_size, random_state=42)

        self.train_dataset = VectorizedDataset(
            response_sae_repr[train_indices],
            prompt_embeddings[train_indices]
        )
        self.val_dataset = VectorizedDataset(
            response_sae_repr[val_indices],
            prompt_embeddings[val_indices]
        )
        self.test_dataset = VectorizedDataset(
            response_sae_repr[test_indices],
            prompt_embeddings[test_indices]
        )

        # Per-process shuffled sampler for the training set.
        # NOTE(review): accelerator.prepare below also processes the loaders for
        # distributed runs — confirm that combining it with an explicit
        # DistributedSampler yields the intended per-process share of the data.
        train_sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=self.accelerator.num_processes,
            rank=self.accelerator.process_index,
            shuffle=True
        )

        # Settings shared by all three loaders.
        loader_kwargs = dict(
            batch_size=self.config.batch_size,
            pin_memory=self.config.dataloader_pin_memory,
            num_workers=self.config.dataloader_num_workers,
        )

        self.train_loader = DataLoader(
            self.train_dataset,
            sampler=train_sampler,
            persistent_workers=False,
            **loader_kwargs
        )
        self.val_loader = DataLoader(self.val_dataset, shuffle=False, **loader_kwargs)
        self.test_loader = DataLoader(self.test_dataset, shuffle=False, **loader_kwargs)

        logger.info(f"Dataset split: train={len(self.train_dataset)}, val={len(self.val_dataset)}, test={len(self.test_dataset)}")

        # Hand everything to Accelerate.
        self.model, self.optimizer, self.train_loader, self.val_loader, self.test_loader = self.accelerator.prepare(
            self.model, self.optimizer, self.train_loader, self.val_loader, self.test_loader
        )

        if is_main:
            print(f"🔧 Process {self.accelerator.process_index} on device {self.accelerator.device}")
            print(f"🔧 Batch size: {self.config.batch_size}")
            print(f"🔧 Train batches: {len(self.train_loader)}")
            
    def train_epoch(self, epoch: int):
        """Train the model for one pass over ``self.train_loader``.

        Batches whose prediction or loss contains NaN are skipped. Every
        ``config.intermediate_eval_frequency`` batches a quick validation on
        at most 10 validation batches is run; a non-positive frequency
        disables it.

        Args:
            epoch: Zero-based epoch index, used for progress display only.

        Returns:
            None. No epoch-level average loss is computed.
        """
        self.model.train()

        is_main = self.accelerator.is_main_process
        log_to_wandb = self.config.use_wandb and is_main
        eval_every = self.config.intermediate_eval_frequency

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {epoch + 1}",
            disable=not is_main
        )

        for batch_idx, batch in enumerate(progress_bar):
            # Force fp32 inputs and targets.
            response_sae = batch['response_sae'].float()
            prompt_embedding = batch['prompt_embedding'].float()

            predicted_prompt = self.model(response_sae)

            if torch.isnan(predicted_prompt).any():
                logger.warning(f"NaN detected in model output at batch {batch_idx}")
                continue

            loss = self.criterion(predicted_prompt, prompt_embedding)

            if torch.isnan(loss):
                logger.warning(f"NaN loss detected at batch {batch_idx}")
                continue

            # NOTE(review): no gradient clipping happens here although
            # DecoderConfig defines max_grad_norm — confirm this is intended.
            self.accelerator.backward(loss)
            self.optimizer.step()
            self.optimizer.zero_grad()

            progress_bar.set_postfix({
                'loss': f"{loss.item():.6f}"
            })

            # Only the current batch loss is tracked.
            if log_to_wandb:
                self.accelerator.log({
                    'train_loss': loss.item()
                })

            # Periodic quick validation; eval_every <= 0 turns it off instead
            # of failing on a modulo by zero.
            if eval_every > 0 and (batch_idx + 1) % eval_every == 0:
                if is_main:
                    print(f"\n🔄 Epoch {epoch + 1}, Batch {batch_idx + 1}: Running intermediate validation...")

                avg_intermediate_val_loss = self._intermediate_validation_loss()

                if avg_intermediate_val_loss is None:
                    logger.warning("Intermediate validation skipped: validation loader produced no batches")
                else:
                    if is_main:
                        print(f"📊 Intermediate Val Loss: {avg_intermediate_val_loss:.6f}")

                    # Separate metric name so it does not clash with the epoch-level val loss.
                    if log_to_wandb:
                        self.accelerator.log({
                            'intermediate_val_loss': avg_intermediate_val_loss,
                            'step': batch_idx + 1
                        })

        return None

    def _intermediate_validation_loss(self, max_batches: int = 10):
        """Average the loss over the first ``max_batches`` validation batches.

        The model is switched to eval mode for the duration of the call and
        back to train mode afterwards, even if an exception is raised.

        Returns:
            The mean of the per-batch losses, or ``None`` when the validation
            loader yielded no batches.
        """
        self.model.eval()
        loss_sum = 0
        batches_seen = 0

        try:
            with torch.no_grad():
                for i, val_batch in enumerate(self.val_loader):
                    if i >= max_batches:
                        break
                    val_response_sae = val_batch['response_sae'].float()
                    val_prompt_embedding = val_batch['prompt_embedding'].float()

                    val_predicted_prompt = self.model(val_response_sae)
                    val_loss = self.criterion(val_predicted_prompt, val_prompt_embedding)

                    loss_sum += val_loss.item()
                    batches_seen += 1
        finally:
            self.model.train()

        if batches_seen == 0:
            return None
        return loss_sum / batches_seen
    
    def evaluate(self, dataloader, split_name: str):
        """评估模型"""
        self.model.eval()
        total_loss = 0
        num_batches = 0
        
        with torch.no_grad():
            for batch in dataloader:
                response_sae = batch['response_sae'].float()  # 强制转换为 fp32
                prompt_embedding = batch['prompt_embedding'].float()  # 强制转换为 fp32
                
                predicted_prompt = self.model(response_sae)
                loss = self.criterion(predicted_prompt, prompt_embedding)
                
                total_loss += loss.item()
                num_batches += 1
        
        avg_loss = total_loss / num_batches
        
        if self.config.use_wandb and self.accelerator.is_main_process:
            self.accelerator.log({
                f'{split_name}_loss': avg_loss
            })
        
        return avg_loss
    
    def save_checkpoint(self, epoch: int, val_loss: float, is_final: bool = False):
        """保存检查点"""
        if not self.accelerator.is_main_process:
            return
        
        os.makedirs(self.config.output_dir, exist_ok=True)
        
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.accelerator.unwrap_model(self.model).state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'config': vars(self.config)
        }
        
        # 生成包含hidden_layer和loss_type信息的文件名
        hidden_str = f"h{self.config.hidden_layer}" if self.config.hidden_layer > 0 else "linear"
        loss_str = self.config.loss_type  # "mse" 或 "cosine"
        
        # 只在训练结束时保存最终模型
        if is_final:
            final_path = os.path.join(self.config.output_dir, f'final_model_{hidden_str}_{loss_str}.pt')
            torch.save(checkpoint, final_path)
            logger.info(f"Saved final model at epoch {epoch} with val_loss: {val_loss:.6f}")
        
        # 保存最新检查点（用于cross-prediction测试）
        latest_path = os.path.join(self.config.output_dir, f'latest_checkpoint_{hidden_str}_{loss_str}.pt')
        torch.save(checkpoint, latest_path)
        logger.info(f"Saved latest checkpoint at epoch {epoch}")
    
    def _cleanup_latest_checkpoint(self):
        """删除latest checkpoint文件以节省磁盘空间"""
        if not self.accelerator.is_main_process:
            return
        
        # 生成包含hidden_layer和loss_type信息的文件名
        hidden_str = f"h{self.config.hidden_layer}" if self.config.hidden_layer > 0 else "linear"
        loss_str = self.config.loss_type  # "mse" 或 "cosine"
        latest_path = os.path.join(self.config.output_dir, f'latest_checkpoint_{hidden_str}_{loss_str}.pt')
        
        try:
            if os.path.exists(latest_path):
                os.remove(latest_path)
                logger.info(f"🗑️  Deleted latest checkpoint: {latest_path}")
            else:
                logger.debug(f"Latest checkpoint not found: {latest_path}")
        except Exception as e:
            logger.warning(f"Failed to delete latest checkpoint {latest_path}: {e}")
    
    def _save_epoch_results(self, epoch: int, test_results: dict):
        """保存每个epoch的测试结果到JSON文件"""
        if not self.accelerator.is_main_process:
            return
        
        # 创建结果目录
        results_dir = os.path.join(self.config.output_dir, "training_results")
        os.makedirs(results_dir, exist_ok=True)
        
        # 生成文件名
        # 从vectorized_data_dir提取training dataset名字
        dataset_name = os.path.basename(self.config.vectorized_data_dir)
        # 移除常见的后缀
        dataset_name = dataset_name.replace("_sae_llama3b_layers_14", "").replace("_sae_llama3b_layers_10", "")
        # 移除vectorized前缀（包括vectorized_last_）
        if dataset_name.startswith("vectorized_last_"):
            dataset_name = dataset_name.replace("vectorized_last_", "")
        elif dataset_name.startswith("vectorized_"):
            dataset_name = dataset_name.replace("vectorized_", "")
        
        run_name = dataset_name
        
        results_file = os.path.join(results_dir, f"training_results_{run_name}.json")
        
        # 准备要保存的数据 - 只保存rewritten的结果
        import time
        epoch_data = {
            'epoch': epoch,
            'timestamp': time.time(),
            'rewritten_results': test_results.get('rewritten_results', {})
        }
        
        # 读取现有数据或创建新文件
        if os.path.exists(results_file):
            try:
                with open(results_file, 'r') as f:
                    all_results = json.load(f)
            except (json.JSONDecodeError, FileNotFoundError):
                all_results = {'epochs': [], 'run_name': run_name}
        else:
            all_results = {'epochs': [], 'run_name': run_name}
        
        # 添加当前epoch的结果
        all_results['epochs'].append(epoch_data)
        
        # 保存到文件
        try:
            with open(results_file, 'w') as f:
                json.dump(all_results, f, indent=2)
            logger.info(f"💾 Saved epoch {epoch} results to {results_file}")
        except Exception as e:
            logger.error(f"Failed to save epoch results: {e}")
    
    def train(self):
        """Run the full training procedure.

        Sequence: initial validation and cross-prediction test (recorded as
        epoch 0), then for each epoch train / validate / checkpoint / step the
        scheduler / cross-prediction test, and finally an evaluation on the
        test loader when one is available.
        """
        logger.info("Starting training...")

        is_main = self.accelerator.is_main_process
        track = self.config.use_wandb and is_main
        total_epochs = self.config.num_epochs

        # Baseline validation before any update (random init or loaded weights).
        if is_main:
            logger.info("Running initial evaluation before training...")
        initial_val_loss = self.evaluate(self.val_loader, 'val')
        if is_main:
            print(f"Initial Val Loss: {initial_val_loss:.6f}")
        if track:
            self.accelerator.log({
                'initial_val_loss': initial_val_loss,
                'epoch': 0
            })

        # Baseline cross-prediction test, stored as epoch 0.
        if is_main:
            logger.info("🔬 Running initial cross-prediction test for epoch 0...")
            self._save_epoch_results(0, self.test_cross_prediction(num_tests=500))

        # Epochs are reported 1-based; train_epoch receives the 0-based index.
        for epoch_number in range(1, total_epochs + 1):
            self.train_epoch(epoch_number - 1)
            val_loss = self.evaluate(self.val_loader, 'val')

            if is_main:
                print(f"Epoch {epoch_number}: Val Loss: {val_loss:.6f}")

            # The final model file is only written for the last epoch.
            self.save_checkpoint(epoch_number, val_loss, epoch_number == total_epochs)
            self.scheduler.step()

            if is_main:
                logger.info(f"🔬 Running cross-prediction test for epoch {epoch_number}...")
                self._save_epoch_results(epoch_number, self.test_cross_prediction(num_tests=500))

                # The latest checkpoint is removed once the test has run.
                self._cleanup_latest_checkpoint()

            if track:
                self.accelerator.log({
                    'val_loss': val_loss,
                    'learning_rate': self.optimizer.param_groups[0]['lr'],
                    'epoch': epoch_number
                })

        # Final evaluation on the held-out test split.
        if self.test_loader is None:
            if is_main:
                logger.info("No test set available")
        else:
            if is_main:
                logger.info("Running final test...")
            test_loss = self.evaluate(self.test_loader, 'test')
            if is_main:
                logger.info(f"Final Test Loss: {test_loss:.6f}")

                if self.config.use_wandb:
                    self.accelerator.log({'final_test_loss': test_loss})

        if track:
            self.accelerator.end_training()


    def test_cross_prediction(self, num_tests: int = 500):
        """测试交叉预测能力 - 同时运行固定子集、随机测试和rewritten测试"""
        if not self.accelerator.is_main_process:
            return
        
        logger.info(f"🔬 Running cross-prediction tests for epoch using {self.config.loss_type} loss...")
        
        # 1. 运行固定子集测试（test8或subset500）
        fixed_results = self._test_cross_prediction_fixed()
        
        # 2. 运行随机测试
        random_results = self._test_cross_prediction_random(num_tests=500)
        
        # 3. 运行rewritten测试
        rewritten_results = self._test_cross_prediction_rewritten()
        
        # 4. 总结三种测试的结果
        print(f"\n📊 Cross-Prediction Summary for Epoch:")
        print(f"  Fixed subset: {fixed_results['correct_predictions']}/{fixed_results['total_individual_tests']} ({fixed_results['accuracy']:.2%})")
        print(f"  Random test: {random_results['correct_predictions']}/{random_results['total_individual_tests']} ({random_results['accuracy']:.2%})")
        if rewritten_results['test_subset_used']:
            print(f"  Rewritten test datasets: 6 datasets tested")
        else:
            print(f"  Rewritten test datasets: Not available")
        
        return {
            'fixed_results': fixed_results,
            'random_results': random_results,
            'rewritten_results': rewritten_results
        }
    
    def _test_cross_prediction_fixed(self):
        """使用固定子集进行交叉预测测试"""
        logger.info(f"🔬 Testing cross-prediction with fixed test subset using {self.config.loss_type} loss...")
        
        # 尝试加载test8固定子集，如果不存在则使用subset500
        test_subset_file = Path("datasets/vectorized_last_test8_split_sae_llama3b_layers_14") / "test_subset_last_fixed.pt"
        if not test_subset_file.exists():
            # 回退到默认的subset500子集
            test_subset_file = Path(self.config.vectorized_data_dir) / "test_subset_subset500.pt"
            logger.info(f"Using default subset500: {test_subset_file}")
        else:
            logger.info(f"Using fixed test8 subset: {test_subset_file}")
        
        if not test_subset_file.exists():
            # 如果固定子集不存在，返回空结果
            logger.warning(f"Fixed test subset not found: {test_subset_file}")
            return {
                'correct_predictions': 0,
                'total_tests': 0,
                'total_individual_tests': 0,
                'accuracy': 0.0,
                'test_subset_used': False
            }
        
        # 加载固定测试子集
        try:
            subset_data = torch.load(test_subset_file, map_location="cpu")
            logger.info(f"✅ Loaded fixed test subset with {subset_data['num_pairs']} pairs")
        except Exception as e:
            logger.error(f"Failed to load test subset: {e}")
            return {
                'correct_predictions': 0,
                'total_tests': 0,
                'total_individual_tests': 0,
                'accuracy': 0.0,
                'test_subset_used': False
            }
        
        results = {
            'correct_predictions': 0,
            'total_tests': len(subset_data['pairs']),
            'total_individual_tests': len(subset_data['pairs']) * 2,  # 每个pair有2个individual tests
            'detailed_results': [],
            'test_subset_used': True,
            'subset_seed': subset_data.get('seed', 'unknown')
        }
        
        self.model.eval()
        
        for test_idx, pair in enumerate(subset_data['pairs']):
            # 使用固定子集中的pairs
            start_idx = pair['start_index']
            next_idx = pair['next_index']
            
            # 获取数据
            response1 = pair['start_response'].unsqueeze(0).float().to(self.device)
            response2 = pair['next_response'].unsqueeze(0).float().to(self.device)
            prompt1 = pair['start_prompt'].unsqueeze(0).float().to(self.device)
            prompt2 = pair['next_prompt'].unsqueeze(0).float().to(self.device)
            
            # 进行预测 - 从response预测prompt
            with torch.no_grad():
                pred_prompt1_from_response1 = self.model(response1)  # response1 -> prompt1
                pred_prompt1_from_response2 = self.model(response2)  # response2 -> prompt1
                pred_prompt2_from_response1 = self.model(response1)  # response1 -> prompt2
                pred_prompt2_from_response2 = self.model(response2)  # response2 -> prompt2
            
            # 根据loss类型计算损失
            if self.config.loss_type == "mse":
                # 使用MSE loss
                loss1_1 = torch.nn.functional.mse_loss(pred_prompt1_from_response1, prompt1).item()  # response1 -> prompt1
                loss1_2 = torch.nn.functional.mse_loss(pred_prompt1_from_response2, prompt1).item()  # response2 -> prompt1
                loss2_1 = torch.nn.functional.mse_loss(pred_prompt2_from_response1, prompt2).item()  # response1 -> prompt2
                loss2_2 = torch.nn.functional.mse_loss(pred_prompt2_from_response2, prompt2).item()  # response2 -> prompt2
            elif self.config.loss_type == "cosine":
                # 使用cosine similarity loss (1 - cosine_similarity)
                loss1_1 = self.cosine_loss(pred_prompt1_from_response1, prompt1).item()  # response1 -> prompt1
                loss1_2 = self.cosine_loss(pred_prompt1_from_response2, prompt1).item()  # response2 -> prompt1
                loss2_1 = self.cosine_loss(pred_prompt2_from_response1, prompt2).item()  # response1 -> prompt2
                loss2_2 = self.cosine_loss(pred_prompt2_from_response2, prompt2).item()  # response2 -> prompt2
            else:
                raise ValueError(f"Unknown loss_type: {self.config.loss_type}")
            
            # 判断是否正确预测
            test1_correct = loss1_1 < loss1_2  # response1应该比response2更好地预测prompt1
            test2_correct = loss2_2 < loss2_1  # response2应该比response1更好地预测prompt2
            
            if test1_correct:
                results['correct_predictions'] += 1
            if test2_correct:
                results['correct_predictions'] += 1
            
            # 记录详细结果
            test_result = {
                'test_idx': test_idx,
                'sample_indices': [start_idx, next_idx],
                'loss1_1': float(loss1_1),  # response1 -> prompt1
                'loss1_2': float(loss1_2),  # response2 -> prompt1
                'loss2_1': float(loss2_1),  # response1 -> prompt2
                'loss2_2': float(loss2_2),  # response2 -> prompt2
                'test1_correct': bool(test1_correct),
                'test2_correct': bool(test2_correct),
                'margin1': float(loss1_2 - loss1_1),  # 正数表示test1正确
                'margin2': float(loss2_1 - loss2_2)   # 正数表示test2正确
            }
            
            results['detailed_results'].append(test_result)
            
            # 只显示前8条结果
            if test_idx < 8:
                print(f"Test {test_idx + 1}: Sample indices: {start_idx}, {next_idx}")
                print(f"    Test1 correct: {'✅' if test1_correct else '❌'} and Test2 correct: {'✅' if test2_correct else '❌'}")
                print(f"  Margin1: {test_result['margin1']:.6f}, Margin2: {test_result['margin2']:.6f}")
                print()
        
        # 计算准确率
        accuracy = results['correct_predictions'] / results['total_individual_tests']
        results['accuracy'] = accuracy
        
        print(f"📊 Fixed Subset Results: {results['correct_predictions']}/{results['total_individual_tests']} ({accuracy:.2%})")
        
        self.model.train()
        return results
    
    def _test_cross_prediction_random(self, num_tests: int = 500):
        """Run the cross-prediction test on randomly drawn sample pairs (no per-pair printing).

        For every pair ``(idx1, idx2)`` the model predicts a prompt embedding from
        each response representation.  Two individual tests are scored per pair:

        * test1: ``response1`` must give a lower loss against ``prompt1`` than
          ``response2`` does;
        * test2: ``response2`` must give a lower loss against ``prompt2`` than
          ``response1`` does.

        Args:
            num_tests: Number of random pairs to evaluate.

        Returns:
            dict with the keys ``correct_predictions``, ``total_tests``,
            ``total_individual_tests`` (two per pair), ``detailed_results``
            (one dict per pair), ``test_subset_used`` (always ``False``) and
            ``accuracy`` (``0.0`` when no individual test was run).

        Raises:
            ValueError: if ``self.config.loss_type`` is neither ``"mse"`` nor
                ``"cosine"``.
        """
        logger.info(f"🔬 Testing cross-prediction with {num_tests} random pairs using {self.config.loss_type} loss...")

        # Resolve the loss once, before any data is loaded or the model mode is
        # touched, so a bad configuration fails fast and leaves no side effects.
        if self.config.loss_type == "mse":
            loss_fn = torch.nn.functional.mse_loss
        elif self.config.loss_type == "cosine":
            # Cosine-based loss implemented by a sibling method of this class.
            loss_fn = self.cosine_loss
        else:
            raise ValueError(f"Unknown loss_type: {self.config.loss_type}")

        # Load on CPU (as the other torch.load calls in this file do); only the
        # two rows of each pair are moved to the target device below.
        data_dir = Path(self.config.vectorized_data_dir)
        response_sae_repr = torch.load(data_dir / "response_sae_repr.pt", map_location="cpu")
        prompt_embeddings = torch.load(data_dir / "prompt_embeddings.pt", map_location="cpu")

        results = {
            'correct_predictions': 0,
            'total_tests': num_tests,
            'total_individual_tests': num_tests * 2,  # two individual tests per pair
            'detailed_results': [],
            'test_subset_used': False
        }

        self.model.eval()
        try:
            for test_idx in range(num_tests):
                # Draw two distinct sample indices.
                idx1, idx2 = (
                    int(i) for i in np.random.choice(len(response_sae_repr), 2, replace=False)
                )

                # Slice with i:i+1 to keep the leading batch dimension.
                response1 = response_sae_repr[idx1:idx1 + 1].float().to(self.device)
                response2 = response_sae_repr[idx2:idx2 + 1].float().to(self.device)
                prompt1 = prompt_embeddings[idx1:idx1 + 1].float().to(self.device)
                prompt2 = prompt_embeddings[idx2:idx2 + 1].float().to(self.device)

                # One forward pass per response is enough: the prediction depends
                # only on the response, so it is reused against both prompts.
                with torch.no_grad():
                    pred_from_response1 = self.model(response1)
                    pred_from_response2 = self.model(response2)

                loss1_1 = loss_fn(pred_from_response1, prompt1).item()  # response1 -> prompt1
                loss1_2 = loss_fn(pred_from_response2, prompt1).item()  # response2 -> prompt1
                loss2_1 = loss_fn(pred_from_response1, prompt2).item()  # response1 -> prompt2
                loss2_2 = loss_fn(pred_from_response2, prompt2).item()  # response2 -> prompt2

                # The matching response should explain its own prompt better.
                test1_correct = loss1_1 < loss1_2
                test2_correct = loss2_2 < loss2_1
                results['correct_predictions'] += int(test1_correct) + int(test2_correct)

                # Keep the per-pair details (recorded, not printed).
                results['detailed_results'].append({
                    'test_idx': test_idx,
                    'sample_indices': [idx1, idx2],
                    'loss1_1': float(loss1_1),  # response1 -> prompt1
                    'loss1_2': float(loss1_2),  # response2 -> prompt1
                    'loss2_1': float(loss2_1),  # response1 -> prompt2
                    'loss2_2': float(loss2_2),  # response2 -> prompt2
                    'test1_correct': bool(test1_correct),
                    'test2_correct': bool(test2_correct),
                    'margin1': float(loss1_2 - loss1_1),  # positive means test1 is correct
                    'margin2': float(loss2_1 - loss2_2)   # positive means test2 is correct
                })
        finally:
            # Always hand the model back in training mode, even if a pair failed.
            self.model.train()

        # Guard the division so num_tests == 0 yields 0.0 instead of crashing.
        total_individual = results['total_individual_tests']
        accuracy = results['correct_predictions'] / total_individual if total_individual > 0 else 0.0
        results['accuracy'] = accuracy

        print(f"📊 Random Test Results: {results['correct_predictions']}/{results['total_individual_tests']} ({accuracy:.2%})")

        return results
    
    def _test_cross_prediction_rewritten(self):
        """Run the cross-prediction test on the rewritten test datasets, grouped by dataset name.

        Every directory under ``datasets/`` named
        ``vectorized_last_test_*_sae_llama3b_layers_14`` is treated as one dataset.
        Its ``test_subset_<name>.pt`` file must provide ``response_sae_repr`` and
        ``prompt_embeddings``; ``start_indices`` / ``next_indices`` are optional
        and default to consecutive (even, odd) pairs.  For each pair the "chosen"
        sample is ``start_idx`` and the "other" sample is ``next_idx``; the pair
        counts as correct when the chosen response yields a lower loss against
        the chosen prompt than the other response does.

        Results are aggregated by ``(domain, comparison_type)`` where the domain
        is one of helpful/math/safety and the comparison type is reject/rewrite,
        both inferred from substrings of the dataset name.

        Returns:
            When the ``datasets`` directory or matching sub-directories are
            missing: a dict with ``datasets_tested`` (0), empty ``results`` and
            ``summary``, and ``test_subset_used`` set to ``False``.
            Otherwise: a dict with ``test_subset_used`` set to ``True`` and
            ``six_datasets_accuracy``, mapping each of the six
            ``<domain>_<comparison>`` names to ``{'correct': int, 'total': int}``.

        Raises:
            ValueError: if ``self.config.loss_type`` is neither ``"mse"`` nor
                ``"cosine"`` (checked once matching datasets have been found).
        """
        logger.info("🔬 Testing cross-prediction on rewritten datasets (6-case summary)...")

        base_data_dir = Path("datasets")
        if not base_data_dir.exists():
            logger.warning(f"Datasets directory not found: {base_data_dir}")
            return {
                'datasets_tested': 0,
                'results': {},
                'summary': {},
                'test_subset_used': False,
            }

        # Match every vectorized_last_test_*_sae_llama3b_layers_14 directory.
        dataset_dirs = [
            d for d in base_data_dir.iterdir()
            if d.is_dir() and d.name.startswith("vectorized_last_test_") and d.name.endswith("_sae_llama3b_layers_14")
        ]

        if not dataset_dirs:
            logger.warning("No rewritten test datasets found under datasets/")
            return {
                'datasets_tested': 0,
                'results': {},
                'summary': {},
                'test_subset_used': False,
            }

        # Resolve the loss once, before the model mode is touched, so a bad
        # configuration fails fast instead of part-way through the first dataset.
        if self.config.loss_type == "mse":
            loss_fn = torch.nn.functional.mse_loss
        elif self.config.loss_type == "cosine":
            # Cosine-based loss implemented by a sibling method of this class.
            loss_fn = self.cosine_loss
        else:
            raise ValueError(f"Unknown loss_type: {self.config.loss_type}")

        # Aggregation structure: (domain, comp) -> {'correct': x, 'total': y}
        summary = defaultdict(lambda: {'correct': 0, 'total': 0})

        self.model.eval()
        try:
            for d in sorted(dataset_dirs):
                dataset_name = d.name.replace("vectorized_last_", "").replace("_sae_llama3b_layers_14", "")
                subset_file = d / f"test_subset_{dataset_name}.pt"

                if not subset_file.exists():
                    logger.warning(f"Subset file not found for {dataset_name}: {subset_file}")
                    continue

                # Best-effort: one unreadable dataset must not abort the others.
                try:
                    subset = torch.load(subset_file, map_location="cpu")
                except Exception as e:
                    logger.error(f"Failed to load subset for {dataset_name}: {e}")
                    continue

                response_sae_repr = subset.get('response_sae_repr')
                prompt_embeddings = subset.get('prompt_embeddings')
                if response_sae_repr is None or prompt_embeddings is None:
                    logger.error(f"Missing tensors in subset for {dataset_name}")
                    continue

                # Infer the domain and comparison type from the dataset name.
                domain = (
                    'helpful' if 'helpful' in dataset_name else
                    'math' if 'math' in dataset_name else
                    'safety' if 'safety' in dataset_name else
                    'unknown'
                )
                comp = (
                    'reject' if 'reject' in dataset_name else
                    'rewrite' if 'rewrite' in dataset_name or 'rewritten' in dataset_name else
                    'unknown'
                )
                if domain == 'unknown' or comp == 'unknown':
                    # Such datasets are still evaluated and logged, but they do
                    # not belong to any of the six reported cases.
                    logger.warning(
                        f"Could not classify {dataset_name} (domain={domain}, comparison={comp}); "
                        "it is excluded from the 6-case summary"
                    )

                num_samples = int(subset.get('num_samples', len(response_sae_repr)))

                # Use the provided start/next indices; when either is missing,
                # pair each even index with the odd index right after it
                # (a trailing unpaired sample is ignored).
                start_indices = subset.get('start_indices')
                next_indices = subset.get('next_indices')
                if start_indices is None or next_indices is None:
                    start_indices = list(range(0, num_samples - 1, 2))
                    next_indices = [i + 1 for i in start_indices]

                correct = 0
                total = 0

                # For each pair: chosen = start_idx, other = next_idx.
                with torch.no_grad():
                    for s_idx, n_idx in zip(start_indices, next_indices):
                        s_idx = int(s_idx)
                        n_idx = int(n_idx)

                        # Slice with i:i+1 to keep the leading batch dimension.
                        chosen_response = response_sae_repr[s_idx:s_idx + 1].float().to(self.device)
                        other_response = response_sae_repr[n_idx:n_idx + 1].float().to(self.device)
                        chosen_prompt = prompt_embeddings[s_idx:s_idx + 1].float().to(self.device)

                        loss_chosen = loss_fn(self.model(chosen_response), chosen_prompt).item()
                        loss_other = loss_fn(self.model(other_response), chosen_prompt).item()

                        if loss_chosen < loss_other:
                            correct += 1
                        total += 1

                acc = (correct / total) if total > 0 else 0.0

                # Fold this dataset into its (domain, comparison) case.
                key = (domain, comp)
                summary[key]['correct'] += correct
                summary[key]['total'] += total

                logger.info(f"{dataset_name}: {correct}/{total} ({acc:.2%})")
        finally:
            # Always hand the model back in training mode, even if a dataset failed.
            self.model.train()

        # Print the six summary rows.
        def print_case(dmn: str, cp: str):
            stats = summary.get((dmn, cp), {'correct': 0, 'total': 0})
            c = stats['correct']
            t = stats['total']
            pct = (c / t * 100.0) if t > 0 else 0.0
            print(f"  {dmn} - {cp}: {c}/{t} ({pct:.2f}%)")

        print("\n📊 Rewritten 6-case summary (chosen vs immediate next):")
        for dmn in ["helpful", "math", "safety"]:
            for cp in ["reject", "rewrite"]:
                print_case(dmn, cp)

        return {
            'test_subset_used': True,
            # Only the six per-case counts are kept; they are used for plotting.
            'six_datasets_accuracy': {
                'helpful_reject': summary.get(('helpful', 'reject'), {'correct': 0, 'total': 0}),
                'helpful_rewrite': summary.get(('helpful', 'rewrite'), {'correct': 0, 'total': 0}),
                'math_reject': summary.get(('math', 'reject'), {'correct': 0, 'total': 0}),
                'math_rewrite': summary.get(('math', 'rewrite'), {'correct': 0, 'total': 0}),
                'safety_reject': summary.get(('safety', 'reject'), {'correct': 0, 'total': 0}),
                'safety_rewrite': summary.get(('safety', 'rewrite'), {'correct': 0, 'total': 0}),
            }
        }



def main():
    """Parse CLI arguments, build a ``DecoderConfig`` and run training.

    Settings are resolved with the precedence command line > environment
    variable > YAML file > built-in default.  Environment variables consulted:
    ``BATCH_SIZE``, ``LEARNING_RATE``, ``FINE_TUNE``, ``NUM_EPOCHS`` and
    ``HIDDEN_LAYER``.  ``gpu_ids`` is resolved from the command line or the
    YAML file only.

    NOTE: the fallback defaults used here for ``learning_rate`` (1e-4) and
    ``num_epochs`` (5) differ from the ``DecoderConfig`` dataclass defaults.
    """
    parser = argparse.ArgumentParser(description="Train prompt decoder with Accelerate")
    parser.add_argument("--config_file", type=str, default="pd_train.yaml", help="Path to YAML config file")
    parser.add_argument("--num_epochs", type=int, help="Number of Epochs")
    parser.add_argument("--hidden_layer", type=int, help="Dimension of Hidden Layer")
    parser.add_argument("--batch_size", type=int, help="Batch size (overrides YAML config)")
    parser.add_argument("--learning_rate", type=float, help="Learning rate (overrides YAML config)")
    parser.add_argument("--fine_tune", action="store_true", help="Enable fine-tuning (overrides YAML config)")
    parser.add_argument("--gpu_ids", type=str, help="GPU IDs (overrides YAML config)")
    args = parser.parse_args()

    # Load the YAML config.  safe_load returns None for an empty file, and a
    # section declared without a body is None as well, hence the `or {}`.
    with open(args.config_file, 'r', encoding="utf-8") as f:
        yaml_config = yaml.safe_load(f) or {}

    data_config = yaml_config.get("data") or {}
    model_config = yaml_config.get("model") or {}
    training_config = yaml_config.get("training") or {}
    wandb_config = yaml_config.get("wandb") or {}

    normalize = model_config.get("normalize", False)

    # Resolve training settings: command line first, then environment, then YAML.
    batch_size = args.batch_size if args.batch_size is not None else int(os.environ.get("BATCH_SIZE", training_config.get("batch_size", 32)))
    learning_rate = args.learning_rate if args.learning_rate is not None else float(os.environ.get("LEARNING_RATE", training_config.get("learning_rate", 1e-4)))
    num_epochs = args.num_epochs if args.num_epochs is not None else int(os.environ.get("NUM_EPOCHS", training_config.get("num_epochs", 5)))
    hidden_layer = args.hidden_layer if args.hidden_layer is not None else int(os.environ.get("HIDDEN_LAYER", model_config.get("hidden_layer", 0)))
    gpu_ids = args.gpu_ids if args.gpu_ids is not None else training_config.get("gpu_ids", "0")

    # --fine_tune is a store_true flag, so it is False (never None) when absent.
    # Only an explicitly passed flag may override the environment/YAML value;
    # testing `is not None` would silently ignore both of them.
    if args.fine_tune:
        fine_tune = True
    else:
        yaml_fine_tune = model_config.get("fine_tune", False)
        fine_tune = os.environ.get("FINE_TUNE", str(yaml_fine_tune)).lower() == "true"

    # Build the run name automatically, e.g. "ft_raw_bs32_lr3e-4".
    fine_tune_str = "ft" if fine_tune else "init"
    normalize_str = "norm" if normalize else "raw"
    lr_str = f"lr{learning_rate:.0e}".replace("e-0", "e-").replace("e+0", "e")
    run_name = f"{fine_tune_str}_{normalize_str}_bs{batch_size}_{lr_str}"

    # Assemble the configuration.
    config = DecoderConfig(
        vectorized_data_dir=data_config.get("vectorized_data_dir", "datasets/vectorized_sae_llama3b_layers_10"),
        train_ratio=data_config.get("train_ratio", 0.8),
        val_ratio=data_config.get("val_ratio", 0.1),
        test_ratio=data_config.get("test_ratio", 0.1),

        sae_path=model_config.get("sae_path", "sae/sae_llama3b_layers_10.pth"),
        fine_tune=fine_tune,
        normalize=normalize,
        hidden_layer=hidden_layer,  # resolved from CLI / environment / YAML

        batch_size=batch_size,
        learning_rate=learning_rate,
        num_epochs=num_epochs,  # resolved from CLI / environment / YAML
        weight_decay=float(training_config.get("weight_decay", 1e-5)),
        warmup_steps=training_config.get("warmup_steps", 1000),
        max_grad_norm=float(training_config.get("max_grad_norm", 1.0)),
        mixed_precision=training_config.get("mixed_precision", "no"),
        dataloader_pin_memory=training_config.get("dataloader_pin_memory", True),
        dataloader_num_workers=training_config.get("dataloader_num_workers", 1),

        output_dir=training_config.get("output_dir", "prompt_decoder"),
        intermediate_eval_frequency=training_config.get("intermediate_eval_frequency", 50),
        loss_type=training_config.get("loss_type", "mse"),
        gpu_ids=gpu_ids,  # resolved from CLI / YAML

        use_wandb=wandb_config.get("use_wandb", False),
        wandb_project=wandb_config.get("project", "prompt-decoder"),
        wandb_run_name=run_name  # auto-generated above
    )

    # Fix the random seed.
    set_seed(42)

    # Create the trainer and start training.
    trainer = AccelerateTrainer(config)
    trainer.load_data()
    trainer.train()

# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 