from typing import Literal, Optional, Dict

from dataclasses import dataclass

import json
import h5py
from tqdm import tqdm
import torch
import numpy as np

from data.base_data import BaseSample, BaseDataset


@dataclass
class HeatsinkSample(BaseSample):
    """Sample type returned by the heatsink dataset.

    Declares no fields of its own; everything is inherited from ``BaseSample``.
    """


class HeatsinkDataset(BaseDataset):
    """Heatsink dataset with optional random node subsampling on access.

    Loading, splitting and normalization are delegated to ``BaseDataset``;
    this class only defines how a raw sample is assembled (``_load_sample``)
    and how a stored sample is subsampled and wrapped (``__getitem__``).
    """

    dataset_id = "heatsink"

    def __init__(
        self,
        n_subsampled_nodes: Optional[int],
        path: str = "/system/user/publicdata/lcm/heatsink/fixed_splits/heatsink_subsampled.h5",
        splits_path: str = "/system/user/publicdata/lcm/heatsink/fixed_splits/splits.json",
        difficulty: Literal["easy", "medium", "hard"] = "medium",
        split: Literal["train", "val", "test"] = "train",
        domain: Literal["source", "target"] = "source",
        dtype: torch.dtype = torch.float32,
        **kwargs,
    ):
        """Store the subsampling size and forward everything else to the base class.

        Args:
            n_subsampled_nodes: Number of nodes to keep per sample in
                ``__getitem__``. ``None`` (or a value not smaller than the
                sample's node count) disables subsampling.
            path: Location of the HDF5 data file.
            splits_path: Location of the JSON file describing the splits.
            difficulty: Difficulty level, forwarded to ``BaseDataset``.
            split: Dataset split, forwarded to ``BaseDataset``.
            domain: Domain selector, forwarded to ``BaseDataset``.
            dtype: Torch dtype, forwarded to ``BaseDataset``.
            **kwargs: Additional keyword arguments forwarded to ``BaseDataset``.
        """
        # Assigned before the base constructor runs, as in the original order.
        # NOTE(review): presumably so the attribute already exists if
        # BaseDataset.__init__ touches the dataset -- confirm.
        self.n_subsampled_nodes = n_subsampled_nodes
        super().__init__(
            path=path,
            splits_path=splits_path,
            difficulty=difficulty,
            split=split,
            domain=domain,
            dtype=dtype,
            **kwargs,
        )

    def __getitem__(self, idx: int) -> "HeatsinkSample":
        """Return sample ``idx``, randomly subsampled along the node axis.

        The same sorted set of node indices is applied to ``y``,
        ``mesh_coords`` and ``y_mesh_coords`` so they stay aligned. ``cond``
        and ``edge_index`` are passed through untouched.

        Note:
            Indices are drawn with ``np.random.choice``, i.e. from NumPy's
            global RNG, so repeated accesses to the same ``idx`` generally
            return different subsets unless that RNG is re-seeded.
        """
        # NOTE(review): ``self.data`` is presumably populated by BaseDataset
        # with the dicts built in ``_load_sample`` -- confirm against the base.
        sample = self.data[idx]
        n_nodes = sample["y"].shape[0]
        n_keep = self.n_subsampled_nodes

        if n_keep is None or n_keep >= n_nodes:
            # No subsampling requested (or not enough nodes): keep every node.
            keep_indices = np.arange(n_nodes)
        else:
            # Draw without replacement, then sort so the original node order
            # is preserved in the subsampled arrays.
            keep_indices = np.sort(np.random.choice(n_nodes, n_keep, replace=False))

        return HeatsinkSample(
            cond=sample["cond"],
            y=sample["y"][keep_indices, :],
            mesh_coords=sample["mesh_coords"][keep_indices, :],
            y_mesh_coords=sample["y_mesh_coords"][keep_indices, :],
            mesh_edges=sample["edge_index"],
        )

    def _load_sample(
        self,
        cond: np.ndarray,
        mesh_coords: np.ndarray,
        mesh_fields: np.ndarray,
        mesh_material: Optional[np.ndarray] = None,
    ):
        """Assemble the per-sample dict from raw arrays.

        Args:
            cond: Conditioning values, stored unchanged under ``"cond"``.
            mesh_coords: Node coordinates; shifted so the per-column minimum
                is zero and stored under both ``"mesh_coords"`` and
                ``"y_mesh_coords"``.
            mesh_fields: Field values, stored unchanged under ``"y"``.
            mesh_material: Optional material data; when given, only its first
                entry (``mesh_material[0]``) is stored under
                ``"mesh_material"``.

        Returns:
            Dict with keys ``"cond"``, ``"y"``, ``"mesh_coords"``,
            ``"y_mesh_coords"``, ``"edge_index"`` and, optionally,
            ``"mesh_material"``.
        """
        # Shift all samples to (0,0): subtract the minimum along axis 0.
        origin = np.min(mesh_coords, axis=0, keepdims=True)
        # Two separate subtractions on purpose: the outputs must be distinct
        # arrays so that in-place changes to one do not affect the other.
        p_mesh_coords = mesh_coords - origin
        y_mesh_coords = mesh_coords - origin

        # No connectivity is built here; a 1x1 zero array is stored as a
        # placeholder for the edge index.
        mesh_edges = np.zeros([1, 1])

        dict_out = {
            "cond": cond,
            "y": mesh_fields,
            "mesh_coords": p_mesh_coords,
            "y_mesh_coords": y_mesh_coords,
            "edge_index": mesh_edges,
        }
        if mesh_material is not None:
            dict_out["mesh_material"] = mesh_material[0]
        return dict_out



