from __future__ import annotations
from typing import Any, Dict, Optional
import torch
import numpy as np
from pathlib import Path
import json
from src.utils import DataAttr


def load_mat(path: str, drop_meta: bool = True) -> Dict[str, Any]:
    """
    Load a MATLAB .mat file (v5/v7 or v7.3) into a Python dict.

    - For v5/v7 files, uses scipy.io.loadmat and converts MATLAB structs/cells
      into nested dicts/lists.
    - For v7.3 (HDF5) files, falls back to h5py and reads the group tree.

    Args:
        path: Path to the .mat file.
        drop_meta: Remove __header__, __version__, __globals__ keys (v5/v7 only;
            the HDF5 branch returns the group tree as read).

    Returns:
        dict mapping variable names -> numpy arrays / lists / dicts.

    Raises:
        FileNotFoundError: If SciPy reports that the file does not exist. The
            HDF5 fallback is skipped in that case because it cannot succeed
            and would only hide the real cause.

    Any other SciPy failure triggers the h5py fallback; errors raised by that
    fallback (including a missing h5py installation) propagate to the caller.
    """
    # --- First try SciPy (v5/v7). Will error on v7.3 ---
    try:
        import scipy.io as sio

        data = sio.loadmat(
            path,
            squeeze_me=True,  # collapse singleton dimensions
            struct_as_record=False,  # legacy; ignored if simplify_cells=True
            simplify_cells=True,  # convert structs->dicts, cells->lists
        )
    except FileNotFoundError:
        # Not a format problem: retrying with another reader cannot help.
        raise
    except Exception:
        # --- Fallback for v7.3 (HDF5) ---
        # The catch stays deliberately broad so that every situation that
        # used to reach h5py (e.g. SciPy missing or too old) still does.
        import h5py

        def _read_h5(obj: Any) -> Any:
            """Recursively convert an h5py node into plain Python values."""
            if isinstance(obj, h5py.Dataset):
                val = obj[()]
                # Decode bytes to str if necessary
                if isinstance(val, bytes):
                    return val.decode("utf-8", errors="replace")
                return val
            if isinstance(obj, h5py.Group):
                return {k: _read_h5(obj[k]) for k in obj.keys()}
            return obj

        with h5py.File(path, "r") as f:
            return {k: _read_h5(f[k]) for k in f.keys()}

    # Success path of the SciPy reader: optionally strip its bookkeeping keys.
    if drop_meta:
        for k in ("__header__", "__version__", "__globals__"):
            data.pop(k, None)
    return data


def load_idx_dict(path: str | Path) -> dict:
    """Load the dictionary back from JSON."""
    path = Path(path)
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def _sbj_array(data: Any, sbj: int) -> np.ndarray:
    """Return subject ``sbj``'s trials as a float32 ``(N, 5)`` array, or raise ValueError."""
    d = np.asarray(data[sbj], dtype=np.float32)  # ensure numeric (N, 5)
    if d.ndim != 2 or d.shape[1] != 5:
        raise ValueError(f"data[{sbj}] has shape {d.shape}, expected (N, 5).")
    return d


