# [August 30, 2025]
# ==============================================================================

from __future__ import annotations

import argparse
import csv
import json
import os
from typing import Any, NamedTuple

import deepspeed
import numpy as np
import torch
import torch.distributed as dist
from rich.console import Console
from rich.table import Table
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available

from safe_rlhf.configs import get_deepspeed_eval_config
from safe_rlhf.logger import set_logger_level
from safe_rlhf.models import AutoModelForScore, load_pretrained_models
from safe_rlhf.utils import is_main_process, seed_everything, str2bool

class ArenaOutput(NamedTuple):
    """One scored record: a prompt, the model's response, and its scores."""

    # Text that precedes the 'ASSISTANT:' marker in the scored text.
    prompt: str
    # Text that follows the 'ASSISTANT:' marker in the scored text.
    response: str
    # Score produced by the reward model for the full text.
    reward: float
    # Score produced by the cost model; zero when no cost model is used.
    cost: float

class ResponseDataset(Dataset):
    """Dataset of prompt/response records loaded from a single JSONL file.

    Every non-blank line of the file must be a JSON object that contains at
    least the keys ``'prompt'`` and ``'response_text'``. Any other keys are
    kept in ``self.data`` but are not returned by ``__getitem__``.

    Args:
        response_file: Path to the UTF-8 encoded ``.jsonl`` file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """

    def __init__(self, response_file: str):
        super().__init__()
        with open(response_file, encoding='utf-8') as f:
            # Blank lines (typically a trailing newline at the end of the
            # file) are not records; skipping them avoids a JSONDecodeError.
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self) -> int:
        """Return the number of records loaded from the file."""
        return len(self.data)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Return the prompt and response text of the record at ``index``."""
        record = self.data[index]
        return {
            'prompt': record['prompt'],
            'response_text': record['response_text'],
        }

def parse_arguments() -> argparse.Namespace:
    """Parse and validate the command-line arguments.

    Returns:
        The parsed namespace, including the DeepSpeed config arguments added
        by ``deepspeed.add_config_arguments``.

    Exits (via ``parser.error``) when:
        * ``--local_rank`` was not set, i.e. the script was not started with
          the DeepSpeed launcher;
        * both ``--fp16`` and ``--bf16`` are enabled;
        * ``--bf16`` is enabled but the GPU does not support bfloat16.

    Side effects:
        When ``--tf32`` is given and TF32 is available, sets
        ``torch.backends.cuda.matmul.allow_tf32`` to the requested value.
    """
    parser = argparse.ArgumentParser(
        prog='deepspeed --module safe_rlhf.evaluate.arena',
        description='Evaluate the performance of two models in an arena from their generated responses.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Input files and models
    model_parser = parser.add_argument_group('model')
    model_parser.add_argument(
        '--response_file',
        type=str,
        help='Path to the JSONL file containing responses from the model.',
        required=True,
    )
    model_parser.add_argument(
        '--model_name',
        type=str,
        help='The display name for the model.',
        required=True,
    )
    model_parser.add_argument(
        '--reward_model_name_or_path',
        type=str,
        help='The name or path of the reward model to load from.',
        required=True,
    )
    # Optional: when omitted, no cost model is loaded.
    model_parser.add_argument(
        '--cost_model_name_or_path',
        type=str,
        help='The name or path of the cost model to load from.',
    )
    model_parser.add_argument(
        '--max_length',
        type=int,
        default=512,
        help='The maximum sequence length of the model.',
    )
    model_parser.add_argument(
        '--trust_remote_code',
        type=str2bool,
        default=False,
        help='Whether to trust the remote code.',
    )

    # Evaluation
    evaluation_parser = parser.add_argument_group('evaluation')
    evaluation_parser.add_argument(
        '--per_device_eval_batch_size',
        type=int,
        default=16,
        help='Batch size (per device) for the evaluation dataloader.',
    )
    evaluation_parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='A seed for reproducible evaluation.',
    )
    evaluation_parser.add_argument(
        '--fp16',
        type=str2bool,
        default=False,
        help='Whether to use float16 precision.',
    )
    evaluation_parser.add_argument(
        '--bf16',
        type=str2bool,
        default=False,
        help='Whether to use bfloat16 precision.',
    )
    evaluation_parser.add_argument(
        '--tf32',
        type=str2bool,
        default=None,
        help='Whether to use tf32 mix precision.',
    )

    # Logging
    logging_parser = parser.add_argument_group('logging')
    logging_parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='Where to store the evaluation output.',
    )

    # DeepSpeed
    deepspeed_parser = parser.add_argument_group('deepspeed')
    deepspeed_parser.add_argument(
        '--local_rank',
        type=int,
        default=-1,
        help='Local rank for distributed training on GPUs',
    )
    deepspeed_parser.add_argument(
        '--zero_stage',
        type=int,
        default=0,
        choices=[0, 1, 2, 3],
        help='ZeRO optimization stage for models.',
    )
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()
    if args.local_rank == -1:
        parser.error('`local_rank` not set, please use DeepSpeed launcher to run this script.')
    # The two half-precision modes are mutually exclusive; passing both on to
    # the DeepSpeed config would be contradictory.
    if args.fp16 and args.bf16:
        parser.error('Cannot use both bf16 and fp16 precision.')
    if args.bf16 and not is_torch_bf16_gpu_available():
        parser.error(
            'bf16 precision is not supported on this GPU. '
            'Please disable `--bf16` flag or use another precision flag (e.g., `--fp16`).',
        )
    # `--tf32` defaults to None, meaning "leave the PyTorch default untouched".
    if args.tf32 is not None and is_torch_tf32_available():
        torch.backends.cuda.matmul.allow_tf32 = args.tf32
    return args

def score_texts(
    texts: list[str],
    reward_model: PreTrainedModel,
    cost_model: PreTrainedModel | None,
    reward_tokenizer: PreTrainedTokenizerBase,
    cost_tokenizer: PreTrainedTokenizerBase | None,
    args: argparse.Namespace,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute reward and cost scores for a batch of texts.

    Args:
        texts: The texts to score.
        reward_model: Model whose output exposes ``end_scores``.
        cost_model: Optional cost model with the same output interface.
        reward_tokenizer: Tokenizer paired with ``reward_model``.
        cost_tokenizer: Tokenizer paired with ``cost_model``.
        args: Namespace providing ``device`` and ``max_length``.

    Returns:
        A ``(reward_scores, cost_scores)`` pair. ``cost_scores`` is a tensor
        of zeros shaped like ``reward_scores`` when either the cost model or
        the cost tokenizer is ``None``.
    """
    target_device = args.device

    def _end_scores(model, tokenizer) -> torch.Tensor:
        # Tokenize the whole batch, move it to the target device, and read the
        # model's end-of-sequence score without tracking gradients.
        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=args.max_length,
            return_tensors='pt',
        ).to(target_device)
        with torch.no_grad():
            outputs = model(encoded['input_ids'], attention_mask=encoded['attention_mask'])
            return outputs.end_scores.squeeze(dim=-1)

    reward_scores = _end_scores(reward_model, reward_tokenizer)
    if cost_model is None or cost_tokenizer is None:
        # No cost model available: report a zero cost for every text.
        return reward_scores, torch.zeros_like(reward_scores)
    return reward_scores, _end_scores(cost_model, cost_tokenizer)

