import collections
import functools
import pathlib

import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import upath
from loguru import logger
from tqdm.auto import tqdm
import multiprocessing as mp
import typing as tp


def validated_generation_scores(infer: pd.DataFrame, generation_scores: pd.DataFrame):
  """Cross-check inference rows against their generation-score rows.

  For every row of ``infer`` the full sequence is built by concatenating
  ``input_ids`` and ``token_ids``. The inference rows (deduplicated on that
  sequence) are then outer-joined with ``generation_scores`` on
  ``idx``/``dataset_idx``/``sample_idx`` plus the token sequence.

  Neither input frame is modified; all derived columns live on copies.

  Args:
    infer: Frame with ``idx``, ``dataset_idx``, ``sample_idx``, ``input_ids``
      and ``token_ids`` columns.
    generation_scores: Frame with ``idx``, ``dataset_idx``, ``sample_idx`` and
      ``token_ids`` columns, where ``token_ids`` holds the full sequence.

  Returns:
    A dict with up to two keys. ``"missing_generation_score"`` lists the join
    keys of inference rows without a matching score, and
    ``"unexpected_generation_score"`` lists the join keys of score rows without
    a matching inference row. The dict is empty when everything matches.
  """
  join_keys = ["idx", "dataset_idx", "sample_idx"]

  # Build the sequences with a comprehension instead of a row-wise
  # DataFrame.apply: apply(axis=1) returns a DataFrame (not a Series) for an
  # empty frame, which cannot be assigned to a single column. Tuples are used
  # because merge/drop_duplicates need hashable values.
  full_sequences = pd.Series(
    [
      tuple(np.concatenate([prompt_ids, completion_ids]))
      for prompt_ids, completion_ids in zip(infer["input_ids"], infer["token_ids"])
    ],
    index=infer.index,
    dtype=object,
  )

  # assign() returns a new frame, so the caller's ``infer`` and
  # ``generation_scores`` keep their original columns and values.
  # NOTE(review): duplicates are dropped on the token sequence alone, not per
  # join key -- confirm this matches how the scores file is produced.
  deduped_infer = infer.assign(generation_token_ids=full_sequences).drop_duplicates(
    subset=["generation_token_ids"]
  )
  hashable_scores = generation_scores.assign(
    token_ids=generation_scores["token_ids"].apply(tuple)
  )

  merged = pd.merge(
    deduped_infer,
    hashable_scores,
    left_on=[*join_keys, "generation_token_ids"],
    right_on=[*join_keys, "token_ids"],
    how="outer",
    indicator=True,
  )

  # Rows present on only one side of the outer join are the mismatches.
  errors = {}
  for side, error_type in (
    ("left_only", "missing_generation_score"),
    ("right_only", "unexpected_generation_score"),
  ):
    unmatched = merged.loc[merged["_merge"] == side, join_keys]
    if not unmatched.empty:
      errors[error_type] = unmatched.to_dict(orient="records")
  return errors


def validate_gt_scores(infer: pd.DataFrame, gt_scores: pd.DataFrame):
  """Cross-check inference prompts against their ground-truth score rows.

  The inference rows (deduplicated on ``input_ids``) are outer-joined with
  ``gt_scores`` on ``idx``/``dataset_idx``/``sample_idx`` plus the token
  sequence (``input_ids`` on the left, ``token_ids`` on the right).

  Neither input frame is modified; the hashable token columns live on copies.

  Args:
    infer: Frame with ``idx``, ``dataset_idx``, ``sample_idx`` and
      ``input_ids`` columns.
    gt_scores: Frame with ``idx``, ``dataset_idx``, ``sample_idx`` and
      ``token_ids`` columns.

  Returns:
    A dict with up to two keys. ``"missing_gt_score"`` lists the join keys of
    inference rows without a matching score, and ``"unexpected_gt_score"``
    lists the join keys of score rows without a matching inference row. The
    dict is empty when everything matches.
  """
  join_keys = ["idx", "dataset_idx", "sample_idx"]

  # Tuples make the token sequences hashable for drop_duplicates/merge.
  # assign() returns a new frame, so the caller's frames are left untouched.
  # NOTE(review): duplicates are dropped on ``input_ids`` alone, not per join
  # key -- confirm this matches how the scores file is produced.
  deduped_infer = infer.assign(input_ids=infer["input_ids"].apply(tuple)).drop_duplicates(
    subset=["input_ids"]
  )
  hashable_scores = gt_scores.assign(token_ids=gt_scores["token_ids"].apply(tuple))

  merged = pd.merge(
    deduped_infer,
    hashable_scores,
    left_on=[*join_keys, "input_ids"],
    right_on=[*join_keys, "token_ids"],
    how="outer",
    indicator=True,
  )

  # Rows present on only one side of the outer join are the mismatches.
  errors = {}
  for side, error_type in (
    ("left_only", "missing_gt_score"),
    ("right_only", "unexpected_gt_score"),
  ):
    unmatched = merged.loc[merged["_merge"] == side, join_keys]
    if not unmatched.empty:
      errors[error_type] = unmatched.to_dict(orient="records")
  return errors


