"""
Save utilities for GLEAM-AI.

This module contains utilities for saving data, model checkpoints, and results
in the GLEAM-AI epidemiological forecasting system.
"""

import numpy as np
import torch
from pathlib import Path
from typing import Union, List, Dict, Any, Optional
import json
import pickle


def save_x_data(x: np.ndarray, file_ids: np.ndarray, run_ids: List, dest_path: Union[str, Path]) -> None:
    """
    Write one ``<file_id>_x.npy`` file per entry of ``x``.

    ``x``, ``file_ids`` and ``run_ids`` are walked in lockstep. Each entry of
    ``x`` is copied once per run id in the matching ``run_ids`` entry, the
    copies are stacked along a new leading axis, and the stack is saved.

    Note: the three inputs are paired with ``zip``, so iteration stops at the
    shortest one and any surplus entries are ignored without an error.

    Args:
        x: Input features, iterated along the first axis (one entry per file id)
        file_ids: File identifiers, used as the filename prefix
        run_ids: Per-file collections of run identifiers; only their lengths are used
        dest_path: Destination directory path (created if missing)
    """
    out_dir = Path(dest_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    for file_id, runs, features in zip(file_ids, run_ids, x):
        # One identical copy of the features per run, on a new leading axis.
        stacked = np.repeat(features[np.newaxis, ...], len(runs), axis=0)
        np.save(out_dir / f"{file_id}_x.npy", stacked)


def save_xt_data(xt: np.ndarray, file_ids: np.ndarray, run_ids: List, dest_path: Union[str, Path]) -> None:
    """
    Write one ``<file_id>_xt.npy`` file per entry of ``xt``.

    ``xt``, ``file_ids`` and ``run_ids`` are walked in lockstep. Each entry of
    ``xt`` is copied once per run id in the matching ``run_ids`` entry, the
    copies are stacked along a new leading axis, and the stack is saved.

    Note: the three inputs are paired with ``zip``, so iteration stops at the
    shortest one and any surplus entries are ignored without an error.

    Args:
        xt: Temporal features, iterated along the first axis (one entry per file id)
        file_ids: File identifiers, used as the filename prefix
        run_ids: Per-file collections of run identifiers; only their lengths are used
        dest_path: Destination directory path (created if missing)
    """
    out_dir = Path(dest_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    for file_id, runs, temporal in zip(file_ids, run_ids, xt):
        # One identical copy of the temporal block per run, on a new leading axis.
        stacked = np.repeat(temporal[np.newaxis, ...], len(runs), axis=0)
        np.save(out_dir / f"{file_id}_xt.npy", stacked)


def save_y0_data(y0: np.ndarray, file_ids: List, run_ids: np.ndarray, dest_path: Union[str, Path]) -> None:
    """
    Write one ``<file_id>_y0.npy`` file per entry of ``y0``.

    ``y0``, ``file_ids`` and ``run_ids`` are walked in lockstep. Each entry of
    ``y0`` is copied once per run id in the matching ``run_ids`` entry, the
    copies are stacked along a new leading axis, and the stack is saved.

    Note: the three inputs are paired with ``zip``, so iteration stops at the
    shortest one and any surplus entries are ignored without an error.

    Args:
        y0: Initial conditions, iterated along the first axis (one entry per file id)
        file_ids: File identifiers, used as the filename prefix
        run_ids: Per-file collections of run identifiers; only their lengths are used
        dest_path: Destination directory path (created if missing)
    """
    out_dir = Path(dest_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    for file_id, runs, initial in zip(file_ids, run_ids, y0):
        # One identical copy of the initial condition per run, on a new leading axis.
        stacked = np.repeat(initial[np.newaxis, ...], len(runs), axis=0)
        np.save(out_dir / f"{file_id}_y0.npy", stacked)


def make_x_filename(file_id: Union[int, str, List, np.integer]) -> str:
    """
    Generate filename for x data.

    Args:
        file_id: File identifier. An int, a NumPy integer scalar (as obtained
            when iterating an integer array of ids) or a numeric string is
            normalised with ``int()``; for a list, the first element is used
            as-is.

    Returns:
        Filename string of the form ``"<id>_x.npy"``

    Raises:
        TypeError: If ``file_id`` is not one of the supported types
        ValueError: If ``file_id`` is a string that ``int()`` cannot parse
        IndexError: If ``file_id`` is an empty list
    """
    if isinstance(file_id, list):
        fid = file_id[0]
    elif isinstance(file_id, (str, int, np.integer)):
        # np.integer is not a subclass of int, so it must be listed explicitly.
        fid = int(file_id)
    else:
        raise TypeError(f"Unsupported file_id type: {type(file_id)}")

    return f"{fid}_x.npy"


def make_xt_filename(file_id: Union[int, str, List, np.integer]) -> str:
    """
    Generate filename for temporal features data.

    Args:
        file_id: File identifier. An int, a NumPy integer scalar (as obtained
            when iterating an integer array of ids) or a numeric string is
            normalised with ``int()``; for a list, the first element is used
            as-is.

    Returns:
        Filename string of the form ``"<id>_xt.npy"``

    Raises:
        TypeError: If ``file_id`` is not one of the supported types
        ValueError: If ``file_id`` is a string that ``int()`` cannot parse
        IndexError: If ``file_id`` is an empty list
    """
    if isinstance(file_id, list):
        fid = file_id[0]
    elif isinstance(file_id, (str, int, np.integer)):
        # np.integer is not a subclass of int, so it must be listed explicitly.
        fid = int(file_id)
    else:
        raise TypeError(f"Unsupported file_id type: {type(file_id)}")

    return f"{fid}_xt.npy"


def make_y_filename(file_id: Union[int, str, List, np.integer], y_filename_suffix: str) -> str:
    """
    Generate filename for y data.

    Args:
        file_id: File identifier. An int, a NumPy integer scalar (as obtained
            when iterating an integer array of ids) or a numeric string is
            normalised with ``int()``; for a list, the first element is used
            as-is.
        y_filename_suffix: Suffix placed between the id and the ``.npy``
            extension

    Returns:
        Filename string of the form ``"<id>_<y_filename_suffix>.npy"``

    Raises:
        TypeError: If ``file_id`` is not one of the supported types
        ValueError: If ``file_id`` is a string that ``int()`` cannot parse
        IndexError: If ``file_id`` is an empty list
    """
    if isinstance(file_id, list):
        fid = file_id[0]
    elif isinstance(file_id, (str, int, np.integer)):
        # np.integer is not a subclass of int, so it must be listed explicitly.
        fid = int(file_id)
    else:
        raise TypeError(f"Unsupported file_id type: {type(file_id)}")

    return f"{fid}_{y_filename_suffix}.npy"


def make_y0_filename(file_id: Union[int, str, List, np.integer]) -> str:
    """
    Generate filename for initial conditions data.

    Args:
        file_id: File identifier. An int, a NumPy integer scalar (as obtained
            when iterating an integer array of ids) or a numeric string is
            normalised with ``int()``; for a list, the first element is used
            as-is.

    Returns:
        Filename string of the form ``"<id>_y0.npy"``

    Raises:
        TypeError: If ``file_id`` is not one of the supported types
        ValueError: If ``file_id`` is a string that ``int()`` cannot parse
        IndexError: If ``file_id`` is an empty list
    """
    if isinstance(file_id, list):
        fid = file_id[0]
    elif isinstance(file_id, (str, int, np.integer)):
        # np.integer is not a subclass of int, so it must be listed explicitly.
        fid = int(file_id)
    else:
        raise TypeError(f"Unsupported file_id type: {type(file_id)}")

    return f"{fid}_y0.npy"


def save_model_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    loss: float,
    filepath: Union[str, Path],
    **kwargs
) -> None:
    """
    Write a training checkpoint to disk with ``torch.save``.

    The saved dictionary holds the keys ``"epoch"``, ``"model_state_dict"``,
    ``"optimizer_state_dict"`` and ``"loss"``, followed by every extra keyword
    argument. An extra keyword that reuses one of those key names replaces the
    standard value.

    Args:
        model: Model whose ``state_dict()`` is stored
        optimizer: Optimizer whose ``state_dict()`` is stored
        epoch: Current epoch
        loss: Current loss
        filepath: Path to save checkpoint; missing parent directories are created
        **kwargs: Additional metadata to save alongside the standard entries
    """
    payload: Dict[str, Any] = {"epoch": epoch}
    payload["model_state_dict"] = model.state_dict()
    payload["optimizer_state_dict"] = optimizer.state_dict()
    payload["loss"] = loss
    # Applied last so caller-supplied entries take precedence on a key clash.
    payload.update(kwargs)

    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, target)


def save_training_history(
    history: Dict[str, List],
    filepath: Union[str, Path]
) -> None:
    """
    Save training history to JSON file.

    NumPy arrays and NumPy scalars are converted to native Python values
    wherever they appear in ``history`` (top level, inside lists, or inside
    nested dicts). The history is serialized in memory before the file is
    opened, so a serialization failure does not truncate an existing file.

    Args:
        history: Training history dictionary
        filepath: Path to save history; missing parent directories are created

    Raises:
        TypeError: If ``history`` contains a value that is neither JSON
            serializable nor a convertible NumPy array/scalar
    """
    def _numpy_to_native(obj: Any) -> Any:
        # The JSON encoder calls this only for objects it cannot encode itself.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            native = obj.item()
            # Some scalar types have no native equivalent and come back
            # unchanged; refuse them rather than loop in the encoder.
            if not isinstance(native, np.generic):
                return native
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    # Serialize first: if this raises, nothing on disk has been touched.
    serialized = json.dumps(history, indent=2, default=_numpy_to_native)

    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(serialized)


def save_predictions(
    predictions: np.ndarray,
    targets: np.ndarray,
    filepath: Union[str, Path],
    metadata: Optional[Dict[str, Any]] = None
) -> None:
    """
    Store predictions and targets together in one NumPy ``.npz`` archive.

    The archive contains the entries ``"predictions"`` and ``"targets"``, plus
    ``"metadata"`` when metadata is supplied.

    NOTE: ``np.savez`` appends ``.npz`` to the filename when it is missing.
    A metadata dict is stored by NumPy as a pickled object array, so reading
    that entry back requires ``np.load(..., allow_pickle=True)``.

    Args:
        predictions: Model predictions
        targets: Ground truth targets
        filepath: Path to save predictions; missing parent directories are created
        metadata: Additional metadata to save; omitted from the archive when None
    """
    out_path = Path(filepath)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    arrays = dict(predictions=predictions, targets=targets)
    if metadata is not None:
        arrays.update(metadata=metadata)

    np.savez(out_path, **arrays)


def save_config(
    config: Dict[str, Any],
    filepath: Union[str, Path]
) -> None:
    """
    Save configuration to JSON file.

    The configuration is serialized in memory before the file is opened, so a
    configuration that cannot be serialized leaves any existing file at
    ``filepath`` untouched instead of truncating it.

    Args:
        config: Configuration dictionary
        filepath: Path to save configuration; missing parent directories are created

    Raises:
        TypeError: If ``config`` contains a value that is not JSON serializable
    """
    # Serialize first: if this raises, nothing on disk has been touched.
    serialized = json.dumps(config, indent=2)

    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(serialized)


def load_config(filepath: Union[str, Path]) -> Dict[str, Any]:
    """
    Load configuration from JSON file.

    The file is always decoded as UTF-8 rather than with the platform's
    default text encoding, so files containing non-ASCII text load the same
    way on every system.

    Args:
        filepath: Path to configuration file

    Returns:
        Configuration dictionary (whatever value the JSON document decodes to)

    Raises:
        FileNotFoundError: If ``filepath`` does not exist
        json.JSONDecodeError: If the file is not valid JSON
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_results(
    results: Dict[str, Any],
    filepath: Union[str, Path]
) -> None:
    """
    Pickle experiment results to a file.

    Pickle is used so that results holding arbitrary Python objects can be
    stored without a custom serializer.

    Args:
        results: Results dictionary
        filepath: Path to save results; missing parent directories are created
    """
    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)

    with target.open('wb') as stream:
        pickle.dump(results, stream)


def load_results(filepath: Union[str, Path]) -> Dict[str, Any]:
    """
    Read back experiment results stored as a pickle file.

    WARNING: unpickling can execute arbitrary code. Only load files that come
    from a trusted source.

    Args:
        filepath: Path to results file

    Returns:
        Results dictionary (whatever object the pickle file contains)
    """
    with open(filepath, 'rb') as stream:
        results = pickle.load(stream)
    return results


def save_active_learning_logs(
    logs: List[Dict[str, Any]],
    filepath: Union[str, Path]
) -> None:
    """
    Save active learning logs to JSON file.

    The logs are serialized in memory before the file is opened, so logs that
    cannot be serialized leave any existing file at ``filepath`` untouched
    instead of truncating it.

    Args:
        logs: List of log entries
        filepath: Path to save logs; missing parent directories are created

    Raises:
        TypeError: If ``logs`` contains a value that is not JSON serializable
    """
    # Serialize first: if this raises, nothing on disk has been touched.
    serialized = json.dumps(logs, indent=2)

    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(serialized)


def load_active_learning_logs(filepath: Union[str, Path]) -> List[Dict[str, Any]]:
    """
    Load active learning logs from JSON file.

    The file is always decoded as UTF-8 rather than with the platform's
    default text encoding, so files containing non-ASCII text load the same
    way on every system.

    Args:
        filepath: Path to logs file

    Returns:
        List of log entries (whatever value the JSON document decodes to)

    Raises:
        FileNotFoundError: If ``filepath`` does not exist
        json.JSONDecodeError: If the file is not valid JSON
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


def create_experiment_directory(
    base_path: Union[str, Path],
    experiment_name: str,
    create_subdirs: bool = True
) -> Path:
    """
    Set up the on-disk layout for one experiment.

    Creates ``<base_path>/<experiment_name>`` (including missing parents) and,
    when requested, the subdirectories ``checkpoints``, ``logs``, ``results``,
    ``configs`` and ``plots`` inside it. Directories that already exist are
    left as they are.

    Args:
        base_path: Base directory path
        experiment_name: Name of the experiment
        create_subdirs: Whether to create subdirectories

    Returns:
        Path to experiment directory
    """
    experiment_root = Path(base_path) / experiment_name
    experiment_root.mkdir(parents=True, exist_ok=True)

    if not create_subdirs:
        return experiment_root

    for child in ("checkpoints", "logs", "results", "configs", "plots"):
        (experiment_root / child).mkdir(exist_ok=True)

    return experiment_root


def cleanup_old_files(
    directory: Union[str, Path],
    pattern: str = "*.npy",
    keep_latest: int = 10
) -> None:
    """
    Clean up old files, keeping only the latest ones.

    Only regular files matching ``pattern`` directly inside ``directory`` are
    considered; matching subdirectories are never counted or removed. Files
    are ranked by modification time and everything beyond the ``keep_latest``
    most recent ones is deleted. A missing directory is a no-op.

    Args:
        directory: Directory to clean up
        pattern: File pattern to match
        keep_latest: Number of latest files to keep (0 removes every match)

    Raises:
        ValueError: If ``keep_latest`` is negative
    """
    # A negative count would make the slice below start from the end of the
    # list and delete only the oldest file, which is never what is meant.
    if keep_latest < 0:
        raise ValueError(f"keep_latest must be >= 0, got {keep_latest}")

    directory = Path(directory)
    if not directory.exists():
        return

    # Skip directories that happen to match: unlink() cannot remove them and
    # would abort the cleanup partway through.
    files = [path for path in directory.glob(pattern) if path.is_file()]
    if len(files) <= keep_latest:
        return

    # Sort by modification time (newest first)
    files.sort(key=lambda path: path.stat().st_mtime, reverse=True)

    # Remove old files
    for stale in files[keep_latest:]:
        stale.unlink()


def get_file_size_mb(filepath: Union[str, Path]) -> float:
    """
    Report how large a file is, in megabytes (1 MB = 1024 * 1024 bytes).

    Args:
        filepath: Path to file

    Returns:
        File size in MB
    """
    bytes_per_mb = 1024 * 1024
    size_in_bytes = Path(filepath).stat().st_size
    return size_in_bytes / bytes_per_mb


def get_directory_size_mb(directory: Union[str, Path]) -> float:
    """
    Add up the sizes of all files under a directory, in megabytes.

    The directory is walked recursively; only entries that are files
    contribute to the total (1 MB = 1024 * 1024 bytes).

    Args:
        directory: Path to directory

    Returns:
        Total size in MB
    """
    byte_count = sum(
        entry.stat().st_size
        for entry in Path(directory).rglob('*')
        if entry.is_file()
    )
    return byte_count / (1024 * 1024)
