import os
import json
import datetime
from typing import Optional, Dict, Any, TypeVar, Union, cast
import jax.numpy as jnp

# Placeholder types used only in annotations. The real ExperimentIdentifier and
# TrainState classes are defined in other modules; importing them at module
# level here would create a circular import, so TypeVars stand in for them.
# NOTE: these are NOT the real classes. Attribute access such as
# ``ExperimentIdentifier.from_model_identifier`` fails on a TypeVar, so code
# that needs the real class must import it at call time. (The ``__main__``
# block at the bottom of this file rebinds ``ExperimentIdentifier`` to the
# real class when the file is run as a script.)
ExperimentIdentifier = TypeVar('ExperimentIdentifier')
# Placeholder for the training-state type consumed by update_from_train_state.
TrainState = TypeVar('TrainState')

"""
This module implements a system for tracking and managing the training status of experiments.
It provides:

1. TrainingMetadata - A class that tracks training progress, validation metrics, and completion status
2. Utility functions for saving, loading, and updating training metadata
3. Functions for checking training status and generating status reports

The training tracker enables persistent monitoring of long-running training jobs,
facilitates resumption of interrupted training, and provides standardized reporting
of training progress and completion status.
"""

class TrainingMetadata:
  """
  Training-status metadata for a single experiment.

  Tracks the current step, the best validation loss seen so far, how many
  steps have elapsed since that best loss, and start / last-update /
  completion timestamps (ISO-format strings). Completion is not stored as
  state: ``is_complete`` is derived from the early-stopping patience and the
  optional step budget.
  """

  def __init__(self,
               current_step: int = 0,
               best_validation_loss: float = float('inf'),
               steps_since_best_val_loss: int = 0,
               is_complete: bool = False,  # Kept for backward compatibility but no longer stored
               max_allowed_steps: Optional[int] = None,
               early_stopping_patience: int = 3,
               started: Optional[str] = None,
               last_updated: Optional[str] = None,
               completed: Optional[str] = None) -> None:
    """
    Args:
      current_step: Number of training steps taken so far.
      best_validation_loss: Lowest validation loss seen so far (``inf`` if none yet).
      steps_since_best_val_loss: Steps elapsed since ``best_validation_loss`` was set.
      is_complete: Legacy flag from older metadata files. It is not stored on
        the instance; a truthy value only causes a ``completed`` timestamp to
        be recorded when none was supplied.
      max_allowed_steps: Optional step budget; ``None`` means no budget.
      early_stopping_patience: Training counts as complete once
        ``steps_since_best_val_loss`` exceeds this value.
      started: ISO timestamp when tracking started; defaults to now.
      last_updated: ISO timestamp of the last update; defaults to now.
      completed: ISO timestamp of completion, or ``None``.
    """
    self.current_step = current_step
    self.best_validation_loss = best_validation_loss
    self.steps_since_best_val_loss = steps_since_best_val_loss
    # is_complete is a derived property, not an attribute.
    self.max_allowed_steps = max_allowed_steps
    self.early_stopping_patience = early_stopping_patience

    # Timestamps
    self.started = started or self._get_current_time()
    self.last_updated = last_updated or self._get_current_time()
    self.completed = completed

    # Record a completion time when either the legacy flag or the derived
    # property says training is done and no timestamp was supplied.
    if not self.completed and (is_complete or self.is_complete):
      self.completed = self._get_current_time()

  def _get_current_time(self) -> str:
    """Return the current local time as an ISO-8601 string."""
    return datetime.datetime.now().isoformat()

  @property
  def is_complete(self) -> bool:
    """True once early-stopping patience is exceeded or the step budget is reached."""
    return (self.steps_since_best_val_loss > self.early_stopping_patience or
            (self.max_allowed_steps is not None and self.current_step >= self.max_allowed_steps))

  def update_from_train_state(self, train_state: "TrainState", max_allowed_steps: Optional[int] = None) -> None:
    """
    Refresh progress fields from a training-state object.

    Reads ``i``, ``best_validation_loss`` and
    ``number_of_steps_since_best_validation_loss`` from ``train_state``.
    It records a completion timestamp the first time training becomes
    complete, and always bumps ``last_updated``.

    Args:
      train_state: Object exposing the three attributes above (values may be
        array scalars; they are converted with ``int``/``float``).
      max_allowed_steps: If given, replaces the stored step budget.
    """
    self.current_step = int(train_state.i)  # type: ignore
    self.best_validation_loss = float(train_state.best_validation_loss)  # type: ignore
    self.steps_since_best_val_loss = int(train_state.number_of_steps_since_best_validation_loss)  # type: ignore

    if max_allowed_steps is not None:
      self.max_allowed_steps = max_allowed_steps

    # Only stamp the completion time on the transition into "complete".
    if self.is_complete and self.completed is None:
      self.completed = self._get_current_time()

    self.last_updated = self._get_current_time()

  def to_dict(self) -> Dict[str, Any]:
    """
    Convert to a JSON-serializable dictionary.

    ``best_validation_loss`` may be ``inf``. Python's ``json`` module writes
    that as the non-standard ``Infinity`` token, which ``json.load`` reads back.
    """
    return {
      "current_step": self.current_step,
      "best_validation_loss": self.best_validation_loss,
      "steps_since_best_val_loss": self.steps_since_best_val_loss,
      "is_complete": self.is_complete,  # Stored for readers only; recomputed on load
      "max_allowed_steps": self.max_allowed_steps,
      "early_stopping_patience": self.early_stopping_patience,
      "started": self.started,
      "last_updated": self.last_updated,
      "completed": self.completed
    }

  @classmethod
  def from_dict(cls, data: Dict[str, Any]) -> 'TrainingMetadata':
    """
    Create metadata from a dictionary produced by ``to_dict``.

    ``is_complete``, ``max_allowed_steps``, ``early_stopping_patience`` and
    ``completed`` are optional, so files that omit them still load.

    Raises:
      KeyError: If a required field (step, loss, counter or timestamps) is missing.
    """
    return cls(
      current_step=data["current_step"],
      best_validation_loss=data["best_validation_loss"],
      steps_since_best_val_loss=data["steps_since_best_val_loss"],
      # Derived value; tolerate files written without it.
      is_complete=data.get("is_complete", False),
      max_allowed_steps=data.get("max_allowed_steps"),
      early_stopping_patience=data.get("early_stopping_patience", 3),
      started=data["started"],
      last_updated=data["last_updated"],
      completed=data.get("completed")
    )

  @property
  def progress_percentage(self) -> float:
    """
    Training progress in percent, capped at 100 and rounded to 2 decimals.

    Without a positive step budget, returns 100.0 if complete, else 0.0.
    """
    if self.max_allowed_steps is None or self.max_allowed_steps <= 0:
      return 100.0 if self.is_complete else 0.0
    return min(100.0, round(100 * self.current_step / self.max_allowed_steps, 2))

  def __str__(self) -> str:
    """Return a multi-line, human-readable summary of the metadata."""
    status = "Completed" if self.is_complete else "In Progress"
    step_text = f"{self.current_step}"
    if self.max_allowed_steps:
      step_text += f"/{self.max_allowed_steps}"
    return "\n".join([
      f"Training Status: {status} ({self.progress_percentage:.1f}%)",
      f"Current Step: {step_text}",
      f"Best Validation Loss: {self.best_validation_loss:.6f}",
      f"Steps Since Improvement: {self.steps_since_best_val_loss}",
      f"Started: {self.started}",
      f"Last Updated: {self.last_updated}",
      f"Completed: {self.completed or 'Not completed'}",
    ])

  def __repr__(self) -> str:
    """Return the same multi-line summary as ``__str__``."""
    return str(self)

