#!/usr/bin/env python
# coding=utf-8

import argparse
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def read_yaml(path: str) -> Dict:
    """Load a YAML config file and return its top-level mapping.

    ``yaml.safe_load`` is used, so the file cannot instantiate arbitrary
    Python objects.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed top-level mapping. An empty file (or one containing only
        comments), for which ``safe_load`` yields ``None``, returns ``{}``.

    Raises:
        ValueError: If the document's top level is not a mapping.
    """
    # Imported here rather than at module top, so PyYAML is only required
    # when a config file is actually read.
    import yaml

    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {path} must contain a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data


def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSON Lines file and return one parsed object per non-blank line.

    Blank or whitespace-only lines (such as a trailing empty line left by many
    writers) are skipped rather than passed to ``json.loads``, which would
    raise on them. The file is decoded with ``utf-8-sig`` so a leading UTF-8
    BOM is tolerated; files without a BOM decode exactly as plain UTF-8.

    Args:
        file_path: Path to the ``.jsonl`` file (``str`` or ``os.PathLike``).

    Returns:
        The parsed records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, "r", encoding="utf-8-sig") as f:
        stripped = (line.strip() for line in f)
        return [json.loads(text) for text in stripped if text]


def build_inputs(tokenizer, prompts: List[str], responses: List[str], max_length: int) -> Dict[str, torch.Tensor]:
    """Tokenize (prompt, response) pairs into one padded batch for the reward model.

    Each prompt is joined to its response with a blank line in between. The
    tokenizer is asked to pad the batch, truncate each text to
    ``max_length`` tokens, and return PyTorch tensors.

    Args:
        tokenizer: Tokenizer callable accepting a list of texts plus
            ``padding``/``truncation``/``max_length``/``return_tensors``.
        prompts: Prompt strings, aligned position-by-position with ``responses``.
        responses: Response strings; ``zip`` stops at the shorter list.
        max_length: Maximum number of tokens per example.

    Returns:
        A dict holding only the ``input_ids`` and ``attention_mask`` fields of
        the tokenizer output; any other fields it produces are dropped.
    """
    separator = "\n\n"
    texts = [separator.join((prompt, response)) for prompt, response in zip(prompts, responses)]
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding=True,
    )
    return {field: getattr(encoded, field) for field in ("input_ids", "attention_mask")}


def _resolve_dtype(torch_dtype: str = None):
    """Translate a config dtype string into the value ``from_pretrained`` accepts.

    ``None``/empty -> ``None`` (library default); ``"auto"`` -> ``"auto"`` (use
    the checkpoint's own dtype); otherwise a ``torch`` dtype name such as
    ``"bfloat16"`` or ``"torch.float16"``.

    Raises:
        ValueError: If the name does not resolve to a ``torch.dtype``.
    """
    if not torch_dtype:
        return None
    if torch_dtype == "auto":
        return "auto"
    name = torch_dtype[len("torch."):] if torch_dtype.startswith("torch.") else torch_dtype
    dtype = getattr(torch, name, None)
    # getattr alone would also accept non-dtype attributes such as "nn".
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported torch_dtype: {torch_dtype!r}")
    return dtype


def _build_pairs(data: List[Dict], dataset_name: str) -> List[Dict]:
    """Group consecutive records (0,1), (2,3), ... into evaluation pairs.

    The even-indexed record's response is the one expected to score higher.
    A trailing unpaired record is dropped. Pairs whose prompts differ, or
    whose records lack a ``prompt``/``response`` field, are skipped with a
    warning.
    """
    if len(data) % 2 != 0:
        print(f"   ⚠️  Warning: Dataset {dataset_name} has odd number of samples ({len(data)}), skipping last sample")
        data = data[:-1]

    pairs = []
    for i in range(0, len(data), 2):
        first, second = data[i], data[i + 1]
        pair_idx = i // 2
        if not all(
            isinstance(rec, dict) and "prompt" in rec and "response" in rec
            for rec in (first, second)
        ):
            print(f"   ⚠️  Warning: Missing 'prompt' or 'response' for pair {pair_idx}: skipping")
            continue
        if first["prompt"] != second["prompt"]:
            print(f"   ⚠️  Warning: Prompts don't match for pair {pair_idx}: skipping")
            continue
        pairs.append({
            "prompt": first["prompt"],
            "response_0": first["response"],   # even index: expected to score higher
            "response_1": second["response"],  # odd index: expected to score lower
            "pair_idx": pair_idx,
        })
    return pairs


def _rewards(model, tokenizer, prompts: List[str], responses: List[str], max_length: int, device) -> List[float]:
    """Run one batch through the reward model and return one score per example."""
    encoded = build_inputs(tokenizer, prompts, responses, max_length)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    # num_labels=1 is forced at load time, so squeeze(-1) leaves one score per row.
    # Casting to float32 is exact for fp16/bf16; one .tolist() avoids a
    # device sync per element.
    return logits.squeeze(-1).float().cpu().tolist()


