from npf.jax.typing import *

import numpy as np

from jax import numpy as jnp
from jax.scipy import stats
from flax import linen as nn

from ..utils import npf_io, MultivariateNormalDiag
from ..data import NPData

from npf.jax.models.cnp import CNPBase
from npf.jax.models.canp import CANPBase
from npf.jax import functional as F
from npf.jax.modules import (
    MLP,
    MultiheadAttention,
    MultiheadSelfAttention,
)

from npf.jax.modules.autoregressive_feature import autoregressive


# Public API of this module: the two concrete model classes and their
# factory constructors. MPNPFeatureMixin is not listed here.
__all__ = [
    "MPNPFeatureBase",
    "MPNPFeature",
    "MPANPFeatureBase",
    "MPANPFeature",
]


class MPNPFeatureMixin(nn.Module):
    """
    Shared likelihood / loss methods for the autoregressive NP models below.

    The host class's ``__call__`` is invoked as
    ``self(data, num_samples=..., num_generates=..., return_aux=..., skip_io=True)``
    and must return ``(mu, sigma)`` or, with ``return_aux=True``,
    ``(mu, sigma, *aux)``.
    """

    @npf_io(flatten_input=True)
    def log_likelihood(
        self,
        data: NPData,
        *,
        num_samples: int = 1,
        num_generates: int = 40,
        joint: bool = False,
        split_set: bool = False,
        return_aux: bool = False,
    ) -> Union[Array, Tuple[Array, ...]]:
        """
        Per-point log-likelihood of ``data.y`` under the model's predictive.

        Args:
            data: Batch to score. ``data.mask`` selects the points that count;
                ``data.mask_ctx`` / ``data.mask_tar`` select the context / target
                subsets (only read when ``split_set=True``).
            num_samples: Number of model samples; likelihoods are combined over
                samples with log-mean-exp.
            num_generates: Number of generated pseudo-context points, forwarded
                to the model.
            joint: If True, sum log-probs over points for each sample, then
                log-mean-exp over samples and divide by the number of points.
                If False, log-mean-exp over samples per point, then take the
                masked mean over points.
            split_set: Additionally return the context-only and target-only
                values computed the same way.
            return_aux: Forward ``return_aux`` to the model and append its
                auxiliary outputs to the result.

        Returns:
            ``ll`` alone by default; ``(ll, *aux)`` if ``return_aux``;
            ``(ll, ll_ctx, ll_tar, *aux)`` if ``split_set`` (``aux`` is empty
            unless ``return_aux``). Every ``ll*`` value is a batch-averaged scalar.
        """
        mu, sigma, *aux = self(data, num_samples=num_samples, num_generates=num_generates, return_aux=return_aux, skip_io=True)  # [batch, sample, point, y_dim] x 2, (aux)

        s_y = jnp.expand_dims(data.y, axis=1)                                                       # [batch, 1,      point, y_dim]
        log_prob = MultivariateNormalDiag(mu, sigma).log_prob(s_y)                                  # [batch, sample, point]

        if joint:
            ll = F.masked_sum(log_prob, data.mask, axis=-1, non_mask_axis=1)                        # [batch, sample]
            ll = F.logmeanexp(ll, axis=1) / jnp.sum(data.mask, axis=-1)                             # [batch]

            if split_set:
                ll_ctx = F.masked_sum(log_prob, data.mask_ctx, axis=-1, non_mask_axis=1)            # [batch, sample]
                ll_tar = F.masked_sum(log_prob, data.mask_tar, axis=-1, non_mask_axis=1)            # [batch, sample]
                ll_ctx = F.logmeanexp(ll_ctx, axis=1) / jnp.sum(data.mask_ctx, axis=-1)             # [batch]
                ll_tar = F.logmeanexp(ll_tar, axis=1) / jnp.sum(data.mask_tar, axis=-1)             # [batch]

        else:
            ll_all = F.logmeanexp(log_prob, axis=1)                                                 # [batch, point]
            ll = F.masked_mean(ll_all, data.mask, axis=-1)                                          # [batch]

            if split_set:
                ll_ctx = F.masked_mean(ll_all, data.mask_ctx, axis=-1)                              # [batch]
                ll_tar = F.masked_mean(ll_all, data.mask_tar, axis=-1)                              # [batch]

        ll = jnp.mean(ll)                                                                           # (1)

        if split_set:
            ll_ctx = jnp.mean(ll_ctx)                                                               # (1)
            ll_tar = jnp.mean(ll_tar)                                                               # (1)

            return ll, ll_ctx, ll_tar, *aux    # aux is empty when return_aux=False                # (1) x 3, (aux)
        elif return_aux:
            return ll, *aux                                                                         # (1), (aux)
        else:
            return ll                                                                               # (1)

    @npf_io(flatten_input=True)
    def loss(
        self,
        data: NPData,
        *,
        num_samples: int = 1,
        num_generates: int = 40,
        alpha: float = 1.0,
        beta: float = 1.0,
        loss_type: str = "sum_mean_log",    # currently unused
        average: bool = False,              # currently unused
        stop_grad: bool = False,            # currently unused
        **kwargs,
    ) -> Tuple[Array, dict]:
        """
        Training loss: ``-(alpha * ll + beta * train_ll)``.

        ``ll`` is the joint (``joint=True``) log-likelihood of ``data`` from
        :meth:`log_likelihood`. ``train_ll`` is the model's single auxiliary
        output obtained with ``return_aux=True``. For the classes in this file,
        that is the log-likelihood of the real plus generated context points.

        ``loss_type``, ``average``, ``stop_grad`` and ``**kwargs`` are accepted
        so existing call sites keep working, but they currently have no effect.

        Returns:
            ``(loss, dict(ll=ll, train_ll=train_ll))`` with scalar entries.
        """
        ll, train_ll = self.log_likelihood(
            data=data, skip_io=True,
            num_samples=num_samples, num_generates=num_generates,
            joint=True, return_aux=True,
        )
        ll_loss = alpha * ll + beta * train_ll                                                      # (1)
        return -ll_loss, dict(ll=ll, train_ll=train_ll)


