import re
import json
import argparse
from pathlib import Path
from typing import Dict, List, Optional, Union

import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
warnings.filterwarnings("ignore")

from labeling_functions.bias_judge import BiasEvaluator
from labeling_functions.safe_judge import SafetyEvaluator
from labeling_functions.factual_judge import FactualityEvaluator
from labeling_functions.relevance_judge import RelevanceEvaluator
from labeling_functions.structure_judge import StructuralEvaluator
from labeling_functions.readability_judge import ReadabilityEvaluator

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values to built-in Python types.

    Handles NumPy booleans, integers, floats and arrays; anything else is
    delegated to ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of *obj*.

        Raises:
            TypeError: if *obj* is not a supported NumPy type (raised by the
                base class).
        """
        # np.bool_ is not a subclass of np.integer, so it needs its own branch.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

class DataLoader:
    """Loads pairwise-comparison training data from several sources.

    Each loader returns a DataFrame with a ``winner_index`` column (0 when the
    first answer wins, 1 otherwise). The PandaLM, SHP and Prometheus loaders
    also normalise their output to the columns ``question_id``,
    ``question_body``, ``answer1_body`` and ``answer2_body``; the JudgeLM
    loader keeps whatever columns the source file has.
    """

    @staticmethod
    def load_judgeLM(
        score_diff_threshold: int = 3,
        data_path: str = "./llm_as_judge_data/judgeLM/judgelm_train_100k.jsonl",
    ) -> pd.DataFrame:
        """Load JudgeLM training data and keep only clearly separated pairs.

        Args:
            score_diff_threshold: Minimum absolute difference between the two
                entries of ``score`` for a row to be kept.
            data_path: JSON-lines file to read; every record needs a ``score``
                field holding the two scores.

        Returns:
            The file's rows with ``winner_index`` and ``difference`` columns
            added, filtered by the threshold and re-indexed from 0.
        """
        data = pd.read_json(data_path, lines=True)
        # Equal scores are labelled 1; they only survive when the threshold is <= 0.
        data["winner_index"] = data["score"].apply(lambda pair: 0 if pair[0] > pair[1] else 1)
        data["difference"] = data["score"].apply(lambda pair: abs(pair[0] - pair[1]))
        return data[data["difference"] >= score_diff_threshold].reset_index(drop=True)

    @staticmethod
    def load_pandaLM(data_path: str = "./llm_as_judge_data/pandaLM/train.json") -> pd.DataFrame:
        """Load PandaLM training data and parse its prompt sections.

        Items whose ``output_sequence`` starts with ``T`` are skipped; the
        first character of every other output must be ``1`` or ``2``.

        Args:
            data_path: JSON file holding a list of objects with
                ``input_sequence`` and ``output_sequence`` strings.

        Returns:
            DataFrame with columns ``question_id`` (position of the item in
            the file), ``question_body``, ``answer1_body``, ``answer2_body``
            and ``winner_index``.

        Raises:
            ValueError: if an item has an unexpected winner label or its
                ``input_sequence`` lacks one of the expected sections.
        """
        # Explicit encoding: the platform default may not be UTF-8.
        with open(data_path, "r", encoding="utf-8") as file:
            data = json.load(file)

        # Compiled once; each section runs up to the next blank line.
        section_patterns = {
            "instruction": re.compile(r"### Instruction:\n(.*?)\n\n", re.DOTALL),
            "response1": re.compile(r"### Response 1:\n(.*?)\n\n", re.DOTALL),
            "response2": re.compile(r"### Response 2:\n(.*?)\n\n", re.DOTALL),
        }

        rows = []
        for idx, item in enumerate(data):
            input_text = item["input_sequence"]
            output_text = item["output_sequence"]

            label = output_text[:1]
            if label == "T":
                continue
            if label not in ("1", "2"):
                raise ValueError(
                    f"PandaLM item {idx}: unexpected winner label {label!r} in output_sequence"
                )

            sections = {}
            for name, pattern in section_patterns.items():
                match = pattern.search(input_text)
                if match is None:
                    raise ValueError(
                        f"PandaLM item {idx}: missing '{name}' section in input_sequence"
                    )
                sections[name] = match.group(1)

            rows.append([
                idx,
                sections["instruction"],
                sections["response1"],
                sections["response2"],
                int(label) - 1,
            ])

        return pd.DataFrame(
            rows,
            columns=["question_id", "question_body", "answer1_body", "answer2_body", "winner_index"]
        )

    @staticmethod
    def load_shp(score_diff_threshold: int = 10) -> pd.DataFrame:
        """Load the SHP train split and keep pairs with a large score gap.

        Args:
            score_diff_threshold: Minimum ``|score_A - score_B|`` for a row to
                be kept.

        Returns:
            DataFrame with columns ``question_id`` (row position in the
            unfiltered split), ``domain``, ``question_body`` (the ``history``
            field), ``answer1_body``, ``answer2_body`` and ``winner_index``.
        """
        data = pd.DataFrame(load_dataset("stanfordnlp/shp")["train"])
        # NOTE(review): presumably labels == 1 means response A is preferred — confirm.
        data["winner_index"] = 1 - data["labels"]
        # Vectorised; also works on an empty frame, unlike a row-wise apply.
        data["score_diff"] = (data["score_A"] - data["score_B"]).abs()
        # reset_index() exposes the original row position as the "index" column.
        selected_data = data[data["score_diff"] >= score_diff_threshold].reset_index()
        selected_data = selected_data[["index", "domain", "history", "human_ref_A", "human_ref_B", "winner_index"]]
        return selected_data.rename(columns={
            "index": "question_id",
            "history": "question_body",
            "human_ref_A": "answer1_body",
            "human_ref_B": "answer2_body",
        })

    # Disabled loader, kept for reference only:
    #
    # @staticmethod
    # def load_rlhf() -> pd.DataFrame:
    #     # Load and preprocess RLHF-single-round training data.
    #
    #     data = pd.DataFrame(load_dataset("beyond/rlhf-reward-single-round")["train"]).reset_index()
    #     data["winner_index"] = 0
    #     data.columns = ["question_id", "question_body", "answer1_body", "answer2_body", "winner_index"]
    #     return data

    @staticmethod
    def load_prometheus(score_diff_threshold: int = 3) -> pd.DataFrame:
        """Load the Prometheus Preference-Collection train split.

        Args:
            score_diff_threshold: Minimum ``|orig_score_A - orig_score_B|``
                for a sample to be kept.

        Returns:
            DataFrame with columns ``question_id`` (position of the sample in
            the split), ``question_body``, ``answer1_body``, ``answer2_body``
            and ``winner_index`` (0 when ``orig_preference`` is ``"A"``,
            otherwise 1).
        """
        prometheus_dataset = load_dataset("prometheus-eval/Preference-Collection")["train"]

        # Single pass: filter and convert each sample as it is read.
        selected_rows = []
        for index, sample in enumerate(prometheus_dataset):
            score_gap = abs(int(sample["orig_score_A"]) - int(sample["orig_score_B"]))
            if score_gap < score_diff_threshold:
                continue
            gt_winner = 0 if sample["orig_preference"] == "A" else 1
            selected_rows.append([
                index,
                sample["orig_instruction"],
                sample["orig_response_A"],
                sample["orig_response_B"],
                gt_winner,
            ])

        return pd.DataFrame(
            selected_rows,
            columns=["question_id", "question_body", "answer1_body", "answer2_body", "winner_index"]
        )

class ResponseEvaluator:
    """Runs every judge on a response pair and derives per-metric weak labels."""

    # Judge name -> prefix prepended to that judge's metric names.
    METRIC_PREFIXES = {
        name: f"{name}_"
        for name in ('safety', 'structure', 'bias', 'relevance', 'readability', 'factuality')
    }

    def __init__(self):
        """Instantiate one judge per entry of ``METRIC_PREFIXES``."""
        self.evaluators = dict(
            safety=SafetyEvaluator(),
            structure=StructuralEvaluator(),
            bias=BiasEvaluator(),
            relevance=RelevanceEvaluator(),
            readability=ReadabilityEvaluator(),
            factuality=FactualityEvaluator(),
        )

    def evaluate_row(self, row: pd.Series) -> Dict[str, Union[str, int, List]]:
        """Judge the two answers of one row and summarise the outcome.

        Args:
            row: Record providing ``question_id``, ``question_body``,
                ``answer1_body``, ``answer2_body`` and ``winner_index``.

        Returns:
            Dict holding the row's texts and ground-truth winner, the combined
            judge table as plain JSON data (``judging_results``) and one weak
            label per metric (``weak_labels``): 0 if response 1 scores higher,
            1 if response 2 scores higher, -1 otherwise.
        """
        prompt = row["question_body"]
        first_answer = row["answer1_body"]
        second_answer = row["answer2_body"]

        # These judges receive the question in addition to the two answers.
        question_aware = ('relevance', 'readability', 'factuality')

        frames = []
        for judge_name, prefix in self.METRIC_PREFIXES.items():
            judge = self.evaluators[judge_name]
            if judge_name in question_aware:
                scores = judge.evaluate(prompt, first_answer, second_answer)
            else:
                scores = judge.evaluate(first_answer, second_answer)
            frame = pd.DataFrame(scores)
            # Namespace the metric names so judges cannot collide after concat.
            frame.index = [f"{prefix}{metric}" for metric in frame.index]
            frames.append(frame)

        combined_table = pd.concat(frames)
        first_scores = combined_table['response1']
        second_scores = combined_table['response2']
        first_wins = first_scores > second_scores
        second_wins = first_scores < second_scores
        weak_labels = np.select([first_wins, second_wins], [0, 1], default=-1).tolist()

        summary = {
            "question_id": row["question_id"],
            "question": prompt,
            "response1_text": first_answer,
            "response2_text": second_answer,
            "winner_index": int(row["winner_index"]),
        }
        summary["judging_results"] = json.loads(combined_table.to_json())
        summary["weak_labels"] = weak_labels
        return summary

def main():
    """Command-line entry point: load a dataset, judge every row, save JSON.

    Parses ``--dataset``, ``--score-diff`` and ``--output-dir``, loads the
    chosen dataset through ``DataLoader``, evaluates each row with
    ``ResponseEvaluator`` and writes the per-row results to
    ``<output-dir>/<dataset>_outputs.json``. Rows whose evaluation raises are
    reported on stdout and left out of the output.
    """
    parser = argparse.ArgumentParser(description="Evaluate responses from JudgeLM, PandaLM, SHP, Prometheus datasets")
    parser.add_argument(
        "--dataset",
        choices=["judgeLM", "pandaLM", "shp", "prometheus"],
        required=True,
        help="Dataset to process (judgeLM, pandaLM, shp, prometheus)"
    )
    parser.add_argument(
        "--score-diff",
        type=int,
        default=3,
        help="Score difference threshold (default: 3)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="program_outputs",
        help="Directory to save output JSON (default: program_outputs)"
    )

    args = parser.parse_args()

    # Create the output directory up front so an unusable path fails before
    # the (potentially long) evaluation instead of after it.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{args.dataset}_outputs.json"

    # Load data
    # NOTE(review): --score-diff (default 3) is always forwarded, so load_shp's
    # own default threshold of 10 is never used from the CLI — confirm intent.
    loader = DataLoader()
    if args.dataset == "judgeLM":
        training_data = loader.load_judgeLM(args.score_diff)
        dataset_name = f"judgeLM_training_data_diff>={args.score_diff}"
    elif args.dataset == "pandaLM":
        training_data = loader.load_pandaLM()
        dataset_name = "pandaLM_training_data"
    elif args.dataset == "shp":
        training_data = loader.load_shp(args.score_diff)
        dataset_name = f"shp_training_data_diff>={args.score_diff}"
    elif args.dataset == "prometheus":
        training_data = loader.load_prometheus(args.score_diff)
        dataset_name = f"prometheus_training_data_diff>={args.score_diff}"
    else:
        # Unreachable while argparse enforces `choices`; guards future edits.
        parser.error(f"Unsupported dataset: {args.dataset}")

    # Initialize evaluator
    evaluator = ResponseEvaluator()
    results = {}
    failed_rows = 0

    # Evaluate responses
    for idx in tqdm(range(len(training_data)), desc="Evaluating responses", total=len(training_data)):

        # Best-effort: one failing row must not abort the whole run.
        try:
            row = training_data.iloc[idx]
            row_result = evaluator.evaluate_row(row)
            row_result["dataset"] = dataset_name
            # Only some loaders (e.g. SHP) provide a "domain" column.
            if "domain" in row:
                row_result["domain"] = row["domain"]
            results[idx] = row_result
        except Exception as e:
            failed_rows += 1
            print(f"Error at index {idx}: {e}")
            continue

    if failed_rows:
        print(f"Skipped {failed_rows} of {len(training_data)} rows due to errors")

    # Save results
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, cls=NumpyEncoder, indent=2)

    print(f"Results saved to {output_path}")

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()