# flake8: noqa: E402
from minimal.event_loop import fix_asyncio

fix_asyncio()

import asyncio
import itertools
import logging
import math
import random
import typing as T
from collections import defaultdict
from datetime import datetime, timezone

import llama_index.core.instrumentation as instrument
import numpy as np
from llama_index.core.bridge.pydantic import Field
from llama_index.core.evaluation import CorrectnessEvaluator
from llama_index.core.evaluation.base import EvaluationResult
from llama_index.core.llms import CompletionResponse
from llama_index.core.llms.function_calling import FunctionCallingLLM

import minimal.core as core
from minimal.configuration import cfg
from minimal.flows import Flow
from minimal.instrumentation.tokens import LLMCallData
from minimal.llm import (
    ANTHROPIC_CLAUDE_HAIKU_35,  # noqa: F401
    ANTHROPIC_CLAUDE_SONNET_35,
    AZURE_GPT4O_MINI,  # noqa: F401
    GCP_GEMINI_FLASH,  # noqa: F401
    GCP_GEMINI_FLASH2,  # noqa: F401
    GCP_GEMINI_PRO,  # noqa: F401
    AZURE_o3_MINI,  # noqa: F401
)

# Logging setup: basicConfig configures the root logger at WARNING, then the
# root logger (getLogger() with no name) is switched to INFO so the progress
# messages emitted by this module are not filtered out.
logging.basicConfig(level=logging.WARNING)
log = logging.getLogger()
log.setLevel(logging.INFO)

# Replace CorrectnessEvaluator's sync and async evaluate methods with versions
# wrapped by the llama_index instrumentation dispatcher's ``span`` decorator.
# NOTE(review): presumably this makes evaluator calls visible to span handlers
# (e.g. for token/latency accounting) - confirm against the instrumentation setup.
dispatcher = instrument.get_dispatcher()
CorrectnessEvaluator.evaluate = dispatcher.span(CorrectnessEvaluator.evaluate)  # type: ignore
CorrectnessEvaluator.aevaluate = dispatcher.span(CorrectnessEvaluator.aevaluate)  # type: ignore

# Default pool of judge LLMs. One CorrectnessEvaluator is built per entry and,
# in the "debias" mode, one of them is picked at random for every Q&A pair.
EVAL_LLMS = [
    GCP_GEMINI_FLASH,
    GCP_GEMINI_FLASH2,
    ANTHROPIC_CLAUDE_HAIKU_35,
    ANTHROPIC_CLAUDE_SONNET_35,
]


class FlowgenEvaluationResult(EvaluationResult):
    """Evaluation result for one Q&A pair, extended with flow run metadata.

    On top of the inherited ``EvaluationResult`` fields this records which
    pair was evaluated, how long the flow took, any exception captured during
    generation or evaluation, and the per-call LLM data gathered for the run.
    """

    class Config:
        # NOTE(review): presumably required because some fields below hold
        # non-pydantic types such as ``Exception`` - confirm.
        arbitrary_types_allowed = True

    # The question/reference-answer pair this result belongs to (None when unknown).
    qa_pair: T.Optional[core.QAPair] = Field(default=None, description="Q&A pair")
    # Defaults to NaN; the metric helpers in this module skip NaN/falsy run times.
    run_time: T.Optional[float] = Field(
        default=np.nan, description="Flow completion time"
    )
    # Set by the batch runner in this module when a pair's task times out.
    generation_exception: T.Optional[Exception] = Field(
        default=None, description="Exception during generation"
    )
    evaluation_exception: T.Optional[Exception] = Field(
        default=None, description="Exception during evaluation"
    )
    # default_factory gives every instance its own list.
    llm_call_data: T.List[LLMCallData] = Field(
        default_factory=list,
        description="Token counts and latencies for all LLM calls made during flow",
    )


async def agenerate_pair(
    qa_pair: core.QAPair,
    flow: Flow,
) -> T.Tuple[CompletionResponse, float, T.List[LLMCallData]]:
    """Run the flow on the question of a Q&A pair asynchronously.

    Args:
        qa_pair: Pair whose ``question`` is sent to the flow as ``query``.
        flow: Flow whose ``agenerate`` coroutine produces the answer.

    Returns:
        Whatever ``flow.agenerate`` yields, unchanged; per the annotation this
        is the response, the run time and the recorded LLM call data.
    """
    generation = flow.agenerate(query=qa_pair.question)
    return await generation


