# Score evaluation results (latent, steering and train_data modes).
# 
# example launch command:
#     python axbench/scripts/evaluate.py --config axbench/demo/sweep/evaluate.yaml --mode latent

import shutil
from axbench.models.language_models import (
    LanguageModel
)

import os, argparse, yaml, json, glob, pickle, tempfile, copy
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch
from pathlib import Path
import numpy as np
from openai import AsyncOpenAI
import httpx, asyncio
import datetime
import yaml
from axbench.scripts.inference import LATENT_EXCLUDE_MODELS, STEERING_EXCLUDE_MODELS
import axbench
from axbench.utils.plot_utils import (
    plot_aggregated_roc, 
    plot_metrics,
    plot_accuracy_bars,
    plot_win_rates,
    plot_metrics_multiple_datasets,
)
from axbench.templates.html_templates import (
    generate_html_with_highlight_text,
)
from axbench.scripts.args.eval_args import EvalArgs
from functools import partial
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing
from axbench.utils.constants import *

import logging
logging.basicConfig(
    level=logging.WARN,
    datefmt='%Y-%m-%d:%H:%M:%S',
    format='%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
)
logger = logging.getLogger(__name__)

# File name of the pickled resume state; save_results/load_state prefix it with the partition (mode).
STATE_FILE = "evaluate_state.pkl"


def data_generator(data_dir, mode, winrate_split_ratio=None):
    """
    Read the parquet file for ``mode`` once and yield its rows grouped by concept_id.

    The whole file is loaded up front and then split per concept, so each
    concept is yielded without further disk I/O.

    Args:
        data_dir (str | Path): Path to the data directory.
        mode (str): 'latent' -> latent_data.parquet; 'train_data' ->
            dpo_train_data.parquet; 'winrate' or any mode containing 'steering'
            (e.g. 'steering', 'steering_test') -> steering_data.parquet.
        winrate_split_ratio (float, optional): Currently has no effect; see the
            NOTE in the body.

    Yields:
        (concept_id, df_subset): concept ids in ascending order with the rows
        belonging to that concept.

    Raises:
        ValueError: if ``mode`` does not map to a known data file (previously
            this surfaced as an UnboundLocalError on ``df``).
    """
    if mode == "latent":
        file_name = "latent_data.parquet"
    elif "steering" in mode or mode == "winrate":
        file_name = "steering_data.parquet"
    elif mode == "train_data":
        file_name = "dpo_train_data.parquet"
    else:
        raise ValueError(f"Unknown data mode: {mode!r}")
    df = pd.read_parquet(os.path.join(data_dir, file_name))

    # NOTE(review): the original code forces the ratio to None here, which turns
    # the split below into dead code, so ``winrate_split_ratio`` is ignored.
    # Kept as-is to avoid changing which rows get evaluated; delete this line to
    # re-enable the steering / steering_test split.
    winrate_split_ratio = None

    # groupby sorts keys by default, so concepts come out in ascending order and
    # each key already maps to exactly one group (no regrouping/concat needed).
    for concept_id, df_subset in df.groupby('concept_id'):
        if winrate_split_ratio is not None:
            n_input_ids = df_subset["input_id"].max() + 1
            n_steering_ids = n_input_ids - round(n_input_ids * winrate_split_ratio)
            if mode == "steering":
                df_subset = df_subset[df_subset["input_id"] < n_steering_ids]
            elif mode == "steering_test" or mode == "winrate":
                df_subset = df_subset[df_subset["input_id"] >= n_steering_ids]
        yield (concept_id, df_subset)


def get_best_factors(aggregated_results):
    """
    Map each concept to the steering factor with the highest LM-judge rating per method.

    Args:
        aggregated_results (list[dict]): Entries shaped like
            ``{"concept_id": ..., "results": {"LMJudgeEvaluator": {method: scores}}}``
            where ``scores["factor"]`` and ``scores["lm_judge_rating"]`` are
            parallel sequences.

    Returns:
        dict: ``{concept_id: {method: best_factor}}``. A later entry with the
        same concept_id replaces an earlier one.
    """
    def _pick_factor(scores):
        winner = np.argmax(scores["lm_judge_rating"])
        return scores["factor"][winner]

    return {
        entry["concept_id"]: {
            method: _pick_factor(scores)
            for method, scores in entry["results"]["LMJudgeEvaluator"].items()
        }
        for entry in aggregated_results
    }


