import jax
from jax import numpy as jnp
from chex import Array, dataclass
from typing import Tuple, Callable, Optional
from functools import partial
from jax.scipy.special import logsumexp
from einshape import jax_einshape as einshape

@dataclass
class QueryCluster:
    """
    Class representing a cluster of queries in a high-dimensional space.

    A cluster is summarised by its total log-mass and its mass-weighted
    centroid. Instances may be batched: fields can carry leading batch
    dimensions in front of the per-cluster shapes noted below.
    """
    logmass: Array # scalar: log of the total (unnormalised) mass
    centroid: Array # (k,): mass-weighted mean of the member queries

    @classmethod
    def make(cls, qs: Array, logmass: Array) -> 'QueryCluster':
        """
        Summarise a set of queries as a single QueryCluster.

        Args:
            qs: (n, k) array of queries.
            logmass: (n,) array of per-query log-masses (required).

        Returns:
            A QueryCluster with ``logmass = logsumexp(logmass)`` and a
            ``centroid`` equal to the softmax(logmass)-weighted mean of ``qs``.
        """
        assert qs.ndim == 2, "Queries should be 2D"
        assert logmass.ndim == 1, "Logmass should be 1D"
        assert qs.shape[0] == logmass.shape[0], "Queries and logmass should have the same first dimension"

        weights = jax.nn.softmax(logmass, axis=0)
        new_logmass = logsumexp(logmass, axis=0)
        centroid = jnp.sum(weights[..., None] * qs, axis=0)
        return cls(logmass=new_logmass, centroid=centroid)

    @classmethod
    def make_leaves(cls, qs: Array, labels: Array, label: int, size: int) -> 'QueryCluster':
        """
        Create a batch of exactly ``size`` leaf QueryClusters for one label.

        Queries whose label equals ``label`` come first, in their original
        order, with logmass 0. The remaining slots hold other queries, or
        all-zero rows when fewer than ``size`` queries were given. These slots
        get logmass -inf so they carry no weight when merged.

        The number of items with the chosen label must be less than or equal
        to ``size``; any excess is silently dropped by the truncation.

        Args:
            qs: (n, k) array of queries.
            labels: (n,) array of labels, one per query.
            label: The label selecting the active queries.
            size: Number of leaves in the returned batch.

        Returns:
            A batched QueryCluster with logmass of shape (size,) and centroid
            of shape (size, k).
        """
        assert qs.ndim == 2, "Queries should be 2D"
        assert labels.ndim == 1, "Labels should be 1D"
        assert qs.shape[0] == labels.shape[0], "Queries and labels should have the same first dimension"

        active = (labels == label)
        # Pad with inactive all-zero rows so the batch always has exactly
        # `size` items, even when fewer than `size` queries were given.
        pad = size - qs.shape[0]
        if pad > 0:
            active = jnp.concatenate([active, jnp.zeros((pad,), dtype=active.dtype)])
            qs = jnp.concatenate([qs, jnp.zeros((pad, qs.shape[1]), dtype=qs.dtype)], axis=0)
        # A descending sort of the boolean mask moves active rows to the front.
        # jnp.argsort is stable by default, so their original order is kept.
        indices = jnp.argsort(active, descending=True)[:size]
        # log(1) = 0 for active rows; log(0) = -inf for inactive/padding rows.
        logmass = jnp.log(active[indices].astype(float))
        centroid = qs[indices]
        return cls(logmass=logmass, centroid=centroid)

    def batch_merge(self) -> 'QueryCluster':
        """
        Merge a batch of query clusters along the last dimension of logmass.

        The merged logmass is the logsumexp of the members' logmasses. The
        merged centroid is the softmax(logmass)-weighted mean of the member
        centroids.

        NOTE: if every member has logmass -inf (e.g. no active leaves), the
        softmax weights are not finite.
        """
        weights = jax.nn.softmax(self.logmass, axis=-1)
        new_logmass = logsumexp(self.logmass, axis=-1)
        new_centroid = jnp.sum(weights[..., None] * self.centroid, axis=-2)
        return self.replace(logmass=new_logmass, centroid=new_centroid)