def _validate_hidden_states(args, detailed: bool = True):
  """Validate the hidden-states file referenced by a single score row.

  Args:
    args: A ``(credentials, row)`` pair. ``credentials`` is the value to export
      as ``GOOGLE_APPLICATION_CREDENTIALS`` (or ``None`` to leave the
      environment untouched). ``row`` is a mapping with ``idx``,
      ``dataset_idx``, ``sample_idx``, ``hidden_states_file`` and
      ``token_ids`` entries.
    detailed: When true, the file is also loaded and the leading dimension of
      every stored array is compared with ``len(row["token_ids"])``.

  Returns:
    A dict mapping an error type (``"missing_file"`` or ``"mismatched_shape"``)
    to a list of error records; empty when no problem was found.
  """
  import os

  credentials, row = args
  # os.environ only accepts strings. The caller obtains the value through
  # os.getenv(), which yields None when the variable is unset, so assigning it
  # unconditionally would raise TypeError for runs without GCP credentials.
  if credentials is not None:
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials

  row_keys = {
    "idx": row["idx"],
    "dataset_idx": row["dataset_idx"],
    "sample_idx": row["sample_idx"],
  }
  errors = collections.defaultdict(list)
  hidden_states_file = upath.UPath(row["hidden_states_file"])
  if not hidden_states_file.exists():
    errors["missing_file"].append(
      {**row_keys, "hidden_states": hidden_states_file.as_posix()}
    )
    return dict(errors)

  if detailed:
    expected_length = len(row["token_ids"])
    with hidden_states_file.open("rb") as f:
      # The loaded object must expose ``.files`` -- presumably an ``.npz``
      # archive; TODO confirm the on-disk format.
      hidden_states = np.load(f)
      for key in hidden_states.files:
        # Index the archive once per key and reuse the shape, instead of
        # indexing it a second time when building the error record.
        shape = hidden_states[key].shape
        if shape[0] != expected_length:
          errors["mismatched_shape"].append(
            {
              **row_keys,
              "key": key,
              "shape": shape,
              "expected_shape": expected_length,
            }
          )
  return dict(errors)


