"""
Input/Output utilities for saving and loading results.
"""

from __future__ import annotations

from typing import Any, Dict, Optional, Union
from pathlib import Path
import json
import pickle
from datetime import datetime

import numpy as np


def ensure_dir(path: Union[str, Path]) -> Path:
    """
    Create ``path`` (including any missing parents) unless it already exists.

    Args:
        path: Directory location, as a string or Path

    Returns:
        The same location as a ``Path`` object
    """
    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)
    return directory


def save_results(
    results: Dict[str, Any],
    path: Union[str, Path],
    format: str = 'json'
) -> None:
    """
    Save results to file.

    The payload is fully serialized in memory before the output file is
    opened, so a value that cannot be serialized raises without truncating
    an existing file at ``path``.

    Args:
        results: Results dictionary. For 'json', values are first passed
            through _make_json_serializable (numpy arrays/scalars, sets, ...).
        path: Output path; missing parent directories are created
        format: 'json' or 'pickle'

    Raises:
        ValueError: If ``format`` is unknown. This is checked before anything
            is created on disk.
    """
    if format not in ('json', 'pickle'):
        raise ValueError(f"Unknown format: {format}")

    path = Path(path)
    ensure_dir(path.parent)

    if format == 'json':
        # Convert numpy arrays etc. to plain Python first. json.dumps (default
        # ensure_ascii=True) emits ASCII only; the explicit encoding just
        # keeps the write independent of the platform locale.
        text = json.dumps(_make_json_serializable(results), indent=2)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)
    else:
        data = pickle.dumps(results)
        with open(path, 'wb') as f:
            f.write(data)


def load_results(
    path: Union[str, Path],
    format: Optional[str] = None
) -> Dict[str, Any]:
    """
    Load results from file.

    Args:
        path: Input path
        format: 'json' or 'pickle'. When None, the format is inferred from the
            file extension: '.json' (case-insensitive) means JSON, anything
            else is treated as pickle.

    Returns:
        Results dictionary (for JSON, whatever top-level value the file holds)

    Raises:
        ValueError: If ``format`` is neither 'json' nor 'pickle'

    Warning:
        Unpickling can execute arbitrary code; only load pickle files from a
        trusted source.
    """
    path = Path(path)

    if format is None:
        format = 'json' if path.suffix.lower() == '.json' else 'pickle'

    if format == 'json':
        # Read raw bytes: json.load then detects UTF-8/16/32 (and a UTF-8
        # BOM) itself instead of relying on the platform's locale encoding.
        with open(path, 'rb') as f:
            return json.load(f)

    if format == 'pickle':
        with open(path, 'rb') as f:
            return pickle.load(f)

    raise ValueError(f"Unknown format: {format}")


def _make_json_serializable(obj: Any) -> Any:
    """
    Convert object to JSON-serializable form.

    Args:
        obj: Object to convert

    Returns:
        JSON-serializable version
    """
    if isinstance(obj, dict):
        return {k: _make_json_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [_make_json_serializable(item) for item in obj]
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, (np.integer, np.int64, np.int32)):
        return int(obj)
    elif isinstance(obj, (np.floating, np.float64, np.float32)):
        return float(obj)
    elif isinstance(obj, (np.bool_, bool)):
        return bool(obj)
    elif isinstance(obj, set):
        return list(obj)
    elif isinstance(obj, frozenset):
        return list(obj)
    elif hasattr(obj, 'to_dict'):
        return _make_json_serializable(obj.to_dict())
    elif hasattr(obj, '__dict__'):
        return _make_json_serializable(obj.__dict__)
    else:
        return obj


