#!/usr/bin/env python
# coding=utf-8
"""
Compute Semantic Alignment Score (SAS) for DPO training data.

This script pre-computes anticausal scores using a fixed LLaMA model, SAE, and prompt decoder.
The scores are written to a JSONL file (one row per processed input line) for efficient DPO training with anticausal regularization.
"""

import os
import sys
import json
import torch
import argparse
import re
from pathlib import Path
from typing import Dict, List, Tuple
from dataclasses import dataclass
from tqdm import tqdm
import yaml
import logging

# Force use of the local sparsify source tree: SPARSIFY_DIR is inserted at the
# front of sys.path, so the `from sparsify import ...` line below resolves there
# before any installed copy.
# NOTE(review): presumably SPARSIFY_DIR contains the `sparsify` package directory — confirm.
SCRIPT_DIR = Path(__file__).resolve().parent
SPARSIFY_DIR = SCRIPT_DIR.parent / "sparsify"
sys.path.insert(0, str(SPARSIFY_DIR))

from transformers import AutoTokenizer, AutoModelForCausalLM
from sparsify import Sae
from accelerate import PartialState

# Import PromptDecoder; pd_train_accelerate is expected to be importable from
# the script's parent directory. This insert happens after `sparsify` has
# already been imported above, so it does not affect which sparsify is used.
sys.path.insert(0, str(SCRIPT_DIR.parent))
from pd_train_accelerate import PromptDecoder

# Configure root logging at import time (INFO level, timestamped records),
# then create this module's logger.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)


@dataclass
class SASConfig:
    """Configuration for SAS computation.

    ``__post_init__`` rejects sharding/progress-logging settings that would
    otherwise silently process nothing or crash partway through a run.
    """
    # Model paths
    llama_model_path: str = "models/llama3-8b"
    sae_path: str = "sae/sae_llama3b_layers_14.pth"
    prompt_decoder_path: str = "prompt_decoder/final_model_linear.pt"

    # Data paths (input and output are JSONL, one JSON object per line)
    input_dataset: str = "datasets/dpo_mix_7k.jsonl"
    output_dataset: str = "datasets/dpo_mix_7k_with_sas.jsonl"

    # Processing
    batch_size: int = 8  # NOTE(review): currently unused; samples are scored one at a time
    max_length: int = 1024  # tokenizer truncation length
    layer_idx: int = 14  # index into the model's hidden_states tuple
    device: str = "cuda"
    gpu_id: int = 0  # CUDA device index (0-7) used when device == "cuda"; negative = default GPU
    dtype: str = "float32"  # name of a torch dtype attribute, e.g. "float16", "float32"

    # Parallelism / sharding
    num_shards: int = 1
    shard_id: int = 0  # this run handles input lines with line_idx % num_shards == shard_id
    stats_every: int = 1000  # log progress every N processed lines
    legacy_mode: bool = False  # read the whole input into memory instead of streaming

    # Output
    save_format: str = "jsonl"  # "jsonl", "json" (only JSONL is currently written)

    def __post_init__(self):
        """Validate sharding and progress-logging settings.

        Raises:
            ValueError: if ``num_shards`` < 1, ``shard_id`` is outside
                ``[0, num_shards)`` when sharding is enabled (``num_shards`` > 1),
                or ``stats_every`` < 1.
        """
        if self.num_shards < 1:
            raise ValueError(f"num_shards must be >= 1, got {self.num_shards}")
        # With a single shard, shard_id is ignored by the processing code.
        if self.num_shards > 1 and not 0 <= self.shard_id < self.num_shards:
            raise ValueError(
                f"shard_id must be in [0, {self.num_shards}), got {self.shard_id}"
            )
        # stats_every is used as a modulus for progress logging.
        if self.stats_every < 1:
            raise ValueError(f"stats_every must be >= 1, got {self.stats_every}")


