import argparse
import numpy as np
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
import torch
from cloud_audio_new.model import CLoudRewardModel
from transformers import AutoProcessor
from accelerate import Accelerator
import json

def set_seed(seed=17):
    """Seed every random number generator this script relies on.

    The same seed is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and PyTorch's
    CUDA generators.

    Args:
        seed: Value handed unchanged to each seeding function.
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

###########
# Build eval data
###########

def load_reward_bench(dataset_path, split, accelerator):
    """Load a preference split from disk and build this rank's scoring inputs.

    The split is cut into contiguous shards, one per process. When the split
    size is not divisible by the number of processes, the last process also
    takes the leftover examples. Each example in the shard yields two scoring
    items, one for the chosen and one for the rejected response, whose ids
    are the example id suffixed with ``-chosen`` / ``-rejected``.

    Args:
        dataset_path: Directory passed to ``load_from_disk``.
        split: Key of the split to read from the loaded dataset.
        accelerator: Object exposing ``num_processes`` and ``process_index``
            (an ``accelerate.Accelerator`` in this script).

    Returns:
        A tuple ``(eval_data, eval_metadata)``. ``eval_data`` is a list of
        ``{"id", "prompt", "response"}`` dicts for this rank's shard only.
        ``eval_metadata`` is a list of ``{"id": str}`` dicts covering the
        whole split, so pairwise results can be computed after gathering.
    """
    data = load_from_disk(dataset_path)[split]
    total_size = len(data)

    # Contiguous sharding; the last rank absorbs the remainder.
    per_rank = total_size // accelerator.num_processes
    start = accelerator.process_index * per_rank
    is_last_rank = accelerator.process_index == accelerator.num_processes - 1
    end = total_size if is_last_rank else start + per_rank
    shard = data.select(range(start, end))

    eval_data = []
    for example in shard:
        for kind in ("chosen", "rejected"):
            eval_data.append({
                "id": f"{example['id']}-{kind}",
                "prompt": example["prompt"],
                "response": example[kind],
            })

    # Only the ids are needed here. Reading the single column avoids
    # materializing every full row of the split a second time on every rank.
    eval_metadata = [{"id": str(example_id)} for example_id in data["id"]]
    return eval_data, eval_metadata


###########
# Post-process Scores
###########

def post_process_reward_bench(eval_metadata, rewards, accelerator, save_path):
    """Compute pairwise accuracy from gathered rewards and save per-example results.

    An example counts as correct when the reward stored under
    ``"<id>-chosen"`` is strictly greater than the one stored under
    ``"<id>-rejected"``.

    Args:
        eval_metadata: Iterable of dicts, each with a string ``"id"``.
        rewards: Mapping from ``"<id>-chosen"`` / ``"<id>-rejected"`` to a
            comparable reward, or ``None`` on processes with nothing to score.
        accelerator: Object exposing ``print`` (an ``accelerate.Accelerator``
            in this script), used to report the summary.
        save_path: Path of the JSON file that receives the list of
            ``{"id", "correct"}`` records. Missing parent directories are
            created.

    Returns:
        The list of 0/1 correctness flags in ``eval_metadata`` order, or
        ``None`` when ``rewards`` is ``None``.

    Raises:
        KeyError: If ``rewards`` lacks an entry for one of the ids.
    """
    if rewards is None:  # non-main processes have nothing to score
        return None

    from pathlib import Path

    result = []
    infer_result = []
    for example in eval_metadata:
        id_ = example["id"]
        correct = int(rewards[id_ + "-chosen"] > rewards[id_ + "-rejected"])
        result.append(correct)
        infer_result.append({"id": id_, "correct": correct})

    # np.mean of an empty list warns and yields nan; produce the nan quietly.
    acc = np.mean(result) * 100 if result else float("nan")
    accelerator.print(f"Data size: {len(result)}")
    accelerator.print(f"Pairwise acc: {acc:.2f}%")

    # Create the output directory first so the results of a long inference
    # run are not lost to a FileNotFoundError.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as outfile:
        json.dump(infer_result, outfile, ensure_ascii=False, indent=2)
    return result



###########
# Scoring
###########

def _to_plain_reward(reward):
    """Return a one-element tensor as a Python number; move other tensors to CPU."""
    if isinstance(reward, torch.Tensor):
        reward = reward.detach().cpu()
        if reward.numel() == 1:
            return reward.item()
    return reward


def generate_rewards(model, processor, eval_data, batch_size, accelerator):
    """Score this rank's items and collect all ranks' rewards on the main process.

    Items are scored in batches of ``batch_size`` under ``torch.no_grad()``
    by calling ``predict_reward(prompts, responses, processor)`` on the model.
    That call is expected to return a pair whose first element yields one
    reward per input, in order.

    Args:
        model: Model exposing ``predict_reward``, or a wrapper exposing the
            real model as ``.module``.
        processor: Passed through unchanged to ``predict_reward``.
        eval_data: List of dicts with ``"id"``, ``"prompt"`` and
            ``"response"`` keys.
        batch_size: Number of items scored per ``predict_reward`` call.
        accelerator: Object exposing ``gather_for_metrics``,
            ``is_main_process``, ``is_local_main_process`` and ``print``
            (an ``accelerate.Accelerator`` in this script).

    Returns:
        On the main process, a dict mapping item id to reward across all
        processes. ``None`` on every other process.
    """
    rewards = {}
    # Reach through a wrapper that exposes the real model as ``.module`` so
    # ``predict_reward`` is called on the model itself.
    raw_model = model.module if hasattr(model, "module") else model

    # Draw the progress bar only once per node instead of once per rank.
    batch_starts = tqdm(
        range(0, len(eval_data), batch_size),
        disable=not accelerator.is_local_main_process,
    )
    for i in batch_starts:
        batch = eval_data[i:i + batch_size]

        prompts = [item["prompt"] for item in batch]
        responses = [item["response"] for item in batch]
        ids = [item["id"] for item in batch]

        with torch.no_grad():
            batch_rewards, _ = raw_model.predict_reward(prompts, responses, processor)

        for id_, reward in zip(ids, batch_rewards):
            # The (id, reward) pairs are gathered as pickled Python objects;
            # plain numbers / CPU tensors avoid shipping device tensors.
            rewards[id_] = _to_plain_reward(reward)

    # Gather the (id, reward) pairs from all processes.
    all_rewards = accelerator.gather_for_metrics(list(rewards.items()))
    if not accelerator.is_main_process:
        return None

    # Merge the gathered pairs into one dict; a repeated id keeps the value
    # that appears last.
    merged_rewards = dict(all_rewards)
    accelerator.print("rewards counts: ", len(merged_rewards))
    return merged_rewards




if __name__ == "__main__":
    # Script entry point: load the reward model and processor, score this
    # rank's shard of the dataset, then report pairwise accuracy and save the
    # per-example results from the main process.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--dataset-path", type=str, required=True)
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--save-path", type=str, default="infer_results/save.json")
    parser.add_argument(
        "--processor-path",
        type=str,
        default="models/Qwen2-Audio-7B-Instruct",
        help="Path or name passed to AutoProcessor.from_pretrained.",
    )
    parser.add_argument("--seed", type=int, default=17)

    # HF args
    parser.add_argument("--batch-size", type=int, default=1)
    args = parser.parse_args()

    # A zero step makes range() raise and a negative one silently scores
    # nothing, so reject both before loading the model.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    set_seed(args.seed)
    accelerator = Accelerator()
    accelerator.print("Loading model...")

    model = CLoudRewardModel.from_pretrained(args.model_path)
    model = accelerator.prepare(model)
    processor = AutoProcessor.from_pretrained(args.processor_path)

    accelerator.print("Loading dataset...")
    eval_data, eval_metadata = load_reward_bench(args.dataset_path, args.split, accelerator)

    accelerator.print("Generating rewards...")
    rewards = generate_rewards(model, processor, eval_data, args.batch_size, accelerator)

    accelerator.print("Evaluating...")
    # Only the main process receives the merged rewards; the others wait
    # at the barrier below.
    if accelerator.is_main_process:
        post_process_reward_bench(eval_metadata, rewards, accelerator, args.save_path)
    accelerator.wait_for_everyone()
