import torch
import argparse
import json
import os
import re
import numpy as np
from tqdm import tqdm
from transformers import (
    LlamaTokenizer,
    AutoModelForSequenceClassification,
)

# --- Configuration ---
# Name or path handed to `from_pretrained` when loading the reward ("rank") model.
DEFAULT_RANK_MODEL_NAME: str = "weqweasdas/hh_rlhf_rm_open_llama_3b"
# Tokenizer max length for reward-model inputs (formatted prompt + response).
# Longer inputs are truncated, so this should cover
# max_instruction_length + max_generation used during training.
DEFAULT_MAX_LENGTH: int = 512
# --- End Configuration ---


def convert_query_to_rm_query(query_string):
    """
    Turn an Alpaca-style policy query into the reward model's prompt format.

    The instruction text is pulled out from between the fixed Alpaca header
    and the trailing "### Response:" marker and re-wrapped as
    ``"###Human: <instruction> ###Assistant: "``. The instruction is used
    verbatim (it is NOT stripped).

    Two extraction strategies are tried in order:
      1. Exact framing: the string starts with the header and ends with the
         marker; everything in between is the instruction.
      2. Fallback search: the leftmost header followed by the nearest marker
         anywhere in the string (newlines allowed); surrounding text is ignored.

    Args:
        query_string: The full policy query string.

    Returns:
        The reward-model prompt string.

    Raises:
        ValueError: If the string is framed by header and marker but they
            overlap (no room for an instruction), or if neither strategy finds
            the expected structure.
    """
    header = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n"
    marker = "\n\n### Response:"

    def _as_rm_prompt(text):
        return f"###Human: {text} ###Assistant: "

    # Strategy 1: exact framing. The header and marker can overlap by one
    # newline, in which case there is no instruction to extract.
    framed = query_string.startswith(header) and query_string.endswith(marker)
    if framed:
        lo = len(header)
        hi = len(query_string) - len(marker)
        if lo > hi:
            raise ValueError("Query string is malformed or too short to extract instruction.")
        return _as_rm_prompt(query_string[lo:hi])

    # Strategy 2: unanchored, non-greedy search; DOTALL lets the instruction span lines.
    pattern = re.escape(header) + r"(.*?)" + re.escape(marker)
    found = re.search(pattern, query_string, re.DOTALL)
    if found is None:
        raise ValueError("Query string does not match the expected format to extract instruction.")
    return _as_rm_prompt(found.group(1))


def load_rank_model_and_tokenizer(model_name_or_path: str):
    """
    Load the reward (rank) model and its tokenizer.

    If the tokenizer has no pad token, its EOS token is used for padding. The
    model's ``config.pad_token_id`` is then filled in from the tokenizer when it
    is unset. The sequence-classification head uses that id to locate the last
    real token of each padded row, so without it batches larger than one cannot
    be scored. A mismatch between the two ids is reported as a warning.

    Args:
        model_name_or_path: Hub id or local path accepted by ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` if loading fails.
        The error is printed, not raised.
    """
    print(f"Loading Rank Model: {model_name_or_path}")
    try:
        tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False)
        if tokenizer.pad_token is None: # Important for Llama tokenizers
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
            print(f"Set pad_token_id to eos_token_id: {tokenizer.eos_token_id}")

        model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path,
            output_attentions=True,
            return_dict_in_generate=True,
            attn_implementation="eager",
            # Fall back to CPU when CUDA is unavailable, matching the device
            # chosen by main_evaluation; a hard-coded "cuda:0" fails there.
            device_map="cuda:0" if torch.cuda.is_available() else "cpu",
        )

        # Keep the model's notion of "padding" in sync with what the tokenizer
        # actually pads with, otherwise the reward may be read from a pad position.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = tokenizer.pad_token_id
            print(f"Set model.config.pad_token_id to {tokenizer.pad_token_id}")
        elif model.config.pad_token_id != tokenizer.pad_token_id:
            print(
                f"Warning: model.config.pad_token_id ({model.config.pad_token_id}) differs from "
                f"tokenizer.pad_token_id ({tokenizer.pad_token_id}); rewards for padded inputs "
                "may be read from the wrong position."
            )

        print("Rank Model and Tokenizer loaded.")
        return model, tokenizer
    except Exception as e:
        # Top-level boundary: report and signal failure to the caller via (None, None).
        print(f"Error loading Rank Model or Tokenizer: {e}")
        return None, None