async def aevaluate_pair(
    qa_pair: core.QAPair,
    response: CompletionResponse,
    evaluator: CorrectnessEvaluator,
) -> EvaluationResult:
    """Grade a flow response against the reference answer asynchronously.

    Args:
        qa_pair: Supplies the ``question`` (query) and ``answer`` (reference).
        response: Flow output; only its ``text`` is handed to the evaluator.
        evaluator: Judge whose ``aevaluate`` coroutine performs the grading.

    Returns:
        The awaited result of ``evaluator.aevaluate``.
    """
    grading = evaluator.aevaluate(
        return_values_on_exception=(None,),
        query=qa_pair.question,
        response=response.text,
        reference=qa_pair.answer,
    )
    return await grading


async def _aeval_pair_debias(
    qa_pair: core.QAPair,
    flow: Flow,
    evaluators: T.List[CorrectnessEvaluator],
) -> FlowgenEvaluationResult:
    """Generate an answer for one Q&A pair and grade it with a random evaluator.

    The flow is run first. If it yields a truthy response, one evaluator is
    drawn at random from ``evaluators`` and its verdict is merged into the
    result; otherwise only the generation metadata is recorded.

    Args:
        qa_pair: Pair to answer and grade.
        flow: Flow under evaluation.
        evaluators: Candidate judges; one is chosen per call.

    Returns:
        Result carrying the pair, run time, LLM call data and, when an
        evaluation was produced, the evaluator's dumped fields.
    """
    response, run_time, call_data = await agenerate_pair(qa_pair, flow)

    verdict_fields: T.Dict[str, T.Any] = {}
    if response:
        judge = random.choice(evaluators)
        verdict = await aevaluate_pair(qa_pair, response, judge)
        if verdict:
            verdict_fields = verdict.model_dump()

    return FlowgenEvaluationResult(
        qa_pair=qa_pair,
        run_time=run_time,
        llm_call_data=call_data,
        **verdict_fields,
    )


async def _aeval_all_pair_runner(
    pair_eval_runner: T.Callable,
    dataset: T.List[core.QAPair],
    flow: Flow,
    evaluators: T.List[CorrectnessEvaluator],
    eval_timeout: int = 300,
) -> T.List[FlowgenEvaluationResult]:
    """Run ``pair_eval_runner`` for every pair concurrently with a shared timeout.

    Args:
        pair_eval_runner: Coroutine function called as
            ``pair_eval_runner(pair, flow, evaluators)`` for each pair.
        dataset: Q&A pairs to evaluate in this batch.
        flow: Flow under evaluation, forwarded to the runner.
        evaluators: Evaluators forwarded to the runner.
        eval_timeout: Seconds to wait for the whole batch before giving up on
            tasks that are still running.

    Returns:
        One result per pair, in ``dataset`` order. Pairs whose task did not
        finish within ``eval_timeout`` get a placeholder result whose
        ``generation_exception`` is a ``TimeoutError``.

    Raises:
        Any exception raised by a finished task is re-raised by
        ``Task.result()`` and propagates to the caller.
    """
    if not dataset:
        # asyncio.wait() raises ValueError when given no tasks.
        return []

    tasks = [
        asyncio.create_task(pair_eval_runner(pair, flow, evaluators))
        for pair in dataset
    ]

    _, pending = await asyncio.wait(tasks, timeout=eval_timeout)

    # asyncio.wait() does not cancel on timeout; without this the unfinished
    # tasks would keep running (and spending LLM calls) in the background.
    for task in pending:
        task.cancel()
    if pending:
        # Let the cancellations be processed before reporting results.
        await asyncio.gather(*pending, return_exceptions=True)

    all_results = []
    for task in tasks:
        if task in pending:
            # Providing empty result for proper reporting.
            timeout_exc = TimeoutError(
                f"Eval of task {task} terminated due to timeout of {eval_timeout} seconds."
            )
            result = FlowgenEvaluationResult(
                qa_pair=None,
                run_time=np.nan,
                generation_exception=timeout_exc,
                llm_call_data=[],
            )
        else:
            result = task.result()
        all_results.append(result)
    return all_results