@dataclass
class BasicCluster:
    """
    Class representing a leaf cluster of keys and values in a high-dimensional space.

    A leaf stores only a log-mass, a key (``centroid``) and a value. The
    second-moment statistics of a full ``Cluster`` are computed by
    ``batch_merge``. Instances may be batched: fields can carry leading batch
    dimensions in front of the per-cluster shapes noted below.
    """
    logmass: Array # scalar
    centroid: Array # (k,)
    value: Array # (v,)

    @classmethod
    def make_leaves(cls, ks: Array, vs: Array, labels: Array, label: int, size: int, qmean: Array) -> 'BasicCluster':
        """
        Create a batch of exactly ``size`` leaf BasicClusters for one label.

        Active keys (``labels == label``) come first, in their original order,
        with logmass ``ks @ qmean``. The remaining slots hold other keys, or
        all-zero rows when fewer than ``size`` keys were given. These slots
        get logmass -inf so they carry no weight when merged.

        The number of items with the chosen label must be less than or equal
        to ``size``; any excess is silently dropped by the truncation.

        Args:
            ks: (n, k) keys.
            vs: (n, v) values.
            labels: (n,) labels, one per key.
            label: The label selecting the active keys.
            size: Number of leaves in the returned batch.
            qmean: (k,) vector used to tilt each key's log-mass.

        Returns:
            A batched BasicCluster with logmass (size,), centroid (size, k)
            and value (size, v).
        """
        assert ks.ndim == 2, "Keys should be 2D"
        assert vs.ndim == 2, "Values should be 2D"
        assert ks.shape[0] == vs.shape[0], "Keys and values should have the same first dimension"
        assert labels.ndim == 1, "Labels should be 1D"
        assert labels.shape[0] == ks.shape[0], "Labels and keys should have the same first dimension"
        assert qmean.ndim == 1, "qmean should be 1D"
        assert qmean.shape[0] == ks.shape[1], "qmean should match the second dimension of keys"

        active = (labels == label)
        # log(1) = 0 for active keys and log(0) = -inf otherwise, then each
        # key's log-mass is adjusted by its dot product with qmean.
        logmass = jnp.log(active.astype(float)) + jnp.einsum('nk, k->n', ks, qmean)
        # Pad with inactive rows (zero key/value, -inf log-mass) so the batch
        # always has exactly `size` items, even when fewer keys were given.
        pad = size - ks.shape[0]
        if pad > 0:
            active = jnp.concatenate([active, jnp.zeros((pad,), dtype=active.dtype)])
            logmass = jnp.concatenate([logmass, jnp.full((pad,), -jnp.inf, dtype=logmass.dtype)])
            ks = jnp.concatenate([ks, jnp.zeros((pad, ks.shape[1]), dtype=ks.dtype)], axis=0)
            vs = jnp.concatenate([vs, jnp.zeros((pad, vs.shape[1]), dtype=vs.dtype)], axis=0)
        # A descending sort of the boolean mask moves active rows to the front.
        # jnp.argsort is stable by default, so their original order is kept.
        indices = jnp.argsort(active, descending=True)[:size]
        return cls(
            logmass=logmass[indices],
            centroid=ks[indices],
            value=vs[indices],
        )

    def batch_tilt(self, queries: "QueryCluster") -> "BasicCluster":
        """
        Tilt every leaf towards every query.

        With query centroids of shape (B, Q, k) and leaf centroids of shape
        (B, K, k), the result has:
            logmass[b, q, c] = logmass[b, c] + <query q, centroid c>
        with shape (B, Q, K). The centroid and value fields are broadcast
        (unchanged) to (B, Q, K, k) and (B, Q, K, v).

        Args:
            queries: Object whose ``centroid`` is a 3D (B, Q, k) array.
        """
        assert queries.centroid.ndim == 3, "Query should be 3D"
        assert self.centroid.ndim == 3, "Centroid should be 3D"
        new_logmass = self.logmass[..., None, :] + jnp.einsum('...qd, ...kd->...qk', queries.centroid, self.centroid)
        # NOTE: this broadcasting may be a massive waste of memory.
        Q = queries.centroid.shape[-2]
        new_centroid_shape = self.centroid.shape[:-2] + (Q,) + self.centroid.shape[-2:]
        new_centroid = jnp.broadcast_to(self.centroid[..., None, :, :], new_centroid_shape)
        new_value_shape = self.value.shape[:-2] + (Q,) + self.value.shape[-2:]
        new_value = jnp.broadcast_to(self.value[..., None, :, :], new_value_shape)
        return self.replace(logmass=new_logmass, centroid=new_centroid, value=new_value)

    def batch_merge(self) -> "Cluster":
        """
        Merge a batch of leaves along the last dimension of logmass into a Cluster.

        With softmax(logmass) weights w, the merged cluster has:
          - logmass = logsumexp(logmass);
          - centroid and value equal to the w-weighted means;
          - corr, (..., v, k), the w-weighted value/key cross-covariance;
          - var, (..., k, k), the w-weighted key covariance.
        """
        weights = jax.nn.softmax(self.logmass, axis=-1)
        new_logmass = logsumexp(self.logmass, axis=-1)
        new_centroid = jnp.sum(weights[..., None] * self.centroid, axis=-2)
        new_value = jnp.sum(weights[..., None] * self.value, axis=-2)
        # Centred values/keys, shared by both second-moment terms.
        value_dev = self.value - new_value[..., None, :]
        key_dev = self.centroid - new_centroid[..., None, :]
        w = weights[..., None, None]
        new_corr = jnp.sum(w * value_dev[..., :, None] * key_dev[..., None, :], axis=-3)
        new_var = jnp.sum(w * key_dev[..., :, None] * key_dev[..., None, :], axis=-3)
        return Cluster(
            logmass=new_logmass,
            centroid=new_centroid,
            value=new_value,
            corr=new_corr,
            var=new_var,
        )