def save_numpy(
    array: np.ndarray,
    path: Union[str, Path],
    compressed: bool = True
) -> Path:
    """
    Save numpy array to file.

    numpy appends its canonical extension to filenames that lack it
    ('.npz' when compressed, '.npy' otherwise). For example, 'x.npy' with
    ``compressed=True`` becomes 'x.npy.npz'. This function applies the same
    rule up front and returns the path that was actually written.

    Args:
        array: Array to save (stored under the key 'data' when compressed)
        path: Output path; missing parent directories are created
        compressed: Whether to use compression (``np.savez_compressed``)
            instead of a plain ``np.save``

    Returns:
        Path of the file actually written
    """
    path = Path(path)
    ensure_dir(path.parent)

    extension = '.npz' if compressed else '.npy'
    if not path.name.endswith(extension):
        # Mirror numpy's "append, don't replace" behaviour so the returned
        # path matches the file on disk.
        path = path.with_name(path.name + extension)

    if compressed:
        np.savez_compressed(path, data=array)
    else:
        np.save(path, array)

    return path


def load_numpy(path: Union[str, Path]) -> np.ndarray:
    """
    Load numpy array from file.

    The file type is taken from the file's contents (what ``np.load``
    returns), not from its extension:

    * ``.npy`` content -> the stored array
    * ``.npz`` archive -> the array stored under the key 'data'

    Args:
        path: Input path

    Returns:
        Numpy array

    Raises:
        KeyError: If an .npz archive has no 'data' entry
    """
    loaded = np.load(Path(path))

    if isinstance(loaded, np.ndarray):
        return loaded

    # np.load returned an NpzFile, which keeps the file open; close it
    # deterministically once the array has been read into memory.
    with loaded:
        return loaded['data']


def generate_output_filename(
    experiment_name: str,
    suffix: str = '',
    extension: str = 'json',
    include_timestamp: bool = True
) -> str:
    """
    Build a filename of the form ``<name>[_<suffix>][_<timestamp>].<ext>``.

    The timestamp is local time formatted as ``%Y%m%d_%H%M%S`` (one-second
    resolution). Two calls with the same arguments within the same second
    therefore return the same name; the result is not guaranteed unique.

    Args:
        experiment_name: Name of the experiment
        suffix: Additional suffix; omitted when empty
        extension: File extension, with or without a leading dot ('json' and
            '.json' are equivalent). When empty, no extension and no trailing
            dot are added.
        include_timestamp: Whether to include timestamp

    Returns:
        Filename string
    """
    parts = [experiment_name]

    if suffix:
        parts.append(suffix)

    if include_timestamp:
        parts.append(datetime.now().strftime("%Y%m%d_%H%M%S"))

    stem = "_".join(parts)
    # Accept '.json' as well as 'json' so callers never get 'name..json'.
    extension = extension.lstrip('.')
    return f"{stem}.{extension}" if extension else stem


def save_checkpoint(
    state: Dict[str, Any],
    checkpoint_dir: Union[str, Path],
    name: str,
    iteration: int
) -> Path:
    """
    Save experiment checkpoint.

    The file is named ``<name>_iter<iteration:06d>.pkl`` and contains a
    pickled dict with 'state', 'iteration' and 'timestamp' (local-time ISO
    string). The payload is first written to a sibling ``.tmp`` file and then
    renamed over the final name. An interrupted save therefore never leaves
    a truncated ``*.pkl`` checkpoint behind.

    Args:
        state: State dictionary (must be picklable)
        checkpoint_dir: Directory for checkpoints; created if missing
        name: Checkpoint name
        iteration: Current iteration

    Returns:
        Path to saved checkpoint
    """
    checkpoint_dir = ensure_dir(checkpoint_dir)
    path = checkpoint_dir / f"{name}_iter{iteration:06d}.pkl"
    # '.pkl.tmp' does not match the '<name>_iter*.pkl' checkpoint pattern.
    tmp_path = path.with_name(path.name + '.tmp')

    payload = {
        'state': state,
        'iteration': iteration,
        'timestamp': datetime.now().isoformat(),
    }

    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(payload, f)
        # Path.replace uses os.replace: an atomic rename within one directory.
        tmp_path.replace(path)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error
        # is re-raised regardless.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise

    return path


