import json
import os
from types import NoneType
from typing import Any

import numpy as np
import yaml
from pathlib import Path


class NumpyJsonEncoder(json.JSONEncoder):
    """JSON encoder that can also serialize NumPy arrays and NumPy scalars.

    Use it as ``json.dumps(obj, cls=NumpyJsonEncoder)``.
    """

    def default(self, obj):
        """Convert NumPy objects that ``json`` cannot encode on its own.

        Arrays become (nested) lists; NumPy integer, floating and boolean
        scalars become the matching Python scalar. Anything else is passed to
        the base class, which raises ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalars such as np.int64, np.float32 or np.bool_ are not
        # subclasses of the Python builtins, so json hands them to default().
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)


def get_ndarray_keys_recursively(obj: dict | list, path: list[str | int]):
    """Find the full paths for each numpy array in a nested structure of dicts and lists.
    Return a mapping from ndarray object IDs to the corresponding full path."""
    if isinstance(obj, np.ndarray):
        return {id(obj): path}
    elif isinstance(obj, dict):
        paths = {}
        for key in obj:
            paths.update(get_ndarray_keys_recursively(obj[key], path + [key]))
        return paths
    elif isinstance(obj, list):
        paths = {}
        for key, value in enumerate(obj):
            paths.update(get_ndarray_keys_recursively(value, path + [key]))
        return paths
    else:
        return {}


def _get_numpy_yaml_loader(yaml_path: str | os.PathLike):
    """Create a ``yaml.SafeLoader`` subclass that understands ``!ndarray`` tags.

    The scalar value of an ``!ndarray`` node is treated as a file path relative
    to the directory containing ``yaml_path``; constructing the node loads that
    file with ``np.load`` and yields the result.
    """

    def _load_referenced_array(loader, node):
        # The node's scalar is the array file location, relative to the YAML file.
        relative_location = loader.construct_scalar(node)
        return np.load(Path(yaml_path).parent / relative_location)

    class NumpySafeLoader(yaml.SafeLoader):
        """SafeLoader that also records the path of the YAML file being read."""

        def __init__(self, stream):
            super().__init__(stream)
            self.yaml_path = Path(yaml_path)

    NumpySafeLoader.add_constructor("!ndarray", _load_referenced_array)
    return NumpySafeLoader


def _get_numpy_yaml_dumper(output_path: str | os.PathLike):
    """Create a ``yaml.SafeDumper`` subclass that stores arrays in side files.

    Each ``np.ndarray`` encountered while dumping is written with ``np.save``
    to ``<stem>_np_files/<key path>.npy`` next to ``output_path`` and appears
    in the YAML document as an ``!ndarray`` scalar holding that relative path.
    ``<stem>`` is the file name of ``output_path`` without its suffix, and
    ``<key path>`` is the dot-joined sequence of keys/indices leading to the
    array, as reported by ``get_ndarray_keys_recursively``.

    Raises (at dump time):
        KeyError: if an array is reached whose position was not recorded by
            ``get_ndarray_keys_recursively``.
    """

    class NumpySafeDumper(yaml.SafeDumper):
        def __init__(self, *args, **kwargs):
            self.output_path = Path(output_path)
            # id(array) -> key path. Populated by represent(); initialised
            # here under the same name the representer reads.
            self.keys = {}
            super().__init__(*args, **kwargs)

        def represent(self, data):
            # Record where every array lives before the representers run, so
            # the array representer can derive a file name from the position.
            self.keys = get_ndarray_keys_recursively(data, [])
            return super().represent(data)

    def numpy_array_representer(dumper: NumpySafeDumper, data: np.ndarray):
        current_key = ".".join(str(key) for key in dumper.keys[id(data)])
        # Always join with "/" so the reference stored in the YAML file also
        # resolves on platforms other than the one that wrote it
        # (os.path.join would emit "\\" on Windows).
        # NOTE(review): keys containing "/" or ".." end up in the file path
        # unescaped, and distinct key paths can join to the same name
        # (e.g. "a.b" vs. "a" -> "b") — confirm keys are plain identifiers.
        relative_np_path = f"{dumper.output_path.stem}_np_files/{current_key}.npy"

        np_path = dumper.output_path.parent / relative_np_path
        # parents=True also creates the output directory itself if needed.
        np_path.parent.mkdir(parents=True, exist_ok=True)
        np.save(np_path, data)

        return dumper.represent_scalar("!ndarray", relative_np_path)

    NumpySafeDumper.add_representer(np.ndarray, numpy_array_representer)
    return NumpySafeDumper


def dump_yaml_with_arrays(data: dict[str, Any], yaml_filepath: str | os.PathLike):
    """Write ``data`` as YAML to ``yaml_filepath``, creating parent directories.

    Serialization uses the dumper class returned by ``_get_numpy_yaml_dumper``
    for the destination path.
    """
    destination = Path(yaml_filepath)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w") as stream:
        dumper_cls = _get_numpy_yaml_dumper(output_path=destination)
        yaml.dump(data, stream, Dumper=dumper_cls)


def load_yaml_with_arrays(yaml_filepath: str | os.PathLike) -> dict[str, Any]:
    """Parse the YAML file at ``yaml_filepath`` with the array-aware loader.

    The loader class comes from ``_get_numpy_yaml_loader``, which resolves
    ``!ndarray`` nodes to arrays loaded from files located relative to the
    YAML file's directory.

    Returns:
        The object constructed from the YAML document.
    """
    yaml_filepath = Path(yaml_filepath)
    # Binary mode lets PyYAML decode the stream itself (UTF-8 unless a BOM
    # says otherwise) instead of depending on the platform's locale encoding.
    with open(yaml_filepath, "rb") as yaml_file:
        return yaml.load(yaml_file, Loader=_get_numpy_yaml_loader(yaml_filepath))