@dataclass
class Cluster:
    """
    Class representing a cluster of keys and values in a high-dimensional space.

    Besides the total log-mass and the mass-weighted means of the keys
    (``centroid``) and values (``value``), a cluster stores two covariances:
    the weighted value/key cross-covariance (``corr``) and the weighted key
    covariance (``var``). Instances may be batched: fields can carry leading
    batch dimensions in front of the per-cluster shapes noted below.
    """
    logmass: Array # scalar
    centroid: Array # (k,)
    value: Array # (v,)
    corr: Array # (v, k)
    var: Array # (k, k)

    @classmethod
    def make(cls, ks: Array, vs: Array, logmass: Array) -> 'Cluster':
        """
        Summarise a set of keys and values as a single Cluster.

        Args:
            ks: (n, k) keys.
            vs: (n, v) values.
            logmass: (n,) per-item log-masses (required).

        Returns:
            A Cluster with the following fields, using softmax(logmass)
            weights w:
              - logmass = logsumexp(logmass);
              - centroid and value equal to the w-weighted means;
              - corr, (v, k), the w-weighted value/key cross-covariance;
              - var, (k, k), the w-weighted key covariance.
        """
        assert ks.ndim == 2, "Keys should be 2D"
        assert vs.ndim == 2, "Values should be 2D"
        assert ks.shape[0] == vs.shape[0], "Keys and values should have the same first dimension"
        assert logmass.shape == (ks.shape[0],), "Logmass should have the same shape as first dimension of keys"

        weights = jax.nn.softmax(logmass, axis=0)
        new_logmass = logsumexp(logmass, axis=0)
        centroid = jnp.einsum('w, wk->k', weights, ks)
        value = jnp.einsum('w, wv->v', weights, vs)
        corr = jnp.einsum('w, wv, wk->vk', weights, vs - value, ks - centroid)
        var = jnp.einsum('w, wk, wl->kl', weights, ks - centroid, ks - centroid)
        return cls(logmass=new_logmass, centroid=centroid, value=value, corr=corr, var=var)

    def batch_tilt(self, queries: "QueryCluster", *, alpha: float = 0.0) -> "Cluster":
        """
        Tilt every cluster towards every query.

        Return type has queries in the leading dimension, then clusters. With
        query centroids of shape (B, Q, k) and cluster centroids of shape
        (B, K, k), the output logmass is (B, Q, K) and the output centroid is
        (B, Q, K, k). For query q and cluster c:
            logmass[q, c]  = logmass[c] + q . centroid[c] + 0.5 * alpha * q^T var[c] q
            centroid[q, c] = centroid[c] + alpha * var[c] q
            value[q, c]    = value[c] + corr[c] q
        ``corr`` and ``var`` are broadcast unchanged.

        Args:
            queries: Object whose ``centroid`` is a 3D (B, Q, k) array.
            alpha: Weight of the key-covariance (``var``) terms in the logmass
                and centroid updates. The default 0.0 matches the previously
                hard-coded value and disables those terms. The ``corr``
                correction of ``value`` is applied regardless of alpha.
        """
        assert queries.centroid.ndim == 3, "Query should be 3D"
        assert self.centroid.ndim == 3, "Centroid should be 3D"
        new_logmass = (
            self.logmass[..., None, :]
            + jnp.einsum('...qd, ...kd->...qk', queries.centroid, self.centroid)
            + 0.5 * alpha * jnp.einsum('...qd, ...qc, ...kdc->...qk', queries.centroid, queries.centroid, self.var)
        )
        # NOTE: this broadcasting may be a massive waste of memory.
        Q = queries.centroid.shape[-2]
        new_centroid_shape = self.centroid.shape[:-2] + (Q,) + self.centroid.shape[-2:]
        new_centroid = jnp.broadcast_to(self.centroid[..., None, :, :], new_centroid_shape) + alpha * jnp.einsum('...qd, ...kvd->...qkv', queries.centroid, self.var)
        new_value_shape = self.value.shape[:-2] + (Q,) + self.value.shape[-2:]
        new_value = jnp.broadcast_to(self.value[..., None, :, :], new_value_shape) + jnp.einsum('...qd, ...kvd->...qkv', queries.centroid, self.corr)
        new_corr_shape = self.corr.shape[:-3] + (Q,) + self.corr.shape[-3:]
        new_corr = jnp.broadcast_to(self.corr[..., None, :, :, :], new_corr_shape)
        new_var_shape = self.var.shape[:-3] + (Q,) + self.var.shape[-3:]
        new_var = jnp.broadcast_to(self.var[..., None, :, :, :], new_var_shape)

        return self.replace(logmass=new_logmass, centroid=new_centroid, value=new_value, corr=new_corr, var=new_var)

    def batch_merge(self) -> "Cluster":
        """
        Merge a batch of clusters along the last dimension of logmass.

        With softmax(logmass) weights w, the merged means are w-weighted means.
        The merged ``corr`` and ``var`` each add two terms (the law of total
        covariance): the w-weighted mean of the member covariances, plus the
        w-weighted covariance of the member means about the merged means.
        """
        weights = jax.nn.softmax(self.logmass, axis=-1)
        new_logmass = logsumexp(self.logmass, axis=-1)
        new_centroid = jnp.sum(weights[..., None] * self.centroid, axis=-2)
        new_value = jnp.sum(weights[..., None] * self.value, axis=-2)
        # Member means centred on the merged means, shared by both terms.
        value_dev = self.value - new_value[..., None, :]
        key_dev = self.centroid - new_centroid[..., None, :]
        w = weights[..., None, None]
        new_corr = jnp.sum(w * self.corr, axis=-3) + jnp.sum(w * value_dev[..., :, None] * key_dev[..., None, :], axis=-3)
        new_var = jnp.sum(w * self.var, axis=-3) + jnp.sum(w * key_dev[..., :, None] * key_dev[..., None, :], axis=-3)

        return self.replace(logmass=new_logmass, centroid=new_centroid, value=new_value, corr=new_corr, var=new_var)