def calculate_rewards(
    rank_model: AutoModelForSequenceClassification,
    rank_tokenizer: LlamaTokenizer,
    data: list, # List of {"instruction": str, "output": str}
    batch_size: int = 8,
    max_length: int = DEFAULT_MAX_LENGTH,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
):
    """
    Score each generated output with the reward (rank) model.

    For every item, the full policy query in ``item["instruction"]`` is converted
    to the reward model's ``###Human: ... ###Assistant: `` format via
    ``convert_query_to_rm_query``, and ``item["output"]`` is appended. This
    mirrors ``q + r`` for ``(rm_query, response)`` used during training.

    Args:
        rank_model: Sequence-classification reward model.
        rank_tokenizer: Tokenizer matching ``rank_model``. It must have a pad token.
        data: Records with ``"instruction"`` and ``"output"`` keys.
        batch_size: Number of records per forward pass.
        max_length: Tokenizer max length. Inputs are padded to it and longer
            inputs are truncated.
        device: Device the model and inputs are moved to.

    Returns:
        One reward per record, in input order: ``logits.squeeze(-1)`` converted
        with ``tolist()``. This is a float when the head has a single label.

    Raises:
        KeyError: If a record lacks ``"instruction"`` or ``"output"``.
        ValueError: Propagated from ``convert_query_to_rm_query`` for malformed queries.
    """
    rank_model.to(device)
    rank_model.eval()

    all_rewards = []
    print(f"Calculating rewards for {len(data)} samples on device: {device}...")

    for i in tqdm(range(0, len(data), batch_size)):
        batch_data = data[i:i + batch_size]
        texts_for_rm = [
            convert_query_to_rm_query(item["instruction"]) + item["output"]
            for item in batch_data
        ]

        if i == 0:
            # Show one formatted input as a sanity check, once rather than per batch;
            # tqdm.write keeps the progress bar intact.
            tqdm.write(texts_for_rm[0])

        inputs = rank_tokenizer(
            texts_for_rm,
            return_tensors="pt",
            max_length=max_length,
            padding="max_length", # Or "longest"
            truncation=True,
        ).to(device)

        with torch.no_grad():
            # load_rank_model_and_tokenizer puts output_attentions=True in the config,
            # which makes every forward keep all per-layer attention maps. Only the
            # logits are needed here, so disable that per call to save memory.
            outputs = rank_model(**inputs, output_attentions=False)
            # The reward is the raw logit of the classification head (single label assumed).
            rewards = outputs.logits.squeeze(-1)
        all_rewards.extend(rewards.cpu().tolist())

    return all_rewards

def _load_inference_data(inference_output_file: str):
    """Read and validate the inference JSON; return its list of records, or None on error."""
    if not os.path.exists(inference_output_file):
        print(f"Error: Inference output file not found: {inference_output_file}")
        return None

    print(f"Loading inference outputs from: {inference_output_file}")
    try:
        with open(inference_output_file, "r", encoding="utf-8") as f:
            inference_data = json.load(f)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {inference_output_file}")
        return None
    except Exception as e:
        print(f"Error reading inference output file: {e}")
        return None

    # Every record must carry the keys calculate_rewards reads.
    is_valid = isinstance(inference_data, list) and all(
        isinstance(item, dict) and "instruction" in item and "output" in item
        for item in inference_data
    )
    if not is_valid:
        print("Error: Invalid JSON format. Expected a list of {'instruction': ..., 'output': ...} dictionaries.")
        return None

    print(f"Loaded {len(inference_data)} samples from inference output.")
    return inference_data