def load_checkpoint(path: Union[str, Path]) -> Dict[str, Any]:
    """
    Read back a pickled checkpoint file.

    Security: unpickling can run arbitrary code, so only load checkpoints
    from a trusted location.

    Args:
        path: Location of the checkpoint file

    Returns:
        The unpickled object. For files written by save_checkpoint this is a
        dict with 'state', 'iteration' and 'timestamp' keys.
    """
    with open(path, mode='rb') as stream:
        checkpoint = pickle.load(stream)
    return checkpoint


def find_latest_checkpoint(
    checkpoint_dir: Union[str, Path],
    name: str
) -> Optional[Path]:
    """
    Find the checkpoint with the highest iteration number for ``name``.

    Only files named exactly ``<name>_iter<integer>.pkl`` are considered.
    Files that merely share the prefix, such as ``<name>_iter_backup.pkl``,
    are ignored. ``name`` is matched literally, so glob metacharacters in it
    have no special meaning.

    Args:
        checkpoint_dir: Directory containing checkpoints
        name: Checkpoint name prefix

    Returns:
        Path to the latest checkpoint, or None when the directory does not
        exist or holds no matching checkpoint
    """
    checkpoint_dir = Path(checkpoint_dir)

    if not checkpoint_dir.is_dir():
        return None

    prefix = f"{name}_iter"
    extension = ".pkl"
    latest_path: Optional[Path] = None
    latest_iteration = 0

    for candidate in checkpoint_dir.iterdir():
        filename = candidate.name
        if not (filename.startswith(prefix) and filename.endswith(extension)):
            continue
        # Parse exactly the text between the known prefix and extension.
        # Splitting on 'iter' breaks for names that contain 'iter' themselves
        # (e.g. 'iterative').
        number = filename[len(prefix):len(filename) - len(extension)]
        try:
            iteration = int(number)
        except ValueError:
            continue  # shares the prefix but is not a numbered checkpoint
        if latest_path is None or iteration > latest_iteration:
            latest_path, latest_iteration = candidate, iteration

    return latest_path


class ResultsWriter:
    """
    Helper class for incrementally writing results.

    Results are accumulated in memory and written in one go by ``save``.
    The stored document has the keys 'experiment', 'start_time', 'data'
    (list of added results) and 'metadata'. 'end_time' and 'n_results' are
    added each time ``save`` runs.
    """

    # Formats accepted by save_results. They are checked at construction so
    # that a typo fails immediately rather than at the end of a long run.
    _SUPPORTED_FORMATS = ('json', 'pickle')

    def __init__(
        self,
        output_dir: Union[str, Path],
        experiment_name: str,
        format: str = 'json'
    ):
        """
        Args:
            output_dir: Directory results are written to (created if missing)
            experiment_name: Stored in the results and used as the default
                filename prefix
            format: 'json' or 'pickle'; also used as the default file extension

        Raises:
            ValueError: If ``format`` is not supported. This is raised before
                ``output_dir`` is created.
        """
        if format not in self._SUPPORTED_FORMATS:
            raise ValueError(f"Unknown format: {format}")

        self.output_dir = ensure_dir(output_dir)
        self.experiment_name = experiment_name
        self.format = format
        self.results = {
            'experiment': experiment_name,
            'start_time': datetime.now().isoformat(),
            'data': [],
            'metadata': {},
        }

    def add_result(self, result: Dict[str, Any]) -> None:
        """Append a single result (stored by reference, not copied)."""
        self.results['data'].append(result)

    def set_metadata(self, key: str, value: Any) -> None:
        """Set (or overwrite) a metadata value."""
        self.results['metadata'][key] = value

    def save(self, filename: Optional[str] = None) -> Path:
        """
        Save all results to file.

        Args:
            filename: Optional custom filename. By default a timestamped name
                is generated from the experiment name, with the format as the
                extension.

        Returns:
            Path to saved file
        """
        self.results['end_time'] = datetime.now().isoformat()
        self.results['n_results'] = len(self.results['data'])

        if filename is None:
            filename = generate_output_filename(
                self.experiment_name,
                extension=self.format
            )

        path = self.output_dir / filename
        save_results(self.results, path, format=self.format)

        return path