def _score_pairs(model, tokenizer, pairs: List[Dict], batch_size: int, max_length: int, device) -> Tuple[int, List[Dict]]:
    """Score all pairs in mini-batches; return ``(num_correct, per_pair_details)``.

    A pair counts as correct only when ``reward_0 > reward_1`` strictly, so
    ties (and NaN scores) count as incorrect.
    """
    correct = 0
    details = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start:start + batch_size]
        prompts = [p["prompt"] for p in batch]
        rewards_0 = _rewards(model, tokenizer, prompts, [p["response_0"] for p in batch], max_length, device)
        rewards_1 = _rewards(model, tokenizer, prompts, [p["response_1"] for p in batch], max_length, device)

        for pair, r0, r1 in zip(batch, rewards_0, rewards_1):
            is_correct = r0 > r1
            correct += int(is_correct)
            details.append({
                "pair_idx": pair["pair_idx"],
                "reward_0": r0,
                "reward_1": r1,
                "correct": is_correct,
                "margin": r0 - r1,
            })
    return correct, details


def _print_summary(all_results: Dict, output_path: str) -> None:
    """Print the overall and per-dataset accuracy table."""
    overall = all_results["overall"]
    print("\n📊 Overall Results:")
    print(f"   Datasets evaluated: {len(all_results['per_dataset'])}")
    print(f"   Total pairs: {overall['total_pairs']}")
    print(f"   Correct predictions: {overall['correct_predictions']}")
    print(f"   Overall accuracy: {overall['accuracy']:.2%}")

    print("\n📊 Per-dataset breakdown:")
    for dataset_name, results in all_results["per_dataset"].items():
        print(f"   {dataset_name}: {results['correct_predictions']}/{results['total_pairs']} ({results['accuracy']:.2%})")

    print(f"\n💾 Results saved to: {output_path}")