def winrate_data_generator(data_dir, aggregated_results, winrate_split_ratio):
    """
    Yield, per concept, one row per input carrying each method's best-factor generation.

    For every method, the "best" factor is the one chosen by ``get_best_factors``
    (highest ``lm_judge_rating``). Rows of the last method at its best factor form
    the base frame, and every method's ``{method}_steered_generation`` column is
    left-merged onto it by (concept_id, input_concept, input_id).

    Concepts with no entry (or no methods) in ``aggregated_results`` are skipped
    with a warning; previously they raised KeyError / NameError.

    Yields:
        (concept_id, concept_best_df)
    """
    best_factors = get_best_factors(aggregated_results)
    key_columns = ["concept_id", "input_concept", "input_id"]
    base_columns = key_columns + ["original_prompt", "steered_input"]
    df_generator = data_generator(data_dir, mode="winrate", winrate_split_ratio=winrate_split_ratio)
    for concept_id, current_df in df_generator:
        # if concept_id >= start_concept_id: # TODO: uncomment this when we fix our pipeline
        method_factors = best_factors.get(concept_id)
        if not method_factors:
            logger.warning(f"No best factors for concept_id: {concept_id}; skipping winrate data.")
            continue

        concept_best_dfs = {}
        for method, factor in method_factors.items():
            method_df = current_df[base_columns + ["factor", f"{method}_steered_generation"]]
            concept_best_dfs[method] = method_df[method_df["factor"] == factor].copy()

        # Same base as before: the rows selected for the last method.
        last_method = list(concept_best_dfs)[-1]
        concept_best_df = concept_best_dfs[last_method][base_columns].copy()
        for method, method_best_df in concept_best_dfs.items():
            # Merge instead of direct assignment to ensure proper row alignment.
            concept_best_df = concept_best_df.merge(
                method_best_df[key_columns + [f"{method}_steered_generation"]],
                on=key_columns,
                how='left')
        yield (concept_id, concept_best_df)


