import jax
from jax import numpy as jnp
import numpy as np
from chex import Array, dataclass
from typing import Optional, Tuple, Callable
from functools import partial
from jax.scipy.special import logsumexp
from einshape import jax_einshape as einshape
from jax.tree_util import tree_map
from math import prod

from .cluster import Cluster, QueryCluster

def exact_attention(
    queries: Array,
    keys: Array,
    values: Array,
    logmass: Optional[Array] = None,
) -> Array:
    """
    Compute softmax attention exactly (no hierarchical approximation).

    Logits are the raw query/key dot products; no 1/sqrt(depth) scaling is
    applied.

    Args:
        queries: Shape (batch_size, num_heads, seq_len_q, depth). Any number of
            leading batch axes is accepted; only the last two axes are contracted.
        keys: Shape (batch_size, num_heads, seq_len_kv, depth).
        values: Shape (batch_size, num_heads, seq_len_kv, depth_v). The value
            depth does not have to match the query/key depth.
        logmass: Optional additive term for the logits, which have shape
            (..., seq_len_q, seq_len_kv). It must broadcast against them, e.g.
            shape (seq_len_kv,) or (..., 1, seq_len_kv) to bias each key.
            ``None`` (the default) means no bias.

    Returns:
        Attention output of shape (batch_size, num_heads, seq_len_q, depth_v).
    """
    # Raw similarity between every query and every key.
    logits = jnp.einsum('...qd,...kd->...qk', queries, keys)
    if logmass is not None:
        # Adding a log-weight is equivalent to multiplying each key's
        # (unnormalised) attention weight by exp(logmass).
        logits = logits + logmass
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum('...qk,...kd->...qd', weights, values)

def attention_step(queries: QueryCluster, clusters: Cluster) -> Cluster:
    """
    Run one level of hierarchical attention: tilt the clusters, then merge.

    Args:
        queries: QueryCluster(Shape(Q, D)) used to tilt the clusters.
        clusters: Cluster(Shape(K, D)) to be tilted and merged.
    Returns:
        The result of ``batch_merge`` applied to the tilted clusters
        (Cluster(Shape(Q, D))).
    """
    tilted = clusters.batch_tilt(queries)
    return tilted.batch_merge()

def hierarchical_attention(queries: QueryCluster, clusters: Cluster) -> Cluster:
    """
    Compute hierarchical attention.

    Queries and clusters are both ``num_levels``-level hierarchies. Their
    ``logmass`` arrays have one axis per level, and those shapes are the
    per-level sizes ``Qs`` and ``Ks``. Query levels are visited coarsest first.
    At each step the clusters are tilted by that query level, and then one key
    level (the last remaining key axis) is collapsed with ``batch_merge``.

    Args:
        queries: QueryCluster(Shape([Q], D)); ``queries.logmass.shape`` gives
            the per-level query sizes ``Qs``.
        clusters: Cluster(Shape([K], D)); ``clusters.logmass.shape`` gives the
            per-level key sizes ``Ks``.
    Returns:
        clusters: Cluster whose leading axis flattens all query levels
            (``prod(Qs)`` entries). NOTE(review): assuming ``batch_merge``
            drops the last axis, the final logmass has shape ``(prod(Qs), 1)``.
            Confirm this against ``Cluster``.
    Raises:
        AssertionError: if the queries have no levels, or if queries and
            clusters have a different number of levels.
    """
    num_levels = len(queries.logmass.shape)
    assert num_levels > 0, "Queries must have at least one level."
    assert num_levels == len(clusters.logmass.shape), "Queries and clusters must have the same number of levels."
    Ks = clusters.logmass.shape
    Qs = queries.logmass.shape

    # Query pyramid: entry 0 is the input; each further entry merges the
    # previous one, so entry i has logmass shape Qs[:num_levels - i].
    coarse_queries = [queries]
    for _ in range(num_levels - 1):
        coarse_queries.append(coarse_queries[-1].batch_merge())

    # Re-express every non-coarsest level's centroid relative to its parent's,
    # broadcast over the level axis that was merged away. Presumably this makes
    # the successive tilts below add back up to the original query.
    # Ascending order matters: level i reads level i+1 before level i+1 is
    # itself rewritten, so the parent centroid is still the unmodified one.
    for i in range(num_levels - 1):
        parent_centroid = coarse_queries[i + 1].centroid
        coarse_queries[i] = coarse_queries[i].replace(
            centroid=coarse_queries[i].centroid - parent_centroid[..., None, :],
        )

    # Flatten all key levels into a single axis behind a singleton batch axis.
    clusters = tree_map(
        lambda x: jnp.reshape(x, (1, prod(Ks)) + x.shape[num_levels:]),
        clusters,
    )

    # tree_map runs immediately, so the lambdas below capture the current
    # `level` / `key_level` values; there is no late-binding issue.
    for level, coarse_qs in enumerate(reversed(coarse_queries)):
        # Group this query level by all coarser levels so its batch axis
        # lines up with the clusters' batch axis, i.e. prod(Qs[:level]).
        coarse_qs = tree_map(
            lambda x: jnp.reshape(x, (prod(Qs[:level]), Qs[level]) + x.shape[level + 1:]),
            coarse_qs,
        )
        clusters = clusters.batch_tilt(coarse_qs)
        # Fold the new query axis into the batch axis, then split the
        # innermost remaining key level off so batch_merge collapses it.
        # NOTE(review): x.shape[3:] assumes batch_tilt returns leading
        # (batch, query, key) axes. Confirm against Cluster.batch_tilt.
        key_level = num_levels - level - 1
        clusters = tree_map(
            lambda x: jnp.reshape(
                x,
                (prod(Qs[:level + 1]), prod(Ks[:key_level]), Ks[key_level]) + x.shape[3:],
            ),
            clusters,
        )
        clusters = clusters.batch_merge()
    return clusters