def evaluate_robustness(
    model_path: str,
    dataset_dir: str,
    output_dir: str,
    tokenizer_path: str = None,
    batch_size: int = 8,
    max_length: int = 2048,
    device: str = None,
    trust_remote_code: bool = False,
    torch_dtype: str = None,
    attn_implementation: str = None,
) -> Dict:
    """Evaluate reward-model robustness on a directory of paired JSONL datasets.

    Every ``*.jsonl`` file directly inside ``dataset_dir`` is one dataset.
    Consecutive records (0,1), (2,3), ... form a pair sharing one ``prompt``.
    The model is correct on a pair when it scores the even-indexed
    ``response`` strictly higher than the odd-indexed one. Results are written
    to ``<output_dir>/robust_evaluation_results.json`` and also returned.

    Args:
        model_path: Model name or path passed to ``from_pretrained``.
        dataset_dir: Directory scanned (non-recursively) for ``*.jsonl`` files.
        output_dir: Directory for the results JSON; created if missing.
        tokenizer_path: Tokenizer name or path; defaults to ``model_path``.
        batch_size: Pairs scored per forward pass; must be >= 1.
        max_length: Token limit used when truncating each prompt+response.
        device: Torch device; defaults to ``"cuda"`` if available, else ``"cpu"``.
        trust_remote_code: Forwarded to ``from_pretrained``.
        torch_dtype: ``None``, ``"auto"``, or a torch dtype name like ``"bfloat16"``.
        attn_implementation: Forwarded to ``from_pretrained``.

    Returns:
        ``{"overall": {...}, "per_dataset": {...}}``. The overall accuracy is
        micro-averaged over all pairs; per-dataset entries include per-pair
        rewards and margins.

    Raises:
        ValueError: If ``batch_size`` < 1, ``torch_dtype`` is not a torch dtype,
            or ``dataset_dir`` contains no ``*.jsonl`` files.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    dtype = _resolve_dtype(torch_dtype)

    print(f"Loading tokenizer from {tokenizer_path or model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path or model_path, trust_remote_code=trust_remote_code)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        # build_inputs pads every batch, which fails without a pad token; many
        # decoder-only tokenizers ship without one, so reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading model from {model_path}...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
        torch_dtype=dtype,
        attn_implementation=attn_implementation,
        problem_type="regression",
        num_labels=1,
    ).to(device)
    if getattr(model.config, "pad_token_id", None) is None and tokenizer.pad_token_id is not None:
        # In transformers, decoder-only *ForSequenceClassification heads use
        # config.pad_token_id to find each row's last real token; keep it in
        # sync with the tokenizer so batched scoring works.
        model.config.pad_token_id = tokenizer.pad_token_id
    model.eval()

    dataset_dir = Path(dataset_dir)
    jsonl_files = sorted(dataset_dir.glob("*.jsonl"))
    if not jsonl_files:
        raise ValueError(f"No JSONL files found in {dataset_dir}")

    print(f"Found {len(jsonl_files)} dataset files: {[f.name for f in jsonl_files]}")

    all_results = {
        "overall": {
            "total_datasets": len(jsonl_files),
            "results_by_dataset": {}
        },
        "per_dataset": {}
    }

    for jsonl_file in jsonl_files:
        dataset_name = jsonl_file.stem
        print(f"\n🔄 Processing dataset: {dataset_name}")

        try:
            data = load_jsonl(jsonl_file)
        except (OSError, ValueError) as e:  # unreadable file / bad JSON / bad encoding
            print(f"   ❌ Failed to load {jsonl_file}: {e}")
            continue
        print(f"   Loaded {len(data)} samples")

        pairs = _build_pairs(data, dataset_name)
        print(f"   Created {len(pairs)} valid pairs")
        if not pairs:
            print(f"   ❌ No valid pairs found in {dataset_name}")
            continue

        correct, detailed_results = _score_pairs(model, tokenizer, pairs, batch_size, max_length, device)
        total = len(pairs)
        accuracy = correct / total

        all_results["per_dataset"][dataset_name] = {
            "dataset_name": dataset_name,
            "total_pairs": total,
            "correct_predictions": correct,
            "accuracy": accuracy,
            "detailed_results": detailed_results
        }
        all_results["overall"]["results_by_dataset"][dataset_name] = {
            "accuracy": accuracy,
            "total_pairs": total,
            "correct_predictions": correct
        }

        print(f"   📊 Results: {correct}/{total} ({accuracy:.2%})")

    # Micro-average over every scored pair (sums are 0 when nothing was scored).
    per_dataset = all_results["per_dataset"].values()
    total_pairs_all = sum(r["total_pairs"] for r in per_dataset)
    correct_all = sum(r["correct_predictions"] for r in per_dataset)
    all_results["overall"]["total_pairs"] = total_pairs_all
    all_results["overall"]["correct_predictions"] = correct_all
    all_results["overall"]["accuracy"] = correct_all / total_pairs_all if total_pairs_all > 0 else 0.0

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "robust_evaluation_results.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)

    _print_summary(all_results, output_path)

    return all_results


def _as_bool(value) -> bool:
    """Interpret a config flag value as a boolean.

    YAML already turns unquoted ``true``/``false`` into booleans, but a quoted
    ``"false"`` arrives as a non-empty string, which plain ``bool()`` treats as
    ``True``. That is risky for a flag like ``trust_remote_code``, so common
    "false" spellings are mapped to ``False``.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("", "false", "0", "no", "off")
    return bool(value)


def main():
    """Command-line entry point: load a YAML config and run ``evaluate_robustness``.

    Config keys read:
        model_name_or_path or model_path (required), dataset_dir (required),
        output_dir (defaults to the model path), tokenizer_name_or_path,
        per_device_eval_batch_size (default 8), max_length (default 2048),
        trust_remote_code (default False), torch_dtype, attn_implementation,
        device (default: auto-detected inside evaluate_robustness).

    A key that is absent and a key explicitly set to ``null`` both fall back
    to the default.
    """
    parser = argparse.ArgumentParser(description="Evaluate reward model robustness on rewritten datasets")
    parser.add_argument("config", type=str, help="Path to YAML config file")
    args = parser.parse_args()

    cfg = read_yaml(args.config)
    if cfg is None:
        # yaml.safe_load yields None for an empty document.
        cfg = {}
    if not isinstance(cfg, dict):
        raise ValueError(f"Config file {args.config} must contain a YAML mapping at the top level")

    def option(key, default=None):
        """Return cfg[key], treating a missing key and an explicit null alike."""
        value = cfg.get(key)
        return default if value is None else value

    model_path = option("model_name_or_path") or option("model_path")
    dataset_dir = option("dataset_dir")

    if not model_path:
        raise ValueError("model_name_or_path must be set in the config")
    if not dataset_dir:
        raise ValueError("dataset_dir must be set in the config")

    evaluate_robustness(
        model_path=model_path,
        dataset_dir=dataset_dir,
        output_dir=option("output_dir") or model_path,
        tokenizer_path=option("tokenizer_name_or_path"),
        batch_size=int(option("per_device_eval_batch_size", 8)),
        max_length=int(option("max_length", 2048)),
        device=option("device"),
        trust_remote_code=_as_bool(option("trust_remote_code", False)),
        torch_dtype=option("torch_dtype"),
        attn_implementation=option("attn_implementation"),
    )


if __name__ == "__main__":
    main()