def save_results(dump_dir, state, concept_id, partition, eval_results, eval_df=None):
    """
    Persist one concept's evaluation output under ``dump_dir``.

    Writes:
      * ``{partition}_{STATE_FILE}``: the pickled ``state`` (read back by ``load_state``).
      * ``{partition}.jsonl``: one appended line ``{"concept_id", "results"}``.
      * ``{partition}_data.parquet``: per-example columns merged across
        evaluators/models and appended to any existing file. This is only written
        when ``eval_df`` has at least one evaluator with stored frames.

    Args:
        dump_dir (str | Path): Output directory; created if missing.
        state (dict): Resume state to pickle.
        concept_id (int): Concept these results belong to.
        partition (str): File-name prefix (the evaluation mode).
        eval_results (dict): JSON-serialisable results keyed by evaluator, then model.
        eval_df (dict, optional): ``{evaluator_name: {model_name: DataFrame}}``.
    """
    dump_dir = Path(dump_dir)
    dump_dir.mkdir(parents=True, exist_ok=True)

    # Save resume state.
    with open(dump_dir / f"{partition}_{STATE_FILE}", "wb") as f:
        pickle.dump(state, f)

    # Append this concept's results as one JSON line.
    result_entry = {
        "concept_id": int(concept_id),
        "results": eval_results,
    }
    with open(dump_dir / f"{partition}.jsonl", "a") as f:
        f.write(json.dumps(result_entry) + "\n")

    if eval_df is None:
        return

    # Only evaluators that stored per-example frames take part. An empty entry
    # (e.g. PerplexityEvaluator, or LMJudgeEvaluator without raw ratings) used to
    # become the base and crash with IndexError.
    evaluator_names = sorted(name for name, model_dfs in eval_df.items() if model_dfs)
    if not evaluator_names:
        return
    base_evaluator = evaluator_names[0]
    model_names = sorted(eval_df[base_evaluator])
    base_model = model_names[0]
    # The first (evaluator, model) frame already holds its own columns and is
    # the base that every other column is copied onto.
    current_df = eval_df[base_evaluator][base_model].copy()

    # Column suffixes (after "{model}_{evaluator}") produced for LMJudgeEvaluator.
    lm_judge_suffixes = (
        "",
        "_relevance_concept_ratings",
        "_relevance_concept_completions",
        "_relevance_instruction_ratings",
        "_relevance_instruction_completions",
        "_fluency_ratings",
        "_fluency_completions",
    )
    for evaluator_name in evaluator_names:
        if evaluator_name == "PerplexityEvaluator":
            continue
        for model_name in model_names:
            if evaluator_name == base_evaluator and model_name == base_model:
                continue
            source_df = eval_df[evaluator_name][model_name]
            prefix = f"{model_name}_{evaluator_name}"
            if evaluator_name == "LMJudgeEvaluator":
                for suffix in lm_judge_suffixes:
                    current_df[prefix + suffix] = source_df[prefix + suffix]
                continue
            try:
                current_df[f"{prefix}_rule_following"] = source_df[f"{prefix}_rule_following"]
            except KeyError:
                # train_data frames carry winning/losing columns keyed by evaluator only.
                current_df[f"{prefix}_rule_following_winning"] = \
                    source_df[f"{evaluator_name}_rule_following_winning"]
                current_df[f"{prefix}_rule_following_losing"] = \
                    source_df[f"{evaluator_name}_rule_following_losing"]

    df_path = dump_dir / f"{partition}_data.parquet"
    if os.path.exists(df_path):
        existing_df = pd.read_parquet(df_path)
        combined_df = pd.concat([existing_df, current_df], ignore_index=True)
    else:
        combined_df = current_df
    combined_df.to_parquet(df_path, index=False)


def load_state(dump_dir, mode):
    """
    Load the resume state that ``save_results`` pickled for ``mode``.

    Args:
        dump_dir (str | Path): Directory that may contain ``{mode}_{STATE_FILE}``.
        mode (str): Partition name used as the state-file prefix.

    Returns:
        The unpickled state object, or None when the state file does not exist.
    """
    state_path = os.path.join(f"{dump_dir}", f"{mode}_{STATE_FILE}")
    if not os.path.exists(state_path):
        return None
    with open(state_path, "rb") as state_file:
        loaded = pickle.load(state_file)
    return loaded


def combine_scores_per_concept(concept_data):
    """
    Return the LMJudgeEvaluator results of one concept entry.

    Despite the name, nothing is merged: ``concept_data["results"]["LMJudgeEvaluator"]``
    is returned as-is (the same object).
    """
    per_evaluator = concept_data["results"]
    return per_evaluator["LMJudgeEvaluator"]


def process_jsonl_file(jsonl_lines):
    """
    Normalise loaded result entries in place and return the same list.

    Each entry's ``results["LMJudgeEvaluator"]`` is replaced by
    ``combine_scores_per_concept(entry)``.
    """
    for entry in jsonl_lines:
        combined = combine_scores_per_concept(entry)
        entry["results"]["LMJudgeEvaluator"] = combined
    return jsonl_lines


