#!/usr/bin/env python3
"""
Script to check the status of training experiments.
"""

import os
import sys
import argparse
import pandas as pd
from tqdm import tqdm
from Models.experiment_identifier import ExperimentIdentifier, find_all_experiments

def parse_args(argv=None):
  """Parse the command-line options of the status checker.

  Args:
    argv: list of argument strings to parse. ``None`` (the default) makes
      argparse read ``sys.argv[1:]``, matching normal script usage.

  Returns:
    argparse.Namespace with one attribute per option. List-valued filters are
    ``None`` when not given. ``--incomplete``/``--complete``/``--save`` are
    booleans.

  Raises:
    SystemExit: on invalid arguments, including passing both ``--incomplete``
      and ``--complete``.
  """
  parser = argparse.ArgumentParser(description='Check status of training experiments')
  parser.add_argument('--config', type=str, help='Filter by config name')
  parser.add_argument('--model', type=str, nargs='+', help='Filter by model name (multiple allowed)')
  parser.add_argument('--objective', type=str, nargs='+', help='Filter by objective (multiple allowed)')
  parser.add_argument('--sde_type', type=str, nargs='+', help='Filter by SDE type (multiple allowed)')
  parser.add_argument('--freq', type=int, nargs='+', help='Filter by frequency (multiple allowed)')
  parser.add_argument('--group', type=str, nargs='+', help='Filter by group (multiple allowed)')
  parser.add_argument('--seed', type=int, nargs='+', help='Filter by seed (multiple allowed)')

  # The two completion filters contradict each other: applying both would
  # discard every experiment, so argparse rejects the combination up front.
  completion = parser.add_mutually_exclusive_group()
  completion.add_argument('--incomplete', action='store_true', help='Show only incomplete experiments')
  completion.add_argument('--complete', action='store_true', help='Show only complete experiments')

  # This is a plain on/off flag; the output path is fixed (see main()).
  parser.add_argument('--save', action='store_true', help='Save summary to training_status.txt')

  return parser.parse_args(argv)

def filter_experiments(experiments, args):
  """Keep only the experiments that satisfy every command-line filter.

  A filter that was not supplied (falsy, e.g. ``None``) imposes no
  constraint. ``args.config`` must equal the experiment's config name
  exactly. Each list-valued filter accepts an experiment whose value is one
  of the listed entries.

  Args:
    experiments: iterable of experiment identifiers to screen.
    args: namespace produced by ``parse_args``.

  Returns:
    A new list of the matching experiments, in their original order.
  """
  # (allowed values, how to read the experiment's value). Checks run in this
  # order and stop at the first failure, so get_model_objective() is only
  # invoked for experiments that already passed the earlier checks.
  criteria = (
    (args.model, lambda e: e.model_name),
    (args.objective, lambda e: e.get_model_objective()),
    (args.sde_type, lambda e: e.sde_type),
    (args.freq, lambda e: e.freq),
    (args.group, lambda e: e.group),
    (args.seed, lambda e: e.global_key_seed),
  )

  def _passes(exp):
    if args.config and exp.config_name != args.config:
      return False
    return all(not allowed or read(exp) in allowed for allowed, read in criteria)

  return [exp for exp in experiments if _passes(exp)]

def get_experiment_statuses(experiments):
  """Collect each experiment's training status, tagged with its identity.

  For every experiment, ``experiment_training_status()`` is called. The
  mapping it returns is updated in place with the identifying fields listed
  in ``columns`` below. A tqdm progress bar tracks the loop.

  Args:
    experiments: iterable of experiment identifiers.

  Returns:
    list with one status mapping per experiment, in input order.
  """
  # (output key, experiment attribute) pairs merged into every status record.
  # NOTE(review): 'objective' reads exp.objective, while filter_experiments
  # matches --objective against exp.get_model_objective(); confirm they agree.
  columns = (
    ('config_name', 'config_name'),
    ('model_name', 'model_name'),
    ('objective', 'objective'),
    ('sde_type', 'sde_type'),
    ('freq', 'freq'),
    ('group', 'group'),
    ('seed', 'global_key_seed'),
  )

  def _annotated(exp: 'ExperimentIdentifier'):
    record = exp.experiment_training_status()
    record.update({key: getattr(exp, attr) for key, attr in columns})
    return record

  return [_annotated(exp) for exp in tqdm(experiments, desc="Fetching experiment statuses")]

