from npf.jax.typing import *

import math
import jax
from jax import numpy as jnp
from jax import random
from flax import linen as nn
from jax.scipy import stats

from ..utils import npf_io, MultivariateNormalDiag
from ..data import NPData

from npf.jax import functional as F
from npf.jax.models import CNPBase
from npf.jax.models.canp import CANPBase
from npf.jax.modules import (
    MLP,
    MultiheadAttention,
    MultiheadSelfAttention,
)
from npf.jax.modules.autoregressive_direct import (
    autoregressive,
    ISAB,
)


# Names exported by `from <module> import *`; MPNPDirectMixin is not listed.
__all__ = ["MPANPDirectBase", "MPANPDirect", "MPNPDirectBase", "MPNPDirect"]


class MPNPDirectMixin(nn.Module):
    """
    Mixin shared by the MP*Direct models.

    Provides a secondary decoding helper (``_decode_2``), a sample-based
    log-likelihood evaluator (``log_likelihood``) and the training objective
    (``loss``).

    Host classes must define a ``decoder`` sub-module. Their ``__call__`` must
    accept ``num_samples``, ``num_generates``, ``return_aux``, ``stop_grad``
    and ``skip_io`` and return ``(mu, sigma, *aux)``.
    """

    def _decode_2(
        self,
        x_tar:    Array[B, ([M],), T, X],
        r_ctx:    Array[B, ([M],), T, R],
        mask_tar: Array[B, T],
    ) -> Tuple[Array[B, ([M],), T, Y], Array[B, ([M],), T, Y]]:
        """
        Decode a predictive mean and scale from target inputs and representations.

        Args:
            x_tar: target inputs; concatenated with ``r_ctx`` on the last axis.
            r_ctx: per-target representations.
            mask_tar: accepted but not used by this body.

        Returns:
            ``(mu, sigma)``, each with ``y_dim`` features on the last axis.
            ``sigma = softplus(raw)``. Unlike a ``min_sigma`` floor, no lower
            bound is applied here.
        """
        query = jnp.concatenate((x_tar, r_ctx), axis=-1)                                            # [batch, (*model), target, x_dim + r_dim]

        # Merge the batch and any model axes so the decoder sees one batch axis,
        # then restore the original leading shape afterwards.
        query, shape = F.flatten(query, start=0, stop=-2, return_shape=True)                        # [batch x (*model), target, x_dim + r_dim]
        mu_log_sigma = self.decoder(query)                                                          # [batch x (*model), target, y_dim x 2]
        mu_log_sigma = F.unflatten(mu_log_sigma, shape, axis=0)                                     # [batch, (*model), target, y_dim x 2]

        mu, log_sigma = jnp.split(mu_log_sigma, 2, axis=-1)                                         # [batch, (*model), target, y_dim] x 2
        sigma = nn.softplus(log_sigma)                                                              # [batch, (*model), target, y_dim]
        return mu, sigma

    @npf_io(flatten_input=True)
    def log_likelihood(
        self,
        data: NPData,
        *,
        num_samples:   int = 1,
        num_generates: int = 40,
        joint: bool = False,
        return_aux: bool = False,
        stop_grad: bool = False,
        split_set: bool = False,
    ) -> Union[
        Array,
        Tuple[Array, Array[B, T, R]],
    ]:
        """
        Monte-Carlo log-likelihood of ``data.y`` under the model's predictive.

        Args:
            data: batch providing ``y``, ``mask``, and (when ``split_set``)
                ``mask_ctx`` / ``mask_tar``.
            num_samples: forwarded to ``self(...)``.
            num_generates: forwarded to ``self(...)``.
            joint: if True, sum masked per-point log-probs within each sample,
                ``logmeanexp`` over samples, then divide by the number of
                masked points. If False, ``logmeanexp`` each point over
                samples, then take the masked mean over points.
            return_aux: forward to ``self(...)`` and append its extra outputs.
            stop_grad: forwarded to ``self(...)``.
            split_set: also return the same quantity restricted to the context
                and target masks.

        Returns:
            The scalar ``ll`` averaged over the batch. If ``split_set`` is set,
            returns ``(ll, ll_ctx, ll_tar, *aux)``. Otherwise, if
            ``return_aux`` is set, returns ``(ll, *aux)``.
        """
        mu, sigma, *aux = self(data, num_samples=num_samples, num_generates=num_generates, return_aux=return_aux, stop_grad=stop_grad, skip_io=True)  # [batch, sample, point, y_dim] x 2, (aux)

        # Insert a sample axis so y broadcasts against every sample's prediction.
        s_y = jnp.expand_dims(data.y, axis=1)                                                       # [batch, 1,      point, y_dim]
        log_prob = MultivariateNormalDiag(mu, sigma).log_prob(s_y)                                  # [batch, sample, point]

        if joint:
            ll = F.masked_sum(log_prob, data.mask, axis=-1, non_mask_axis=1)                        # [batch, sample]
            ll = F.logmeanexp(ll, axis=1) / jnp.sum(data.mask, axis=-1)                             # [batch]

            if split_set:
                ll_ctx = F.masked_sum(log_prob, data.mask_ctx, axis=-1, non_mask_axis=1)            # [batch, sample]
                ll_tar = F.masked_sum(log_prob, data.mask_tar, axis=-1, non_mask_axis=1)            # [batch, sample]
                ll_ctx = F.logmeanexp(ll_ctx, axis=1) / jnp.sum(data.mask_ctx, axis=-1)             # [batch]
                ll_tar = F.logmeanexp(ll_tar, axis=1) / jnp.sum(data.mask_tar, axis=-1)             # [batch]

        else:
            ll_all = F.logmeanexp(log_prob, axis=1)                                                 # [batch, point]
            ll = F.masked_mean(ll_all, data.mask, axis=-1)                                          # [batch]

            if split_set:
                ll_ctx = F.masked_mean(ll_all, data.mask_ctx, axis=-1)                              # [batch]
                ll_tar = F.masked_mean(ll_all, data.mask_tar, axis=-1)                              # [batch]

        ll = jnp.mean(ll)                                                                           # (1)

        if split_set:
            ll_ctx = jnp.mean(ll_ctx)                                                               # (1)
            ll_tar = jnp.mean(ll_tar)                                                               # (1)

            return ll, ll_ctx, ll_tar, *aux    # aux becomes empty if return_aux=False              # (1) x 3, (aux)
        elif return_aux:
            return ll, *aux                                                                         # (1), (aux)
        else:
            return ll                                                                               # (1)

    @npf_io(flatten_input=True)
    def loss(
        self,
        data: NPData,
        *,
        num_samples: int = 1,
        num_generates: int = 40,
        loss_type: str = "sum_mean_log",    # or "log_mean_exp_sum_log"
        average: bool = False,
        stop_grad: bool = False,
        **kwargs,
    ) -> Array:
        """
        Training loss: the negative sum of the joint log-likelihood and the
        single auxiliary term returned by ``self(..., return_aux=True)``.

        Args:
            data: batch passed to ``log_likelihood``.
            num_samples: forwarded to ``log_likelihood``.
            num_generates: forwarded to ``log_likelihood``.
            loss_type: currently ignored (kept for interface compatibility).
            average: currently ignored (kept for interface compatibility).
            stop_grad: forwarded to ``log_likelihood``.
            **kwargs: currently ignored.

        Returns:
            ``-(ll + train_ll)``.

        Note:
            The unpacking below requires ``self(...)`` to return exactly one
            aux output. It is presumably a training log-likelihood term — TODO
            confirm against the host ``__call__``.
        """
        ll, train_ll = self.log_likelihood(
            data=data, skip_io=True,
            num_samples=num_samples, num_generates=num_generates,
            joint=True, return_aux=True, stop_grad=stop_grad,
        )
        return -(ll + train_ll)                                                                     # (1)