class MPNPFeatureBase(MPNPFeatureMixin, CNPBase):
    """
    CNP whose context set is augmented with autoregressively generated
    pseudo-context representations.

    ``auto_regression`` produces ``num_generates`` extra per-point
    representations (per sample) from the encoded context. These are
    concatenated with the real context representations and aggregated before
    decoding the targets. ``decoder2`` maps a representation back to an
    ``(x, y)`` pair so generated points can be inspected and scored.
    """

    auto_regression: nn.Module = None
    encoder:         nn.Module = None
    decoder:         nn.Module = None
    decoder2:        nn.Module = None
    min_sigma:       float = 0.1

    def _decode2(self, r_i_gen):
        """
        Decode per-point representations into ``(x, y)`` with ``decoder2``.

        Output channel 0 is returned as x and channel 1 as y; further channels
        are ignored. NOTE(review): this presumably assumes x_dim == y_dim == 1.
        """
        out = self.decoder2(r_i_gen)
        x_generated_ctx, y_generated_ctx = out[..., :1], out[..., 1:2]
        return x_generated_ctx, y_generated_ctx

    def __post_init__(self):
        super().__post_init__()
        if self.auto_regression is None:
            raise ValueError("auto_regression is not specified")

    def _auto_regressive(
        self,
        r_i_ctx:  Array[B, C, R],
        mask_ctx: Array[B, C],
        num_samples: int = 1,
        num_generates: int = 10,
    ) -> Tuple[Array[B, S, G, R], Array[B, G]]:
        """Generate ``num_generates`` representations per sample, plus an all-ones mask for them."""
        s_r_i_ctx = F.repeat_axis(r_i_ctx, num_samples, axis=1)                                     # [batch, sample, context, r_dim]
        r_i_gen   = self.auto_regression(s_r_i_ctx, num_generates, mask_ctx)                        # [batch, sample, generate, r_dim]
        mask_gen = np.ones((mask_ctx.shape[0], num_generates), dtype=mask_ctx.dtype)                # [batch, generate]
        return r_i_gen, mask_gen                                                                    # [batch, sample, generate, r_dim], [batch, generate]

    def _predict_ctx_gen(self, data, r_i_ctx_gen, mask_ctx_gen, x_generated_ctx, y_generated_ctx, num_samples):
        """
        Predict y at the real + generated context inputs.

        Returns ``(x_ctx_gen, y_ctx_gen, mu_train, sigma_train)``. Here x/y are
        the real context (repeated over samples) followed by the decoded
        generated points, and mu/sigma are the model's predictive at those x.
        """
        s_x_ctx = F.repeat_axis(data.x_ctx, num_samples, axis=1)                                    # [batch, sample, context, x_dim]
        s_y_ctx = F.repeat_axis(data.y_ctx, num_samples, axis=1)                                    # [batch, sample, context, y_dim]
        x_ctx_gen = jnp.concatenate((s_x_ctx, x_generated_ctx), axis=-2)                            # [batch, sample, context + generate, x_dim]
        y_ctx_gen = jnp.concatenate((s_y_ctx, y_generated_ctx), axis=-2)                            # [batch, sample, context + generate, y_dim]
        r_ctx = self._aggregate(x_ctx_gen, x_ctx_gen, r_i_ctx_gen, mask_ctx_gen)
        query = jnp.concatenate((x_ctx_gen, r_ctx), axis=-1)
        mu_train, sigma_train = self._decode(query, mask_ctx_gen)
        return x_ctx_gen, y_ctx_gen, mu_train, sigma_train

    @nn.compact
    @npf_io(flatten=True)
    def __call__(
        self,
        data:          NPData,
        *,
        num_samples:   int = 1,
        num_generates: int = 10,
        return_aux: bool = False,
        plotting: bool = False,
    ):
        """
        Predict the targets from the context plus generated pseudo-context.

        Args:
            data: Batch; predictions are made at ``data.x`` (masked by
                ``data.mask``) from ``data.x_ctx``/``data.y_ctx``/``data.mask_ctx``.
            num_samples: Number of autoregressive samples.
            num_generates: Number of pseudo-context points generated per sample.
            return_aux: Also return ``ll``, a scalar. It is the log-likelihood of
                the real + generated context y under the model's predictive at
                the real + generated context x: summed over points,
                log-mean-exp over samples, divided by the point count and
                averaged over the batch.
            plotting: Sow extra tensors into the ``'intermediates'`` collection.

        Returns:
            ``(mu, sigma)``, each ``[batch, sample, target, y_dim]``, plus ``ll``
            when ``return_aux`` is True.
        """
        s_x_tar = F.repeat_axis(data.x, num_samples, axis=1)                                        # [batch, sample, target, x_dim]

        r_i_ctx = self._encode(data.x_ctx, data.y_ctx, data.mask_ctx)                               # [batch, context, r_dim]
        r_i_gen, mask_gen = self._auto_regressive(r_i_ctx, data.mask_ctx, num_samples, num_generates)  # [batch, sample, generate, r_dim], [batch, generate]

        s_r_i_ctx = F.repeat_axis(r_i_ctx, num_samples, axis=1)                                     # [batch, sample, context, r_dim]

        r_i_ctx_gen = jnp.concatenate((s_r_i_ctx, r_i_gen), axis=-2)                                # [batch, sample, context + generate, r_dim]
        mask_ctx_gen = jnp.concatenate((data.mask_ctx, mask_gen), axis=-1)                          # [batch, context + generate]

        r_ctx_gen = self._aggregate(s_x_tar, None, r_i_ctx_gen, mask_ctx_gen)                       # [batch, sample, target, r_dim]
        query = jnp.concatenate((s_x_tar, r_ctx_gen), axis=-1)
        mu, sigma = self._decode(query, data.mask)                                                  # [batch, sample, target, y_dim] x 2

        # Mask out padded targets
        mu    = F.masked_fill(mu,    data.mask, fill_value=0.,   mask_axis=(0, -2))                 # [batch, sample, target, y_dim]
        sigma = F.masked_fill(sigma, data.mask, fill_value=1e-6, mask_axis=(0, -2))                 # [batch, sample, target, y_dim]

        x_generated_ctx, y_generated_ctx = self._decode2(r_i_gen)
        self.sow('intermediates', 'x_generated_ctx', x_generated_ctx)
        self.sow('intermediates', 'y_generated_ctx', y_generated_ctx)

        if plotting:
            self.sow('intermediates', 'x_ctx', data.x_ctx)
            self.sow('intermediates', 'y_ctx', data.y_ctx)
            x_ctx_dec, y_ctx_dec = self._decode2(r_i_ctx)
            self.sow('intermediates', 'x_ctx_dec', x_ctx_dec)
            self.sow('intermediates', 'y_ctx_dec', y_ctx_dec)

        if not (plotting or return_aux):
            return mu, sigma                                                                        # [batch, sample, target, y_dim] x 2

        # Shared by plotting and return_aux; computed once even if both are set.
        x_ctx_gen, y_ctx_gen, mu_train, sigma_train = self._predict_ctx_gen(
            data, r_i_ctx_gen, mask_ctx_gen, x_generated_ctx, y_generated_ctx, num_samples,
        )

        if plotting:
            self.sow('intermediates', 'mu_train', mu_train)
            self.sow('intermediates', 'sigma_train', sigma_train)
            self.sow('intermediates', 'x_ctx_generated_ctx', x_ctx_gen)
            self.sow('intermediates', 'y_ctx_generated_ctx', y_ctx_gen)
            self.sow('intermediates', 'mask_ctx_generated_ctx', mask_ctx_gen)

        if not return_aux:
            return mu, sigma                                                                        # [batch, sample, target, y_dim] x 2

        log_prob = MultivariateNormalDiag(mu_train, sigma_train).log_prob(y_ctx_gen)                # [batch, sample, point]
        ll = F.masked_sum(log_prob, mask_ctx_gen, axis=-1, non_mask_axis=1)                         # [batch, sample]
        ll = F.logmeanexp(ll, axis=1) / jnp.sum(mask_ctx_gen, axis=-1)                              # [batch]
        ll = jnp.mean(ll)                                                                           # (1)
        return mu, sigma, ll                                                                        # [batch, sample, target, y_dim] x 2, (1)