def plot_steering(aggregated_results, dump_dir, report_to=None, wandb_name=None, mode=None, rule=False):
    """
    Plot steering metrics; failures are logged as warnings, never raised.

    Args:
        aggregated_results (list[dict]): Per-concept result entries (as loaded
            from ``{mode}.jsonl``).
        dump_dir (str | Path): Directory the plots are written to; also holds
            ``{mode}_data.parquet``, which feeds the multi-dataset plot.
        report_to (list, optional): Reporting targets forwarded to the plot helpers.
            Defaults to an empty list (a fresh one per call, not a shared default).
        wandb_name (str, optional): Forwarded to the plot helpers.
        mode (str, optional): Evaluation mode / file prefix.
        rule (bool): Forwarded to ``plot_metrics_multiple_datasets``.
    """
    if report_to is None:
        report_to = []
    configs = [
        {
            'evaluator_name': 'PerplexityEvaluator',
            'metric_name': 'perplexity',
            'y_label': 'Perplexity',
            'use_log_scale': False
        },
        {
            'evaluator_name': 'LMJudgeEvaluator',
            'metric_name': 'relevance_concept_ratings',
            'y_label': 'Concept',
            'use_log_scale': False
        },
        {
            'evaluator_name': 'LMJudgeEvaluator',
            'metric_name': 'relevance_instruction_ratings',
            'y_label': 'Instruct',
            'use_log_scale': False
        },
        {
            'evaluator_name': 'LMJudgeEvaluator',
            'metric_name': 'fluency_ratings',
            'y_label': 'Fluency',
            'use_log_scale': False
        },
        {
            'evaluator_name': 'LMJudgeEvaluator',
            'metric_name': 'lm_judge_rating',
            'y_label': 'Aggregated',
            'use_log_scale': False
        },
        {
            'evaluator_name': 'PerplexityEvaluator',
            'metric_name': 'strength',
            'y_label': 'Strength',
            'use_log_scale': False
        },
        {
            'evaluator_name': 'RuleEvaluator',
            'metric_name': 'rule_following',
            'y_label': 'Rule',
            'use_log_scale': False
        }
    ]
    try:
        plot_metrics(
            jsonl_data=aggregated_results,
            configs=configs,
            write_to_path=dump_dir,
            report_to=report_to,
            wandb_name=wandb_name,
            mode=mode
        )
        plot_metrics_multiple_datasets(
            data_path=os.path.join(dump_dir, f"{mode}_data.parquet"),
            write_to_path=dump_dir,
            report_to=report_to,
            wandb_name=wandb_name,
            mode=mode,
            rule=rule
        )
    except Exception as e:
        # Plotting is best-effort: evaluation results are already saved.
        logger.warning(f"Failed to plot: {e}")


def eval_steering_single_task(args_tuple):
    """
    Evaluate one (concept, evaluator, model) combination inside a worker process.

    ``args_tuple`` is ``(concept_id, current_df, evaluator_name, model_name,
    dump_dir, lm_model, winrate_baseline, lm_caches, steer_dataset_type)``.
    ``lm_model`` is the judge-model identifier passed to ``LanguageModel``.
    A fresh AsyncOpenAI client is created per call and is always closed on
    exit, including when constructing the ``LanguageModel`` fails.

    Returns:
        tuple: ``(concept_id, str(evaluator), str(model_name), eval_result,
        lm_usage_report, lm_cache_or_None, current_df)``. The in-memory cache is
        returned only when no ``lm_caches`` were supplied.
    """
    (concept_id, current_df, evaluator_name, model_name, dump_dir,
     lm_model_name, winrate_baseline, lm_caches, steer_dataset_type) = args_tuple

    client = AsyncOpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
        timeout=60.0,
        http_client=httpx.AsyncClient(
            limits=httpx.Limits(
                max_keepalive_connections=100,
                max_connections=1000
            ),
            headers={"Connection": "close"},
        ),
        max_retries=3,
    )
    try:
        # Built inside the try so the client is closed even if this fails.
        lm_model = LanguageModel(
            lm_model_name,
            client,
            dump_dir=dump_dir,
            use_cache=True,
            cache_level="prompt",
            cache_tag="evaluate",
            master_data_dir="axbench/data",
            temperature=0.0
        )
        # Overwrite the model's in-memory cache when the caller supplies one.
        if lm_caches:
            lm_model.cache_in_mem = lm_caches

        evaluator = getattr(axbench, evaluator_name)(
            model_name, dump_dir=dump_dir,
            concept_id=concept_id, lm_model=lm_model, winrate_baseline=winrate_baseline,
            steer_dataset_type=steer_dataset_type)
        if "Rule" not in evaluator_name:
            eval_result = evaluator.compute_metrics(current_df)
        elif "winning_output" in current_df.columns:
            # Frames with winning/losing outputs (e.g. the DPO train data):
            # the rule is looked up from the output concept.
            eval_result = evaluator.compute_metrics_train(
                current_df, rule_type=CONCEPT_TO_RULE[current_df['output_concept'].iloc[0]])
        else:
            eval_result = evaluator.compute_metrics(
                current_df, rule_type=CONCEPT_TO_RULE[current_df['input_concept'].iloc[0]])
        return (concept_id, str(evaluator), str(model_name), eval_result,
                lm_model.stats.get_report(), None if lm_caches else lm_model.cache_in_mem, current_df)
    finally:
        # Close the async client (and its HTTP connection pool) in a fresh loop.
        asyncio.run(client.close())