def reset_training_metadata(experiment_identifier: ExperimentIdentifier) -> TrainingMetadata:
  """
  Start an experiment's training metadata over from scratch.

  A fresh ``TrainingMetadata`` with default values is written to disk,
  overwriting whatever metadata file already existed for the experiment.
  Use this when an experiment is retrained from the beginning.

  Args:
    experiment_identifier: Experiment whose metadata should be reset.

  Returns:
    The fresh metadata instance that was saved.
  """
  fresh = TrainingMetadata()
  save_training_metadata(experiment_identifier, fresh)
  print("Reset training metadata for experiment")
  return fresh

def get_training_metadata_path(experiment_identifier: ExperimentIdentifier) -> str:
  """Return the path of ``training_metadata.json`` inside the experiment's model folder."""
  folder = experiment_identifier.model_folder_name
  return os.path.join(folder, "training_metadata.json")

def has_training_metadata(experiment_identifier: ExperimentIdentifier) -> bool:
  """Return True when a training metadata file exists for the experiment."""
  return os.path.exists(get_training_metadata_path(experiment_identifier))

def save_training_metadata(experiment_identifier: ExperimentIdentifier, metadata: TrainingMetadata) -> None:
  """
  Persist an experiment's training metadata as indented JSON.

  The data is first written to ``<path>.tmp`` and then moved over the real
  file with ``os.replace``. A job interrupted mid-write (or a failure while
  serializing) therefore leaves the previous metadata file intact, rather
  than a truncated one that ``load_training_metadata`` could not parse.

  Args:
    experiment_identifier: Experiment whose metadata file is written.
    metadata: Metadata to serialize via ``metadata.to_dict()``.
  """
  metadata_path = get_training_metadata_path(experiment_identifier)
  metadata_dir = os.path.dirname(metadata_path)
  if metadata_dir:
    # os.makedirs('') raises, so only create a directory when the path has one.
    os.makedirs(metadata_dir, exist_ok=True)

  tmp_path = metadata_path + ".tmp"
  try:
    with open(tmp_path, 'w') as f:
      json.dump(metadata.to_dict(), f, indent=2)
    # Swap the complete file into place in a single rename.
    os.replace(tmp_path, metadata_path)
  except BaseException:
    # Remove the partial temp file, then propagate the original error.
    try:
      os.remove(tmp_path)
    except OSError:
      pass
    raise

