from __future__ import annotations
from typing import Any, Dict, Optional
import torch
import numpy as np
from pathlib import Path
import json



def load_mat(path: str, drop_meta: bool = True) -> Dict[str, Any]:
    """
    Load a MATLAB .mat file (v5/v7 or v7.3) into a Python dict.

    - For v5/v7 files, uses scipy.io.loadmat and converts MATLAB structs/cells
      into nested dicts/lists.
    - For v7.3 (HDF5) files, falls back to h5py and reads the group tree.
      Values are returned exactly as h5py yields them (only ``bytes`` scalars
      are decoded to ``str``); no struct/cell simplification is applied.

    Args:
        path: Path to the .mat file.
        drop_meta: Remove __header__, __version__, __globals__ keys (v5/v7 only).

    Returns:
        dict mapping variable names -> numpy arrays / lists / dicts.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        ImportError: If the HDF5 fallback is needed but h5py is not installed.
        Exception: Any other SciPy read error is propagated unchanged rather
            than being masked by an h5py retry.
    """
    # --- First try SciPy (v5/v7) ---
    try:
        import scipy.io as sio

        data = sio.loadmat(
            path,
            squeeze_me=True,          # collapse singleton dimensions
            struct_as_record=False,   # legacy; ignored if simplify_cells=True
            simplify_cells=True,      # convert structs->dicts, cells->lists
        )
    except (ImportError, NotImplementedError):
        # SciPy signals a v7.3 file with NotImplementedError; ImportError keeps
        # the h5py route usable when SciPy itself is unavailable. Every other
        # failure (missing file, corrupt v5 data, ...) is a real error and must
        # reach the caller instead of being retried as HDF5.
        pass
    else:
        if drop_meta:
            for k in ("__header__", "__version__", "__globals__"):
                data.pop(k, None)
        return data

    # --- Fallback for v7.3 (HDF5) ---
    import h5py

    def _read_h5(obj: Any) -> Any:
        """Recursively convert an h5py node into plain Python/numpy values."""
        if isinstance(obj, h5py.Dataset):
            val = obj[()]
            # Decode bytes to str if necessary
            if isinstance(val, bytes):
                return val.decode("utf-8", errors="replace")
            return val
        if isinstance(obj, h5py.Group):
            return {k: _read_h5(obj[k]) for k in obj.keys()}
        return obj

    # Everything is materialised inside the `with`, so no h5py handle outlives
    # the open file.
    with h5py.File(path, "r") as f:
        return {k: _read_h5(f[k]) for k in f.keys()}
        

def load_idx_dict(path: str | Path) -> dict:
    """Read a UTF-8 encoded JSON file and return its parsed content.

    Args:
        path: Location of the JSON file, as a string or ``Path``.

    Returns:
        The decoded JSON document. Callers in this module index the result
        by subject number converted to a string (e.g. ``"0"``).
    """
    raw_text = Path(path).read_text(encoding="utf-8")
    return json.loads(raw_text)

def _subject_array(data: Any, sbj: int) -> np.ndarray:
    """Return ``data[sbj]`` as a float32 array, validated to have shape (N, 5)."""
    d = np.asarray(data[sbj], dtype=np.float32)      # ensure numeric (N, 5)
    if d.ndim != 2 or d.shape[1] != 5:
        raise ValueError(f"data[{sbj}] has shape {d.shape}, expected (N, 5).")
    return d


def _trials_to_xy(d: np.ndarray, idx: Optional[list] = None):
    """Split an (N, 5) trial array into model inputs ``x`` and responses ``y``.

    If ``idx`` is given, only those rows are kept (in the order listed).
    """
    if idx is not None:
        d = d[idx, :]

    y = d[:, 3:4]                                    # (n_rows, 1)
    x = d[:, [4, 0, 1, 2]].copy()                    # (n_rows, 4)
    # Raw columns 4 and 0 (now x[:, 0] and x[:, 1]) are shifted down by one.
    x[:, :2] -= 1

    return torch.from_numpy(x), torch.from_numpy(y)


def get_sbj_data(path: str, sbj: int, idx_path: str = "all", drop_meta: bool = True):
    """
    Load one subject's trials from the ``'BAV_data'`` variable of a .mat file.

    Args:
        path: Path to the .mat file; it must contain a ``'BAV_data'`` variable
            that can be indexed by subject number.
        sbj: Index of the subject to extract.
        idx_path: Either the string ``"all"`` (keep every trial) or the path of
            a JSON file mapping ``str(sbj)`` to the list of row indices to keep.
        drop_meta: Forwarded to ``load_mat``.

    Returns:
        Tuple ``(x, y)`` of float32 torch tensors:
            - ``x`` has shape (n_rows, 4) and holds raw columns 4, 0, 1, 2 in
              that order, with 1 subtracted from the first two.
            - ``y`` has shape (n_rows, 1) and holds raw column 3.

    Raises:
        ValueError: If the subject's data is not a 2-D array with 5 columns.
    """
    data = load_mat(path, drop_meta)['BAV_data']
    d = _subject_array(data, sbj)

    idx = None
    if idx_path != "all":
        idx = load_idx_dict(idx_path)[str(sbj)]

    return _trials_to_xy(d, idx)


def get_split_data(data_path: str, idx_paths: list, n_sbj: int = 15):
    """
    Load the split **real** BAV data.

    The .mat file and every JSON split file are read exactly once, however
    many subjects and splits are requested.

    Args:
        data_path: Path of the .mat file with the data.
        idx_paths: Non-empty list of paths of the .json files that specify the
            splits. Each file maps a subject index (as a string) to the row
            indices belonging to that split.
        n_sbj: Number of human participants in the dataset, defaults to 15.

    Returns:
        x: torch tensor of size (`n_sbj`, `n_splits`, `len_splits`, 4).
        The last dimension represents the following variables, in order:
            - Response type (can be 0 or 1)
            - Sigma_V level (can be 0, 1, or 2)
            - S_A (can be -15, -10, -5, 0, 5, 10, or 15)
            - S_V (continuous value between -20 and 20)

        y: torch tensor of size (`n_sbj`, `n_splits`, `len_splits`, 1).
        Contains the participants' responses (continuous value between -45 and 45).

        `len_splits` is taken from subject ``"0"`` of the first split file;
        every other subject/split is expected to have the same length.

    Raises:
        ValueError: If `idx_paths` is empty, or a subject's data is not a
            2-D array with 5 columns.
    """
    if not idx_paths:
        raise ValueError("idx_paths must contain at least one split file.")

    # Read each split file and the data file once, up front.
    indices = [load_idx_dict(idx_path) for idx_path in idx_paths]
    data = load_mat(data_path)['BAV_data']

    n_splits = len(indices)
    len_splits = len(indices[0]["0"])
    x = torch.zeros((n_sbj, n_splits, len_splits, 4))
    y = torch.zeros((n_sbj, n_splits, len_splits, 1))
    for sbj in range(n_sbj):
        d = _subject_array(data, sbj)
        for n, split in enumerate(indices):
            x_split, y_split = _trials_to_xy(d, split[str(sbj)])
            x[sbj, n, :, :] = x_split
            y[sbj, n, :, :] = y_split

    return x, y