def _attach_example_columns(current_df, evaluator_str, model_str, result, mode):
    """
    Copy per-example ratings from an evaluator ``result`` onto ``current_df`` in place.

    Returns True when columns were added (the caller then keeps a copy of the
    frame for ``save_results``), False otherwise.
    """
    prefix = f"{model_str}_{evaluator_str}"
    raw_lm_judge_keys = (
        "raw_relevance_concept_ratings",
        "raw_relevance_instruction_ratings",
        "raw_fluency_ratings",
        "raw_aggregated_ratings",
    )
    if evaluator_str == "LMJudgeEvaluator" and any(key in result for key in raw_lm_judge_keys):
        current_df[f"{prefix}_relevance_concept_ratings"] = result["raw_relevance_concept_ratings"]
        current_df[f"{prefix}_relevance_instruction_ratings"] = result["raw_relevance_instruction_ratings"]
        current_df[f"{prefix}_fluency_ratings"] = result["raw_fluency_ratings"]
        current_df[prefix] = result["raw_aggregated_ratings"]
        current_df[f"{prefix}_relevance_concept_completions"] = result["relevance_concept_completions"]
        current_df[f"{prefix}_relevance_instruction_completions"] = result["relevance_instruction_completions"]
        current_df[f"{prefix}_fluency_completions"] = result["fluency_completions"]
        return True
    if evaluator_str == "RuleEvaluator" and "rule_following" in result:
        if mode != "train_data":
            current_df[f"{prefix}_rule_following"] = result["raw_rule_following"]
        else:
            # train_data columns are keyed by evaluator only (no model prefix).
            current_df[f"{evaluator_str}_rule_following_winning"] = result["raw_rule_following_winning"]
            current_df[f"{evaluator_str}_rule_following_losing"] = result["raw_rule_following_losing"]
        return True
    return False


