#!/usr/bin/env python3
"""
Script to check the status of experiment metric calculations.
This script identifies which experiments have complete metrics and which ones still need metrics to be calculated.
"""

import os
import sys
import argparse
import pandas as pd
from tqdm import tqdm
from Models.experiment_identifier import ExperimentIdentifier, find_all_experiments

def _unavailable_status(all_metric_columns):
  """Build the status dict used when the experiment's results cannot be read at all."""
  return {
    'is_complete': False,
    'in_results_file': False,
    'missing_metrics': list(all_metric_columns),
    'completion_percentage': 0.0,
    'completed_metrics': 0,
    'total_metrics': len(all_metric_columns)
  }

def get_metric_status(experiment_identifier):
  """
  Check which metrics are computed for a specific experiment in the results file.

  Every expected metric column is named "<setting>_<metric>" for each evaluation
  setting and each metric reported by the experiment identifier. A column counts
  as computed when it exists in the results CSV and holds a non-NaN value for
  (any row of) this experiment.

  Args:
    experiment_identifier: ExperimentIdentifier instance for the experiment to check

  Returns:
    dict: Dictionary containing metric status information. It always has the same keys:
      - is_complete: Whether all metrics are complete
      - in_results_file: Whether the experiment is in the results file
      - missing_metrics: List of missing metric column names
      - completion_percentage: Percentage of completed metrics (0.0 when no metrics are expected)
      - completed_metrics: Number of computed metric columns
      - total_metrics: Number of expected metric columns
  """
  # All possible settings and metrics
  settings = experiment_identifier.get_evaluation_settings()
  all_metrics = experiment_identifier.get_metrics_to_compute()

  # Results rows are looked up by the string form of the model identifier
  exp_id_str = str(experiment_identifier.get_model_identifier())
  results_path = experiment_identifier.result_metric_path

  # Generate all possible metric columns
  all_metric_columns = [f"{setting}_{metric}" for setting in settings for metric in all_metrics]

  if not os.path.exists(results_path):
    return _unavailable_status(all_metric_columns)

  try:
    results_df = pd.read_csv(results_path, index_col=0)
  except Exception as e:
    # Best effort: an unreadable results file is reported as "nothing computed"
    # rather than aborting the whole status report. Errors go to stderr so they
    # do not interleave with the report table on stdout.
    print(f"Error loading results file {results_path}: {e}", file=sys.stderr)
    return _unavailable_status(all_metric_columns)

  # read_csv infers the index dtype (e.g. numeric-looking ids become ints), so
  # normalise it to str to match the string form of the identifier.
  results_df.index = results_df.index.astype(str)

  # A boolean mask always yields a DataFrame, even if the id appears on several
  # rows; .loc[id, column] would return a Series there and make pd.isna() ambiguous.
  exp_rows = results_df[results_df.index == exp_id_str]
  in_results = not exp_rows.empty

  if in_results:
    missing_metrics = [
      column for column in all_metric_columns
      if column not in exp_rows.columns or not exp_rows[column].notna().to_numpy().any()
    ]
  else:
    missing_metrics = list(all_metric_columns)

  # Calculate completion percentage
  total_metrics = len(all_metric_columns)
  completed_metrics = total_metrics - len(missing_metrics)
  completion_percentage = (completed_metrics / total_metrics) * 100 if total_metrics > 0 else 0.0

  return {
    'is_complete': len(missing_metrics) == 0 and in_results,
    'in_results_file': in_results,
    'missing_metrics': missing_metrics,
    'completion_percentage': completion_percentage,
    'completed_metrics': completed_metrics,
    'total_metrics': total_metrics
  }

def parse_args(argv=None):
  """
  Parse the command-line options for the metric-status report.

  Args:
    argv: Optional list of argument strings to parse instead of sys.argv[1:]
      (useful for programmatic use and testing).

  Returns:
    argparse.Namespace: Parsed options. List-valued filters are None when not given.
  """
  parser = argparse.ArgumentParser(description='Check status of experiment metrics')
  parser.add_argument('--config', type=str, help='Filter by config name')
  parser.add_argument('--model', type=str, nargs='+', help='Filter by model name (multiple allowed)')
  parser.add_argument('--objective', type=str, nargs='+', help='Filter by objective (multiple allowed)')
  parser.add_argument('--sde_type', type=str, nargs='+', help='Filter by SDE type (multiple allowed)')
  parser.add_argument('--freq', type=int, nargs='+', help='Filter by frequency (multiple allowed)')
  parser.add_argument('--group', type=str, nargs='+', help='Filter by group (multiple allowed)')
  parser.add_argument('--seed', type=int, nargs='+', help='Filter by seed (multiple allowed)')

  # --complete and --incomplete select disjoint sets of experiments; asking for
  # both would always yield an empty report, so reject the combination up front.
  status_group = parser.add_mutually_exclusive_group()
  status_group.add_argument('--incomplete', action='store_true', help='Show only experiments with incomplete metrics')
  status_group.add_argument('--complete', action='store_true', help='Show only experiments with complete metrics')

  parser.add_argument('--missing', action='store_true', help='Show detailed information about missing metrics')

  return parser.parse_args(argv)

