import torch
import torch.nn.functional as F
import numpy as np
from typing import Sequence, Union, Tuple, Optional
import os

# Public names exported by ``from <this module> import *``.
__all__ = [
    "get_fine2coarse_tensor",
    "fine_logits_to_coarse_prob",
    "fine_pred_to_coarse_pred",
    "hierarchical_consistency_rate",
    "get_hierarchical_label_maps",
]

# Built-in fine -> coarse lookup: list position = fine class id (13 fine
# classes), value = coarse class id (4 coarse groups, ids 0..3). The groups are
# described elsewhere in this module as static / opening / furniture / misc —
# presumably in that id order; confirm against the dataset definition.
_DEFAULT_FINE2COARSE = [
    0, 0, 0, 0, 0,  # fine 0-4  -> coarse 0
    1, 1,           # fine 5-6  -> coarse 1
    2, 2, 2, 2,     # fine 7-10 -> coarse 2
    0,              # fine 11   -> coarse 0
    3,              # fine 12   -> coarse 3
]


def get_fine2coarse_tensor(
    device: Union[torch.device, str, None] = None,
    mapping: Optional[Sequence[int]] = None,
) -> torch.Tensor:
    """Build the fine -> coarse lookup table as a 1-D ``torch.long`` tensor.

    Entry ``i`` of the returned tensor is the coarse class id of fine class
    ``i``, so the table can be indexed directly with fine labels.

    Parameters
    ----------
    device : torch.device | str | None
        Device to move the table to. ``None`` keeps the device chosen by
        :func:`torch.as_tensor`.
    mapping : Sequence[int] | None
        Custom fine -> coarse list. When ``None``, the built-in PointLiBR
        grouping (static / opening / furniture / misc) is used.

    Returns
    -------
    torch.Tensor
        Long tensor of shape ``[num_fine]``.
    """
    source = _DEFAULT_FINE2COARSE if mapping is None else mapping
    table = torch.as_tensor(source, dtype=torch.long)
    return table if device is None else table.to(device)


def fine_logits_to_coarse_prob(
    fine_logits: torch.Tensor,
    fine2coarse: torch.Tensor,
    num_coarse: int = 4,
) -> torch.Tensor:
    """Convert *logits* over fine classes into *probabilities* over coarse classes.

    A softmax is taken over the fine classes, and the probability mass of every
    fine class is then summed into the coarse class it maps to.

    Parameters
    ----------
    fine_logits : Tensor
        Floating-point tensor of shape ``[..., C_f]``.
    fine2coarse : Tensor
        Integer tensor with ``C_f`` elements; entry ``i`` is the coarse id of
        fine class ``i``. Every value must be in ``[0, num_coarse)``.
    num_coarse : int
        Total number of coarse classes ``C_c``.

    Returns
    -------
    Tensor
        Probabilities of shape ``[..., C_c]``, with the same leading dims,
        dtype and device as ``fine_logits``.

    Raises
    ------
    ValueError
        If ``fine2coarse`` does not have exactly ``C_f`` elements.
    """
    num_fine = fine_logits.shape[-1]
    if fine2coarse.numel() != num_fine:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(
            "Mapping length must equal # fine classes "
            f"(got {fine2coarse.numel()} vs {num_fine})"
        )
    prob_fine = F.softmax(fine_logits, dim=-1)  # [..., C_f]
    # Aggregation matrix [C_f, C_c]. It is built on the probabilities' device
    # and in their dtype: a hard-coded float32 matrix makes the matmul fail for
    # float64/bfloat16 inputs, or for a mapping that lives on another device.
    aggregate = F.one_hot(
        fine2coarse.reshape(-1).to(device=prob_fine.device, dtype=torch.long),
        num_classes=num_coarse,
    ).to(prob_fine.dtype)
    # matmul already yields [..., C_c]; no reshape needed.
    return prob_fine @ aggregate


def fine_pred_to_coarse_pred(
    fine_pred: torch.Tensor,
    fine2coarse: torch.Tensor,
) -> torch.Tensor:
    """Translate fine-level class indices into their coarse-level ids.

    Parameters
    ----------
    fine_pred : Tensor
        Integer tensor of any shape holding fine class indices.
    fine2coarse : Tensor
        Lookup table, e.g. as built by :func:`get_fine2coarse_tensor`.

    Returns
    -------
    Tensor
        ``fine2coarse`` gathered at ``fine_pred``. For a 1-D table this has
        the same shape as ``fine_pred``.
    """
    coarse_pred = fine2coarse[fine_pred]
    return coarse_pred