def eval_steering(args):
    """
    Evaluate steering generations for every (concept, evaluator, model) in parallel.

    Resumes after the last concept recorded by ``load_state`` and fans out one
    ``eval_steering_single_task`` per combination over a process pool. It then
    saves results per concept via ``save_results``. Outside ``train_data`` mode,
    it reloads ``{mode}.jsonl``, logs aggregated LM usage, and plots.

    Note: ``args.models`` may be rewritten (PromptSteering expanded per defense)
    and ``args.num_of_workers`` is filled in when unset.
    """
    dump_dir = args.dump_dir

    # Initialize data generator (lazy: nothing is read until it is iterated).
    df_generator = data_generator(
        args.data_dir, mode=args.mode,
        winrate_split_ratio=args.winrate_split_ratio)

    # Load previous state if exists
    state = load_state(dump_dir, mode=args.mode)
    start_concept_id = state.get("concept_id", 0) if state else 0
    logger.warning(f"Starting concept_id: {start_concept_id}")

    if "PromptSteering" in args.models and len(args.defense) > 0:
        # Replace PromptSteering with one PromptSteering_<defense> entry per defense.
        args.models = [model for model in args.models if model != "PromptSteering"]
        args.models.extend([f"PromptSteering_{d}" for d in args.defense])

    # One task per (concept, evaluator, model) for maximum parallelism.
    all_tasks = [
        (concept_id, current_df, evaluator_name, model_name, dump_dir,
         args.lm_model, args.winrate_baseline, {}, args.steer_data_type)
        for concept_id, current_df in df_generator
        if concept_id >= start_concept_id
        for evaluator_name in args.steering_evaluators
        for model_name in args.models
        if model_name not in STEERING_EXCLUDE_MODELS
    ]

    # Resolve the worker count before logging it (the attribute may be missing).
    if getattr(args, 'num_of_workers', None) is None:
        args.num_of_workers = max(1, multiprocessing.cpu_count() - 1)
    logger.warning(f"Number of workers: {args.num_of_workers}; Number of CPUs: {multiprocessing.cpu_count()}")

    all_results = {}
    eval_dfs = {}
    lm_reports = []
    with ProcessPoolExecutor(max_workers=args.num_of_workers) as executor:
        # The worker also returns its LM cache; it was never used here, so it is ignored.
        for concept_id, evaluator_str, model_str, result, lm_report, _lm_cache, current_df in executor.map(
                eval_steering_single_task, all_tasks):
            concept_results = all_results.setdefault(concept_id, {})
            concept_dfs = eval_dfs.setdefault(concept_id, {})
            concept_results.setdefault(evaluator_str, {})[model_str] = result
            evaluator_dfs = concept_dfs.setdefault(evaluator_str, {})
            if _attach_example_columns(current_df, evaluator_str, model_str, result, args.mode):
                evaluator_dfs[model_str] = current_df.copy()
            lm_reports.append(lm_report)
            logger.warning(f"Completed task for concept_id: {concept_id}, model: {model_str}, evaluator: {evaluator_str}")

    # Batch save all results; state points at the next concept to resume from.
    for concept_id, eval_results in sorted(all_results.items()):
        save_results(
            dump_dir,
            {"concept_id": concept_id + 1},
            concept_id,
            args.mode,
            eval_results,
            eval_dfs[concept_id]
        )
    if args.mode == "train_data":
        return

    # Reload for plotting and optional winrate
    try:
        aggregated_results = process_jsonl_file(
            load_jsonl(os.path.join(dump_dir, f'{args.mode}.jsonl')))
    except Exception as e:
        logger.warning(f"Failed to load {args.mode}.jsonl: {e}. Aborting evaluation.")
        return

    aggregated_lm_report = {
        "total_calls": sum(report["total_calls"] for report in lm_reports),
        "total_cache_hits": sum(report["total_cache_hits"] for report in lm_reports),
        "total_price": sum(report["total_price"] for report in lm_reports),
    }
    logger.warning("="*20)
    logger.warning(f"Total calls: {aggregated_lm_report['total_calls']}, "
                   f"Total cache hits: {aggregated_lm_report['total_cache_hits']}")
    logger.warning(f"Total price: ${aggregated_lm_report['total_price']}")
    logger.warning("="*20)

    # Single plotting pass over this mode's own data file. The former extra call
    # re-plotted a hard-coded steering_data.parquet (wrong for steering_test) and
    # was not guarded against plotting errors.
    logger.warning("Generating final plot...")
    plot_steering(
        aggregated_results, dump_dir, args.report_to, args.wandb_name, args.mode,
        rule=args.steer_data_type == "rule")
    logger.warning("Evaluation completed!")


def load_jsonl(jsonl_path):
    """
    Parse a JSON Lines file into a list with one decoded object per line.

    Every line is passed to ``json.loads``, so a blank line raises
    ``json.JSONDecodeError``.
    """
    with open(jsonl_path, 'r') as handle:
        return [json.loads(raw_line) for raw_line in handle]
    

def plot_latent(dump_dir, report_to=None, wandb_name=None):
    """
    Plot latent-evaluation results read from ``dump_dir/latent.jsonl``.

    Produces the aggregated ROC plot and the HardNegativeEvaluator accuracy bars.

    Args:
        dump_dir (str | Path): Directory holding ``latent.jsonl``; plots are written here.
        report_to (list, optional): Reporting targets forwarded to the plot helpers.
            Defaults to an empty list (a fresh one per call, not a shared default).
        wandb_name (str, optional): Forwarded to the plot helpers.
    """
    if report_to is None:
        report_to = []
    aggregated_results = load_jsonl(os.path.join(dump_dir, 'latent.jsonl'))
    plot_aggregated_roc(
        aggregated_results, write_to_path=dump_dir, report_to=report_to, wandb_name=wandb_name)
    plot_accuracy_bars(
        aggregated_results, "HardNegativeEvaluator", write_to_path=dump_dir,
        report_to=report_to, wandb_name=wandb_name)


