#!/usr/bin/env python
# coding=utf-8

import argparse
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def read_yaml(path: str) -> Dict:
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result.

    ``yaml`` is imported inside the function, so it is only needed when this
    function is actually called.
    """
    from yaml import safe_load

    with open(path, mode="r", encoding="utf-8") as handle:
        parsed = safe_load(handle)
    return parsed


def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSONL file and return its records as a list.

    Each non-empty line is parsed as one JSON document. Blank or
    whitespace-only lines (for example trailing newlines at the end of the
    file) are skipped instead of being passed to ``json.loads``.

    Args:
        file_path: Path of the UTF-8 encoded JSONL file.

    Returns:
        The parsed records, in file order.

    Raises:
        json.JSONDecodeError: If a non-empty line is not valid JSON. The
            message names the file and the 1-based line number.
    """
    records = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type (a ValueError subclass) so
                # existing handlers keep working, but say where the bad line is.
                raise json.JSONDecodeError(
                    f"{err.msg} ({file_path}, line {line_no})", err.doc, err.pos
                ) from err
    return records


def build_inputs(tokenizer, prompts: List[str], responses: List[str], max_length: int) -> Dict[str, torch.Tensor]:
    """Tokenize prompt/response pairs into a batch for the reward model.

    Each prompt is joined to its response with a blank line ("\\n\\n"); the
    resulting texts are tokenized together with padding and truncation to
    ``max_length``. Only ``input_ids`` and ``attention_mask`` are returned.
    """
    joined_texts = []
    for prompt_text, response_text in zip(prompts, responses):
        joined_texts.append("\n\n".join((prompt_text, response_text)))

    tokenizer_kwargs = {
        "padding": True,
        "truncation": True,
        "max_length": max_length,
        "return_tensors": "pt",
    }
    encoded = tokenizer(joined_texts, **tokenizer_kwargs)

    batch = {}
    batch["input_ids"] = encoded.input_ids
    batch["attention_mask"] = encoded.attention_mask
    return batch


def _first_field_should_win(comp: Dict) -> bool:
    """Tell whether ``comp["field1"]`` is the response expected to score higher.

    ``comp["expected"]`` has the form ``"<field>_higher"``. The winner is
    resolved by field *name*, not by position, so a pair such as
    (field1="rewrite2", field2="rewrite3") with expected "rewrite2_higher"
    correctly expects the first response to win.
    """
    expected = comp["expected"]
    suffix = "_higher"
    winner = expected[: -len(suffix)] if expected.endswith(suffix) else expected
    if winner == comp["field1"]:
        return True
    if winner == comp["field2"]:
        return False
    raise ValueError(
        f"Comparison {comp['name']!r}: expected outcome {expected!r} does not name "
        f"either compared field ({comp['field1']!r}, {comp['field2']!r})"
    )


