"""Combine multiple Tokens objects for dataset accumulation."""

import jax.numpy as jnp
from jax import tree
from jaxtyping import Array
from typing import Optional
import dataclasses

from .tokens import Tokens

def combine_tokens(tokens1: Tokens, tokens2: Tokens) -> Tokens:
    """
    Combine two Tokens objects by concatenating samples and padding tokens.

    Parameters
    ----------
    tokens1 : Tokens
        First Tokens object
    tokens2 : Tokens
        Second Tokens object

    Returns
    -------
    Tokens
        Combined Tokens with concatenated samples and padded tokens. Its
        partition_idx is the larger of the two input partition_idx values.

    Raises
    ------
    ValueError
        If only one input has functional_inputs, or both have them but
        their final dimensions differ
    ValueError
        If tokens have different sample_ndims
    NotImplementedError
        If a leaf is declared with more than two token dimensions while at
        least one input has condition tokens

    Notes
    -----
    - Sample dimension (axis 0) is concatenated
    - Condition tokens (before partition_idx) are padded to the larger
      partition_idx and target tokens (from partition_idx on) to the larger
      target count, so the partition lands at the same index for both inputs
    - Leaves with two token dimensions (e.g. self-attention masks) are
      padded block-wise with zeros
    - Padding masks track which tokens are real vs padded: an input's
      existing padding_mask is kept, an input without one gets an all-ones
      mask, and padded positions are filled with zeros
    - Neither input is modified; the masks are attached to copies
    """
    # functional_inputs must be present on both inputs or on neither.
    has_func1 = tokens1.functional_inputs is not None
    has_func2 = tokens2.functional_inputs is not None
    if has_func1 != has_func2:
        raise ValueError(
            "Cannot combine tokens: one has functional_inputs, "
            "the other does not"
        )

    # When present, only the trailing dimension has to agree.
    if has_func1:
        final_dim1 = tokens1.functional_inputs.shape[-1]  # type: ignore
        final_dim2 = tokens2.functional_inputs.shape[-1]  # type: ignore
        if final_dim1 != final_dim2:
            raise ValueError(
                "Cannot combine tokens: functional_inputs have different "
                f"final dimensions ({final_dim1} vs {final_dim2})"
            )

    if tokens1.sample_ndims != tokens2.sample_ndims:
        raise ValueError(
            "Cannot combine tokens with different sample_ndims: "
            f"sample_ndims ({tokens1.sample_ndims} vs {tokens2.sample_ndims})"
        )

    partition1 = tokens1.partition_idx
    partition2 = tokens2.partition_idx
    n_tokens1 = tokens1.data.shape[tokens1.sample_ndims]
    n_tokens2 = tokens2.data.shape[tokens2.sample_ndims]

    # Target sizes of the condition / target segments after padding.
    max_n_condition = max(partition1, partition2)
    max_n_target = max(n_tokens1 - partition1, n_tokens2 - partition2)

    # Number of leading sample axes; the token axis comes right after them.
    sample_ndims = len(tokens1.sample_shape)

    def resolve_padding_mask(tokens, n_tokens):
        """Return the input's own mask, or an all-ones mask if it has none."""
        # Keeping an existing mask matters when an input is itself the
        # result of an earlier combine: resetting it to ones would mark its
        # padded tokens as real.
        # NOTE(review): an existing mask is assumed to have shape
        # sample_shape + (n_tokens,) -- confirm against Tokens.
        if tokens.padding_mask is not None:
            return tokens.padding_mask
        return jnp.ones(tokens.sample_shape + (n_tokens,))

    # Work on copies so the caller's objects are left untouched. The copy of
    # tokens2 also takes tokens1.partition_idx so both pytrees passed to
    # tree.map carry the same partition_idx.
    masked1 = dataclasses.replace(
        tokens1,
        padding_mask=resolve_padding_mask(tokens1, n_tokens1),
    )
    masked2 = dataclasses.replace(
        tokens2,
        padding_mask=resolve_padding_mask(tokens2, n_tokens2),
        partition_idx=partition1,
    )

    # Create Tokens-like structure for token_ndims to match pytree structure:
    # each field holds the number of token dimensions of that leaf.
    # type: ignore for using ints instead of Arrays
    token_ndims_tokens = Tokens(
        data=1,  # type: ignore
        labels=1,  # type: ignore
        position=1,  # type: ignore
        condition=1,  # type: ignore
        padding_mask=1,  # type: ignore
        functional_inputs=1,  # type: ignore
        partition_idx=partition1,
    )

    def split_leaf(leaf, idx, sample_ndims):
        """Split a leaf along the token axis into (condition, target) parts."""
        context_idx = jnp.arange(idx)
        target_idx = jnp.arange(idx, leaf.shape[sample_ndims])
        context = jnp.take(leaf, context_idx, axis=sample_ndims)
        target = jnp.take(leaf, target_idx, axis=sample_ndims)
        return context, target

    def pad_and_stack(part1, part2, n_max, ndims):
        """Pad both parts to n_max tokens and concatenate along axis 0."""
        return jnp.concatenate(
            [
                _pad_data_to_max_tokens(
                    part1, n_max, sample_ndims, token_ndims=ndims
                ),
                _pad_data_to_max_tokens(
                    part2, n_max, sample_ndims, token_ndims=ndims
                ),
            ],
            axis=0,
        )

    def pad_token_leaf(leaf1: Optional[Array], leaf2: Optional[Array], ndims: int) -> Optional[Array]:
        # A leaf missing on either side stays missing in the result.
        if leaf1 is None or leaf2 is None:
            return None

        # No condition tokens anywhere: the whole token axis is target, so
        # there is nothing to split.
        if max_n_condition == 0:
            return pad_and_stack(leaf1, leaf2, max_n_target, ndims)

        if ndims == 1:
            # Pad the condition and target segments separately so the
            # partition ends up at max_n_condition for both inputs.
            context_1, target_1 = split_leaf(leaf1, partition1, sample_ndims)
            context_2, target_2 = split_leaf(leaf2, partition2, sample_ndims)
            context = pad_and_stack(context_1, context_2, max_n_condition, ndims)
            target = pad_and_stack(target_1, target_2, max_n_target, ndims)
            return jnp.concatenate([context, target], axis=sample_ndims)

        if ndims == 2:
            def double_take(m, idx1, idx2):
                """Select submatrix after sample_ndims"""
                return jnp.take(
                    jnp.take(m, idx1, axis=sample_ndims),
                    idx2,
                    axis=sample_ndims + 1
                )

            def pad_block(block, pad_rows, pad_cols):
                """Zero-pad the two token axes of one submatrix."""
                # The pad spec has sample_ndims + 2 entries, so the leaf must
                # have exactly two axes after the sample axes.
                return jnp.pad(
                    block,
                    [(0, 0)] * sample_ndims + [(0, pad_rows), (0, pad_cols)],
                    constant_values=0.0
                )

            def pad_matrix(m, partition):
                """Pad each condition/target quadrant of m and reassemble."""
                n = m.shape[sample_ndims]
                c_idx = jnp.arange(partition)
                t_idx = jnp.arange(partition, n)
                pad_condition = max_n_condition - partition
                pad_target = max_n_target - (n - partition)
                c_c = pad_block(
                    double_take(m, c_idx, c_idx), pad_condition, pad_condition
                )
                c_t = pad_block(
                    double_take(m, c_idx, t_idx), pad_condition, pad_target
                )
                t_c = pad_block(
                    double_take(m, t_idx, c_idx), pad_target, pad_condition
                )
                t_t = pad_block(
                    double_take(m, t_idx, t_idx), pad_target, pad_target
                )
                # jnp.block joins along the last two axes, which are the
                # token axes here.
                return jnp.block([[c_c, c_t], [t_c, t_t]])

            padded_leaf_1 = pad_matrix(leaf1, partition1)
            padded_leaf_2 = pad_matrix(leaf2, partition2)
            return jnp.concatenate([padded_leaf_1, padded_leaf_2], axis=0)

        raise NotImplementedError("padding token_ndims > 2 is not supported")

    combined_tokens = tree.map(
        pad_token_leaf,
        masked1,
        masked2,
        token_ndims_tokens,
        # Treat None as a leaf so optional fields reach pad_token_leaf.
        is_leaf=lambda x: x is None
    )
    # combined_tokens is a fresh object, so setting its partition here does
    # not touch either input.
    combined_tokens.partition_idx = max_n_condition
    return combined_tokens

def _pad_data_to_max_tokens(
    data: Array,
    max_n_tokens: int,
    sample_ndims: int,
    pad_value: float = 0.0,
    token_ndims: int = 1
) -> Array:
    """
    Pad data array to max_n_tokens on the token dimension.

    Parameters
    ----------
    data : Array
        Data array with shape (*sample_shape, n_tokens, max_batch_size)
    max_n_tokens : int
        Target number of tokens
    sample_ndims : int
        Number of sample dimensions
    pad_value : float
        Value to use for padding

    Returns
    -------
    Array
        Padded array with shape (*sample_shape, max_n_tokens,
        max_batch_size)
    """
    current_n_tokens = data.shape[sample_ndims]

    if current_n_tokens >= max_n_tokens:
        return data

    # Calculate padding: [(0, 0), ..., (0, pad_amount), (0, 0)]
    pad_width = [(0, 0)] * len(data.shape)

    # Pad on token dimension(s)
    for i in range(token_ndims):
        pad_width[sample_ndims + i] = (0, max_n_tokens - current_n_tokens)

    return jnp.pad(
        data, pad_width, mode='constant', constant_values=pad_value
    )
