"""
Dataset utilities for converting CFD/FE meshes into graph data.

This module loads samples from disk, extracts node/edge features from Muscat
meshes, and builds train/test datasets for GNN training/evaluation.
"""

import json
import numpy as np
from rich.progress import track
from plaid.containers.sample import Sample

from Muscat.FE.Fields.FEField import FEField
from Muscat.FE.FETools import PrepareFEComputation
from Muscat.Bridges.CGNSBridge import CGNSToMesh

import Muscat.MeshContainers.ElementsDescription as ED
from Muscat.MeshContainers.Filters import FilterObjects as FO
from Muscat.MeshTools import MeshModificationTools as MMT
from Muscat.MeshTools.MeshFieldOperations import GetFieldTransferOp


def tri_cells_to_edges(cells: np.ndarray) -> np.ndarray:
    """
    Convert triangle connectivity into a bidirectional edge list.

    Each triangle contributes its three sides (0-1, 1-2, 0-2). Sides shared by
    neighbouring triangles are de-duplicated, then every undirected edge is
    emitted once in each direction.

    Args:
        cells: Array of shape (n_cells, 3) with the node indices of each triangle.

    Returns:
        Integer array of shape (2 * n_unique_edges, 2) of (sender, receiver)
        pairs. The first half holds each edge as (larger index, smaller index),
        sorted lexicographically; the second half repeats those rows with the
        two columns swapped.
    """
    # Lay out the three sides of every triangle as consecutive rows:
    # (c0, c1), (c1, c2), (c0, c2).
    sides = cells[:, [0, 1, 1, 2, 0, 2]].reshape(-1, 2)

    # Orient every side as (max, min) so both traversal directions collapse
    # onto the same row before de-duplication.
    oriented = np.sort(sides, axis=1)[:, ::-1].astype(int)
    one_way = np.unique(oriented, axis=0)

    return np.concatenate((one_way, one_way[:, ::-1]), axis=0)


def quad_cells_to_edges(cells: np.ndarray) -> np.ndarray:
    """
    Convert quadrilateral connectivity into a bidirectional edge list.

    Each quad contributes its four perimeter sides (0-1, 1-2, 2-3, 3-0); the
    diagonals are not included. Shared sides are de-duplicated, then every
    undirected edge is emitted once in each direction.

    Args:
        cells: Array of shape (n_cells, 4) with the node indices of each quad,
            listed in order around the element.

    Returns:
        Integer array of shape (2 * n_unique_edges, 2) of (sender, receiver)
        pairs. The first half holds each edge as (larger index, smaller index),
        sorted lexicographically; the second half repeats those rows with the
        two columns swapped.
    """
    # Pair each corner with the next one around the quad (wrapping 3 -> 0):
    # (c0, c1), (c1, c2), (c2, c3), (c3, c0).
    following = np.roll(cells, -1, axis=1)
    sides = np.stack((cells, following), axis=-1).reshape(-1, 2)

    low = sides.min(axis=1)
    high = sides.max(axis=1)
    one_way = np.unique(np.column_stack((high, low)).astype(int), axis=0)

    return np.concatenate((one_way, one_way[:, ::-1]), axis=0)