class MPANPDirectBase(CANPBase):
    """
    Attentive base model extended with an autoregressive generator
    (``auto_regress``) and a query/key transform (``transform_qk``).

    All sub-modules are supplied as dataclass fields. ``transform_qk`` and
    ``auto_regress`` are mandatory and are checked after the parent's
    ``__post_init__`` runs.
    """

    # NOTE: field order is kept as-is; Flax modules are dataclasses, so it
    # determines the positional constructor order.
    auto_regress: nn.Module = None
    encoder:         nn.Module = None
    self_attention:  Optional[nn.Module] = None
    transform_qk:    nn.Module = None
    cross_attention: nn.Module = None
    decoder:         nn.Module = None
    min_sigma:       float = 0.1

    def __post_init__(self):
        super().__post_init__()
        # Checked in this order, so the first missing field is the one reported.
        # encoder/decoder are presumably validated by CANPBase — not visible here.
        for required in ("transform_qk", "auto_regress"):
            if getattr(self, required) is None:
                raise ValueError(f"{required} is not specified")



class MPNPDirectBase(MPNPDirectMixin, CNPBase):
    """
    Non-attentive base model combining ``CNPBase`` with the MP*Direct mixin.

    ``encoder``, ``decoder`` and ``auto_regress`` must all be supplied; a
    ``ValueError`` names the first one found missing.
    """

    encoder: nn.Module = None
    decoder: nn.Module = None
    auto_regress: nn.Module = None
    min_sigma: float = 0.1

    def __post_init__(self):
        super().__post_init__()
        for required in ("encoder", "decoder", "auto_regress"):
            if getattr(self, required) is None:
                raise ValueError(f"{required} is not specified")