def load_training_metadata(experiment_identifier: ExperimentIdentifier) -> Optional[TrainingMetadata]:
  """
  Read an experiment's training metadata from disk.

  Returns:
    The parsed ``TrainingMetadata``, or ``None`` when no metadata file exists.
  """
  path = get_training_metadata_path(experiment_identifier)
  if os.path.exists(path):
    with open(path, 'r') as handle:
      raw = json.load(handle)
    return TrainingMetadata.from_dict(raw)
  return None

def update_training_metadata_from_train_state(
    experiment_identifier: ExperimentIdentifier,
    train_state: TrainState,
    max_steps: Optional[int] = None
) -> TrainingMetadata:
  """
  Load (or create), update, report and save an experiment's metadata.

  Args:
    experiment_identifier: Experiment whose metadata is updated.
    train_state: Training-state object passed to
      ``TrainingMetadata.update_from_train_state``.
    max_steps: Optional step budget to store on the metadata.

  Returns:
    The updated metadata, already written to disk.
  """
  metadata = load_training_metadata(experiment_identifier)
  is_new = metadata is None
  if metadata is None:
    metadata = TrainingMetadata()

  # Snapshot values before the update so changes can be reported.
  prev_val_loss = metadata.best_validation_loss
  prev_complete = metadata.is_complete

  metadata.update_from_train_state(train_state, max_steps)

  # Only log when creating new metadata or when the loss improves.
  if is_new:
    print(f"Created new training metadata for experiment (step: {metadata.current_step})")
  elif metadata.best_validation_loss < prev_val_loss:
    print(f"Training progress: Step {metadata.current_step}, val loss improved: {prev_val_loss:.6f} → {metadata.best_validation_loss:.6f}")

  # Checked independently of the branches above: the final step can both be the
  # first recorded step or improve the loss AND exhaust the step budget, and the
  # completion must still be reported.
  if metadata.is_complete and not prev_complete:
    print(f"Training complete at step {metadata.current_step}! Best val loss: {metadata.best_validation_loss:.6f}")
    if metadata.steps_since_best_val_loss > metadata.early_stopping_patience:
      print(f"  Stopped early: No improvement for {metadata.steps_since_best_val_loss} steps")

  save_training_metadata(experiment_identifier, metadata)
  return metadata

def get_training_status(experiment_identifier: ExperimentIdentifier) -> Dict[str, Any]:
  """
  Summarize an experiment's training status.

  Returns:
    A dict with ``is_complete`` and a human-readable ``status_summary``. When
    metadata exists it also includes ``current_step``, ``max_steps``,
    ``best_val_loss`` and ``steps_since_improvement``.
  """
  md = load_training_metadata(experiment_identifier)
  if md is None:
    return {"is_complete": False, "status_summary": "No metadata available"}

  done = md.is_complete
  if done:
    summary = f"Training complete after {md.current_step} steps, best val loss: {md.best_validation_loss:.6f}"
  elif md.max_allowed_steps:
    summary = (f"Training in progress: {md.progress_percentage:.1f}% "
               f"({md.current_step}/{md.max_allowed_steps} steps)")
  else:
    summary = f"Training in progress: {md.current_step} steps completed"

  return dict(
    is_complete=done,
    current_step=md.current_step,
    max_steps=md.max_allowed_steps,
    best_val_loss=md.best_validation_loss,
    steps_since_improvement=md.steps_since_best_val_loss,
    status_summary=summary,
  )






import shutil
from typing import List, Optional, Tuple
def _ensure_prefix(component: str, prefix: str) -> str:
  """Return *component* unchanged if it starts with *prefix*, else rebuild it as prefix + its second '_'-separated field (or the whole component if it has no '_')."""
  if component.startswith(prefix):
    return component
  suffix = component.split('_')[1] if '_' in component else component
  return f"{prefix}{suffix}"