def main() -> None:
    """Score every response in the input file and write the results to a CSV.

    Steps:
        1. Parse arguments and initialise the distributed environment.
        2. Load the reward model (and the cost model, if one was given) and
           wrap them with DeepSpeed in evaluation mode.
        3. Score this rank's shard of the response file.
        4. Gather all shards on the main process, drop the samples that
           ``DistributedSampler`` repeated as padding, restore the original
           file order, and write ``evaluation_results.csv`` into
           ``args.output_dir``.

    Only the main process writes the output; all other ranks return after the
    gather.
    """
    args = parse_arguments()
    deepspeed.init_distributed()
    args.global_rank = dist.get_rank()
    args.device = torch.device('cuda', args.local_rank)
    torch.cuda.set_device(args.device)
    seed_everything(args.seed)
    set_logger_level()
    dist.barrier()

    ds_config = get_deepspeed_eval_config(stage=args.zero_stage, fp16=args.fp16, bf16=args.bf16)
    if ds_config['zero_optimization']['stage'] == 3:
        # Stored on `args` so the config object stays referenced while the
        # models below are being loaded.
        args.dschf = HfDeepSpeedConfig(ds_config)

    reward_model, reward_tokenizer = load_pretrained_models(
        args.reward_model_name_or_path,
        model_max_length=args.max_length,
        padding_side='right',
        auto_model_type=AutoModelForScore,
        trust_remote_code=args.trust_remote_code,
    )
    if args.cost_model_name_or_path:
        cost_model, cost_tokenizer = load_pretrained_models(
            args.cost_model_name_or_path,
            model_max_length=args.max_length,
            padding_side='right',
            auto_model_type=AutoModelForScore,
            trust_remote_code=args.trust_remote_code,
        )
    else:
        cost_model, cost_tokenizer = None, None

    reward_model, *_ = deepspeed.initialize(model=reward_model, config=ds_config)
    reward_model.eval()
    # Explicit None check: module truthiness is not a reliable "is present" test.
    if cost_model is not None:
        cost_model, *_ = deepspeed.initialize(model=cost_model, config=ds_config)
        cost_model.eval()

    dataset = ResponseDataset(args.response_file)
    sampler = DistributedSampler(dataset, shuffle=False)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=args.per_device_eval_batch_size)
    # With shuffle=False the sampler yields the same index sequence on every
    # iteration, so this iterator lines up one-to-one with the samples the
    # dataloader produces on this rank.
    sample_indices = iter(list(sampler))

    model_name = args.model_name

    # (dataset index, scored record) pairs for this rank's shard.
    local_results: list[tuple[int, ArenaOutput]] = []
    for batch in tqdm(dataloader, desc='Evaluating', disable=not is_main_process()):
        texts = batch['response_text']
        reward, cost = score_texts(
            texts, reward_model, cost_model, reward_tokenizer, cost_tokenizer, args
        )
        # Convert each score tensor once instead of calling `.item()` per element.
        rewards = reward.float().tolist()
        costs = cost.float().tolist()

        for text, reward_value, cost_value in zip(texts, rewards, costs):
            # Everything before the first 'ASSISTANT:' marker is the prompt,
            # everything after it is the response.
            prompt_full, _, output = text.partition('ASSISTANT:')
            local_results.append(
                (
                    next(sample_indices),
                    ArenaOutput(
                        prompt=prompt_full.strip(),
                        response=output.strip(),
                        reward=reward_value,
                        cost=cost_value,
                    ),
                )
            )

    dist.barrier()
    if is_main_process():
        gathered_results_list = [None for _ in range(dist.get_world_size())]
        dist.gather_object(local_results, gathered_results_list, dst=0)
    else:
        dist.gather_object(local_results, None, dst=0)
        return

    # DistributedSampler pads the index list by repeating samples so that all
    # ranks get equally many; keep one record per dataset index and emit them
    # in the original file order.
    unique_results: dict[int, ArenaOutput] = {}
    for rank_results in gathered_results_list:
        for dataset_index, record in rank_results:
            unique_results.setdefault(dataset_index, record)
    table_rows = [unique_results[index] for index in sorted(unique_results)]

    os.makedirs(args.output_dir, exist_ok=True)

    # Save CSV; the response column is headed by the model's display name.
    output_file = os.path.join(args.output_dir, 'evaluation_results.csv')
    with open(output_file, mode='w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Prompt', model_name, 'Reward', 'Cost'])
        writer.writerows(
            [row.prompt, row.response, row.reward, row.cost] for row in table_rows
        )


# Run the evaluation only when this file is executed as a script or module,
# not when it is imported.
if __name__ == '__main__':
    main()