from __future__ import annotations
import json
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Any
import numpy as np


def save_npz(
    data: list[tuple[str, np.ndarray, np.ndarray] | tuple[str, np.ndarray, np.ndarray, dict]],
    config: Any,
    output_path: str | Path,
) -> None:
    """
    Save dataset as npz with trajectory data and human-readable JSON metadata.

    NPZ structure:
      - ids: (N,) array of trajectory IDs
      - traj_{id}: (T, D) trajectory array
      - params_{id}: (P,) parameter array
      - meta_{id}: (optional) dictionary as JSON string

    Metadata saved separately as {filename}.json for human inspection.

    Args:
        data: Items of the form (id, traj, params) or (id, traj, params, meta).
            IDs are converted with ``str`` and must be unique after conversion.
        config: Run configuration, forwarded to ``_build_metadata``.
        output_path: Destination of the archive. ``.npz`` is appended when the
            name does not already end with it (this mirrors what
            ``np.savez_compressed`` does with a string path).

    Raises:
        ValueError: If an item has the wrong length or an ID occurs twice.
    """
    output_path = Path(output_path)
    # np.savez_compressed silently appends ".npz" to names that lack it, so
    # resolve the real archive path up front; otherwise the log line and the
    # JSON sidecar would be derived from a file name that is never written.
    if output_path.name.endswith(".npz"):
        npz_path = output_path
    else:
        npz_path = output_path.with_name(output_path.name + ".npz")
    npz_path.parent.mkdir(parents=True, exist_ok=True)

    ids = []
    seen_ids = set()
    save_dict = {}

    # Everything is validated and collected before any file is written, so a
    # bad item never leaves a partial dataset on disk.
    for item in data:
        if len(item) == 3:
            traj_id, traj, params = item
            meta = {}
        elif len(item) == 4:
            traj_id, traj, params, meta = item
        else:
            raise ValueError("Each item must be (id, traj, params) or (id, traj, params, meta)")

        traj_id = str(traj_id)
        # A repeated ID would overwrite the earlier traj_/params_ entries while
        # `ids` still listed it twice, i.e. silent data loss.
        if traj_id in seen_ids:
            raise ValueError(f"Duplicate trajectory id: {traj_id!r}")
        seen_ids.add(traj_id)

        ids.append(traj_id)
        save_dict[f"traj_{traj_id}"] = traj
        save_dict[f"params_{traj_id}"] = params

        # Store metadata as JSON string, not pickle
        if meta:
            save_dict[f"meta_{traj_id}"] = json.dumps(meta)

    save_dict["ids"] = np.array(ids, dtype=str)

    np.savez_compressed(str(npz_path), **save_dict)
    print(f"Saved {len(ids)} trajectories -> {npz_path}", '\n')

    metadata = _build_metadata(config, len(ids))
    # Sidecar sits next to the archive: "<name>.npz" -> "<name>.json".
    json_path = npz_path.with_suffix(".json")
    json_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    print(f"Metadata -> {json_path}")


def _build_metadata(config: Any, n_samples: int) -> dict:
    """Extract config into clean JSON-serializable dict"""
    
    if hasattr(config, '__dataclass_fields__'):
        config_dict = asdict(config)
    else:
        config_dict = config.__dict__ if hasattr(config, '__dict__') else {}
    
    metadata = {
        "schema_version": "3",
        "timestamp": datetime.now().isoformat(),
        "n_samples": n_samples,
        
        # System parameters
        "experiment": config_dict.get("experiment"),
        "n_dim": config_dict.get("n_dim"),
        
        # Time discretization
        "t_start": config_dict.get("t_start"),
        "t_end": config_dict.get("t_end"),
        "dt": config_dict.get("dt"),
        "subsample_stride": config_dict.get("subsample_stride"),
        "n_timesteps": config_dict.get("n_timesteps"),
        
        # Sampling configuration
        "sampling_mode": config_dict.get("sampling_mode"),
        "crop_length": config_dict.get("crop_length"),
        "transient_cutoff": config_dict.get("transient_cutoff"),
        "crop_validator": config_dict.get("crop_validator"),
        
        # Parameter configuration
        "parameter_mode": config_dict.get("parameter_mode"),
        "param_ranges": config_dict.get("param_ranges"),
        "parameters": (
            [float(p) for p in config_dict["parameters"]] 
            if config_dict.get("parameters") else None
        ),
    }
    
    return {k: v for k, v in metadata.items() if v is not None}