class SASComputer:
    """Semantic Alignment Score (SAS) calculator.

    Loads a frozen LLaMA model, an SAE and a prompt decoder once, then scores
    (prompt, response) pairs. For a pair, the response's last-token hidden
    state at ``config.layer_idx`` is passed through ``sae.encoder`` and the
    prompt decoder. The SAS score is the MSE between that prediction and the
    prompt's own last-token hidden state.

    ``process_dataset`` writes one JSON row per processed input line: the
    scores, an ``unknown_format`` warning, or an ``error`` message.
    ``config.batch_size`` is not used; samples are scored one at a time.
    """

    # Split an RLHF comparison prompt "[CONTEXT] ... [RESPONSE A] ... [RESPONSE B] ..."
    # into its three parts. The patterns are compiled once rather than per item.
    _CONTEXT_RE = re.compile(r'\[CONTEXT\](.*?)\[RESPONSE A\]', re.DOTALL)
    _RESPONSE_A_RE = re.compile(r'\[RESPONSE A\](.*?)\[RESPONSE B\]', re.DOTALL)
    _RESPONSE_B_RE = re.compile(r'\[RESPONSE B\](.*?)$', re.DOTALL)

    def __init__(self, config: "SASConfig"):
        """Select the compute device and load all frozen components.

        Args:
            config: Model paths and processing options.
        """
        self.config = config
        # With device == "cuda", a non-negative gpu_id pins that specific GPU.
        if config.device == "cuda":
            if config.gpu_id >= 0:
                self.device = torch.device(f"cuda:{config.gpu_id}")
                # Also make it the current CUDA device
                torch.cuda.set_device(config.gpu_id)
                logger.info(f"🎯 Using GPU {config.gpu_id}: {torch.cuda.get_device_name(config.gpu_id)}")
            else:
                self.device = torch.device("cuda")
                logger.info(f"🎯 Using default GPU: {torch.cuda.get_device_name()}")
        else:
            self.device = torch.device(config.device)

        # e.g. "float32" -> torch.float32
        self.dtype = getattr(torch, config.dtype)

        # Load the frozen components
        self._load_llama_model()
        self._load_sae()
        self._load_prompt_decoder()

        logger.info("✅ All components loaded successfully")

    def _load_llama_model(self):
        """Load the frozen LLaMA tokenizer and model onto ``self.device`` in eval mode."""
        logger.info(f"Loading LLaMA model from: {self.config.llama_model_path}")

        self.llama_tokenizer = AutoTokenizer.from_pretrained(self.config.llama_model_path)
        # Reuse EOS as the pad token when the tokenizer defines none, so padding=True works.
        if self.llama_tokenizer.pad_token is None:
            self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

        self.llama_model = AutoModelForCausalLM.from_pretrained(
            self.config.llama_model_path,
            torch_dtype=self.dtype,
            device_map={"": str(self.device)},
            low_cpu_mem_usage=True
        )
        self.llama_model.eval()

        logger.info("✅ LLaMA model loaded")

    def _load_sae(self):
        """Load the SAE from a local checkpoint, or from the Hub if the path does not exist.

        A local checkpoint needs a sibling ``.json`` metadata file with the same
        stem. That file supplies the ``SparseCoderConfig`` fields and ``d_in``.

        Raises:
            FileNotFoundError: if the local checkpoint exists but its metadata does not.
            Any exception raised while loading is logged and re-raised.
        """
        logger.info(f"Loading SAE from: {self.config.sae_path}")

        try:
            if Path(self.config.sae_path).exists():
                from sparsify import SparseCoder, SparseCoderConfig

                # Build the SAE from its metadata, then load the weights
                sae_path = Path(self.config.sae_path)
                meta_path = sae_path.with_suffix(".json")

                if meta_path.exists():
                    with open(meta_path, "r", encoding="utf-8") as f:
                        meta = json.load(f)

                    cfg = SparseCoderConfig.from_dict(meta)
                    self.sae = SparseCoder(
                        d_in=meta["d_in"],
                        cfg=cfg,
                        device=self.device,
                        dtype=self.dtype
                    )

                    state_dict = torch.load(sae_path, map_location="cpu")
                    self.sae.load_state_dict(state_dict)
                    self.sae = self.sae.to(device=self.device, dtype=self.dtype)
                else:
                    raise FileNotFoundError(f"SAE metadata not found: {meta_path}")
            else:
                # Fall back to a Hub SAE. Warn, because a mistyped local path would
                # otherwise silently load a different SAE.
                hookpoint = f"layers.{self.config.layer_idx}"
                logger.warning(
                    f"SAE path {self.config.sae_path} does not exist; falling back to "
                    f"EleutherAI/sae-llama-3-8b-32x ({hookpoint}) from the Hub"
                )
                self.sae = Sae.load_from_hub("EleutherAI/sae-llama-3-8b-32x", hookpoint=hookpoint).to(
                    device=self.device, dtype=self.dtype
                )

            self.sae.eval()
            logger.info("✅ SAE loaded")

        except Exception as e:
            logger.error(f"Failed to load SAE: {e}")
            raise

    def _load_prompt_decoder(self):
        """Load the prompt decoder checkpoint onto ``self.device`` in eval mode.

        The checkpoint is a dict containing ``model_state_dict``. It may also
        contain a ``config`` dict supplying ``hidden_layer`` (default 0) and
        ``normalize`` (default False) for the ``PromptDecoder`` constructor.
        """
        logger.info(f"Loading prompt decoder from: {self.config.prompt_decoder_path}")

        # Initialize Accelerate's state so that its logging does not raise
        try:
            PartialState()
        except Exception as e:
            logger.warning(f"Failed to initialize Accelerate PartialState: {e}")

        checkpoint = torch.load(self.config.prompt_decoder_path, map_location="cpu")
        ckpt_config = checkpoint.get('config', {})
        hidden_layer = ckpt_config.get('hidden_layer', 0)
        normalize = ckpt_config.get('normalize', False)

        self.prompt_decoder = PromptDecoder(
            device=str(self.device),
            dtype=self.dtype,
            normalize=normalize,
            hidden_layer=hidden_layer
        )

        self.prompt_decoder.load_state_dict(checkpoint['model_state_dict'])
        self.prompt_decoder.to(self.device)
        self.prompt_decoder.eval()

        logger.info("✅ Prompt decoder loaded")

    def _tokenize(self, text: str):
        """Tokenize one text, truncated to ``config.max_length``, and move it to ``self.device``."""
        return self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.config.max_length
        ).to(self.device)

    def _get_layer_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the hidden state of each sequence's last non-padding token.

        Args:
            input_ids: Token ids, indexed as (batch, seq) below.
            attention_mask: Same shape as ``input_ids``; 1 marks real tokens, 0 marks padding.

        Returns:
            One vector per sequence, gathered from ``hidden_states[layer_idx]``.
        """
        batch_size = input_ids.shape[0]

        with torch.no_grad():
            outputs = self.llama_model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states

            # Clamp, so an out-of-range layer_idx uses the last available hidden state.
            # NOTE(review): if this model follows the usual HF convention, hidden_states[0]
            # is the embedding output, so index k is the output of decoder block k-1;
            # confirm this matches the SAE hookpoint and the prompt decoder's training data.
            layer_idx = min(self.config.layer_idx, len(hidden_states) - 1)
            layer_output = hidden_states[layer_idx]

            # Position of the last real token per row. Flipping the mask and taking the
            # first 1 handles both right and left padding (sum()-1 is only correct for
            # right padding).
            seq_len = attention_mask.shape[1]
            flipped_first_one = attention_mask.flip(dims=[1]).long().argmax(dim=1)
            last_positions = seq_len - 1 - flipped_first_one
            batch_indices = torch.arange(batch_size, device=input_ids.device)
            return layer_output[batch_indices, last_positions]

    def compute_sas_score(self, prompt_text: str, response_text: str) -> float:
        """Compute the SAS score for a single (prompt, response) pair.

        Both texts are embedded independently via ``_get_layer_embeddings``.
        The response embedding goes through ``self.sae.encoder`` and the prompt
        decoder to predict the prompt embedding.

        Returns:
            The MSE between the predicted and actual prompt embedding, as a Python
            float (lower = prediction closer to the prompt embedding). On any
            exception, logs a warning and returns 0.0.
        """
        try:
            # Inference only. Without no_grad, the SAE encoder and the prompt decoder
            # would record an autograd graph on every call.
            with torch.no_grad():
                prompt_inputs = self._tokenize(prompt_text)
                response_inputs = self._tokenize(response_text)

                prompt_embeddings = self._get_layer_embeddings(
                    prompt_inputs["input_ids"],
                    prompt_inputs["attention_mask"]
                )
                response_embeddings = self._get_layer_embeddings(
                    response_inputs["input_ids"],
                    response_inputs["attention_mask"]
                )

                # NOTE(review): this calls the SAE's `encoder` submodule directly, not a
                # full encode step; confirm it matches how the prompt decoder was trained.
                response_sparse = self.sae.encoder(response_embeddings)

                # Predict the prompt embedding from the response's SAE features
                predicted_prompt = self.prompt_decoder(response_sparse)

                # SAS score = MSE between the predicted and actual prompt embedding
                return torch.nn.functional.mse_loss(predicted_prompt, prompt_embeddings).item()

        except Exception as e:
            logger.warning(f"Error computing SAS score: {e}")
            # NOTE(review): 0.0 is also the best possible MSE, so a failed sample is
            # indistinguishable from a perfectly aligned one in the output.
            return 0.0

    @classmethod
    def _split_ab_prompt(cls, text: str):
        """Split "[CONTEXT] p [RESPONSE A] a [RESPONSE B] b" into stripped ``(p, a, b)``.

        Returns None if any of the three sections cannot be found.
        """
        context_match = cls._CONTEXT_RE.search(text)
        response_a_match = cls._RESPONSE_A_RE.search(text)
        response_b_match = cls._RESPONSE_B_RE.search(text)
        if context_match and response_a_match and response_b_match:
            return (
                context_match.group(1).strip(),
                response_a_match.group(1).strip(),
                response_b_match.group(1).strip(),
            )
        return None

    def _score_messages(self, messages):
        """Score a chat-style ``messages`` list; return ``(row, kind)`` or None if unusable.

        If the first message contains the [RESPONSE A]/[RESPONSE B] markers and can be
        split, both responses are scored against the context ("rlhf_pair"). The second
        message's content is copied as ``human_preference``. Otherwise, the first user
        and first assistant messages are scored as a single pair ("single").
        """
        if not isinstance(messages, list) or len(messages) < 2:
            return None

        first_content = messages[0].get("content", "")
        second_content = messages[1].get("content", "")

        if "[RESPONSE A]" in first_content and "[RESPONSE B]" in first_content:
            parts = self._split_ab_prompt(first_content)
            if parts is not None:
                prompt, response_a, response_b = parts
                sas_score_a = self.compute_sas_score(prompt, response_a)
                sas_score_b = self.compute_sas_score(prompt, response_b)
                return {
                    "sas_score_a": sas_score_a,
                    "sas_score_b": sas_score_b,
                    "human_preference": second_content
                }, "rlhf_pair"

        # Fall back to the first user/assistant turn
        prompt_text = ""
        response_text = ""
        for msg in messages:
            role = msg.get("role")
            if role == "user" and not prompt_text:
                prompt_text = msg.get("content", "")
            elif role == "assistant" and not response_text:
                response_text = msg.get("content", "")
        if prompt_text and response_text:
            return {"sas_score": self.compute_sas_score(prompt_text, response_text)}, "single"
        return None

    def _score_item(self, item) -> Tuple[Dict, str]:
        """Score one parsed input record and return ``(output_row, kind)``.

        Recognized formats, checked in this order:
          * ``{"messages": [...]}`` -> see ``_score_messages`` ("rlhf_pair" / "single")
          * ``{"prompt", "response"}`` -> ``{"sas_score"}`` ("single")
          * ``{"chosen", "rejected", "prompt"}`` with a non-empty prompt ->
            ``{"chosen_sas_score", "rejected_sas_score"}`` ("dpo")
        Anything else, including non-dict JSON values, yields
        ``{"warning": "unknown_format"}`` with kind "unknown".
        """
        unknown = ({"warning": "unknown_format"}, "unknown")
        # Non-dict records would otherwise hit `in`/indexing on lists or strings.
        if not isinstance(item, dict):
            return unknown

        if "messages" in item:
            scored = self._score_messages(item["messages"])
            return scored if scored is not None else unknown

        if "prompt" in item and "response" in item:
            return {"sas_score": self.compute_sas_score(item["prompt"], item["response"])}, "single"

        if "chosen" in item and "rejected" in item:
            chosen_prompt = item.get("prompt", "")
            if chosen_prompt:
                chosen_sas = self.compute_sas_score(chosen_prompt, item["chosen"])
                rejected_sas = self.compute_sas_score(chosen_prompt, item["rejected"])
                return {
                    "chosen_sas_score": chosen_sas,
                    "rejected_sas_score": rejected_sas
                }, "dpo"

        return unknown

    def _in_shard(self, line_idx: int) -> bool:
        """True if raw input line ``line_idx`` belongs to this run's shard."""
        num_shards = self.config.num_shards
        return num_shards <= 1 or line_idx % num_shards == self.config.shard_id

    def process_dataset_legacy(self):
        """Non-streaming mode: read the whole input, score it, then write all rows.

        Keeps the earlier behavior that a malformed JSON line aborts the run before
        any scoring. Like the streaming path, it only keeps lines belonging to this
        shard.

        Raises:
            FileNotFoundError: if the input dataset does not exist.
        """
        logger.info(f"Processing dataset (legacy, non-stream): {self.config.input_dataset}")
        input_path = Path(self.config.input_dataset)
        if not input_path.exists():
            raise FileNotFoundError(f"Input dataset not found: {input_path}")
        output_path = Path(self.config.output_dataset)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Read everything up front, sharding on the raw line index like the streaming path.
        data = []
        with open(input_path, "r", encoding="utf-8") as f:
            for line_idx, line in enumerate(f):
                if line.strip() and self._in_shard(line_idx):
                    data.append(json.loads(line.strip()))
        logger.info(f"Loaded {len(data)} samples into memory (legacy mode)")

        processed_rows = []
        for i, item in enumerate(tqdm(data, desc="Computing SAS scores (legacy)")):
            try:
                row, _ = self._score_item(item)
            except Exception as e:
                logger.error(f"Error processing item {i}: {e}")
                row = {"error": str(e)}
            processed_rows.append(row)

        # Write all rows at once
        with open(output_path, "w", encoding="utf-8") as fout:
            for obj in processed_rows:
                fout.write(json.dumps(obj, ensure_ascii=False) + "\n")
        logger.info(f"✅ Legacy write completed: {output_path} rows={len(processed_rows)}")

    def process_dataset(self):
        """Score every in-shard line of the input JSONL, writing one output row per line.

        Streams input to output line by line, unless ``config.legacy_mode`` is set.
        Blank lines are skipped and produce no output. A line with
        ``line_idx % num_shards != shard_id`` belongs to another shard and is skipped.
        Each processed line yields exactly one row: the scores, an ``unknown_format``
        warning, or an ``error`` message.

        Raises:
            FileNotFoundError: if the input dataset does not exist.
        """
        if self.config.save_format != "jsonl":
            logger.warning(f"save_format={self.config.save_format!r} is not supported; writing JSONL")
        if self.config.legacy_mode:
            return self.process_dataset_legacy()

        logger.info(f"Processing dataset (stream): {self.config.input_dataset}")

        input_path = Path(self.config.input_dataset)
        if not input_path.exists():
            raise FileNotFoundError(f"Input dataset not found: {input_path}")

        output_path = Path(self.config.output_dataset)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        total_read = 0
        total_written = 0
        counts = {"rlhf_pair": 0, "single": 0, "dpo": 0}
        stats_every = self.config.stats_every

        with open(input_path, "r", encoding="utf-8") as fin, open(output_path, "w", encoding="utf-8") as fout:
            for line_idx, line in enumerate(fin):
                if not line.strip() or not self._in_shard(line_idx):
                    continue
                total_read += 1
                try:
                    item = json.loads(line.strip())
                    row, kind = self._score_item(item)
                except Exception as e:
                    logger.error(f"Error processing line {line_idx}: {e}")
                    row, kind = {"error": str(e)}, "error"

                fout.write(json.dumps(row, ensure_ascii=False) + "\n")
                total_written += 1
                if kind in counts:
                    counts[kind] += 1

                # stats_every <= 0 would otherwise divide by zero here, outside any try.
                if stats_every > 0 and total_read % stats_every == 0:
                    logger.info(
                        f"Processed {total_read} lines in shard {self.config.shard_id}/{self.config.num_shards}; "
                        f"written={total_written}, rlhf_pairs={counts['rlhf_pair']}, "
                        f"single={counts['single']}, dpo={counts['dpo']}"
                    )

        logger.info(
            f"✅ Done. shard {self.config.shard_id}/{self.config.num_shards} | read={total_read}, "
            f"written={total_written}, rlhf_pairs={counts['rlhf_pair']}, single={counts['single']}, "
            f"dpo={counts['dpo']}. "
            f"Output -> {output_path}"
        )