def filter_experiments(experiments, args):
  """
  Keep only the experiments that satisfy every filter given on the command line.

  A filter is applied only when its argument is truthy. --config requires an
  exact match; each list-valued filter requires the experiment's value to be one
  of the given values. Filters are checked in a fixed order and checking stops
  at the first mismatch.

  Args:
    experiments: Iterable of experiment identifiers.
    args: Parsed command-line namespace.

  Returns:
    list: Matching experiments, in their original order.
  """
  def keep(exp):
    if args.config and exp.config_name != args.config:
      return False

    # (requested values, accessor) pairs; accessors are lazy so nothing is
    # evaluated past the first failing filter.
    checks = (
      (args.model, lambda: exp.model_name),
      (args.objective, lambda: exp.get_model_objective()),
      (args.sde_type, lambda: exp.sde_type),
      (args.freq, lambda: exp.freq),
      (args.group, lambda: exp.group),
      (args.seed, lambda: exp.global_key_seed),
    )
    for wanted, actual in checks:
      if wanted and actual() not in wanted:
        return False
    return True

  return [exp for exp in experiments if keep(exp)]

def get_metric_statuses(experiments):
  """
  Collect the metric status of every experiment, tagged with its identifying fields.

  Args:
    experiments: Iterable of experiment identifiers.

  Returns:
    list[dict]: One dict per experiment, in input order, holding the keys from
      get_metric_status() followed by experiment_id, config_name, model_name,
      objective, sde_type, freq, group and seed.
  """
  def describe(exp):
    # NOTE(review): reports exp.objective, while filter_experiments filters on
    # exp.get_model_objective() — confirm the two always agree.
    return {
      'experiment_id': exp.get_model_identifier(),
      'config_name': exp.config_name,
      'model_name': exp.model_name,
      'objective': exp.objective,
      'sde_type': exp.sde_type,
      'freq': exp.freq,
      'group': exp.group,
      'seed': exp.global_key_seed
    }

  return [
    {**get_metric_status(exp), **describe(exp)}
    for exp in tqdm(experiments, desc="Checking metric status")
  ]

def _print_missing_details(statuses):
  """Print, for each incomplete experiment, its missing metric columns grouped by setting."""
  print("\nMissing Metrics Details:")
  for status in statuses:
    if status['is_complete'] or not status['missing_metrics']:
      continue
    print(f"\n{status['config_name']}, {status['model_name']}, {status['sde_type']}, freq_{status['freq']}, {status['group']}, seed_{status['seed']}:")

    # Missing columns are named "<setting>_<metric>"; split on the first "_"
    # to recover the setting.
    # NOTE(review): this mis-groups if a setting name itself contains "_" — confirm setting names.
    by_setting = {}
    for column in status['missing_metrics']:
      setting, metric_name = column.split('_', 1)
      by_setting.setdefault(setting, []).append(metric_name)

    for setting, metrics in by_setting.items():
      print(f"  {setting}: {', '.join(metrics)}")

def main():
  """
  Report metric completion status for the selected experiments.

  Discovers all experiments and applies the command-line filters. It then checks
  each experiment's results file and prints a summary table. When --missing is
  given, it also prints the missing metrics of every incomplete experiment.
  """
  args = parse_args()

  # Find all experiments
  all_experiments = find_all_experiments()
  if not all_experiments:
    print("No experiments found")
    return

  # Filter experiments based on command line arguments
  experiments = filter_experiments(all_experiments, args)
  if not experiments:
    print("No experiments match the specified filters")
    return

  # Get metric statuses for each experiment
  statuses = get_metric_statuses(experiments)

  # Filter by completion status if requested
  if args.incomplete:
    statuses = [s for s in statuses if not s['is_complete']]
  if args.complete:
    statuses = [s for s in statuses if s['is_complete']]

  if not statuses:
    print("No experiments match the filters after checking completion status")
    return

  df = pd.DataFrame(statuses)

  # Reorder columns for better readability
  cols = [
    'config_name', 'model_name', 'objective', 'sde_type', 'freq', 'group', 'seed',
    'is_complete', 'in_results_file', 'completion_percentage'
  ]
  display_cols = [col for col in cols if col in df.columns]
  # .copy() so the column reassignment below acts on an independent frame, not a
  # slice of the original (avoids pandas' SettingWithCopyWarning).
  df = df[display_cols].copy()

  # Summary statistics come from the raw numeric values, before the percentage
  # column is turned into rounded display strings.
  complete_count = int(df['is_complete'].sum())
  avg_completion = df['completion_percentage'].mean()

  df['completion_percentage'] = df['completion_percentage'].apply(lambda x: f"{x:.1f}%")

  print(f"Found {len(df)} matching experiments")
  print(f"Complete metrics: {complete_count}, Incomplete: {len(df) - complete_count}")
  print(f"Average completion: {avg_completion:.1f}%")

  print("\nMetric Status Summary:")
  print(df.to_string(index=False))

  if args.missing:
    _print_missing_details(statuses)

if __name__ == "__main__":
  # Run the report only when executed as a script, not when imported.
  main()