import argparse
from pathlib import Path

import upath
from loguru import logger
from tabulate import tabulate

from llm_inference import utils
from llm_inference.validation import validate_experiment


def validate_experiments(
  root_dir: str | Path, single_experiment: bool = False, detailed: bool = True
) -> list[dict[str, str | bool]]:
  """Validate one experiment directory or every experiment directory under a root.

  Args:
    root_dir: Local path or URL (anything ``upath.UPath`` accepts). Either a
      directory whose child directories are experiments or, when
      ``single_experiment`` is true, one experiment directory.
    single_experiment: If true, validate ``root_dir`` itself instead of
      iterating over its children.
    detailed: Forwarded unchanged to ``validate_experiment(..., detailed=...)``.

  Returns:
    One row per experiment directory, sorted by name when iterating a root.
    Each row maps ``"experiment"`` to the directory name, and each of
    ``"infer"``, ``"gt-score"``, ``"generation-scores"`` and ``"hidden_states"``
    to that component's ``"success"`` flag. A flag is ``False`` when the
    component is absent from the validation result. Non-directory entries are
    skipped silently.
  """
  # Columns reported in the results table, in display order.
  components = ("infer", "gt-score", "generation-scores", "hidden_states")

  root_path = upath.UPath(root_dir)
  if single_experiment:
    experiments = [root_path]
  else:
    # iterdir() order is backend/filesystem dependent; sort for stable output.
    experiments = sorted(root_path.iterdir(), key=lambda p: p.name)

  results = []
  for exp_path in experiments:
    if not exp_path.is_dir():
      continue

    logger.info(f"Validating experiment: {exp_path.name}")
    validation_result = validate_experiment(exp_path, detailed=detailed)

    row = {"experiment": exp_path.name}
    for component in components:
      row[component] = validation_result.get(component, {}).get("success", False)
    results.append(row)

    # Log detailed errors. Missing "success"/"errors" keys are handled the
    # same way the table treats them (as a failure) rather than raising
    # KeyError and aborting validation of the remaining experiments.
    for component, result in validation_result.items():
      if not result.get("success", False):
        for error_type, error_details in (result.get("errors") or {}).items():
          logger.error(f"{exp_path.name} - {component} - {error_type}: {error_details}")

  return results


def main():
  """Command-line entry point: validate experiments and print a summary.

  Parses ``root_dir``, ``--single`` and ``--light`` from the command line,
  runs ``validate_experiments``, prints the results as a pipe-format table,
  and prints how many experiments passed every component check.

  Raises:
    ValueError: If ``root_dir`` is a ``gs://`` URL and
      ``utils.check_gcs_credentials()`` returns a falsy value.
  """
  parser = argparse.ArgumentParser(description="Validate experiments")
  parser.add_argument(
    "root_dir",
    type=str,
    help="Root directory containing experiments (or one experiment with --single)",
  )
  parser.add_argument(
    "--single",
    help="Treat root_dir itself as a single experiment directory to validate",
    action="store_true",
  )
  parser.add_argument(
    "--light",
    help="Run light validation (e.g. don't check each hidden state)",
    action="store_true",
  )
  args = parser.parse_args()

  # Fail fast before touching remote storage without credentials.
  if args.root_dir.startswith("gs://"):
    if not utils.check_gcs_credentials():
      raise ValueError("GCS credentials not found")

  results = validate_experiments(args.root_dir, args.single, detailed=not args.light)

  # Print results in tabular format
  table = tabulate(results, headers="keys", tablefmt="pipe")
  print("\nValidation Results:")
  print(table)

  # Print summary. Only the boolean component flags count toward "fully
  # validated"; the "experiment" column holds the name, not a status.
  fully_validated = sum(
    all(value for key, value in row.items() if key != "experiment")
    for row in results
  )
  print(f"\nFully validated experiments: {fully_validated}/{len(results)}")


if __name__ == "__main__":
  # Run the CLI only when executed as a script, not when imported.
  main()