def main():
    """Command-line entry point: build an SASConfig from YAML plus CLI flags and run scoring.

    ``--gpu_id``, ``--num_shards`` and ``--shard_id`` override the YAML
    ``processing`` section when given. ``--legacy`` selects the non-streaming
    mode. Missing YAML keys fall back to the defaults below.

    Raises:
        ValueError: if the top level of the YAML file is not a mapping.
    """
    parser = argparse.ArgumentParser(description="Compute Semantic Alignment Score (SAS) for DPO data")
    parser.add_argument("--config_file", type=str, default="compute_sas_scores.yaml", help="Path to YAML config file")
    parser.add_argument("--gpu_id", type=int, default=None, help="GPU ID to use (0-7, overrides config file)")
    parser.add_argument("--num_shards", type=int, default=None, help="Total number of shards for parallel runs")
    parser.add_argument("--shard_id", type=int, default=None, help="Current shard id [0..num_shards-1]")
    parser.add_argument("--legacy", action="store_true", help="Use legacy non-streaming mode (load all into memory)")
    args = parser.parse_args()

    # Load the YAML config. An empty file parses to None, so treat it as {}.
    with open(args.config_file, "r", encoding="utf-8") as f:
        yaml_config = yaml.safe_load(f) or {}
    if not isinstance(yaml_config, dict):
        raise ValueError(
            f"Top level of config file {args.config_file} must be a mapping, "
            f"got {type(yaml_config).__name__}"
        )

    # A section may be missing or present but empty (parsed as None); treat both as {}.
    model_cfg = yaml_config.get("model") or {}
    sae_cfg = yaml_config.get("sae") or {}
    decoder_cfg = yaml_config.get("prompt_decoder") or {}
    data_cfg = yaml_config.get("data") or {}
    proc_cfg = yaml_config.get("processing") or {}
    output_cfg = yaml_config.get("output") or {}

    def _cli_or_yaml(cli_value, key, default):
        """Prefer an explicitly passed CLI value over the YAML `processing` entry."""
        return cli_value if cli_value is not None else proc_cfg.get(key, default)

    # Build the config
    config = SASConfig(
        llama_model_path=model_cfg.get("llama_path", "models/llama3-8b"),
        sae_path=sae_cfg.get("path", "sae/sae_llama3b_layers_14.pth"),
        prompt_decoder_path=decoder_cfg.get("path", "prompt_decoder/final_model_linear.pt"),
        input_dataset=data_cfg.get("input_dataset", "datasets/dpo_mix_7k.jsonl"),
        output_dataset=data_cfg.get("output_dataset", "datasets/dpo_mix_7k_with_sas.jsonl"),
        batch_size=proc_cfg.get("batch_size", 8),
        max_length=proc_cfg.get("max_length", 1024),
        layer_idx=proc_cfg.get("layer_idx", 14),
        device=proc_cfg.get("device", "cuda"),
        gpu_id=_cli_or_yaml(args.gpu_id, "gpu_id", 0),
        dtype=proc_cfg.get("dtype", "float32"),
        save_format=output_cfg.get("format", "jsonl"),
        num_shards=_cli_or_yaml(args.num_shards, "num_shards", 1),
        shard_id=_cli_or_yaml(args.shard_id, "shard_id", 0),
        stats_every=proc_cfg.get("stats_every", 1000),
        legacy_mode=args.legacy
    )

    # Create the SAS computer and process the dataset
    computer = SASComputer(config)
    computer.process_dataset()


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