class MPANPFeatureBase(MPNPFeatureMixin, CANPBase):
    """
    Attentive CNP whose context keys/values are augmented with autoregressively
    generated pseudo-context keys/values.

    Keys come from ``transform_qk`` applied to x and values from ``_encode``.
    ``auto_regression`` generates key/value pairs jointly, as a concatenated
    vector that is split in half. ``cross_attention`` aggregates real +
    generated values for each query. ``decoder2`` maps a value representation
    back to an ``(x, y)`` pair.
    """

    auto_regression: nn.Module = None
    encoder:         nn.Module = None
    self_attention:  Optional[nn.Module] = None
    transform_qk:    nn.Module = None
    cross_attention: nn.Module = None
    decoder:         nn.Module = None
    decoder2:        nn.Module = None
    min_sigma:       float = 0.1

    def _decode2(self, r_i_v_gen):
        """
        Decode per-point value representations into ``(x, y)`` with ``decoder2``.

        Output channel 0 is returned as x and channel 1 as y; further channels
        are ignored. NOTE(review): this presumably assumes x_dim == y_dim == 1.
        """
        out = self.decoder2(r_i_v_gen)
        x_generated_ctx, y_generated_ctx = out[..., :1], out[..., 1:2]
        return x_generated_ctx, y_generated_ctx

    def __post_init__(self):
        super().__post_init__()
        if self.transform_qk is None:
            raise ValueError("transform_qk is not specified")
        if self.auto_regression is None:
            raise ValueError("auto_regression is not specified")

    def _aggregate(
        self,
        r_i_q:  Array[B, T, X],
        r_i_k:  Array[B, C, X],
        r_i_v:  Array[B, C, R],
        mask_ctx: Array[B, C],
    ) -> Array[B, T, R]:
        """Cross-attend from queries to (keys, values), masking keys by ``mask_ctx``."""
        r_ctx = self.cross_attention(r_i_q, r_i_k, r_i_v, mask=mask_ctx)                            # [batch, target, r_dim]
        return r_ctx                                                                                # [batch, target, r_dim]

    def _auto_regressive(
        self,
        r_i_k:  Array[B, C, R],
        r_i_v:  Array[B, C, R],
        mask_kv: Array[B, C],
        num_samples: int = 1,
        num_generates: int = 10,
    ):
        """Generate ``num_generates`` key/value pairs per sample, plus an all-ones mask for them."""
        r_i_kv = jnp.concatenate([r_i_k, r_i_v], axis=-1)
        s_r_i_kv = F.repeat_axis(r_i_kv, num_samples, axis=1)                                       # [batch, sample, context, r_dim x 2]

        r_i_gen   = self.auto_regression(s_r_i_kv, num_generates, mask_kv)                          # [batch, sample, generate, r_dim x 2]
        mask_gen = np.ones((mask_kv.shape[0], num_generates), dtype=mask_kv.dtype)                  # [batch, generate]

        r_i_k_gen, r_i_v_gen = jnp.split(r_i_gen, 2, axis=-1)                                       # [batch, sample, generate, r_dim] x 2
        return r_i_k_gen, r_i_v_gen, mask_gen                                                       # [batch, sample, generate, r_dim] x 2, [batch, generate]

    def _predict_ctx_gen(self, data, r_i_k, r_i_v, mask_ctx_gen, x_generated_ctx, y_generated_ctx, num_samples):
        """
        Predict y at the real + generated context points.

        Returns ``(x_ctx_gen, y_ctx_gen, mu_train, sigma_train)``. Here x/y are
        the real context (repeated over samples) followed by the decoded
        generated points. The queries are the real + generated keys ``r_i_k``.
        """
        s_x_ctx = F.repeat_axis(data.x_ctx, num_samples, axis=1)                                    # [batch, sample, context, x_dim]
        s_y_ctx = F.repeat_axis(data.y_ctx, num_samples, axis=1)                                    # [batch, sample, context, y_dim]
        x_ctx_gen = jnp.concatenate((s_x_ctx, x_generated_ctx), axis=-2)                            # [batch, sample, context + generate, x_dim]
        y_ctx_gen = jnp.concatenate((s_y_ctx, y_generated_ctx), axis=-2)                            # [batch, sample, context + generate, y_dim]
        r_ctx = self._aggregate(r_i_k, r_i_k, r_i_v, mask_ctx_gen)                                  # [batch, sample, context + generate, r_dim]
        query = jnp.concatenate((x_ctx_gen, r_ctx), axis=-1)
        mu_train, sigma_train = self._decode(query, mask_ctx_gen)
        return x_ctx_gen, y_ctx_gen, mu_train, sigma_train

    @nn.compact
    @npf_io(flatten=True)
    def __call__(
        self,
        data:          NPData,
        *,
        num_samples:   int = 1,
        num_generates: int = 40,
        return_aux: bool = False,
        plotting: bool = False,
    ):
        """
        Predict the targets from context plus generated pseudo-context keys/values.

        Args:
            data: Batch; predictions are made at ``data.x`` (masked by
                ``data.mask``) from ``data.x_ctx``/``data.y_ctx``/``data.mask_ctx``.
            num_samples: Number of autoregressive samples.
            num_generates: Number of pseudo-context points generated per sample.
            return_aux: Also return ``ll``, a scalar. It is the log-likelihood of
                the real + generated context y under the model's predictive at
                the real + generated context points: summed over points,
                log-mean-exp over samples, divided by the point count and
                averaged over the batch.
            plotting: Sow extra tensors into the ``'intermediates'`` collection.

        Returns:
            ``(mu, sigma)``, each ``[batch, sample, target, y_dim]``, plus ``ll``
            when ``return_aux`` is True.
        """
        s_x_tar = F.repeat_axis(data.x, num_samples, axis=1)                                        # [batch, sample, target, x_dim]

        r_i_q_base = self.transform_qk(data.x)                                                      # [batch, target,  r_dim]
        r_i_k_base = self.transform_qk(data.x_ctx)                                                  # [batch, context, r_dim]
        r_i_v_base = self._encode(data.x_ctx, data.y_ctx, data.mask_ctx)                            # [batch, context, r_dim]

        r_i_k_gen, r_i_v_gen, mask_gen = \
            self._auto_regressive(r_i_k_base, r_i_v_base, data.mask_ctx, num_samples, num_generates)  # [batch, sample, generate, r_dim] x 2, [batch, generate]

        s_r_i_k_base = F.repeat_axis(r_i_k_base, num_samples, axis=1)                               # [batch, sample, context, r_dim]
        r_i_k = jnp.concatenate((s_r_i_k_base, r_i_k_gen), axis=-2)                                 # [batch, sample, context + generate, r_dim]

        s_r_i_v_base = F.repeat_axis(r_i_v_base, num_samples, axis=1)                               # [batch, sample, context, r_dim]
        r_i_v = jnp.concatenate((s_r_i_v_base, r_i_v_gen), axis=-2)                                 # [batch, sample, context + generate, r_dim]
        r_i_q = F.repeat_axis(r_i_q_base, num_samples, axis=1)                                      # [batch, sample, target, r_dim]

        mask_ctx_gen = jnp.concatenate((data.mask_ctx, mask_gen), axis=-1)                          # [batch, context + generate]

        r_ctx_gen = self._aggregate(r_i_q, r_i_k, r_i_v, mask_ctx_gen)                              # [batch, sample, target, r_dim]
        query = jnp.concatenate((s_x_tar, r_ctx_gen), axis=-1)
        mu, sigma = self._decode(query, data.mask)                                                  # [batch, sample, target, y_dim] x 2

        # Mask out padded targets
        mu    = F.masked_fill(mu,    data.mask, fill_value=0.,   mask_axis=(0, -2))                 # [batch, sample, target, y_dim]
        sigma = F.masked_fill(sigma, data.mask, fill_value=1e-6, mask_axis=(0, -2))                 # [batch, sample, target, y_dim]

        x_generated_ctx, y_generated_ctx = self._decode2(r_i_v_gen)
        self.sow('intermediates', 'x_generated_ctx', x_generated_ctx)
        self.sow('intermediates', 'y_generated_ctx', y_generated_ctx)

        if not (plotting or return_aux):
            return mu, sigma                                                                        # [batch, sample, target, y_dim] x 2

        # Shared by plotting and return_aux; computed once even if both are set.
        x_ctx_gen, y_ctx_gen, mu_train, sigma_train = self._predict_ctx_gen(
            data, r_i_k, r_i_v, mask_ctx_gen, x_generated_ctx, y_generated_ctx, num_samples,
        )

        if plotting:
            self.sow('intermediates', 'mu_train', mu_train)
            self.sow('intermediates', 'sigma_train', sigma_train)
            self.sow('intermediates', 'x_ctx_generated_ctx', x_ctx_gen)
            self.sow('intermediates', 'y_ctx_generated_ctx', y_ctx_gen)
            self.sow('intermediates', 'mask_ctx_generated_ctx', mask_ctx_gen)

        if not return_aux:
            return mu, sigma                                                                        # [batch, sample, target, y_dim] x 2

        log_prob = MultivariateNormalDiag(mu_train, sigma_train).log_prob(y_ctx_gen)                # [batch, sample, point]
        ll = F.masked_sum(log_prob, mask_ctx_gen, axis=-1, non_mask_axis=1)                         # [batch, sample]
        ll = F.logmeanexp(ll, axis=1) / jnp.sum(mask_ctx_gen, axis=-1)                              # [batch]
        ll = jnp.mean(ll)                                                                           # (1)
        return mu, sigma, ll                                                                        # [batch, sample, target, y_dim] x 2, (1)