def _async_eval_runner(
    pair_eval_runner: T.Callable,
    items: T.List[core.QAPair],
    flow: Flow,
    eval_llms: T.List[FunctionCallingLLM],
) -> T.List[FlowgenEvaluationResult]:
    """Evaluate Q&A items batch by batch on the event loop.

    One ``CorrectnessEvaluator`` is created per LLM in ``eval_llms``. The items
    are split into batches of ``cfg.evaluation.num_eval_batch`` and every batch
    is run to completion before the next one starts. After each batch,
    cumulative metrics are computed and logged when they can be.

    Args:
        pair_eval_runner: Coroutine function evaluating a single pair.
        items: Q&A pairs to evaluate.
        flow: Flow under evaluation.
        eval_llms: LLMs used to build the evaluators.

    Returns:
        Results of all batches, concatenated in batch order.
    """
    judges = [CorrectnessEvaluator(llm=judge_llm) for judge_llm in eval_llms]

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    batch_size = cfg.evaluation.num_eval_batch
    total_batches = math.ceil(len(items) / batch_size)
    reported_keys = ("accuracy", "obj2_value", "num_errors")

    collected: T.List[FlowgenEvaluationResult] = []
    for batch_no, chunk in enumerate(itertools.batched(items, batch_size), start=1):
        collected += loop.run_until_complete(
            _aeval_all_pair_runner(pair_eval_runner, list(chunk), flow, judges)
        )
        has_timed_run = any(
            res and res.run_time and not np.isnan(res.run_time) for res in collected
        )
        # Metrics are computed once at least one run has a usable run time.
        # With none so far, they are still computed after enough batches have
        # gone by (max_eval_failure_rate), so that calculate_metrics can fail
        # the trial instead of burning through the remaining batches.
        # NOTE(review): the ratio uses the 0-based batch index, i.e. the number
        # of batches finished *before* this one - confirm this is intended.
        if (
            has_timed_run
            or (batch_no - 1) / total_batches > cfg.evaluation.max_eval_failure_rate
        ):
            current_metrics = calculate_metrics(collected)
            summary = {
                key: value
                for key, value in current_metrics.items()
                if key in reported_keys
            }
            log.info(
                "Finished evaluation batch %s/%s with %s QA pairs. Metrics: %s, Flow: %s",
                batch_no,
                total_batches,
                len(chunk),
                summary,
                flow,
            )
    return collected


def async_eval_debias(
    items: T.List[core.QAPair],
    flow: Flow,
    eval_llms: T.List[FunctionCallingLLM] = EVAL_LLMS,
) -> T.List[FlowgenEvaluationResult]:
    """Evaluate Q&A items, grading each pair with a randomly chosen evaluator.

    Args:
        items: Q&A pairs to evaluate.
        flow: Flow under evaluation.
        eval_llms: LLMs to build evaluators from; defaults to ``EVAL_LLMS``.

    Returns:
        One evaluation result per item.
    """
    return _async_eval_runner(
        pair_eval_runner=_aeval_pair_debias,
        items=items,
        flow=flow,
        eval_llms=eval_llms,
    )


def validate_evaluation_data(results: T.List[FlowgenEvaluationResult]):
    """Report exceptions captured in the results and fail if nothing is usable.

    Generation and evaluation exceptions are collected into nested
    ``ExceptionGroup`` objects and logged as a warning. The group is raised,
    wrapped in a plain ``Exception``, only when exceptions were found and
    there is no usable run time (or no per-result cost entry at all).

    Args:
        results: Evaluation results to inspect.

    Raises:
        Exception: Exceptions were captured and no result has a usable run time.
    """
    timed_runs = [
        res.run_time
        for res in results
        if res and res.run_time and not np.isnan(res.run_time)
    ]
    per_result_costs = [
        sum(call.cost for call in calls)
        for calls in [res.llm_call_data for res in results]
    ]

    generation_errors = [err for res in results if (err := res.generation_exception)]
    evaluation_errors = [err for res in results if (err := res.evaluation_exception)]

    if not generation_errors and not evaluation_errors:
        logging.info("The evaluation finished without exceptions")
        return

    groups = []
    if generation_errors:
        groups.append(
            ExceptionGroup("Exceptions during generation", generation_errors)
        )
    if evaluation_errors:
        groups.append(
            ExceptionGroup("Exceptions during evaluation", evaluation_errors)
        )
    failure = ExceptionGroup("Trial has failed evals", groups)
    log.warning(failure)
    if not (timed_runs and per_result_costs):
        raise Exception(failure)


