"""This script will preprocess experiment data,.
Useful to download, preprocess and cache experiment data
before running eval.
"""

import pathlib
import pandas as pd
from hydra_zen import builds
import json
from loguru import logger
import numpy as np
from hydra_zen import store
import typing as tp


def load_experiment(
  root: pathlib.Path,
  hidden_states=False,
  hidden_state_token: int = -1,
  hidden_state_layer: int = -1,
  hidden_states_path="score-gt.parquet",
):
  """Load one experiment directory and merge its inference, eval and score files.

  Reads three files from ``root``:

  * ``infer.parquet`` -- joined on ``(dataset_idx, sample_idx)``.
  * ``eval.json`` -- JSON whose ``["aux"]["results"]`` list holds records with
    at least ``dataset_idx``, ``sample_idx`` and ``correct``.
  * ``hidden_states_path`` (default ``score-gt.parquet``) -- joined on
    ``dataset_idx``; its ``token_logprobs`` column is summed per row.

  Args:
    root: Experiment directory (a ``str`` is also accepted).
    hidden_states: Currently unused; kept for interface compatibility.
    hidden_state_token: Currently unused; kept for interface compatibility.
    hidden_state_layer: Currently unused; kept for interface compatibility.
    hidden_states_path: File name, relative to ``root``, of the score file whose
      ``token_logprobs`` are summed into ``prompt_logprobs``.

  Returns:
    The ``infer.parquet`` rows plus ``admissible`` (the eval ``correct`` flag)
    and ``prompt_logprobs`` (sum of ``token_logprobs`` for the row's
    ``dataset_idx``).

  Raises:
    ValueError: If the merged frame does not have exactly one row per
      ``infer.parquet`` row (e.g. a sample has no eval result or no score).
    pandas.errors.MergeError: (a ``ValueError`` subclass) if ``eval.json`` has
      duplicate ``(dataset_idx, sample_idx)`` keys or the score file has
      duplicate ``dataset_idx`` values.
  """
  root = pathlib.Path(root)
  infer_df = pd.read_parquet(root / "infer.parquet")
  with open(root / "eval.json", encoding="utf-8") as f:
    eval_results = json.load(f)["aux"]["results"]
  eval_df = pd.DataFrame(eval_results)

  gt_scores = pd.read_parquet(root / hidden_states_path)
  # One total log-prob per score row: the sum of its per-token log-probs.
  prompt_scores = pd.DataFrame(
    {
      "dataset_idx": gt_scores["dataset_idx"],
      "prompt_logprobs": gt_scores["token_logprobs"].apply(lambda x: np.sum(x)),
    }
  )

  # Inner joins: infer rows without a matching eval result / score are dropped,
  # which the row-count check below reports. validate= rejects duplicate keys
  # on the right side, which would otherwise silently duplicate infer rows.
  merged = pd.merge(
    infer_df,
    eval_df[["dataset_idx", "sample_idx", "correct"]],
    on=["dataset_idx", "sample_idx"],
    validate="many_to_one",
  )
  merged = pd.merge(
    merged, prompt_scores, on=["dataset_idx"], validate="many_to_one"
  )
  merged = merged.rename(columns={"correct": "admissible"})

  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if len(merged) != len(infer_df):
    raise ValueError(
      f"Merged {len(merged)} rows but {root / 'infer.parquet'} has "
      f"{len(infer_df)}; some samples lack an eval result or a score in "
      f"{hidden_states_path!r}."
    )

  return merged


def preprocess_experiment_data(
  path: pathlib.Path,
  load_hidden_states: bool = False,
  hidden_states_path: str = "score-gt.parquet",
  hidden_state_token: int = -1,
  hidden_state_layer: int = -1,
  output_path: str | None = None,
) -> pd.DataFrame:
  """Load and merge one experiment run, log its size and return it.

  This is the Hydra task function; its signature defines the CLI config.

  Args:
    path: Experiment directory passed to ``load_experiment``.
    load_hidden_states: Forwarded to ``load_experiment`` as ``hidden_states``.
    hidden_states_path: Score file name (relative to ``path``) forwarded to
      ``load_experiment``.
    hidden_state_token: Forwarded to ``load_experiment``.
    hidden_state_layer: Forwarded to ``load_experiment``.
    output_path: NOTE(review): accepted but not used yet -- nothing is written
      to disk by this function.

  Returns:
    The merged frame produced by ``load_experiment``.
  """
  # Forward every hidden-state option; previously the token/layer selectors
  # were silently dropped here.
  merged_data = load_experiment(
    path,
    hidden_states=load_hidden_states,
    hidden_state_token=hidden_state_token,
    hidden_state_layer=hidden_state_layer,
    hidden_states_path=hidden_states_path,
  )

  logger.info(f"Loaded {len(merged_data)} samples from {path}")

  return merged_data


# Structured (dataclass) config whose fields mirror every parameter of
# preprocess_experiment_data, including those with defaults.
Config = builds(
  preprocess_experiment_data,
  populate_full_signature=True,
)


if __name__ == "__main__":
  from hydra_zen import zen

  store.add_to_hydra_store()
  zen(preprocess_experiment_data).hydra_main(config_name="root_cfg", version_base="1.3")