def hierarchical_consistency_rate(
    fine_pred: torch.Tensor,
    coarse_pred: torch.Tensor,
    fine2coarse: torch.Tensor,
) -> float:
    """Compute the proportion of points whose fine prediction agrees with the coarse one.

    A point is *consistent* if ``fine2coarse[fine_pred] == coarse_pred``.

    Parameters
    ----------
    fine_pred : Tensor
        Integer tensor of fine class indices.
    coarse_pred : Tensor
        Integer tensor of coarse class indices. Its shape must equal that of
        ``fine2coarse[fine_pred]``.
    fine2coarse : Tensor
        Fine -> coarse lookup table.

    Returns
    -------
    float
        Mean consistency in ``[0, 1]``, or ``nan`` when there are no points.

    Raises
    ------
    ValueError
        If the mapped fine predictions and ``coarse_pred`` differ in shape.
        Without this check the comparison would silently broadcast, e.g.
        ``[N]`` vs ``[N, 1]`` becomes ``[N, N]``, and the rate would be
        meaningless.
    """
    with torch.no_grad():
        mapped = fine2coarse[fine_pred]
        if mapped.shape != coarse_pred.shape:
            raise ValueError(
                f"Shape mismatch: mapped fine predictions {tuple(mapped.shape)} "
                f"vs coarse predictions {tuple(coarse_pred.shape)}"
            )
        consistent = mapped == coarse_pred
        if consistent.numel() == 0:
            # The mean of an empty tensor is undefined; make the NaN explicit.
            return float("nan")
        return consistent.float().mean().item()


# ---------------------------------------------------------------------
# Helper: dynamically build label-mapping functions for each hierarchy
# ---------------------------------------------------------------------

def _load_matrix_mapping_fns(h_file_list, device=None):
    """Build one ``labels -> labels`` callable per CSV hierarchy matrix listed in *h_file_list*.

    *h_file_list* is a YAML file whose ``file_list`` key lists CSV paths.
    Relative paths are resolved against the YAML file's own directory. Each
    CSV is a matrix with rows = upper-level classes and columns = leaf
    classes; leaf ``j`` is mapped to the row holding the column's maximum.
    """
    import yaml  # imported lazily: only needed when a hierarchy list is configured

    with open(h_file_list, 'r') as yf:
        file_list = yaml.load(yf, Loader=yaml.FullLoader)['file_list']

    base_dir = os.path.dirname(h_file_list)
    fns = []
    for f_csv in file_list:
        if not os.path.isabs(f_csv):
            # Relative to the directory of the YAML list file.
            f_csv = os.path.join(base_dir, f_csv)
        # ndmin=2 stops loadtxt from squeezing a one-row matrix to 1-D, which
        # would make argmax(axis=0) return a single scalar instead of one
        # entry per leaf class.
        m = np.loadtxt(f_csv, delimiter=',', ndmin=2)
        mapping_arr = torch.as_tensor(np.argmax(m, axis=0), dtype=torch.long, device=device)
        # Bind mapping_arr through a default argument (avoids late binding).
        fns.append(lambda labels, arr=mapping_arr: arr[labels])
    return fns


def get_hierarchical_label_maps(cfg=None, device=None):
    """Return a *list* of callables that map fine labels -> labels at each level.

    The last callable is always the identity (fine -> fine). When
    ``cfg.dataset.common['h_matrix_list_file']`` names an existing file, the
    list starts with one callable per CSV matrix in that file's ``file_list``,
    in file order. Otherwise it starts with a single callable that uses the
    default 2-level mapping from :func:`get_fine2coarse_tensor`.

    Parameters
    ----------
    cfg : EasyConfig | None
        Global config. If it provides
        ``cfg.dataset.common.get('h_matrix_list_file')``, that YAML list of
        hierarchy matrices is loaded to build multi-level mappings. Otherwise
        the default 2-level mapping is used.
    device : torch.device | str | None
        Device of the lookup tensors used by the returned callables.

    Returns
    -------
    list[Callable[[Tensor], Tensor]]
        Label-mapping functions, coarsest first, identity last.
    """
    mapping_fns = []

    # Try to build the mappings from CSV matrices, if a list file is configured.
    h_file_list = None
    if cfg is not None and hasattr(cfg, 'dataset'):
        h_file_list = cfg.dataset.common.get('h_matrix_list_file', None)

    if h_file_list and os.path.isfile(h_file_list):
        mapping_fns = _load_matrix_mapping_fns(h_file_list, device=device)

    # Otherwise fall back to the default 2-level coarse <-> fine mapping.
    if not mapping_fns:
        fine2coarse = get_fine2coarse_tensor(device=device)

        def to_coarse(lbl, arr=fine2coarse):
            # The numel() guard avoids calling max() on an empty tensor, which raises.
            # NOTE(review): a single label >= len(arr) (e.g. an ignore index)
            # switches the WHOLE tensor to the identity mapping; confirm intended.
            if lbl.numel() > 0 and lbl.max() >= arr.numel():
                return lbl
            return arr[lbl]

        mapping_fns = [to_coarse]

    # The finest level is always the identity map.
    mapping_fns.append(lambda x: x)

    return mapping_fns