def eval_dataset(
    dataset_iter,
    flow: Flow,
    evaluation_mode: T.Literal["random"] = "random",
    max_evals: int = 1000,
) -> T.Dict[str, T.Any]:
    """Evaluate a flow on a dataset and return aggregated metrics.

    Args:
        dataset_iter: Object whose ``iter_examples()`` yields the Q&A pairs.
        flow: Flow under evaluation.
        evaluation_mode: Evaluation strategy; only ``"random"`` is implemented.
        max_evals: Upper bound on the number of pairs taken from the dataset.

    Returns:
        The metrics from ``calculate_metrics`` (empty when there are no
        results) plus ``eval_start``, ``eval_end``, ``eval_duration`` (UTC
        timestamps / seconds) and ``total_qa_pairs``.

    Raises:
        NotImplementedError: ``evaluation_mode`` is not supported.
    """
    started_at = datetime.now(timezone.utc).timestamp()
    examples = list(dataset_iter.iter_examples())
    dataset = examples[:max_evals]

    results: T.List[FlowgenEvaluationResult] = []
    if evaluation_mode == "random":
        results = async_eval_debias(dataset, flow)
    else:
        raise NotImplementedError(f"{evaluation_mode} not implemented")

    metrics = calculate_metrics(results) if results else {}
    for message, key in (
        ("Number of evaluations: %d", "num_total"),
        ("Number of successful evaluations: %d", "num_success"),
        ("Number of errored evaluations: %d", "num_errors"),
    ):
        log.info(message, metrics.get(key, 0))

    finished_at = datetime.now(timezone.utc).timestamp()
    metrics.update(
        eval_start=started_at,
        eval_end=finished_at,
        eval_duration=finished_at - started_at,
        total_qa_pairs=len(dataset),
    )
    return metrics


def calculate_metrics(
    results: T.List[FlowgenEvaluationResult],
) -> T.Dict[str, T.Any]:
    """Aggregate per-pair evaluation results into a flat metrics dictionary.

    Args:
        results: Evaluation results, including failed or timed-out ones.

    Returns:
        Dictionary with accuracy, F1, run-time statistics, the second
        objective (``obj2_value``) and the cost/token/latency breakdowns.
        ``num_total`` is the number of results that carry a pass/fail verdict
        (the accuracy denominator), ``num_errors`` the number of results with
        no verdict or a captured exception, and ``num_success`` the number of
        results that are not errors.

    Raises:
        Exception: Propagated from ``validate_evaluation_data``.
        RuntimeError: No result has a verdict, or the share of results with a
            verdict is below ``cfg.evaluation.min_reporting_success_rate``.
    """
    validate_evaluation_data(results)
    objective_2: str = cfg.evaluation.objective_2_name
    passing = [r.passing for r in results if r.passing in [True, False]]

    # The emptiness check also guards both divisions below (empty results or
    # zero verdicts would otherwise end in a ZeroDivisionError).
    if (
        not passing
        or len(passing) / len(results) < cfg.evaluation.min_reporting_success_rate
    ):
        raise RuntimeError(
            f"Too few successful evaluations: {len(passing)} out of {len(results)}"
        )

    num_total = len(passing)
    num_errors = sum(
        1
        for res in results
        if res.passing is None
        or res.generation_exception is not None
        or res.evaluation_exception is not None
    )
    # Errors are counted over all results, so they must be subtracted from all
    # results: num_total already excludes results without a verdict, and
    # subtracting from it would count those failures twice.
    num_success = len(results) - num_errors

    acc = sum(passing) / num_total
    passing_std = float(np.std(passing))
    f1_scores = [
        core.f1_score(result.qa_pair.answer, result.response)
        for result in results
        if result and result.qa_pair and result.response
    ]
    # Same NaN as np.mean([]) but without the empty-slice RuntimeWarning.
    f1_score = float(np.mean(f1_scores)) if f1_scores else float("nan")

    run_times = [
        r.run_time for r in results if r and r.run_time and not np.isnan(r.run_time)
    ]
    min_time = float(np.min(run_times))
    max_time = float(np.max(run_times))
    mean_time = float(np.mean(run_times))
    median_time = float(np.median(run_times))
    p80_time = float(np.percentile(run_times, 80))
    run_times_std = float(np.std(run_times))
    # Spread of the run times at or below the 80th percentile.
    run_times_p80 = [r for r in run_times if r <= p80_time]
    run_times_p80_std = float(np.std(run_times_p80))

    latency_data = extract_llm_latency_data(results)
    cost_data = extract_cost_data(results)
    token_data = extract_token_data(results)
    if objective_2 == "p80_time":
        obj2_value = p80_time
    else:
        # Mean cost per result, as already computed by extract_cost_data.
        obj2_value = cost_data["llm_cost_mean"]
    return {
        "accuracy": acc,
        "min_time": min_time,
        "max_time": max_time,
        "mean_time": mean_time,
        "median_time": median_time,
        "passing_std": passing_std,
        "f1_score": f1_score,
        "num_total": num_total,
        "num_errors": num_errors,
        "num_success": num_success,
        "p80_time": p80_time,
        "run_times_std": run_times_std,
        "run_times_p80_std": run_times_p80_std,
        "obj2_value": obj2_value,
        "objective_1_name": "accuracy",
        "objective_2_name": objective_2,
        **cost_data,
        **token_data,
        **latency_data,
    }