# Preferred column order for the status table. Columns absent from every
# status record are dropped rather than shown empty.
_DISPLAY_COLUMNS = (
  'config_name', 'model_name', 'objective', 'sde_type', 'freq', 'group', 'seed',
  'is_complete', 'current_step', 'max_steps', 'best_val_loss', 'steps_since_improvement',
)

# Output path written when --save is given.
_SAVE_FILE = 'training_status.txt'


def _is_true(value):
  """Return ``bool(value)``, treating missing values (None/NaN) as False.

  When status dicts have different keys, pd.DataFrame fills the gaps with
  NaN. NaN is truthy in Python, so a plain ``bool`` would count such rows as
  complete.
  """
  return bool(pd.notnull(value) and value)


def _format_progress(row):
  """Render one table row's progress as a percentage string.

  Complete rows show ``100%``. Rows with a missing step count, or a missing
  or non-positive ``max_steps``, show ``?%``. Otherwise the value is
  ``current_step / max_steps`` as a percentage, rounded to one decimal and
  capped at 100.
  """
  if _is_true(row.get('is_complete', False)):
    return '100%'
  current, total = row['current_step'], row['max_steps']
  # Without the isnull(current) check, min(100, nan) evaluates to 100 and an
  # unknown step count would be shown as fully done.
  if pd.isnull(current) or pd.isnull(total) or total <= 0:
    return '?%'
  return f"{min(100, round(100 * current / total, 1))}%"


def _build_status_frame(statuses):
  """Build the display DataFrame: known columns in fixed order, plus progress."""
  df = pd.DataFrame(statuses)
  display_cols = [col for col in _DISPLAY_COLUMNS if col in df.columns]
  # .copy() makes the column subset an independent frame, so adding
  # 'progress' below does not trigger pandas' SettingWithCopyWarning.
  df = df[display_cols].copy()
  if 'current_step' in df.columns and 'max_steps' in df.columns:
    df['progress'] = df.apply(_format_progress, axis=1)
  return df


def _build_summary(df):
  """Return the human-readable summary (counts header plus the full table)."""
  lines = [f"Found {len(df)} matching experiments"]
  if 'is_complete' in df.columns:
    complete_count = int(df['is_complete'].map(_is_true).sum())
    lines.append(f"Complete: {complete_count}, Incomplete: {len(df) - complete_count}")
  lines.append("\nExperiment Status Summary:")
  lines.append(df.to_string(index=False))
  return "\n".join(lines)


def main():
  """Script entry point: find, filter, report and optionally save statuses."""
  args = parse_args()

  # Discover every experiment, then apply the identity filters from the CLI.
  experiments = filter_experiments(find_all_experiments(), args)
  if not experiments:
    print("No matching experiments found")
    return

  statuses = get_experiment_statuses(experiments)

  # Completion filters. A status without 'is_complete' counts as incomplete.
  if args.incomplete:
    statuses = [s for s in statuses if not s.get('is_complete', False)]
  if args.complete:
    statuses = [s for s in statuses if s.get('is_complete', False)]
  if not statuses:
    print("No matching experiments found after filtering")
    return

  summary_text = _build_summary(_build_status_frame(statuses))
  print(summary_text)

  if args.save:
    with open(_SAVE_FILE, 'w', encoding='utf-8') as f:
      f.write(summary_text)
    print(f"\nSummary saved to {_SAVE_FILE}")

# Run the status check only when executed as a script, not when imported.
if __name__ == '__main__':
  main()