def validate_hidden_states(
  scores: pd.DataFrame, root: str | pathlib.Path, detailed: bool = True
):
  """Validate the hidden-states file of every row in ``scores``.

  Each row's ``hidden_states`` entry is resolved against ``root`` and handed to
  ``_validate_hidden_states`` on a thread pool; the per-row error dicts are
  merged into a single result.

  Args:
    scores: Frame with ``idx``, ``dataset_idx``, ``sample_idx``, ``token_ids``
      and ``hidden_states`` columns. It is copied, not modified.
    root: Location that the ``hidden_states`` entries are joined onto.
    detailed: Forwarded to ``_validate_hidden_states``.

  Returns:
    A dict mapping each error type to the list of error records collected
    across all rows; empty when no row reported a problem.
  """
  import os
  from multiprocessing.pool import ThreadPool

  credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
  base = upath.UPath(root)

  # Resolve every entry against the root on a private copy of the frame.
  table = scores.copy()
  table["hidden_states_file"] = table["hidden_states"].apply(
    lambda relative: (base / relative).as_posix()
  )
  columns = ["idx", "dataset_idx", "sample_idx", "hidden_states_file", "token_ids"]
  tasks = [
    (credentials, record) for record in table[columns].to_dict(orient="records")
  ]

  worker = functools.partial(_validate_hidden_states, detailed=detailed)
  collected = collections.defaultdict(list)
  with ThreadPool(mp.cpu_count() * 2) as pool:
    # Completion order does not matter: reports are only merged by error type.
    progress = tqdm(
      pool.imap_unordered(worker, tasks),
      total=len(tasks),
      desc="Validating hidden states",
    )
    for report in progress:
      for error_type, entries in report.items():
        collected[error_type].extend(entries)
  return dict(collected)


def validate_experiment(
  experiment: str | upath.UPath,
  infer_file: str = "infer.parquet",
  gt_score_file: str = "gt-score.parquet",
  generation_scores_file: str = "generation-scores.parquet",
  eval_files: tp.Sequence[str] = ("eval.json", "eval.parquet"),
  detailed: bool = True,
):
  """Run every available validation for one experiment directory.

  Args:
    experiment: Directory that holds the experiment's files.
    infer_file: Name of the inference parquet file inside ``experiment``.
    gt_score_file: Name of the ground-truth score parquet file.
    generation_scores_file: Name of the generation-score parquet file.
    eval_files: Candidate evaluation file names; one existing file is enough.
    detailed: Forwarded to ``validate_hidden_states``.

  Returns:
    A dict mapping each check name (``"infer"``, ``"gt-score"``,
    ``"generation-scores"``, ``"hidden_states"`` when the gt-score file was
    loaded, and ``"eval"``) to ``{"success": bool, "errors": dict}``.
  """
  experiment = upath.UPath(experiment)
  results = collections.defaultdict(lambda: {"success": False, "errors": {}})

  # Locate the parquet inputs and load whichever ones exist.
  files_to_load = {
    "infer": experiment / infer_file,
    "gt-score": experiment / gt_score_file,
    "generation-scores": experiment / generation_scores_file,
  }
  posix_paths = {name: path.as_posix() for name, path in files_to_load.items()}
  logger.info(f"Loading files: {posix_paths}")

  dataframes = {}
  for name, file_path in files_to_load.items():
    if not file_path.exists():
      results[name]["errors"]["file_missing"] = f"{file_path} does not exist"
      continue
    dataframes[name] = pq.read_table(file_path).to_pandas()
    results[name]["success"] = True

  infer = dataframes.get("infer")
  gt_scores = dataframes.get("gt-score")
  generation_scores = dataframes.get("generation-scores")

  # Cross-file checks only run when both of their inputs were loaded.
  if infer is not None and gt_scores is not None:
    logger.info("Validating gt-scores")
    gt_score_errors = validate_gt_scores(infer, gt_scores)
    if gt_score_errors:
      results["gt-score"]["success"] = False
      results["gt-score"]["errors"].update(gt_score_errors)

  if infer is not None and generation_scores is not None:
    logger.info("Validating generation scores")
    gen_score_errors = validated_generation_scores(infer, generation_scores)
    if gen_score_errors:
      results["generation-scores"]["success"] = False
      results["generation-scores"]["errors"].update(gen_score_errors)

  if gt_scores is not None:
    logger.info("Validating hidden states")
    hidden_state_errors = validate_hidden_states(
      gt_scores, experiment, detailed=detailed
    )
    results["hidden_states"] = {
      "success": not hidden_state_errors,
      "errors": hidden_state_errors,
    }

  # The eval check passes as soon as any one candidate file exists.
  if any((experiment / candidate).exists() for candidate in eval_files):
    results["eval"] = {"success": True, "errors": {}}
  else:
    results["eval"] = {
      "success": False,
      "errors": {"file_missing": f"{eval_files} missing"},
    }
  return dict(results)
