#!/usr/bin/env python3
"""
Script to check the status of experiment evaluations.
"""

import os
import sys
import argparse
import pandas as pd
from tqdm import tqdm
from Models.experiment_identifier import find_all_experiments, ExperimentIdentifier
from Models.result_data_checkpointer import ResultDataCheckpointer, SampleGenerationMetadata

def parse_args():
  """Parse the command-line options of the evaluation status script.

  Arguments are read from ``sys.argv`` by argparse.

  Returns:
    argparse.Namespace with the attributes ``config``, ``model``,
    ``objective``, ``sde_type``, ``freq``, ``group``, ``seed``,
    ``incomplete``, ``complete``, ``id``, ``details`` and ``save``.
    Options declared with ``nargs='+'`` are lists when supplied and
    ``None`` when omitted; the ``store_true`` flags default to ``False``.
  """
  parser = argparse.ArgumentParser(description='Check status of experiment evaluations')

  # Experiment filters: --config takes a single value, the others accept one or more.
  parser.add_argument('--config', type=str, help='Filter by config name')
  parser.add_argument('--model', type=str, nargs='+', help='Filter by model name (multiple allowed)')
  parser.add_argument('--objective', type=str, nargs='+', help='Filter by objective (multiple allowed)')
  parser.add_argument('--sde_type', type=str, nargs='+', help='Filter by SDE type (multiple allowed)')
  parser.add_argument('--freq', type=int, nargs='+', help='Filter by frequency (multiple allowed)')
  parser.add_argument('--group', type=str, nargs='+', help='Filter by group (multiple allowed)')
  parser.add_argument('--seed', type=int, nargs='+', help='Filter by seed (multiple allowed)')

  # Completion-status filters, applied by main() after the statuses are collected.
  parser.add_argument('--incomplete', action='store_true', help='Show only incomplete evaluations')
  parser.add_argument('--complete', action='store_true', help='Show only complete evaluations')

  # NOTE(review): --id and --details are accepted here, but main() in this
  # file never reads args.id or args.details -- confirm whether they are
  # still meant to be supported.
  parser.add_argument('--id', type=str, help='Check specific experiment ID (format: config,objective,model,sde,freq,group,seed)')
  parser.add_argument('--details', action='store_true', help='Show detailed evaluation metadata')

  # The help text names the file that main() actually writes.
  parser.add_argument('--save', action='store_true', help='Save summary to evaluation_status.txt')

  return parser.parse_args()

def filter_checkpointers(checkpointers, args):
  """Return the checkpointers whose experiment passes every active CLI filter.

  A filter is active only when its value on ``args`` is truthy. ``args.config``
  is compared for equality with ``config_name``; the remaining filters
  (``model``, ``objective``, ``sde_type``, ``freq``, ``group``, ``seed``) are
  treated as collections of accepted values. The input order is preserved.
  """
  def _passes(identifier):
    # --config is a single value, so it is an equality test.
    if args.config and identifier.config_name != args.config:
      return False

    # Each entry pairs the accepted values with a lazy reader, so an attribute
    # is only touched when its filter is active and all earlier filters passed.
    # NOTE(review): the objective is read through get_model_objective() here,
    # while get_evaluation_statuses() reports exp.objective -- confirm both
    # refer to the same value.
    membership_filters = (
      (args.model, lambda: identifier.model_name),
      (args.objective, lambda: identifier.get_model_objective()),
      (args.sde_type, lambda: identifier.sde_type),
      (args.freq, lambda: identifier.freq),
      (args.group, lambda: identifier.group),
      (args.seed, lambda: identifier.global_key_seed),
    )
    return all(
      (not accepted) or (read() in accepted)
      for accepted, read in membership_filters
    )

  return [entry for entry in checkpointers if _passes(entry.experiment_identifier)]

def get_checkpointers():
  """Wrap every experiment from find_all_experiments() in a ResultDataCheckpointer.

  Returns:
    A list with one ResultDataCheckpointer per experiment identifier, in the
    order find_all_experiments() yields them.
  """
  wrapped = []
  for identifier in find_all_experiments():
    wrapped.append(ResultDataCheckpointer(experiment_identifier=identifier))
  return wrapped

def _status_fields_from_metadata(metadata_dict):
  """Build the status fields for an evaluation that has stored metadata."""
  fields = {
    'is_complete': metadata_dict['completed'],
    'progress_percentage': metadata_dict['progress_percentage'],
    'test_data_size': metadata_dict['test_data_size'],
    'samples_for_empirical_distribution': metadata_dict['n_samples_for_empirical_distribution'],
    'evaluation_started': metadata_dict['evaluation_started'],
    'last_updated': metadata_dict['last_updated'],
  }

  if 'highest_completed_index' in metadata_dict:
    # Index-based format: indices are zero-based, so count = index + 1,
    # capped at the test data size.
    top_index = metadata_dict['highest_completed_index']
    size = metadata_dict['test_data_size']
    fields['highest_completed_index'] = top_index
    fields['completed_count'] = min(top_index + 1, size)
    fields['total_count'] = size
  elif 'completed_iterations' in metadata_dict:
    # Older metadata format that tracked iterations instead of an index.
    fields['completed_iterations'] = metadata_dict['completed_iterations']
    fields['total_iterations'] = metadata_dict['total_expected_iterations']

  if metadata_dict['completed']:
    fields['completion_time'] = metadata_dict['completion_time']

  return fields