def eval_latent(args):
    """
    Score latent (concept-detection) results for every concept, then plot them.

    Reads ``latent_data.parquet`` from ``args.data_dir``, skips concepts below
    the resume point returned by ``load_state``, and runs every evaluator in
    ``args.latent_evaluators`` for every model in ``args.models`` not listed in
    ``LATENT_EXCLUDE_MODELS``. Results are appended per concept via ``save_results``.
    Returns early with a warning when the latent data file is missing.
    """
    dump_dir = args.dump_dir
    latent_data_path = os.path.join(args.data_dir, f'latent_data.parquet')
    if not os.path.exists(latent_data_path):
        logger.warning(f"Latent data not found at {latent_data_path}")
        return
    concept_stream = data_generator(args.data_dir, mode="latent")

    state = load_state(args.dump_dir, mode="latent")
    start_concept_id = 0 if not state else state.get("concept_id", 0)
    logger.warning(f"Starting concept_id: {start_concept_id}")

    for concept_id, current_df in concept_stream:
        if concept_id < start_concept_id:
            continue
        logger.warning(f"Evaluating concept_id: {concept_id}")

        # {evaluator_str: {model_str: metrics}} for this concept.
        eval_results = {}
        scored_models = (m for m in args.models if m not in LATENT_EXCLUDE_MODELS)
        for model_name in scored_models:
            for evaluator_name in args.latent_evaluators:
                evaluator = getattr(axbench, evaluator_name)(model_name)
                metrics = evaluator.compute_metrics(current_df)
                per_model = eval_results.setdefault(evaluator.__str__(), {})
                per_model[model_name.__str__()] = metrics
        save_results(
            dump_dir, {"concept_id": concept_id + 1},
            concept_id, 'latent', eval_results, None)

    logger.warning("Generating final plot...")
    plot_latent(dump_dir, args.report_to, args.wandb_name)
    logger.warning("Evaluation completed!")


def main():
    """
    Command-line entry point: parse the ``evaluate`` config and dispatch on ``--mode``.

    Modes:
        'latent'     -> eval_latent
        '*steering*' -> eval_steering (e.g. 'steering', 'steering_test')
        'train_data' -> eval_steering on ``{dump_dir}/generate`` data
        'all'        -> eval_latent, then eval_steering in 'steering' mode
    """
    custom_args = [
        {
            'args': ['--mode'],
            'kwargs': {
                'type': str,
                'default': "all",
                'help': 'The evaluation mode.'
            }
        },
    ]
    args = EvalArgs(custom_args=custom_args, section="evaluate", ignore_unknown=True)
    source_subdir = "generate" if args.mode == "train_data" else "inference"
    args.data_dir = (
        f"{args.dump_dir}/{source_subdir}"
        if args.overwrite_inference_dump_dir is None
        else Path(args.overwrite_inference_dump_dir)
    )
    logger.warning("Evaluating generations with the following configuration:")
    logger.warning(args)

    dump_dir = Path(args.dump_dir) / "evaluate" if args.overwrite_evaluate_dump_dir is None else Path(args.overwrite_evaluate_dump_dir)
    dump_dir.mkdir(parents=True, exist_ok=True)
    args.dump_dir = dump_dir

    if args.mode == "latent":
        eval_latent(args)
    elif "steering" in args.mode or args.mode == "train_data":
        eval_steering(args)
    elif args.mode == "all":
        eval_latent(args)
        # eval_steering derives its data file, state file and output partition
        # from args.mode; "all" maps to none of them, so run it as "steering".
        args.mode = "steering"
        eval_steering(args)
    else:
        logger.warning(f"Unknown evaluation mode: {args.mode}; nothing to evaluate.")

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