def get_heatsink_dataset(
    path: str,
    split: str,
    normalization_method: Literal["zscore", "minmax"] = "zscore",
    normalization_stats: Optional[Dict] = None,
    **kwargs,
):
    """Load the source and target heatsink datasets and normalize both.

    For ``split == "train"`` the normalization statistics are computed from
    the source-domain dataset (any ``normalization_stats`` passed in is
    replaced). For every other split the caller must pass the statistics in.
    The same statistics are applied to both domains.

    Args:
        path: Path of the data file, forwarded to ``HeatsinkDataset``.
        split: Dataset split, forwarded to ``HeatsinkDataset``.
        normalization_method: Method passed to ``get_normalization_stats``
            and ``normalize``.
        normalization_stats: Precomputed statistics; required unless
            ``split == "train"``.
        **kwargs: Extra keyword arguments forwarded to ``HeatsinkDataset``.

    Returns:
        ``((dataset_source, dataset_target), normalization_stats)``.

    Raises:
        ValueError: If no normalization statistics are available.
    """
    is_train = split == "train"

    # Fail before any data is loaded; a plain ``assert`` would be stripped
    # under ``python -O`` and would only fire after the source dataset load.
    if not is_train and normalization_stats is None:
        raise ValueError(
            f"normalization_stats must be provided for split={split!r}; "
            "they are only computed automatically for split='train'."
        )

    # source domain
    dataset_source = HeatsinkDataset(
        path=path,
        split=split,
        domain="source",
        **kwargs,
    )

    if is_train:
        normalization_stats = dataset_source.get_normalization_stats(method=normalization_method)
    if normalization_stats is None:
        raise ValueError("Could not obtain normalization statistics from the source dataset.")
    dataset_source.normalization_stats = normalization_stats
    dataset_source.normalize(method=normalization_method)

    # target domain (normalized with the same statistics as the source)
    dataset_target = HeatsinkDataset(
        path=path,
        split=split,
        domain="target",
        **kwargs,
    )
    dataset_target.normalization_stats = normalization_stats
    dataset_target.normalize(method=normalization_method)

    return (dataset_source, dataset_target), normalization_stats