def _print_reward_statistics(rewards):
    """Print count, mean, median, population std, min, max and percentiles of ``rewards``."""
    rewards_np = np.array(rewards)
    percentiles = [10, 25, 50, 75, 90]
    percentile_values = np.percentile(rewards_np, percentiles)

    print("\n--- Evaluation Statistics ---")
    print(f"Total samples evaluated: {len(rewards_np)}")
    print(f"Mean Reward:            {np.mean(rewards_np):.4f}")
    print(f"Median Reward:          {np.median(rewards_np):.4f}")
    print(f"Standard Deviation:     {np.std(rewards_np):.4f}")
    print(f"Min Reward:             {np.min(rewards_np):.4f}")
    print(f"Max Reward:             {np.max(rewards_np):.4f}")

    print("\nPercentile Ranks for Rewards:")
    for p, v in zip(percentiles, percentile_values):
        print(f"  {p}th percentile:       {v:.4f}")


def main_evaluation(
    inference_output_file: str,
    rank_model_name: str = DEFAULT_RANK_MODEL_NAME,
    batch_size: int = 8,
    max_length_rm: int = DEFAULT_MAX_LENGTH,
):
    """
    Score an inference-output JSON file with a reward model and print statistics.

    Args:
        inference_output_file: Path to a JSON list of
            ``{"instruction": ..., "output": ...}`` records.
        rank_model_name: Name or path of the reward model.
        batch_size: Records per reward-model forward pass.
        max_length_rm: Tokenizer max length for reward-model inputs.

    Returns:
        None. Problems with the file or the model are printed, and the function
        returns early.
    """
    # 1. Load Inference Outputs
    inference_data = _load_inference_data(inference_output_file)
    if inference_data is None:
        return
    if not inference_data:
        print("No data found in the inference output file.")
        return

    # 2. Load Rank Model and Tokenizer
    rank_model, rank_tokenizer = load_rank_model_and_tokenizer(rank_model_name)
    if rank_model is None or rank_tokenizer is None:
        return # Error handled in load_rank_model_and_tokenizer

    # 3. Calculate Rewards (calculate_rewards moves the model to `device` itself)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rewards = calculate_rewards(
        rank_model,
        rank_tokenizer,
        inference_data,
        batch_size=batch_size,
        max_length=max_length_rm,
        device=device
    )

    if not rewards:
        print("No rewards were calculated.")
        return

    # 4. Calculate and Print Statistics
    _print_reward_statistics(rewards)
    # To persist the statistics, collect them in a dict and json.dump it, e.g. to
    # inference_output_file.replace(".json", "_eval_stats.json").


if __name__ == "__main__":
    # Command-line entry point: parse arguments, validate them, and run the evaluation.
    parser = argparse.ArgumentParser(description="Evaluate generated outputs using a Rank Model.")
    parser.add_argument(
        "inference_output_file",
        type=str,
        help="Path to the JSON file containing inference outputs (list of {'instruction': ..., 'output': ...})."
    )
    parser.add_argument(
        "--rank_model_name",
        type=str,
        default=DEFAULT_RANK_MODEL_NAME,
        help="Name or path of the Rank Model (Reward Model) to use for evaluation."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for processing samples through the Rank Model."
    )
    parser.add_argument(
        "--max_length_rm",
        type=int,
        default=DEFAULT_MAX_LENGTH,
        help="Maximum sequence length for the Rank Model's tokenizer."
    )

    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside the loop (range() step of 0)
    # or in the tokenizer, with a clear usage message instead.
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer.")
    if args.max_length_rm <= 0:
        parser.error("--max_length_rm must be a positive integer.")

    main_evaluation(
        inference_output_file=args.inference_output_file,
        rank_model_name=args.rank_model_name,
        batch_size=args.batch_size,
        max_length_rm=args.max_length_rm,
    )