class MPNPDirect:
    """
    Factory that builds an ``MPNPDirectBase`` with MLP encoder and decoder.

    Calling ``MPNPDirect(...)`` returns an ``MPNPDirectBase`` instance, not an
    ``MPNPDirect``.
    """

    def __new__(cls,
        y_dim: int,
        r_dim: int = 128,
        encoder_dims: Sequence[int] = (128, 128, 128, 128, 128),
        decoder_dims: Sequence[int] = (128, 128, 128),
        auto_regress_type: str = 'set_generative_6',
    ):
        # The generator is built first, matching the original construction order.
        generator = autoregressive(auto_regress_type, dim_out=2)
        context_encoder = MLP(hidden_features=encoder_dims, out_features=r_dim)
        # Two outputs per y dimension (split into mean / raw scale downstream).
        output_decoder = MLP(hidden_features=decoder_dims, out_features=(y_dim * 2))

        model = MPNPDirectBase(
            encoder=context_encoder,
            decoder=output_decoder,
            auto_regress=generator,
        )
        return model


class MPANPDirect:
    """
    Factory that builds an attentive ``MPANPDirectBase``.

    Calling ``MPANPDirect(...)`` returns an ``MPANPDirectBase`` instance, not
    an ``MPANPDirect``.

    Args:
        y_dim: output dimension; the decoder emits ``2 * y_dim`` values.
        r_dim: width of the encoder, attention and ``transform_qk`` outputs.
        sa_heads: heads for the self-attention over the encoded context. If
            ``None``, self-attention is disabled and the encoder's last layer
            is built with ``last_activation=False``.
        ca_heads: heads for the cross-attention.
        encoder_dims: hidden widths of the encoder MLP. They are also reused
            for the ``transform_qk`` MLP.
        decoder_dims: hidden widths of the decoder MLP.
        auto_regress_type: identifier passed to ``autoregressive`` to select
            the generator.
        ar_heads: ``num_heads`` passed to the autoregressive generator.
            Defaults to 8, the value previously hard-coded.
    """

    def __new__(cls,
        y_dim: int,
        r_dim: int = 128,
        sa_heads: Optional[int] = 8,
        ca_heads: Optional[int] = 8,
        encoder_dims: Sequence[int] = (128, 128, 128, 128, 128),
        decoder_dims: Sequence[int] = (128, 128, 128),
        auto_regress_type: str = 'set_generative_6',
        ar_heads: int = 8,
    ):

        if sa_heads is not None:
            encoder = MLP(hidden_features=encoder_dims, out_features=r_dim, last_activation=True)
            self_attention = MultiheadSelfAttention(dim_out=r_dim, num_heads=sa_heads)
        else:
            encoder = MLP(hidden_features=encoder_dims, out_features=r_dim, last_activation=False)
            self_attention = None

        # NOTE(review): ca_heads is typed Optional but is passed unconditionally;
        # confirm MultiheadAttention accepts num_heads=None.
        cross_attention = MultiheadAttention(dim_out=r_dim, num_heads=ca_heads)
        decoder = MLP(hidden_features=decoder_dims, out_features=(y_dim * 2))
        # NOTE(review): dim_out is fixed at 2 regardless of y_dim — confirm intended.
        auto_regress = autoregressive(auto_regress_type, dim_out=2, num_heads=ar_heads)
        transform_qk = MLP(hidden_features=encoder_dims, out_features=r_dim, last_activation=False)

        return MPANPDirectBase(
            encoder=encoder,
            self_attention=self_attention,
            cross_attention=cross_attention,
            decoder=decoder,
            auto_regress=auto_regress,
            transform_qk=transform_qk,
        )
