from collections.abc import Generator
from pathlib import Path

import h5py
import numpy as np
import requests

from offline.utils.tqdm import tqdm


# Relative path "data" (resolved against the current working directory when
# used). Presumably the root folder for data files handled by this module --
# NOTE(review): confirm against callers.
DATA_ROOT = Path("data")


def download(
    url: str, *, chunk_size: int = 128, **kwargs
) -> Generator[bytes, None, None]:
    """Stream the body of *url* in chunks while updating a progress bar.

    This is a generator: the request is only sent once iteration starts.

    Args:
        url: Address fetched with an HTTP GET request.
        chunk_size: Number of bytes requested per chunk from the response
            stream. Defaults to 128, the value that used to be hard-coded.
        **kwargs: Extra keyword arguments forwarded to ``tqdm``
            (e.g. ``desc``).

    Yields:
        Successive chunks of the response body.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    # The context manager releases the streamed connection even when the
    # consumer stops iterating early or an error occurs mid-transfer.
    with requests.get(url, stream=True, timeout=10) as response:
        # Fail loudly instead of yielding an HTTP error page as file content.
        response.raise_for_status()
        # A missing Content-Length header is reported as a total of 0.
        total = int(response.headers.get("content-length", 0))
        with tqdm(
            total=total, unit="B", unit_scale=True, unit_divisor=1024, **kwargs
        ) as progress_bar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                progress_bar.update(len(chunk))
                yield chunk


def get_keys_from_h5file(h5file: h5py.File) -> list[str]:
    """Collect the names of all datasets found while walking *h5file*.

    Args:
        h5file: Open HDF5 file to inspect.

    Returns:
        The name passed to the visitor for every ``h5py.Dataset`` node, in
        the order ``visititems`` reports them. Non-dataset nodes are skipped.
    """
    dataset_names: list[str] = []

    def collect(path: str, obj) -> None:
        # Only datasets are recorded; anything else is ignored. Always
        # returning None keeps ``visititems`` walking the whole file.
        if not isinstance(obj, h5py.Dataset):
            return
        dataset_names.append(path)

    h5file.visititems(collect)
    return dataset_names


def load_data_dict(path: Path) -> dict[str, np.ndarray]:
    """Load every dataset of an HDF5 file into memory.

    Args:
        path: Location of the HDF5 file to read.

    Returns:
        A dict mapping each dataset name reported by
        ``get_keys_from_h5file`` to the value read from that dataset.
        NOTE(review): scalar datasets may come back as NumPy scalars rather
        than arrays -- confirm if callers rely on ``ndarray``.
    """
    data_dict: dict[str, np.ndarray] = {}
    # Open explicitly read-only so loading never modifies the file and does
    # not depend on the default mode of the installed h5py version.
    with h5py.File(path, "r") as data_file:
        for key in tqdm(
            get_keys_from_h5file(data_file),
            desc="load data file",
            leave=False,
        ):
            # ``[()]`` reads the full dataset out of the file, so the value
            # remains usable after the file is closed.
            data_dict[key] = data_file[key][()]  # type: ignore
    return data_dict
