"""Unified parameter data interface.

Provides the Tokens class for consolidating structured parameter handling,
masking, functional inputs, and metadata into a single coherent interface.
"""

from dataclasses import dataclass
from typing import (
    overload,
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
    Callable,
    Literal
)
from math import prod

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from jaxtyping import Array

from .flatten import flatten_pytree
from .utils import Independence, Labeller
from .functional_inputs import flatten_functional_inputs
from .reconstruct import decode_pytree

@register_pytree_node_class
@dataclass
class Tokens:
    """
    Unified container for structured token data.

    Stores all parameters and observations in a single flattened array,
    with metadata for efficient decoding to structured format.

    Attributes
    ----------
    data : Array
        All flattened tokens, shape (*sample_shape, n_total_tokens,
        max_batch_size)
    labels : Array
        Integer labels per token, shape (*sample_shape, n_total_tokens)
    partition_idx: int
        Static index which separates conditional and target data
    padding_mask : Optional[Array]
        Padding mask, shape (*sample_shape, n_total_tokens)
    functional_inputs : Optional[Array]
        Functional inputs for tokens, shape (*sample_shape,
        n_total_tokens, max_batch_size)
    """

    data: Array
    labels: Array
    position: Array
    condition: Array
    partition_idx: int
    padding_mask: Optional[Array]
    functional_inputs: Optional[Array]

    @property
    def sample_ndims(self) -> int:
        """
        Get sample shape from data array.

        Returns
        -------
        Tuple[int, ...]
            Shape of sample dimensions
        """
        return len(self.data.shape) - 2  # Remove event and batch

    @property
    def sample_shape(self) -> Tuple[int, ...]:
        """
        Get sample shape from data array.

        Returns
        -------
        Tuple[int, ...]
            Shape of sample dimensions
        """
        return self.data.shape[:self.sample_ndims]

    @overload
    @classmethod
    def from_pytree(
        cls,
        data: Dict[str, Array],
        condition: List[str],
        labeller: Optional[Labeller] = None,
        independence: Optional[Independence] = None,
        functional_inputs: Optional[Dict[str, Array]] = None,
        sample_ndims: int = 1,
        batch_ndims: Optional[Dict[str, int]] = None,
        return_decoder: Literal[False] = False
    ) -> 'Tokens': ...
    @overload
    @classmethod
    def from_pytree(
        cls,
        data: Dict[str, Array],
        condition: List[str],
        labeller: Optional[Labeller] = None,
        independence: Optional[Independence] = None,
        functional_inputs: Optional[Dict[str, Array]] = None,
        sample_ndims: int = 1,
        batch_ndims: Optional[Dict[str, int]] = None,
        return_decoder: Literal[True] = True
    ) -> Union['Tokens', Tuple['Tokens', Callable]]: ...
    @classmethod
    def from_pytree(
        cls,
        data: Dict[str, Array],
        condition: List[str],
        labeller: Optional[Labeller] = None,
        independence: Optional[Independence] = None,
        functional_inputs: Optional[Dict[str, Array]] = None,
        sample_ndims: int = 1,
        batch_ndims: Optional[Dict[str, int]] = None,
        return_decoder: bool = False
    ) -> Union['Tokens', Tuple['Tokens', Callable]]:
        """
        Create Tokens from structured PyTree.

        All keys in data are flattened into a single token array.

        Parameters
        ----------
        data: Dict[str, Array]
            Dictionary of model variable samples. Each array should have shape
            (*sample_dims, *event_dims, *batch_dims).
        condition: List[str]
            List of keys which correspond to conditioning variables
        labeller : Optional[Labeller], optional
            Labeller instance for generating consistent labels across tokens.
            If None, creates a default labeller with sequential indices.
        independence : Optional[Independence], optional
            Independence specification controlling attention patterns between
            tokens. If None, creates default Independence().
        functional_inputs : Optional[Dict[str, Array]], optional
            Dictionary of functional input arrays matching data structure
        sample_ndims : int, optional
            Number of leading sample dimensions. Default is 0.
        batch_ndims : Optional[Dict[str, int]], optional
            Number of trailing batch dimensions for each key.
            If None, defaults to 1 for all keys.
        return_decoder: Optional[bool], optional
            Whether to return a decoding function

        Returns
        -------
        Tokens, optional Callable[Tokens,Tuple[Dict[str,Array],Dict[str,Array]]
            Token object and optional decoding function
        """
        # Default batch_ndims to 1 for all keys
        if batch_ndims is None:
            batch_ndims = {key: 1 for key in data.keys()}

        # Create default labeller if not provided
        if labeller is None:
            labeller = Labeller.for_keys(list(data.keys()))

        # Create default independence if not provided
        if independence is None:
            independence = Independence()

        # Flatten the PyTree
        # Sort data such that conditioning variables come first
        key_order = sorted(data.keys(), key=lambda k: k not in condition)
        data = { k: data[k] for k in key_order }
        flat_data, slices = flatten_pytree(
            data,
            sample_ndims=sample_ndims,
            batch_ndims=batch_ndims
        )
        partition_idx = next(
            s.offset
            for k, s in slices.items()
            if k not in condition
        )

        # Build labels array
        total_tokens = flat_data.shape[sample_ndims]
        sample_shape = flat_data.shape[:sample_ndims]

        # Generate labels using Labeller
        labels_1d = labeller.label(slices)
        broadcast_shape = sample_shape + (total_tokens,)
        labels = jnp.broadcast_to(
            labels_1d.reshape((1,) * sample_ndims + (total_tokens,)),
            broadcast_shape
        )

        # Flatten functional inputs if provided
        func_inputs_flat = None
        if functional_inputs is not None:
            func_inputs_flat = flatten_functional_inputs(
                functional_inputs,
                slices,
                sample_ndims=sample_ndims
            )
            
        position = jnp.concatenate([
            jnp.arange(prod(s.event_shape))
            for s in slices.values()
        ])
        position = jnp.broadcast_to(
            position.reshape((1,) * sample_ndims + (total_tokens,)),
            broadcast_shape
        )

        condition_values = jnp.concatenate([
            jnp.full(
                (prod(v.event_shape),),
                int(k in condition),
                dtype=float
            )
            for k, v in slices.items()
        ])
        condition_values = jnp.broadcast_to(
            condition_values.reshape((1,) * sample_ndims + (total_tokens,)),
            broadcast_shape
        )

        tokens = cls(
            data=flat_data,
            labels=labels,
            position=position,
            condition=condition_values,
            padding_mask=None,
            functional_inputs=func_inputs_flat,
            partition_idx=partition_idx
        )

        if return_decoder:
            def decoder(tokens: Tokens):
                return decode_pytree(
                    tokens.data,
                    slices,
                    tokens.sample_shape,
                    is_subset=False
                )

            return tokens, decoder

        return tokens

    def tree_flatten(self) -> Tuple[Tuple, Dict[str, Any]]:
        """
        Flatten Tokens for JAX PyTree operations.

        Returns
        -------
        Tuple[Tuple, Dict[str, Any]]
            (children, aux_data) where children are arrays with sample
            dimension that get transformed by tree.map, and aux_data
            contains static metadata
        """
        children = (
            self.data,
            self.labels,
            self.position,
            self.condition,
            self.padding_mask,
            self.functional_inputs,
        )
        aux_data = {"partition_idx": self.partition_idx}
        return (children, aux_data)

    @classmethod
    def tree_unflatten(
        cls,
        aux_data: Dict[str, Any],
        children: Tuple
    ) -> 'Tokens':
        """
        Unflatten Tokens from JAX PyTree operations.

        Parameters
        ----------
        aux_data : Dict[str, Any]
            Static metadata
        children : Tuple
            Arrays with sample dimension

        Returns
        -------
        Tokens
            Reconstructed Tokens object
        """
        (
            data,
            labels,
            position,
            condition,
            padding_mask,
            functional_inputs,
        ) = children
        return cls(
            data=data,
            labels=labels,
            position=position,
            condition=condition,
            padding_mask=padding_mask,
            functional_inputs=functional_inputs,
            partition_idx=aux_data["partition_idx"]
        )