def _status_fields_without_metadata(checkpointer, identifier):
  """Build the status fields for an evaluation that has no stored metadata."""
  top_index = checkpointer.get_highest_completed_index()

  if not (top_index >= 0):
    # Nothing has been checkpointed for this experiment.
    return {
      'is_complete': False,
      'progress_percentage': 0,
      'completed_count': 0,
      'total_count': "Unknown",
      'last_updated': "No checkpoints found",
    }

  # Try to recover the total size from the experiment config; 0 means unknown.
  known_size = 0
  make_config = getattr(identifier, 'create_config', None)
  if callable(make_config):
    cfg = make_config()
    if isinstance(cfg, dict) and 'dataset' in cfg:
      dataset_cfg = cfg['dataset']
      if isinstance(dataset_cfg, dict) and 'test_batch_size' in dataset_cfg:
        known_size = dataset_cfg['test_batch_size']

  size_is_known = known_size > 0
  done = top_index + 1
  if size_is_known:
    progress = f"{(done / known_size) * 100:.1f}%"
  else:
    progress = "Unknown"

  return {
    'is_complete': False,
    'progress_percentage': progress,
    'highest_completed_index': top_index,
    'completed_count': done,
    'total_count': known_size if size_is_known else "Unknown",
    'last_updated': "No metadata available",
  }


def get_evaluation_statuses(checkpointers):
  """Collect one status dict per checkpointer, showing a tqdm progress bar.

  Every dict starts with the experiment fields (``config_name``,
  ``model_name``, ``objective``, ``sde_type``, ``freq``, ``group``, ``seed``).
  When the checkpointer reports evaluation metadata, the progress fields are
  copied from ``get_evaluation_metadata().to_dict()``; otherwise they are
  derived from ``get_highest_completed_index()`` and, where available, the
  experiment config.

  Returns:
    A list of status dicts in the same order as ``checkpointers``.
  """
  rows = []

  for checkpointer in tqdm(checkpointers, desc="Fetching evaluation statuses"):
    identifier = checkpointer.experiment_identifier

    row = dict(
      config_name=identifier.config_name,
      model_name=identifier.model_name,
      objective=identifier.objective,
      sde_type=identifier.sde_type,
      freq=identifier.freq,
      group=identifier.group,
      seed=identifier.global_key_seed,
    )

    if checkpointer.has_evaluation_metadata():
      metadata_dict = checkpointer.get_evaluation_metadata().to_dict()
      row.update(_status_fields_from_metadata(metadata_dict))
    else:
      row.update(_status_fields_without_metadata(checkpointer, identifier))

    rows.append(row)

  return rows

def main():
  """Print a status table for the matching evaluations and optionally save it.

  Steps: parse the CLI arguments, build and filter the checkpointers, collect
  their statuses, apply the --incomplete / --complete filters, render the
  result as a pandas table and print it. With --save the same text is also
  written to ``evaluation_status.txt`` in the current working directory.

  Returns early, after printing a message, when no experiment or no
  evaluation is left after filtering.
  """
  args = parse_args()

  # Collect every experiment, then narrow the set with the CLI filters.
  # NOTE(review): args.id and args.details are parsed but not consulted here.
  checkpointers = get_checkpointers()
  checkpointers = filter_checkpointers(checkpointers, args)

  if not checkpointers:
    print("No matching experiments found")
    return

  # Get status information
  statuses = get_evaluation_statuses(checkpointers)

  # Filter by completion status if requested. The two filters are applied
  # one after the other, so passing both flags leaves no rows.
  if args.incomplete:
    statuses = [s for s in statuses if not s.get('is_complete', False)]
  if args.complete:
    statuses = [s for s in statuses if s.get('is_complete', False)]

  if not statuses:
    print("No matching evaluations found after filtering")
    return

  # Convert to DataFrame for nice display
  df = pd.DataFrame(statuses)

  # Reorder columns for better readability
  cols = ['config_name', 'model_name', 'objective', 'sde_type', 'freq', 'group', 'seed',
          'is_complete', 'progress_percentage']

  # Add completion tracking columns based on which format is present
  if 'completed_count' in df.columns:
    cols.extend(['completed_count', 'total_count'])
  if 'completed_iterations' in df.columns:
    cols.extend(['completed_iterations', 'total_iterations'])

  cols.append('last_updated')

  # Only include columns that exist
  display_cols = [col for col in cols if col in df.columns]
  df = df[display_cols]

  # Numeric progress values get a '%' suffix; values that are already
  # strings (e.g. "Unknown" or a pre-formatted percentage) are left alone.
  if 'progress_percentage' in df.columns:
    df['progress_percentage'] = df['progress_percentage'].apply(
      lambda x: f"{x}%" if isinstance(x, (int, float)) else x
    )

  # Create summary text
  summary = []
  summary.append(f"Found {len(df)} matching evaluations")
  if 'is_complete' in df.columns:
    complete_count = df['is_complete'].sum()
    summary.append(f"Complete: {complete_count}, Incomplete: {len(df) - complete_count}")
  summary.append("\nEvaluation Status Summary:")
  summary.append(df.to_string(index=False))
  summary_text = "\n".join(summary)

  # Print to console
  print(summary_text)

  # Save to file if requested
  if args.save:
    save_file = 'evaluation_status.txt'
    # Explicit encoding so the file does not depend on the platform locale.
    with open(save_file, 'w', encoding='utf-8') as f:
      f.write(summary_text)
    print(f"\nSummary saved to {save_file}")

# Run the status report only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
  main()