class MPNPFeature:
    """
    Factory for a default-configured :class:`MPNPFeatureBase`.

    ``MPNPFeature(...)`` returns an ``MPNPFeatureBase`` instance (``__new__``
    returns an object of another class, so no ``MPNPFeature`` is created).
    """

    def __new__(
        cls,
        y_dim: int,
        r_dim: int = 128,
        encoder_dims: Sequence[int] = (128, 128, 128, 128, 128),
        decoder_dims: Sequence[int] = (128, 128, 128),
        auto_regress_type: str = 'set_generative',
    ):
        return MPNPFeatureBase(
            # The generated representations are concatenated with the r_dim-wide
            # encoded context along the point axis, so their width must be r_dim
            # (mirrors MPANPFeature, which sets dim_out to its kv width).
            auto_regression = autoregressive(auto_regress_type, dim_out=r_dim),
            encoder = MLP(hidden_features=encoder_dims, out_features=r_dim),
            decoder = MLP(hidden_features=decoder_dims, out_features=(y_dim * 2)),
            decoder2 = MLP(hidden_features=decoder_dims, out_features=(y_dim * 2)),
        )


class MPANPFeature:
    """
    Factory that assembles an :class:`MPANPFeatureBase` from MLP and attention parts.

    ``MPANPFeature(...)`` returns an ``MPANPFeatureBase`` instance. Passing
    ``sa_heads=None`` disables the self-attention stage after the encoder.
    """

    def __new__(
        cls,
        y_dim: int,
        r_dim: int = 128,
        sa_heads: Optional[int] = 8,
        ca_heads: Optional[int] = 8,
        encoder_dims: Sequence[int] = (128, 128, 128, 128, 128),
        decoder_dims: Sequence[int] = (128, 128, 128),
        auto_regress_type: str = 'set_generative',
    ):
        use_self_attention = sa_heads is not None

        # When self-attention follows the encoder, the encoder's final layer keeps its activation.
        value_encoder = MLP(
            hidden_features=encoder_dims,
            out_features=r_dim,
            last_activation=use_self_attention,
        )
        sa_module = (
            MultiheadSelfAttention(dim_out=r_dim, num_heads=sa_heads)
            if use_self_attention
            else None
        )

        # Keys and values are generated as one concatenated vector, hence 2 * r_dim.
        generator = autoregressive(auto_regress_type, dim_out=(r_dim * 2))

        return MPANPFeatureBase(
            auto_regression=generator,
            encoder=value_encoder,
            self_attention=sa_module,
            transform_qk=MLP(hidden_features=encoder_dims, out_features=r_dim, last_activation=False),
            cross_attention=MultiheadAttention(dim_out=r_dim, num_heads=ca_heads),
            decoder=MLP(hidden_features=decoder_dims, out_features=(y_dim * 2)),
            decoder2=MLP(hidden_features=decoder_dims, out_features=(y_dim * 2)),
        )
