"""Experiment run management utilities.

This module provides utilities for setting up reproducible experiment runs
with unique identifiers based on configuration and timestamps.
"""

from __future__ import annotations

import hashlib
import os
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict

from moltenflow.utils.config import load_yaml
from moltenflow.utils.io import save_json


@dataclass
class RunContext:
    """Bundle of paths and settings describing a single experiment run.

    Instances are produced by ``setup_run()``; field order is part of the
    positional constructor signature.

    Attributes:
        run_dir: Directory of this particular run, i.e. the timestamp-named
            folder inside ``experiment_dir``
        config_hash: Shortened SHA256 hex digest (8 characters) of the
            config file's bytes
        timestamp: String that names the run directory
        config: Configuration dictionary as loaded from the config file
        experiment_dir: Directory named after ``config_hash`` that contains
            ``run_dir``
    """

    run_dir: str
    config_hash: str
    timestamp: str
    config: Dict[str, Any]
    experiment_dir: str


def compute_config_hash(config_path: str) -> str:
    """Derive a short, content-based identifier for a config file.

    The file is read as raw bytes, so the result depends only on the exact
    file contents and not on its name or location.

    Args:
        config_path: Path to YAML configuration file

    Returns:
        Leading 8 hex characters of the SHA256 digest of the file contents
    """
    hasher = hashlib.sha256()
    with open(config_path, "rb") as stream:
        hasher.update(stream.read())
    full_hex = hasher.hexdigest()
    return full_hex[:8]


def setup_run(
    config_path: str,
    base_dir: str = "experiments",
    timestamp: str | None = None,
) -> RunContext:
    """Prepare the on-disk layout for one experiment run and describe it.

    Resulting layout:
        {base_dir}/{config_hash}/
            config.yaml           # Snapshot of the config, copied once
            {timestamp}/          # Outputs of this particular run
                run_metadata.json # Hash, timestamp and resolved config path

    Args:
        config_path: Path to YAML configuration file
        base_dir: Root directory under which experiments are grouped
        timestamp: Name of the run directory; when None, the current time
            from ``datetime.now()`` formatted as YYYYMMDD_HHMMSS is used

    Returns:
        RunContext holding the created paths, the hash, the timestamp and
        the loaded configuration
    """
    digest = compute_config_hash(config_path)
    stamp = (
        timestamp
        if timestamp is not None
        else datetime.now().strftime("%Y%m%d_%H%M%S")
    )

    # Runs that share a config hash are grouped under one experiment folder.
    exp_root = os.path.join(base_dir, digest)
    run_root = os.path.join(exp_root, stamp)
    os.makedirs(run_root, exist_ok=True)

    # An existing snapshot is left untouched rather than overwritten.
    snapshot = os.path.join(exp_root, "config.yaml")
    if not os.path.exists(snapshot):
        shutil.copy2(config_path, snapshot)

    loaded = load_yaml(config_path)

    # Record where this run came from alongside its outputs.
    metadata = dict(
        config_hash=digest,
        timestamp=stamp,
        config_path=str(Path(config_path).resolve()),
    )
    save_json(os.path.join(run_root, "run_metadata.json"), metadata)

    return RunContext(
        run_dir=run_root,
        config_hash=digest,
        timestamp=stamp,
        config=loaded,
        experiment_dir=exp_root,
    )


def get_stage_dir(run_context: RunContext, stage: str) -> str:
    """Return the output directory of one pipeline stage, creating it if needed.

    Args:
        run_context: The run context from setup_run()
        stage: Stage name (e.g., "pretrain", "finetune", "flow"); it is
            joined onto the run directory to form the result

    Returns:
        Path to the stage directory inside ``run_context.run_dir``
    """
    run_root = run_context.run_dir
    target = os.path.join(run_root, stage)
    # exist_ok makes repeated calls for the same stage harmless.
    os.makedirs(target, exist_ok=True)
    return target