def evaluate_rate(
    model_path: str,
    dataset_path: str,
    output_dir: str,
    tokenizer_path: str = None,
    batch_size: int = 8,
    max_length: int = 2048,
    device: str = None,
    trust_remote_code: bool = False,
    torch_dtype: str = None,
    attn_implementation: str = None,
) -> Dict:
    """Evaluate a reward model on pairwise comparisons between rewrites.

    For every sample of the sources "safety", "helpful" and "math", the model
    scores the prompt paired with each of ``rewrite1``, ``rewrite2`` and
    ``rewrite3``. Three comparisons are evaluated (rewrite1 vs rewrite2,
    rewrite1 vs rewrite3, rewrite2 vs rewrite3); a sample counts as correct
    when the response expected to win gets the strictly higher reward (ties
    count as incorrect). Results are printed, written to
    ``<output_dir>/rate_evaluation_results.json`` and returned.

    Args:
        model_path: Model name or path passed to ``from_pretrained``.
        dataset_path: JSONL file; every record needs the fields ``source``,
            ``prompt``, ``chosen``, ``rewrite1``, ``rewrite2``, ``rewrite3``.
        output_dir: Directory for the results file (created if missing).
        tokenizer_path: Tokenizer name or path; defaults to ``model_path``.
        batch_size: Number of samples scored per forward pass (must be >= 1).
        max_length: Truncation length passed to the tokenizer.
        device: Torch device string; defaults to "cuda" when available,
            otherwise "cpu".
        trust_remote_code: Forwarded to ``from_pretrained``.
        torch_dtype: Name of a ``torch`` dtype attribute (e.g. "bfloat16").
            ``None`` and "auto" both leave the dtype unset.
        attn_implementation: Forwarded to the model's ``from_pretrained``.

    Returns:
        A dict with an "overall" section (per-comparison totals and accuracy,
        plus per-source accuracies) and a "per_source" section containing
        per-sample details for each comparison.

    Raises:
        ValueError: If ``batch_size`` is not positive, the dataset cannot be
            loaded, or a sample lacks a required field.
    """
    # A non-positive batch size would either crash range() or silently skip
    # every sample and report 0% accuracy, so reject it before loading anything.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): "auto" is mapped to None here rather than forwarded to
    # from_pretrained; kept as-is to preserve existing numerical behaviour.
    dtype = None
    if torch_dtype and torch_dtype != "auto":
        dtype = getattr(torch, torch_dtype)

    print(f"Loading tokenizer from {tokenizer_path or model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path or model_path, trust_remote_code=trust_remote_code)

    print(f"Loading model from {model_path}...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        problem_type="regression",
        num_labels=1,
    ).to(device)
    model.eval()

    # Batched scoring pads the inputs. Tokenizers without a pad token cannot
    # pad at all, and decoder-style classification heads need the pad id to
    # find the last real token, so fall back to EOS when nothing is configured.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(model.config, "pad_token_id", None) is None and tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Load data
    print(f"Loading dataset from {dataset_path}...")
    try:
        data = load_jsonl(dataset_path)
        print(f"Loaded {len(data)} samples")
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"Failed to load dataset from {dataset_path}: {e}") from e

    # Validate data structure
    required_fields = ["source", "prompt", "chosen", "rewrite1", "rewrite2", "rewrite3"]
    for i, sample in enumerate(data):
        for field in required_fields:
            if field not in sample:
                raise ValueError(f"Sample {i} missing required field: {field}")

    # Group data by source; samples from other sources are reported and dropped
    known_sources = ["safety", "helpful", "math"]
    data_by_source = defaultdict(list)
    for sample in data:
        source = sample["source"]
        if source in known_sources:
            data_by_source[source].append(sample)
        else:
            print(f"⚠️  Unknown source: {source}, skipping")

    print("Data distribution by source:")
    for source, samples in data_by_source.items():
        print(f"  {source}: {len(samples)} samples")

    # Define comparison pairs and their expected outcomes.
    # "expected" names the field that should receive the higher reward.
    comparisons = [
        {
            "name": "rewrite1_vs_rewrite2",
            "field1": "rewrite1",
            "field2": "rewrite2",
            "expected": "rewrite2_higher",  # rewrite2 should have higher reward
            "description": "Rewrite1 vs Rewrite2 (Rewrite2 should win)",
        },
        {
            "name": "rewrite1_vs_rewrite3",
            "field1": "rewrite1",
            "field2": "rewrite3",
            "expected": "rewrite1_higher",  # rewrite1 should have higher reward
            "description": "Rewrite1 vs Rewrite3 (Rewrite1 should win)",
        },
        {
            "name": "rewrite2_vs_rewrite3",
            "field1": "rewrite2",
            "field2": "rewrite3",
            "expected": "rewrite2_higher",  # rewrite2 should have higher reward
            "description": "Rewrite2 vs Rewrite3 (Rewrite2 should win)",
        },
    ]

    # Results storage
    all_results = {
        "overall": {
            "total_sources": len(data_by_source),
            "comparisons": {},
            "results_by_source": {},
        },
        "per_source": {},
    }

    # Initialize comparison results
    for comp in comparisons:
        all_results["overall"]["comparisons"][comp["name"]] = {
            "total_samples": 0,
            "correct_predictions": 0,
            "accuracy": 0.0,
        }

    # Process each source
    for source in known_sources:
        if source not in data_by_source:
            print(f"\n⚠️  No data found for source: {source}")
            continue

        source_data = data_by_source[source]
        print(f"\n🔄 Processing source: {source} ({len(source_data)} samples)")

        # Initialize source results
        source_results = {
            "source": source,
            "total_samples": len(source_data),
            "comparisons": {},
            "detailed_results": [],
        }

        # Process each comparison
        for comp in comparisons:
            print(f"   🔍 {comp['description']}")

            # Decide once per comparison which side is expected to win.
            first_should_win = _first_field_should_win(comp)

            correct = 0
            total = len(source_data)
            detailed_results = []

            # Process in batches
            for start_idx in range(0, total, batch_size):
                batch_data = source_data[start_idx:start_idx + batch_size]

                # Prepare batch data
                prompts = [sample["prompt"] for sample in batch_data]
                responses1 = [sample[comp["field1"]] for sample in batch_data]
                responses2 = [sample[comp["field2"]] for sample in batch_data]

                # Build inputs for both responses and move them to the device
                inputs1 = {
                    k: v.to(device)
                    for k, v in build_inputs(tokenizer, prompts, responses1, max_length).items()
                }
                inputs2 = {
                    k: v.to(device)
                    for k, v in build_inputs(tokenizer, prompts, responses2, max_length).items()
                }

                # Get rewards (one scalar per sample because num_labels=1)
                with torch.no_grad():
                    rewards1 = model(**inputs1).logits.squeeze(-1)
                    rewards2 = model(**inputs2).logits.squeeze(-1)

                # A prediction is correct when the expected winner scores strictly higher
                if first_should_win:
                    predictions = (rewards1 > rewards2).to(torch.long)
                else:
                    predictions = (rewards2 > rewards1).to(torch.long)

                batch_correct = int(predictions.sum().item())
                correct += batch_correct

                # Store detailed results.
                # "margin" is always reward(field2) - reward(field1), regardless
                # of which side is expected to win.
                for i, sample in enumerate(batch_data):
                    is_correct = bool(predictions[i].item())
                    detailed_results.append({
                        "sample_idx": start_idx + i,
                        "source": sample["source"],
                        f"reward_{comp['field1']}": float(rewards1[i].item()),
                        f"reward_{comp['field2']}": float(rewards2[i].item()),
                        "correct": is_correct,
                        "margin": float((rewards2[i] - rewards1[i]).item()),
                        "model": sample.get("model", "unknown"),
                        "comparison": comp["name"],
                    })

            # Calculate accuracy for this comparison and source
            accuracy = correct / total if total > 0 else 0.0

            # Store comparison results for this source
            source_results["comparisons"][comp["name"]] = {
                "total_samples": total,
                "correct_predictions": correct,
                "accuracy": accuracy,
                "detailed_results": detailed_results,
            }

            # Add to overall comparison totals
            overall_comp = all_results["overall"]["comparisons"][comp["name"]]
            overall_comp["total_samples"] += total
            overall_comp["correct_predictions"] += correct

            print(f"      📊 Results: {correct}/{total} ({accuracy:.2%})")

        # Store source results
        all_results["per_source"][source] = source_results
        all_results["overall"]["results_by_source"][source] = {
            "total_samples": len(source_data),
            "comparisons": {
                comp["name"]: source_results["comparisons"][comp["name"]]["accuracy"]
                for comp in comparisons
            },
        }

    # Calculate overall statistics for each comparison
    for comp in comparisons:
        overall_comp = all_results["overall"]["comparisons"][comp["name"]]
        total_samples = overall_comp["total_samples"]
        if total_samples > 0:
            overall_comp["accuracy"] = overall_comp["correct_predictions"] / total_samples
        else:
            overall_comp["accuracy"] = 0.0

    # Save results
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "rate_evaluation_results.json")

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)

    # Print summary
    print("\n📊 Overall Results:")
    print(f"   Sources evaluated: {len(all_results['per_source'])}")

    print("\n📊 Overall Comparison Results:")
    for comp in comparisons:
        comp_results = all_results["overall"]["comparisons"][comp["name"]]
        print(f"   {comp['description']}:")
        print(f"      {comp_results['correct_predictions']}/{comp_results['total_samples']} ({comp_results['accuracy']:.2%})")

    print("\n📊 Per-source breakdown:")
    for source in known_sources:
        if source in all_results["per_source"]:
            print(f"   {source}:")
            for comp in comparisons:
                comp_results = all_results["per_source"][source]["comparisons"][comp["name"]]
                print(f"      {comp['description']}: {comp_results['correct_predictions']}/{comp_results['total_samples']} ({comp_results['accuracy']:.2%})")
        else:
            print(f"   {source}: No data")

    print(f"\n💾 Results saved to: {output_path}")

    return all_results


