# evaluate_ddp.py

import argparse
import json
import logging
import os
import re
from typing import List, Dict, Any

import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model

# DDP: 相关导入
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

try:
    from twnm.models.twnm_pretrained_model import TWNM, TWNMConfig
except ImportError:
    print("错误：无法导入 'TWNM' 和 'TWNMConfig'。")
    print("请确保 'evaluate_ddp.py' 脚本位于可以访问 'models' 包的项目目录中。")
    exit(1)

class AudioDataset(Dataset):
    """Multiple-choice audio QA dataset backed by a JSON-lines file.

    Each non-blank line of ``data_file`` is a JSON object read with the keys
    ``source_metadata.audio_path`` and ``question_data.{question, options,
    answer}``; ``scene_id`` and ``task_type`` are optional.

    Args:
        data_file: Path to the ``.jsonl`` file.
        sample_rate: Target sample rate; audio at any other rate is resampled.
    """

    def __init__(self, data_file, sample_rate=44100):
        super().__init__()
        self.data_dir = os.path.dirname(data_file)
        self.sample_rate = sample_rate
        with open(data_file, 'r', encoding='utf8') as fin:
            # Skip blank lines (e.g. a trailing newline) instead of letting
            # json.loads raise on them.
            self.data_list = [json.loads(line) for line in fin if line.strip()]
        # DDP: log the dataset size only once, from global rank 0.
        if int(os.environ.get("RANK", 0)) == 0:
            logging.info(f"成功加载 {len(self.data_list)} 条数据来自 {data_file}")

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return one sample: waveform, prompt text, gold answer key and metadata."""
        json_obj = self.data_list[index]
        audio_path = json_obj["source_metadata"]["audio_path"]
        if not os.path.isabs(audio_path):
            # NOTE(review): relative paths are resolved against the *parent* of
            # the data file's directory (dirname is applied twice) — confirm
            # this matches the dataset layout.
            audio_path = os.path.join(os.path.dirname(self.data_dir), audio_path)
        try:
            waveform, original_sr = torchaudio.load(audio_path)
            if original_sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=self.sample_rate)
                waveform = resampler(waveform)
            if waveform.shape[0] == 1:
                # Mono input: duplicate the channel so every item has the same
                # two-channel layout as the fallback below (and batches collate).
                waveform = waveform.repeat(2, 1)
            # Keep at most the first two channels.
            waveform = waveform[:2, :]
        except Exception as e:
            # Deliberate best-effort: substitute 5 s of two-channel silence so
            # evaluation continues. Logged on every rank (prefixed with the rank)
            # so failures on non-zero ranks are not hidden.
            rank = os.environ.get("RANK", 0)
            logging.error(f"[rank {rank}] 加载或处理音频文件失败: {audio_path}, 错误: {e}")
            waveform = torch.zeros((2, self.sample_rate * 5))

        question_data = json_obj["question_data"]
        question = question_data["question"]
        options = question_data["options"]
        answer_key = question_data["answer"]
        # One "key: value" line per option, appended to the question text.
        choices_str = "\n".join([f"{key}: {value}" for key, value in options.items()])
        task_prompt = f"{question}. Please choose the answer from the following options: {choices_str}"
        return {
            "audio": waveform,
            "task": task_prompt,
            "solution": answer_key,
            "scene_id": json_obj.get("scene_id", f"item_{index}"),
            "task_type": json_obj.get("task_type", "unknown"),
            "question": question,
        }

# Answer span emitted by the model: |<answer>| ... |</answer>| (may span lines).
_ANSWER_TAG_RE = re.compile(r'\|\<answer\>\|(.*?)\|\</answer\>\|', re.DOTALL)


def parse_answer(text: str) -> str:
    """Extract the first answer wrapped in ``|<answer>|...|</answer>|`` tags.

    Returns the stripped inner text, or the sentinel ``"N/A"`` when no complete
    tag pair occurs in ``text``.
    """
    found = _ANSWER_TAG_RE.search(text)
    return "N/A" if found is None else found.group(1).strip()

def _build_model(args, rank, device):
    """Assemble the TWNM model for evaluation and move it to ``device``.

    Builds TWNM with a temporary LoRA config, loads the SFT checkpoint, merges
    that LoRA into the decoder, then attaches a fresh 'policy' LoRA adapter.
    In ``policy`` mode the trained adapter weights are loaded from
    ``args.policy_adapter_path``; in ``sft`` mode the adapter is left untrained
    and is disabled during generation (see ``_adapter_context``).

    Returns:
        The model, on ``device``, in eval mode.
    """
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # LoRA layout the SFT checkpoint's adapter weights are loaded into before
    # merging. NOTE(review): presumably identical to the SFT training config.
    temp_lora_config = LoraConfig(
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    model = TWNM(config=TWNMConfig(), peft_config=temp_lora_config, quantization_config=quantization_config)

    if rank == 0:
        logging.info(f"正在从 {args.sft_checkpoint_path} 加载SFT权重...")
    state_dict = torch.load(args.sft_checkpoint_path, map_location="cpu")
    # Checkpoints saved from a DDP-wrapped model prefix every key with 'module.'.
    state_dict = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()}
    load_result = model.load_state_dict(state_dict, strict=False)
    del state_dict  # release the CPU copy of the checkpoint
    # strict=False ignores mismatches silently; surface keys that matched nothing.
    unexpected = getattr(load_result, "unexpected_keys", None) or []
    if rank == 0 and unexpected:
        logging.warning(f"SFT 检查点中有 {len(unexpected)} 个未匹配的键，例如: {unexpected[:5]}")

    # Bake the SFT LoRA into the decoder, then attach the 'policy' adapter.
    model.decoder = model.decoder.merge_and_unload()
    policy_lora_config = LoraConfig(
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        inference_mode=False,
    )
    model.decoder = get_peft_model(model.decoder, policy_lora_config, adapter_name="policy")

    if args.model_mode == "policy":
        # This load used to be commented out while the success message was still
        # logged, so 'policy' runs silently evaluated a freshly initialised adapter.
        adapter_result = model.decoder.load_adapter(args.policy_adapter_path, adapter_name="policy")
        unexpected = getattr(adapter_result, "unexpected_keys", None) or []
        if rank == 0:
            logging.info(f"已从 {args.policy_adapter_path} 加载训练好的 'policy' 适配器权重。")
            if unexpected:
                logging.warning(f"'policy' 适配器有 {len(unexpected)} 个未匹配的键，例如: {unexpected[:5]}")

    model.to(device)
    model.eval()
    return model


def _adapter_context(model, model_mode):
    """Return the context manager under which generation runs for ``model_mode``.

    'policy' activates the 'policy' adapter (no-op context); 'sft' returns PEFT's
    ``disable_adapter()`` context. ``disable_adapter`` is a context manager, so it
    only has an effect inside a ``with`` block — calling it bare does nothing.
    """
    import contextlib

    if model_mode == "policy":
        model.decoder.set_adapter("policy")
        return contextlib.nullcontext()
    return model.decoder.disable_adapter()


def _to_builtin(value):
    """Turn a collated tensor scalar (e.g. an int scene_id) back into a JSON-serialisable value."""
    return value.item() if torch.is_tensor(value) else value


def _evaluate_shard(model, tokenizer, dataset, shard_indices, args, rank, device):
    """Greedy-decode every sample in this rank's shard and score the parsed answer.

    Returns:
        A list of ``(dataset_index, result_item)`` pairs; the index lets rank 0
        restore the original dataset order after gathering.
    """
    from torch.utils.data import Subset

    # NOTE(review): default collation stacks the 'audio' tensors, so
    # batch_size > 1 needs equal-length clips within a batch — confirm.
    loader = DataLoader(Subset(dataset, shard_indices), batch_size=args.batch_size, shuffle=False)
    progress_bar = tqdm(loader, desc=f"Evaluating on Rank {rank}", disable=(rank != 0))

    results = []
    offset = 0
    for batch in progress_bar:
        audios = batch["audio"].to(device)
        tasks = batch["task"]
        solutions = batch["solution"]
        # Subset + sequential loading yields samples in shard_indices order.
        batch_indices = shard_indices[offset:offset + len(tasks)]
        offset += len(tasks)

        # NOTE(review): prompts are padded for batching, but no attention_mask is
        # forwarded and the tokenizer's padding side is not set here; with left
        # padding and pad == EOS the EOS split below would also keep nothing.
        # Confirm TWNM.generate handles padded batches (moot at batch_size=1).
        inputs = tokenizer(tasks, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs.input_ids.to(device)

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids, audio=audios, max_new_tokens=1024,
                do_sample=False, pad_token_id=tokenizer.eos_token_id,
            )

        decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
        # Keep only the text before the first EOS token.
        decoded_outputs = [out.split(tokenizer.eos_token)[0].strip() for out in decoded_outputs]

        for i, (dataset_index, raw_output) in enumerate(zip(batch_indices, decoded_outputs)):
            parsed = parse_answer(raw_output)
            results.append((dataset_index, {
                "scene_id": _to_builtin(batch["scene_id"][i]),
                "task_type": _to_builtin(batch["task_type"][i]),
                "task": tasks[i],
                "ground_truth": solutions[i],
                "model_mode": args.model_mode,
                "parsed_answer": parsed,
                "is_correct": parsed == solutions[i],
                "raw_output": raw_output,
            }))
    return results


def _write_report(args, final_results, world_size):
    """Write per-sample results to ``args.output_file`` (jsonl) and print the accuracy report."""
    correct_count = sum(1 for item in final_results if item["is_correct"])
    total_count = len(final_results)

    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_file, "w", encoding="utf-8") as f_out:
        for item in final_results:
            f_out.write(json.dumps(item, ensure_ascii=False) + "\n")

    accuracy = (correct_count / total_count) * 100 if total_count > 0 else 0

    print("\n" + "="*50)
    print("           评 测 报 告 (DDP)")
    print("="*50)
    print(f"  模型模式 (Model Mode): {args.model_mode.upper()}")
    print(f"  测试数据文件: {args.test_data_file}")
    print(f"  详细结果已保存至: {args.output_file}")
    print(f"  使用的 GPU 数量: {world_size}")
    print("-"*50)
    print(f"  总样本数: {total_count}")
    print(f"  正确预测数: {correct_count}")
    print(f"  => 准确率 (Accuracy): {accuracy:.2f}%")
    print("="*50)


def main():
    """Distributed evaluation of the SFT model or the GRPO policy adapter.

    Launch with ``torchrun``. Each process evaluates a disjoint shard of the test
    set; results are gathered on global rank 0, which writes ``--output_file``
    and prints the accuracy report.
    """
    parser = argparse.ArgumentParser(description="GRPO 模型 DDP 评测脚本")
    parser.add_argument("--sft_checkpoint_path", type=str, required=True, help="SFT阶段训练好的模型检查点文件路径 (*.bin)")
    parser.add_argument("--policy_adapter_path", type=str, required=True, help="GRPO训练产出的LoRA适配器目录")
    parser.add_argument("--test_data_file", type=str, required=True, help="评测用的jsonl数据文件")
    parser.add_argument("--output_file", type=str, required=True, help="保存详细评测结果的jsonl文件路径")
    parser.add_argument("--model_mode", type=str, required=True, choices=['policy', 'sft'], help="评测模式")
    parser.add_argument("--batch_size", type=int, default=1, help="每个 GPU 的批处理大小。")
    args = parser.parse_args()

    dist.init_process_group(backend="nccl")
    try:
        # LOCAL_RANK picks the GPU on this node; the global rank decides the data
        # shard and which single process logs and writes (they differ multi-node).
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        torch.cuda.set_device(local_rank)
        device = f"cuda:{local_rank}"

        if rank == 0:
            logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
            logging.info("--- 1. 开始加载和组装模型 ---")

        # No DDP wrapper: evaluation computes no gradients, so DDP would only add
        # a parameter broadcast and gradient buckets without any benefit.
        model = _build_model(args, rank, device)
        tokenizer = model.tokenizer
        if rank == 0:
            logging.info("--- 模型加载和组装全部完成 ---")
            logging.info("--- 2. 加载评测数据集 ---")

        eval_dataset = AudioDataset(args.test_data_file)
        # Disjoint strided shards with no padding. DistributedSampler pads by
        # repeating samples when len(dataset) % world_size != 0, which counted
        # those samples twice in the accuracy.
        shard_indices = list(range(rank, len(eval_dataset), world_size))

        if rank == 0:
            logging.info(f"--- 3. 设置模型模式为: {args.model_mode} ---")
        adapter_ctx = _adapter_context(model, args.model_mode)

        if rank == 0:
            logging.info("--- 4. 开始评测循环 ---")
        with adapter_ctx:
            local_results = _evaluate_shard(model, tokenizer, eval_dataset, shard_indices, args, rank, device)

        all_process_results = [None] * world_size
        dist.all_gather_object(all_process_results, local_results)

        if rank == 0:
            logging.info("--- 5. 评测完成，汇总结果并生成报告 ---")
            # Shards interleave dataset indices; restore dataset order, drop the index.
            indexed = sorted(
                (pair for shard in all_process_results for pair in shard),
                key=lambda pair: pair[0],
            )
            _write_report(args, [item for _, item in indexed], world_size)
    finally:
        # Release NCCL resources explicitly instead of relying on interpreter exit.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

"""
CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=1,7 torchrun --nproc_per_node=2 evaluate_ddp.py \
  --sft_checkpoint_path assets/checkpoints/sft2_checkpoint-2502/pytorch_model.bin \
  --policy_adapter_path assets/checkpoints/grpo_checkpoint-1/policy/policy \
  --test_data_file /data2/wl/RL_train/output/benchmark/benchmark_questions.jsonl \
  --output_file results_sft_ddp.jsonl \
  --model_mode sft \
  --batch_size 6 # 这是每个GPU的batch_size，总batch_size为 2 * 6 = 12

"""