def _model_id_from_path_parts(path_parts: List[str]) -> Tuple[str, ...]:
  """Build the 7-part model identifier tuple from the last seven path components (caller ensures there are at least seven)."""
  config_name, objective, model_name, sde_type, freq, group, seed = path_parts[-7:]
  return (
    config_name,
    objective,
    model_name,
    sde_type,
    _ensure_prefix(freq, "freq_"),
    group,
    _ensure_prefix(seed, "seed_"),
  )


def migrate_all_training_metadata_files(base_dir: str, dry_run: bool = False) -> List[Tuple[str, str]]:
  """
  Find and migrate all training_metadata.json files to their respective model_folder locations.

  Every ``seed_*`` directory under *base_dir* that contains a
  ``training_metadata.json`` is mapped to an experiment identifier via the
  last seven components of its path relative to *base_dir*. The file is then
  copied (with metadata, ``shutil.copy2``) into that experiment's
  ``model_folder_name``. Existing destination files are overwritten. Files
  already at their destination are skipped.

  Args:
    base_dir: Base directory to start the search
    dry_run: If True, only print what would be done without performing file operations

  Returns:
    List of tuples (source_path, destination_path) for all migrated files
  """
  import glob
  import time
  # Imported at call time to avoid a circular import. The module-level
  # ``ExperimentIdentifier`` is only a TypeVar placeholder without
  # ``from_model_identifier``, so it cannot be used here.
  from Models.experiment_identifier import ExperimentIdentifier as _ExperimentIdentifier

  migrated_files: List[Tuple[str, str]] = []
  failed_paths: List[Tuple[str, str]] = []
  already_in_place = 0

  start_time = time.time()
  print(f"Searching for seed folders in {base_dir}...")

  # "**" with recursive=True matches seed_* folders at any depth.
  seed_pattern = os.path.join(base_dir, "**", "seed_*")
  seed_folders = [f for f in glob.glob(seed_pattern, recursive=True) if os.path.isdir(f)]

  print(f"Found {len(seed_folders)} seed folders in {time.time() - start_time:.1f} seconds")

  for i, seed_folder in enumerate(seed_folders):
    # Progress output (and success messages below) only every 50 folders.
    if i % 50 == 0:
      print(f"Processing folder {i+1}/{len(seed_folders)}...")

    source_path = os.path.join(seed_folder, "training_metadata.json")
    if not os.path.exists(source_path):
      continue

    path_parts = os.path.relpath(seed_folder, base_dir).split(os.sep)
    if len(path_parts) < 7:
      failed_paths.append((source_path, f"Path does not have enough components ({len(path_parts)})"))
      print(f"Skipping {source_path}: Path does not have enough components ({len(path_parts)})")
      continue

    try:
      experiment_id = _ExperimentIdentifier.from_model_identifier(_model_id_from_path_parts(path_parts))
      dest_path = os.path.join(experiment_id.model_folder_name, "training_metadata.json")

      # Re-running the migration can map a file onto itself; shutil.copy2
      # raises SameFileError in that case, which is not a real failure.
      if os.path.exists(dest_path) and os.path.samefile(source_path, dest_path):
        already_in_place += 1
        continue

      if dry_run:
        if i % 50 == 0:
          print(f"Would migrate: {source_path} → {dest_path}")
      else:
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        shutil.copy2(source_path, dest_path)
        if i % 50 == 0:
          print(f"Migrated: {source_path} → {dest_path}")

      migrated_files.append((source_path, dest_path))

    except Exception as e:
      # Batch job: record the failure and keep going with the other folders.
      failed_paths.append((source_path, str(e)))
      print(f"Failed to process {source_path}: {e}")

  total_time = time.time() - start_time
  print(f"\nMigration {'would have ' if dry_run else ''}completed in {total_time:.1f} seconds:")
  print(f"- {len(migrated_files)} files migrated")
  if already_in_place:
    print(f"- {already_in_place} files already in place")
  print(f"- {len(failed_paths)} files failed")

  if failed_paths:
    print("\nFailed paths:" if len(failed_paths) <= 10 else "\nFirst 10 failed paths:")
    for path, reason in failed_paths[:10]:
      print(f"  {path}: {reason}")
    if len(failed_paths) > 10:
      print(f"  ... and {len(failed_paths) - 10} more failures")

  return migrated_files

if __name__ == "__main__":
  # Script entry point: run the metadata migration over everything under
  # SAVE_PATH with dry_run=False, i.e. files are actually copied.
  # Importing ExperimentIdentifier here also rebinds this module's global
  # ``ExperimentIdentifier`` (declared near the top as a TypeVar placeholder)
  # to the real class for the rest of the script run.
  from Models.experiment_identifier import SAVE_PATH, ExperimentIdentifier
  migrate_all_training_metadata_files(base_dir=SAVE_PATH, dry_run=False)