import pytest
import pandas as pd
from datasets import Dataset
from llm_inference import output_parsers
from llm_inference.eval_utils import (
  Evaluator,
  EvaluationConfig,
  Metric,
)
import datasets as ds
import numpy as np


class DummyMetric(Metric):
  """Toy metric used by these tests: absolute gap between prediction and reference."""

  # NOTE(review): presumably read by Metric/Evaluator to decide how `compute`
  # may be dispatched -- confirm against llm_inference.eval_utils.
  __execution_strategies__ = {"none", "process", "thread"}

  def compute(self, preds, refs):
    """Return a one-column DataFrame (`dummy_score`) holding |pred - ref| per row.

    Predictions equal to the "[invalid]" marker become NaN, so their score
    is NaN as well; every other prediction is converted with `float`.
    """
    as_floats = [np.nan if p == "[invalid]" else float(p) for p in preds]
    abs_gap = np.abs(np.array(as_floats) - refs)

    return pd.DataFrame({"dummy_score": abs_gap})

  def summarize(self, details, index_key):
    """Collapse the per-row scores into their mean; `index_key` is unused here."""
    mean_of_scores = details["dummy_score"].mean()
    return {"mean_score": mean_of_scores}


def get_reference(x):
  """Return the entry stored under the "value" key of `x`."""
  reference = x["value"]
  return reference


def output_parser(x):
  """Convert a raw model output to a float.

  Args:
    x: Value to convert; anything accepted by the built-in `float`.

  Returns:
    The value of `x` as a float.

  Raises:
    output_parsers.OutputParserException: If `float(x)` fails. The original
      conversion error is attached as the exception's cause.
  """
  try:
    return float(x)
  # Only catch the errors float() raises for bad input; a bare `except` would
  # also intercept KeyboardInterrupt/SystemExit and hide unrelated failures.
  except (TypeError, ValueError, OverflowError) as err:
    raise output_parsers.OutputParserException(f"Could not parse {x}") from err


@pytest.fixture
def dummy_config():
  """Build the EvaluationConfig shared by the tests in this module.

  It combines a fresh DummyMetric, `get_reference` as the reference getter,
  and a single-element parser list containing `output_parser`.
  """
  settings = dict(
    metric=DummyMetric(),
    get_reference=get_reference,
    output_parser=[output_parser],
  )
  return EvaluationConfig(**settings)


@pytest.mark.parametrize(
  "predictions,references,expected_order",
  [
    # Case 1: prediction rows already follow the reference order.
    (
      pd.DataFrame({"dataset_idx": [0, 1, 2], "prediction": ["1.0", "2.0", "3.0"]}),
      ds.Dataset.from_dict({"value": [1, 2, 3]}),
      [0, 1, 2],
    ),
    # Case 2: prediction rows are reversed relative to the references.
    (
      pd.DataFrame({"dataset_idx": [2, 1, 0], "prediction": ["3.0", "2.0", "1.0"]}),
      ds.Dataset.from_dict({"value": [1, 2, 3]}),
      [2, 1, 0],
    ),
    # Case 3: prediction rows are shuffled.
    (
      pd.DataFrame({"dataset_idx": [1, 0, 2], "prediction": ["2.0", "1.0", "3.0"]}),
      ds.Dataset.from_dict({"value": [1, 2, 3]}),
      [1, 0, 2],
    ),
    # Case 4: some references receive more than one prediction.
    (
      pd.DataFrame(
        {
          "dataset_idx": [0, 0, 1, 2, 2],
          "prediction": ["1.0", "1.5", "2.0", "3.0", "3.5"],
        }
      ),
      ds.Dataset.from_dict({"value": [1, 2, 3]}),
      [0, 0, 1, 2, 2],
    ),
  ],
)
def test_evaluator_ordering(dummy_config, predictions, references, expected_order):
  """
  Check that Evaluator.run keeps prediction rows in their input order.

  For each parametrized case the test asserts that:
  1. The details frame has exactly one row per input prediction.
  2. Its `dataset_idx` column lists the indices in `expected_order`.
  3. Its `extracted_ref` column holds, row by row, the reference value
     belonging to that row's dataset index (including repeated indices).
  """
  settings = dict(
    cfg=dummy_config,
    prediction="prediction",
    dataset_index_key="dataset_idx",
    other_index_keys=(),
  )
  evaluator = Evaluator(**settings)
  _, details = evaluator.run(predictions, references)

  n_out, n_in = len(details), len(predictions)
  assert n_out == n_in, "Output should have same number of rows as input"

  observed_order = list(details["dataset_idx"])
  assert observed_order == expected_order, "Output order should match input order"

  observed_refs = list(details["extracted_ref"])
  wanted_refs = [references[idx]["value"] for idx in expected_order]
  assert observed_refs == wanted_refs, "References should match predictions"


@pytest.mark.parametrize(
  "predictions,references",
  [
    # Case: one prediction points at index 3, but the dataset only has rows 0-2.
    (
      pd.DataFrame(
        {"dataset_idx": [0, 1, 2, 3], "prediction": ["1.0", "2.0", "3.0", "4.0"]}
      ),
      ds.Dataset.from_dict({"value": [1, 2, 3]}),
    ),
  ],
)
def test_evaluator_invalid_indices(dummy_config, predictions, references):
  """
  Check that out-of-range dataset indices make Evaluator.run fail loudly.

  The predictions reference an index that does not exist in the reference
  dataset; running the evaluator on them is expected to raise IndexError.
  """
  settings = dict(
    cfg=dummy_config,
    prediction="prediction",
    dataset_index_key="dataset_idx",
    other_index_keys=(),
  )
  # Construct outside the `raises` block so that only `run` may satisfy it.
  evaluator = Evaluator(**settings)

  with pytest.raises(IndexError):
    evaluator.run(predictions, references)


def test_evaluator_parse_error(dummy_config):
  """
  Check that unparseable predictions are flagged in the `parse_error` column.

  The middle prediction ("invalid") cannot be converted by the output parser,
  so its `parse_error` entry must be set, while the first (valid) prediction
  must have a `parse_error` of None.
  """
  raw_predictions = {"dataset_idx": [0, 1, 2], "prediction": ["1.0", "invalid", "3.0"]}
  predictions = pd.DataFrame(raw_predictions)
  references = ds.Dataset.from_dict({"value": [1, 2, 3]})

  settings = dict(
    cfg=dummy_config,
    prediction="prediction",
    dataset_index_key="dataset_idx",
    other_index_keys=(),
  )
  evaluator = Evaluator(**settings)
  _, details = evaluator.run(predictions, references)

  error_for_invalid = details["parse_error"][1]
  assert error_for_invalid is not None, "Parse error should be recorded for invalid prediction"

  error_for_valid = details["parse_error"][0]
  assert error_for_valid is None, "No parse error should be recorded for valid prediction"