def main():
    """Command-line entry point: read a YAML config and run ``evaluate_rate``.

    Recognised config keys: ``model_name_or_path`` (or ``model_path``),
    ``dataset_path``, ``output_dir`` (defaults to the model path),
    ``tokenizer_name_or_path``, ``per_device_eval_batch_size`` (default 8),
    ``max_length`` (default 2048), ``trust_remote_code`` (default False),
    ``torch_dtype`` and ``attn_implementation``.

    Raises:
        ValueError: If the config is not a mapping, or the model path or
            dataset path is missing.
    """
    parser = argparse.ArgumentParser(description="Evaluate reward model rate on rewrite1 vs rewrite2 comparison")
    parser.add_argument("config", type=str, help="Path to YAML config file")
    args = parser.parse_args()

    cfg = read_yaml(args.config)
    # An empty YAML file parses to None; treat it as an empty config so the
    # user gets the "must be set" errors below instead of an AttributeError.
    if cfg is None:
        cfg = {}
    if not isinstance(cfg, dict):
        raise ValueError(f"Config file {args.config} must contain a YAML mapping at the top level")

    # Validate the required settings before deriving anything from them.
    model_path = cfg.get("model_name_or_path") or cfg.get("model_path")
    if not model_path:
        raise ValueError("model_name_or_path must be set in the config")

    dataset_path = cfg.get("dataset_path")
    if not dataset_path:
        raise ValueError("dataset_path must be set in the config")

    evaluate_rate(
        model_path=model_path,
        dataset_path=dataset_path,
        # Results land next to the model unless an output directory is given.
        output_dir=cfg.get("output_dir") or model_path,
        tokenizer_path=cfg.get("tokenizer_name_or_path"),
        batch_size=int(cfg.get("per_device_eval_batch_size", 8)),
        max_length=int(cfg.get("max_length", 2048)),
        trust_remote_code=bool(cfg.get("trust_remote_code", False)),
        torch_dtype=cfg.get("torch_dtype"),
        attn_implementation=cfg.get("attn_implementation"),
    )


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()