def _to_xy(d: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
    """Split an ``(N, 5)`` trial array into input tensor ``x`` (N, 4) and response tensor ``y`` (N, 1)."""
    y = d[:, 3:4]  # raw column 3 is the response; slice keeps the trailing axis
    x = d[:, [4, 0, 1, 2]].copy()  # reorder raw columns; copy so `d` is never mutated
    # The first two reordered columns are shifted down by one (presumably
    # 1-based MATLAB category codes becoming 0-based -- confirm with the data).
    x[:, :2] -= 1
    return torch.from_numpy(x), torch.from_numpy(y)


def get_sbj_data(path: str, sbj: int, idx_path: str = "all", drop_meta: bool = True):
    """
    Load one participant's trials from the .mat file as torch tensors.

    Parameters:
        path: .mat file that contains a ``BAV_data`` variable indexable by subject
        sbj: 0-based position of the participant inside ``BAV_data``
        idx_path: ``"all"`` keeps every trial; anything else is treated as the path
            of a JSON file whose entry ``str(sbj)`` lists the trial rows to keep
        drop_meta: forwarded to ``load_mat``

    Outputs:
        x: float32 tensor of size (N, 4), raw columns [4, 0, 1, 2] with the
        first two shifted down by one
        y: float32 tensor of size (N, 1), raw column 3

    Raises:
        ValueError: if the participant's array is not of shape (N, 5).
    """
    data = load_mat(path, drop_meta)["BAV_data"]
    d = _sbj_array(data, sbj)

    if idx_path != "all":
        idx = load_idx_dict(idx_path)[str(sbj)]
        d = d[idx, :]

    return _to_xy(d)


def get_split_data(data_path: str, idx_paths: list, n_sbj: int = 15):
    """
    Loads the split **real** BAV data.

    The .mat file is read once and each split-index file is parsed once;
    every (participant, split) slice is then cut from those in-memory objects.

    Parameters:
        data_path: a string indicating where to find the .mat file with the data
        idx_paths: a list of strings indicating where to find the .json files that specify the splits
        n_sbj: number of human participants in the dataset, defaults to 15

    Outputs:
        x: torch tensor of size (`n_sbj`, `n_splits`, `len_splits`, 4).
        The last dimension represents the following variables, in order:
            - Response type (can be 0 or 1)
            - Sigma_V level (can be 0, 1, or 2)
            - S_A (presumably one of -15, -10, -5, 0, 5, 10, or 15 -- the
              signs could not be confirmed from this file; verify against the dataset)
            - S_V (continuous value between -20 and 20)

        y: torch tensor of size (`n_sbj`, `n_splits`, `len_splits`, 1).
        Contains the participants responses (continuous value between -45 and 45)

    Every split of every participant must contain `len_splits` trials, where
    `len_splits` is taken from participant "0" of the first split file.
    """
    indices = [load_idx_dict(idx_file) for idx_file in idx_paths]

    n_splits = len(indices)
    len_splits = len(indices[0]["0"])
    x = torch.zeros((n_sbj, n_splits, len_splits, 4))
    y = torch.zeros((n_sbj, n_splits, len_splits, 1))

    # Loop-invariant: read the (potentially large) .mat file a single time.
    data = load_mat(data_path)["BAV_data"]

    for sbj in range(n_sbj):
        d = _sbj_array(data, sbj)
        for n, split_idx in enumerate(indices):
            # Indexing with the row list yields a fresh array, so `d` stays intact.
            x_split, y_split = _to_xy(d[split_idx[str(sbj)], :])
            x[sbj, n, :, :] = x_split
            y[sbj, n, :, :] = y_split

    return x, y


class BavTrueDataloader:
    """
    Loads the real BAV dataset and exposes it as a list of ``DataAttr`` items,
    one per (participant, split) pair.
    """

    # Default file names, resolved relative to ``data_path``.
    IDX_SPLIT_FILE = [
        "trial_idx_400_split_1.json",
        "trial_idx_400_split_2.json",
    ]
    DATA_FILE = "bav_data.mat"

    def __init__(
        self,
        data_path: str,
        data_file: str = DATA_FILE,
        idx_split_file: list = IDX_SPLIT_FILE,
        num_sbj: int = 15,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
        randomize_idx_file: Optional[str] = None,
    ):
        """
        Initialize the data loader with the given parameters.

        Args:
            data_path: Directory that holds the data file and the split files.
            data_file: Name of the .mat file inside ``data_path``.
            idx_split_file: Names of the JSON split-index files inside ``data_path``.
            num_sbj: Number of participants to load.
            dtype: dtype of the returned tensors.
            device: Device of the returned tensors (``None`` keeps torch's default).
            randomize_idx_file: Optional path of a JSON file holding an index
                sequence used to reorder the trials of every split.
        """
        self.data_path = data_path
        self.data_file = data_file
        # Copy: the default is the shared class-level list, and storing it
        # directly would let a mutation on one instance change the default
        # seen by every later instance.
        self.idx_split_file = list(idx_split_file)
        self.num_sbj = num_sbj
        self.dtype = dtype
        self.device = device
        # Minimal stand-in object that only carries a `batch_size` attribute.
        self.dataset = type("Dataset", (), {})()  # empty object
        self.dataset.batch_size = num_sbj * len(self.idx_split_file)
        self.randomize_idx_file = randomize_idx_file

    def __len__(self):
        """Return the number of (participant, split) items."""
        return self.dataset.batch_size

    def load_data(self):
        """
        Read the dataset from disk and build one ``DataAttr`` per (participant, split).

        Returns:
            A list of ``num_sbj * n_splits`` ``DataAttr`` objects. Each one holds
            ``xt`` of shape (1, N, 4) and ``yt`` of shape (1, N, 1) with the trials
            of that item, a single fixed point ``xc`` = [2, 0, 0, 0] / ``yc`` = [0]
            of shape (1, 1, 4) / (1, 1, 1), and ``xb`` = ``yb`` = None.
        """
        root = Path(self.data_path)
        idx_paths = [root / idx_file for idx_file in self.idx_split_file]

        x, y = get_split_data(
            data_path=root / self.data_file,
            idx_paths=idx_paths,
            n_sbj=self.num_sbj,
        )

        if self.randomize_idx_file is not None:
            # Apply the same trial permutation to inputs and responses.
            rand_idx = load_idx_dict(self.randomize_idx_file)
            x = x[:, :, rand_idx, :]
            y = y[:, :, rand_idx, :]
            print(f"Data points order have been randomized using {self.randomize_idx_file}", flush=True)

        n_sub, n_split, N, Dx = x.shape
        batch_size = n_sub * n_split

        # Merge the participant and split axes. `reshape` equals `view` for
        # contiguous tensors and still works should the reordering above
        # return a non-contiguous result.
        xt = x.reshape(batch_size, N, Dx).to(device=self.device, dtype=self.dtype)
        yt = y.reshape(batch_size, N, 1).to(device=self.device, dtype=self.dtype)

        # NOTE(review): the meaning of the constant point [2, 0, 0, 0] -> 0 is
        # not evident from this file. `expand` creates views, so all items
        # share the same underlying xc / yc storage.
        xc = torch.tensor([2.0, 0.0, 0.0, 0.0], device=self.device, dtype=self.dtype)[
            None, None, :
        ].expand(batch_size, 1, 4)
        yc = torch.tensor([0.0], device=self.device, dtype=self.dtype)[
            None, None, :
        ].expand(batch_size, 1, 1)

        # Slicing with i:i+1 keeps a leading axis of size one on every field.
        return [
            DataAttr(
                xc=xc[i:i + 1],
                yc=yc[i:i + 1],
                xb=None,
                yb=None,
                xt=xt[i:i + 1],
                yt=yt[i:i + 1],
            )
            for i in range(batch_size)
        ]