def get_distance(mesh, nTag=None):
    """
    Return each node's distance to its projection onto the mesh boundary.

    The boundary ("Skin") is first computed in place on ``mesh``. A Muscat
    field-transfer operator ("Interp/Clamp") is then built over the elements of
    dimension (mesh dimensionality - 1), optionally further restricted to the
    nodal tag ``nTag``. The operator is applied to the node coordinates, and the
    Euclidean distance between each node and the resulting point is returned.

    Args:
        mesh: Muscat mesh object. Modified in place: ComputeSkin is called on it
            with inPlace=True and skinTagName="Skin".
        nTag: Optional nodal tag name forwarded to the element filter to
            restrict which boundary elements are used.

    Returns:
        1D numpy array with one distance per mesh node.
    """
    MMT.ComputeSkin(mesh, md=None, inPlace=True, skinTagName="Skin")
    mesh_dim = int(mesh.GetElementsDimensionality())

    # Scalar (1-component) FE space on the mesh. No values are assigned to this
    # field here; it is only passed to GetFieldTransferOp as the input field.
    space, numberings, _, _ = PrepareFEComputation(mesh, numberOfComponents=1)
    support_field = FEField("", mesh=mesh, space=space, numbering=numberings[0])

    # Restrict the transfer to elements one dimension below the mesh.
    # NOTE(review): "Interp/Clamp" presumably maps each target point to the
    # closest location on those elements -- confirm against Muscat docs.
    boundary_only = FO.ElementFilter(dimensionality=mesh_dim - 1, nTag=nTag)
    projector, _, _ = GetFieldTransferOp(
        inputField=support_field,
        targetPoints=mesh.nodes,
        method="Interp/Clamp",
        elementFilter=boundary_only,
        verbose=False,
    )

    projected_nodes = projector.dot(mesh.nodes)
    offsets = projected_nodes - mesh.nodes
    return np.linalg.norm(offsets, axis=1)


def get_data(mesh, dataset_name=None):
    """
    Extract node fields, node features, edges, and cell connectivity from a mesh.

    This function:
      - collects all node fields from ``mesh.nodeFields``
      - builds an integer ``node_type`` label per node from its nodal tags
        (0 for nodes not covered by any recognised tag)
      - computes a distance field with ``get_distance``, or, for "VKI_LS59",
        uses the mesh's ``"sdf"`` node field
      - builds bidirectional edges from Triangle_3 elements, or from
        Quadrangle_4 elements for "VKI_LS59"

    Note: for every dataset except "VKI_LS59" this calls ``get_distance``,
    which runs ``ComputeSkin`` in place on ``mesh``.

    Args:
        mesh: Muscat mesh object.
        dataset_name: Optional dataset identifier. Selects the nodal tag used
            for the distance field / object ids, and the element type used
            for the edges.

    Returns:
        node_fields_dict: dict[str, np.ndarray] with every node field of the mesh.
        node_features_dict: dict with keys 'distance', 'node_type', 'mesh_pos'.
        edges: int array of shape (n_directed_edges, 2); each undirected edge
            appears once in each direction.
        object_ids: node ids of the dataset's object tag (see ``nTag`` below),
            or an empty int array when the dataset has none.
        cells: int array of connectivity for the selected element type.

    Raises:
        KeyError: for "VKI_LS59" when the mesh has no "sdf" node field.
    """
    is_vki = dataset_name == "VKI_LS59"

    # automatically retrieve node fields
    node_fields = {key: np.array(value) for key, value in mesh.nodeFields.items()}

    # Nodal tag name -> node_type label. Label values are reused across
    # datasets (e.g. "Airfoil", "Holes", "Bottom" and "Intrado" are all 1).
    tag_label_map = {
        "Airfoil": 1, "Ext_bound": 2, "Inlet": 3,
        "Holes": 1,
        "Bottom": 1, "Top": 2,
        "Intrado": 1, "Extrado": 2, "Inflow": 3,
        "Outflow": 4, "Periodic_1": 5, "Periodic_2": 6
    }

    node_type = np.zeros(mesh.GetNumberOfNodes(), dtype=int)
    for tag_name in mesh.nodesTags:
        # Keep only the part of str(tag) before any "(" as the tag name.
        t = str(tag_name).split('(')[0].strip()
        label = tag_label_map.get(t)
        if label is None:
            # A tag missing from the map must not reset labels already written
            # by a recognised tag. Recognised tags still override each other in
            # iteration order where they share nodes.
            continue
        node_type[mesh.GetNodalTag(t).GetIds()] = label

    # Nodal tag restricting the distance computation and defining object_ids.
    nTag = {
        "2D_Profile": "Airfoil",
        "2D_Multiscale": "Holes",
        "VKI_LS59": "Periodic_2",
        # alternative for VKI_LS59: "Extrado"
    }.get(dataset_name)

    if is_vki:
        # VKI_LS59 meshes carry an "sdf" node field used directly as distance;
        # nTag is then only used for object_ids.
        distance = node_fields["sdf"]
    else:
        distance = get_distance(mesh, nTag=nTag)

    # select element type and compute edges
    if is_vki:
        cells = np.array(mesh.elements[ED.Quadrangle_4].connectivity, dtype=int)
        edges = quad_cells_to_edges(cells)
    else:
        cells = np.array(mesh.elements[ED.Triangle_3].connectivity, dtype=int)
        edges = tri_cells_to_edges(cells)

    mesh_pos = np.array(mesh.nodes)

    # dtype=int so an empty result is still an integer index array
    # (a bare np.array([]) would be float64).
    if nTag:
        object_ids = np.array(mesh.GetNodalTag(nTag).GetIds())
    else:
        object_ids = np.array([], dtype=int)

    node_features_dict = {
        'distance': distance,
        'node_type': node_type,
        'mesh_pos': mesh_pos,
    }

    return node_fields, node_features_dict, edges, object_ids, cells