def extract_llm_latency_data(
    all_results: T.List[FlowgenEvaluationResult],
) -> T.Dict[str, float]:
    """Summarise LLM call latencies per model across all results.

    Args:
        all_results: Results whose ``llm_call_data`` entries expose
            ``llm_name`` and ``llm_call_latency``.

    Returns:
        For every model seen, the keys ``llm_latency_mean_<model>``,
        ``llm_latency_median_<model>`` and ``llm_latency_total_<model>``.
        Empty when no calls were recorded.
    """
    latencies_by_model = defaultdict(list)
    for result in all_results:
        for call in result.llm_call_data:
            latencies_by_model[call.llm_name].append(call.llm_call_latency)

    stats = {}
    for name, values in latencies_by_model.items():
        stats.update(
            {
                f"llm_latency_mean_{name}": float(np.mean(values)),
                f"llm_latency_median_{name}": float(np.median(values)),
                f"llm_latency_total_{name}": sum(values),
            }
        )
    return stats


def extract_cost_data(
    all_results: T.List[FlowgenEvaluationResult],
) -> T.Dict[str, float]:
    """Summarise LLM cost per result and per model.

    Args:
        all_results: Results whose ``llm_call_data`` entries expose
            ``llm_name`` and ``cost``.

    Returns:
        Total, min, max, mean and median of the per-result cost
        (``llm_cost_*``), followed by ``llm_cost_total_<model>`` per model.
    """
    costs_by_model = defaultdict(list)
    cost_per_run = []
    for result in all_results:
        calls = result.llm_call_data
        for call in calls:
            costs_by_model[call.llm_name].append(call.cost)
        cost_per_run.append(sum(call.cost for call in calls))

    summary = {
        "llm_cost_total": sum(cost_per_run),
        "llm_cost_min": float(np.min(cost_per_run)),
        "llm_cost_max": float(np.max(cost_per_run)),
        "llm_cost_mean": float(np.mean(cost_per_run)),
        "llm_cost_median": float(np.median(cost_per_run)),
    }
    for name, costs in costs_by_model.items():
        summary[f"llm_cost_total_{name}"] = sum(costs)
    return summary


def extract_token_data(
    all_results: T.List[FlowgenEvaluationResult],
) -> T.Dict[str, float]:
    """Summarise input and output token usage per result and per model.

    Args:
        all_results: Results whose ``llm_call_data`` entries expose
            ``llm_name``, ``input_tokens`` and ``output_tokens``.

    Returns:
        Total, min, max, mean and median of the per-result input and output
        token counts, followed by the per-model input totals and then the
        per-model output totals.
    """
    input_by_model = defaultdict(list)
    output_by_model = defaultdict(list)
    input_per_run = []
    output_per_run = []
    for result in all_results:
        calls = result.llm_call_data
        for call in calls:
            input_by_model[call.llm_name].append(call.input_tokens)
            output_by_model[call.llm_name].append(call.output_tokens)
        input_per_run.append(sum(call.input_tokens for call in calls))
        output_per_run.append(sum(call.output_tokens for call in calls))

    summary = {
        "llm_input_tokens_total": sum(input_per_run),
        "llm_input_tokens_min": float(np.min(input_per_run)),
        "llm_input_tokens_max": float(np.max(input_per_run)),
        "llm_input_tokens_mean": float(np.mean(input_per_run)),
        "llm_input_tokens_median": float(np.median(input_per_run)),
        "llm_output_tokens_total": sum(output_per_run),
        "llm_output_tokens_min": float(np.min(output_per_run)),
        "llm_output_tokens_max": float(np.max(output_per_run)),
        "llm_output_tokens_mean": float(np.mean(output_per_run)),
        "llm_output_tokens_median": float(np.median(output_per_run)),
    }
    for name, counts in input_by_model.items():
        summary[f"llm_input_tokens_total_{name}"] = sum(counts)
    for name, counts in output_by_model.items():
        summary[f"llm_output_tokens_total_{name}"] = sum(counts)
    return summary
