from typing import Literal
import os.path as osp
from tqdm import tqdm

import numpy as np
from sklearn.neighbors import NearestNeighbors
import torch
import h5py
import json
from huggingface_hub import hf_hub_download
import zipfile


def load_sample(
    cond: np.ndarray,
    mesh_coords: np.ndarray,
    mesh_fields: np.ndarray,
    mesh_material: np.ndarray = None,
    n_neighbors: int = 5,
):
    """Build the per-sample dictionary, including a k-nearest-neighbour edge list.

    Args:
        cond: Conditioning values of the sample; returned unchanged as ``"cond"``.
        mesh_coords: Array whose entry ``0`` holds the coordinates the neighbour
            graph is built on (returned as ``"mesh_coords"``) and whose entry
            ``1`` holds the coordinates returned as ``"y_mesh_coords"``.
        mesh_fields: Array of fields; only entry ``1`` is kept (returned as ``"y"``).
        mesh_material: Optional material array; when given, entry ``0`` is
            returned as ``"mesh_material"``.
        n_neighbors: Number of neighbours requested per node from
            ``NearestNeighbors``. Defaults to 5, the previously hard-coded value.

    Returns:
        dict with keys ``"cond"``, ``"y"``, ``"mesh_coords"``, ``"y_mesh_coords"``,
        ``"edge_index"`` and, if ``mesh_material`` was given, ``"mesh_material"``.
        ``"edge_index"`` has one ``(node, neighbour)`` row per neighbour relation,
        i.e. ``num_nodes * n_neighbors`` rows, ordered node by node.
    """
    p_mesh_coords = mesh_coords[0]
    y_mesh_coords = mesh_coords[1]

    mesh_fields = mesh_fields[1]  # after transformation

    # Shift both coordinate sets so that their per-axis minimum sits at zero.
    p_mesh_coords = p_mesh_coords - np.min(p_mesh_coords, axis=0, keepdims=True)
    y_mesh_coords = y_mesh_coords - np.min(y_mesh_coords, axis=0, keepdims=True)

    # Neighbour computation. The query points are the fitted points themselves,
    # so each node's neighbour list is expected to contain the node itself
    # (self-loop) -- NOTE(review): confirm that self-loops are intended.
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(p_mesh_coords)
    _, indices = nbrs.kneighbors(p_mesh_coords)

    # Vectorised equivalent of appending (i, neighbour) for every neighbour of
    # every node i: repeat each node id once per neighbour column and pair it
    # with the row-major flattened neighbour indices.
    num_nodes, num_neighbors = indices.shape
    sources = np.repeat(np.arange(num_nodes, dtype=indices.dtype), num_neighbors)
    mesh_edges = np.stack((sources, indices.reshape(-1)), axis=1)

    dict_out = {
        "cond": cond,
        "y": mesh_fields,
        "mesh_coords": p_mesh_coords,
        "y_mesh_coords": y_mesh_coords,
        "edge_index": mesh_edges,
    }
    if mesh_material is not None:
        dict_out["mesh_material"] = mesh_material[0]
    return dict_out


def load_data(
    splits_path: str,
    path: str,
    difficulty: Literal["easy", "medium", "hard"] = "medium",
    split: Literal["train", "val", "test"] = "train",
    domain: Literal["source", "target"] = "source",
    dtype: torch.dtype = torch.float32,
):
    """Load all samples of one split from an HDF5 file into torch tensors.

    Args:
        splits_path: Path to a JSON file indexed as
            ``[difficulty][domain][split]``, yielding a sequence of index pairs.
            Element ``0`` of a pair selects the ``data/domain<id>`` group and
            element ``1`` is the sample number (zero-padded to three digits in
            the dataset names).
        path: Path to the HDF5 file holding ``metadata/channels``,
            ``metadata/cond`` and the ``data`` groups.
        difficulty: Top-level key into the splits file.
        split: Which split of the selected domain to load.
        domain: ``"source"`` / ``"target"`` are translated to the ``"src"`` /
            ``"tgt"`` keys used in the splits file; other values are used as-is.
        dtype: Floating dtype applied to every tensor except ``edge_index`` and
            ``mesh_material``, which are cast to ``torch.long``.

    Returns:
        Tuple ``(data, channels, conds)`` where ``data`` maps the running sample
        position to a dict of tensors, ``channels`` maps each channel name to a
        ``slice`` spanning its first to last stored index (inclusive), and
        ``conds`` maps each entry of ``metadata/cond`` to its array.
    """
    # load metadata
    with open(splits_path, "r") as f:
        splits_metadata = json.load(f)
    # The splits file uses abbreviated domain names.
    domain = {"source": "src", "target": "tgt"}.get(domain, domain)
    data_indices = splits_metadata[difficulty][domain][split]

    # load data
    data = {}
    with h5py.File(path, "r", swmr=True) as h5f:
        channels = {k: v[:] for k, v in h5f["metadata/channels"].items()}
        conds = {k: v[:] for k, v in h5f["metadata/cond"].items()}
        data_root = h5f["data"]
        for i, data_index in tqdm(enumerate(data_indices), desc=f"Loading data ({split=}, {domain=})", total=len(data_indices)):
            group = data_root[f"domain{data_index[0]}"]
            sample_id = str(data_index[1]).zfill(3)
            # Test for this sample's material dataset directly instead of
            # listing every key of the group and relying on their ordering.
            material_key = f"material_{sample_id}"
            sample_results = load_sample(
                cond=group[f"cond_{sample_id}"][:],
                mesh_coords=group[f"coords_{sample_id}"][:],
                mesh_fields=group[sample_id][:],
                mesh_material=group[material_key][:] if material_key in group else None,
            )
            sample = {}
            for key, v in sample_results.items():
                # Only array-valued entries are kept.
                if not isinstance(v, np.ndarray):
                    continue
                # Index-like entries stay integral; everything else uses `dtype`.
                is_integral = "edge_index" in key or "mesh_material" in key
                sample[key] = torch.from_numpy(v).to(dtype=torch.long if is_integral else dtype)
            data[i] = sample
    # Replace each stored channel index array by a slice covering first..last.
    channels = {k: slice(c[0], c[-1] + 1) for k, c in channels.items()}
    return data, channels, conds


def download_data(repo_id: str, filename: str, local_dir: str):
    """Download a zip archive from a Hugging Face dataset repo and unpack it.

    Args:
        repo_id: Identifier of the Hugging Face dataset repository.
        filename: Name of the zip file inside the repository.
        local_dir: Directory the archive is downloaded to and extracted into.
    """
    print(f"Downloading dataset from Hugging Face: {repo_id}/{filename}")
    archive_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="dataset",
        local_dir=local_dir,
    )

    print(f"Extracting zip to: {local_dir}")
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(local_dir)