def read_indices_from_csv(dataset_path, split_train_name=None, split_test_name=None):
    """
    Read train/test split indices from the dataset split definition.

    Note: despite the name, this reads a JSON file,
    ``<dataset_path>/problem_definition/split.json``, whose top-level object
    maps split names to lists of sample indices.

    Args:
        dataset_path: Root dataset directory.
        split_train_name: Key of the training split. A missing key (including
            the default None) yields an empty list.
        split_test_name: Key of the test split. Same missing-key behaviour.

    Returns:
        (train_indices, test_indices) as lists of ints.

    Raises:
        FileNotFoundError: if split.json does not exist.
        ValueError: if an entry cannot be converted to int.
    """
    file_path = f"{dataset_path}/problem_definition/split.json"
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale-dependent default encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Coerce to int: callers format indices with an integer format spec
    # (e.g. "{i:09}"), which fails on string entries.
    train_indices = [int(i) for i in data.get(split_train_name, [])]
    test_indices = [int(i) for i in data.get(split_test_name, [])]

    return train_indices, test_indices


def load_datasets(dataset_name, dataset_path, split_train_name=None, split_test_name=None, target_field="all_fields"):
    """
    Load train/test datasets and convert each sample mesh into graph-ready arrays.

    Args:
        dataset_name: Dataset identifier (controls fields, scalars, tags, element type).
        dataset_path: Root dataset directory containing dataset/samples/...
        split_train_name: Split key for training indices.
        split_test_name: Split key for test indices.
        target_field: "all_fields", or one field name from the dataset's
            expected field list.

    Returns:
        train_data: dict containing graph inputs/outputs for training samples
        test_data: dict containing graph inputs/outputs for test samples
        train_idx: list of training sample indices
        test_idx: list of test sample indices

        "X_scalars" / "Y_scalars" are float32 arrays of shape
        (n_samples, n_input_scalars) / (n_samples, n_output_scalars), also
        when a split is empty.

    Raises:
        ValueError: if ``target_field`` is neither "all_fields" nor one of the
            dataset's expected field names. Checked before anything is read
            from disk.
    """
    # scalar input and output definitions
    scalar_input_dict = {
        "VKI_LS59": ["angle_in", "mach_out"],
        "Tensile2D": ["P", "p1", "p2", "p3", "p4", "p5"],
        "2D_Multiscale": ["C11", "C12", "C22"],
        "Rotor37": ["Omega", "P"]
    }
    scalar_output_dict = {
        "VKI_LS59": ["Q", "power", "Pr", "Tr", "eth_is", "angle_out"],
        "Tensile2D": ["max_von_mises", "max_U2_top", "max_sig22_top"],
        "2D_Multiscale": ["effective_energy"],
        "Rotor37": ["Massflow", "Compression_ratio", "Efficiency"]
    }
    # fields expected to be retrieved
    field_names_dict = {
        "VKI_LS59": ['mach', 'nut'],
        "Tensile2D": ['U1', 'U2', 'sig11', 'sig22', 'sig12'],
        "2D_Multiscale": ['u1', 'u2', 'P11', 'P12', 'P22', 'P21', 'psi'],
        "Rotor37": ['Density', 'Pressure', 'Temperature'],
        "2D_Profile": ['Mach', 'Pressure', 'Velocity-x', 'Velocity-y', 'Temperature']
    }

    fnames = field_names_dict.get(dataset_name, [])
    in_names = scalar_input_dict.get(dataset_name, [])
    out_names = scalar_output_dict.get(dataset_name, [])

    # Resolve the requested fields once, up front: the choice does not depend
    # on the sample, and an invalid target_field should fail before any
    # (expensive) sample is loaded.
    if target_field == "all_fields":
        use = fnames
    elif target_field in fnames:
        use = [target_field]
    else:
        raise ValueError("No valid field selected for processing. Please specify a valid target field or use 'all_fields'.")

    # read split indices from problem_definition/split.json
    train_idx, test_idx = read_indices_from_csv(dataset_path, split_train_name, split_test_name)

    def scalars_to_array(rows, n_cols):
        # Stack per-sample scalar rows; keep an empty split 2D as (0, n_cols)
        # instead of the (0,) that np.array([]) would give.
        arr = np.array(rows, dtype=np.float32)
        return arr if rows else arr.reshape(0, n_cols)

    def process_samples(indices, process_type):
        # Load every sample in `indices` and collect its graph arrays.
        x_nodes, x_edges, x_tags, x_dists = [], [], [], []
        x_scalars, y_scalars = [], []
        y_fields = []
        object_ids = []
        cells_list = []

        for i in track(indices, description=f"✅ Processing {process_type}"):
            # load sample
            mesh_data = Sample.load_from_dir(f"{dataset_path}/dataset/samples/sample_{i:09}/")
            tree = mesh_data.get_mesh()
            if dataset_name == "VKI_LS59":
                mesh = CGNSToMesh(tree, baseNames=["Base_2_2"])
            else:
                mesh = CGNSToMesh(tree)

            node_fields, node_features, edges, obj_ids, cells = get_data(mesh, dataset_name=dataset_name)

            # fields: one column per selected field, or an (n_nodes, 0) array
            # when the dataset defines no fields
            arrs = [node_fields[fn] for fn in use]
            y_f = np.column_stack(arrs) if arrs else np.zeros((mesh.GetNumberOfNodes(), 0))
            y_fields.append(y_f)

            # scalars
            x_scalars.append([mesh_data.get_scalar(sn) for sn in in_names])
            y_scalars.append([mesh_data.get_scalar(sn) for sn in out_names])

            # graph data
            x_nodes.append(node_features['mesh_pos'])
            x_edges.append(edges)
            x_tags.append(node_features['node_type'])
            x_dists.append(node_features['distance'])
            object_ids.append(obj_ids)
            cells_list.append(cells)

        return {
            "X_nodes": x_nodes,
            "X_edges": x_edges,
            "X_node_tags": x_tags,
            "X_distances": x_dists,
            "X_scalars": scalars_to_array(x_scalars, len(in_names)),
            "Y_fields": y_fields,
            "Y_scalars": scalars_to_array(y_scalars, len(out_names)),
            "Object_ids": object_ids,
            "Cells_list": cells_list
        }

    train_data = process_samples(train_idx, "train")
    test_data = process_samples(test_idx, "test")

    return train_data, test_data, train_idx, test_idx
