from npf.jax.typing import *

import math
import jax
from jax import numpy as jnp
from jax import random
from flax import linen as nn
from jax.scipy import stats

from ..utils import npf_io, MultivariateNormalDiag
from ..data import NPData

from npf.jax import functional as F
from npf.jax.models import CNPBase
from npf.jax.models.canp import CANPBase
from npf.jax.modules import (
    MLP,
    MultiheadAttention,
    MultiheadSelfAttention,
)
from npf.jax.modules.autoregressive_direct import (
    autoregressive,
    ISAB,
)


# Names exported by `from <this module> import *`: the two Flax base modules
# and their factory wrappers. `MPNPDirectMixin` is not exported.
__all__ = [
    "MPANPDirectBase",
    "MPANPDirect",
    "MPNPDirectBase",
    "MPNPDirect",
]


class MPNPDirectMixin(nn.Module):
    """
    Mixin with the decoding, log-likelihood and loss logic shared by the
    MPNP / MPANP "direct" models.

    The concrete class is expected to provide a ``decoder`` submodule and a
    ``__call__`` that returns ``(mu, sigma)``, followed by auxiliary outputs
    when it is invoked with ``return_aux=True``.
    """

    def _decode_2(
        self,
        x_tar:    Array[B, ([M],), T, X],
        r_ctx:    Array[B, ([M],), T, R],
        mask_tar: Array[B, T],
    ) -> Tuple[Array[B, ([M],), T, Y], Array[B, ([M],), T, Y]]:
        """
        Decode target inputs and their representations into a mean and a scale.

        Args:
            x_tar:    Target inputs.
            r_ctx:    Representation associated with each target point.
            mask_tar: Target mask. Not read by this method.

        Returns:
            ``(mu, sigma)``: the decoder output split in half along the last
            axis, with softplus applied to the second half. No lower bound is
            applied to ``sigma`` here.
        """

        query = jnp.concatenate((x_tar, r_ctx), axis=-1)                                            # [batch, (*model), target, x_dim + r_dim]

        # The decoder is applied with the leading axes flattened, then the
        # original leading shape is restored.
        query, shape = F.flatten(query, start=0, stop=-2, return_shape=True)                        # [batch x (*model), target, x_dim + r_dim]
        mu_log_sigma = self.decoder(query)                                                          # [batch x (*model), target, y_dim x 2]
        mu_log_sigma = F.unflatten(mu_log_sigma, shape, axis=0)                                     # [batch, (*model), target, y_dim x 2]

        mu, log_sigma = jnp.split(mu_log_sigma, 2, axis=-1)                                         # [batch, (*model), target, y_dim] x 2
        sigma = nn.softplus(log_sigma)                                                              # [batch, (*model), target, y_dim]
        return mu, sigma

    @npf_io(flatten_input=True)
    def log_likelihood(
        self,
        data: NPData,
        *,
        num_samples:   int = 1,
        num_generates: int = 40,
        joint: bool = False,
        return_aux: bool = False,
        stop_grad: bool = False,
        split_set: bool = False,
    ) -> Union[
        Array,
        Tuple[Array, Array[B, T, R]],
    ]:
        """
        Batch-averaged log-likelihood of ``data.y`` under the model's predictions.

        Args:
            data:          Input data; ``x``/``y``/``mask`` and, when
                           ``split_set`` is set, ``mask_ctx``/``mask_tar`` are read.
            num_samples:   Forwarded to ``__call__``.
            num_generates: Forwarded to ``__call__``.
            joint:         If True, log-probabilities are summed over points for
                           each sample, combined over samples with ``logmeanexp``
                           and divided by the number of unmasked points.
                           If False, ``logmeanexp`` over samples is taken per
                           point and then averaged over unmasked points.
            return_aux:    Forwarded to ``__call__``; the auxiliary outputs it
                           returns are appended to this method's result.
            stop_grad:     Forwarded to ``__call__``.
            split_set:     If True, the same quantity restricted to
                           ``data.mask_ctx`` and ``data.mask_tar`` is returned too.

        Returns:
            ``ll`` when neither ``split_set`` nor ``return_aux`` is set;
            ``(ll, *aux)`` when only ``return_aux`` is set;
            ``(ll, ll_ctx, ll_tar, *aux)`` when ``split_set`` is set
            (``aux`` is empty unless ``return_aux`` is set).
        """

        mu, sigma, *aux = self(
            data,
            num_samples=num_samples,
            num_generates=num_generates,
            return_aux=return_aux,
            stop_grad=stop_grad,
            skip_io=True,
        )                                                                                           # [batch, sample, point, y_dim] x 2, (aux)

        s_y = jnp.expand_dims(data.y, axis=1)                                                       # [batch, 1,      point, y_dim]
        log_prob = MultivariateNormalDiag(mu, sigma).log_prob(s_y)                                  # [batch, sample, point]

        if joint:
            ll = F.masked_sum(log_prob, data.mask, axis=-1, non_mask_axis=1)                        # [batch, sample]
            ll = F.logmeanexp(ll, axis=1) / jnp.sum(data.mask, axis=-1)                             # [batch]

            if split_set:
                ll_ctx = F.masked_sum(log_prob, data.mask_ctx, axis=-1, non_mask_axis=1)            # [batch, sample]
                ll_tar = F.masked_sum(log_prob, data.mask_tar, axis=-1, non_mask_axis=1)            # [batch, sample]
                ll_ctx = F.logmeanexp(ll_ctx, axis=1) / jnp.sum(data.mask_ctx, axis=-1)             # [batch]
                ll_tar = F.logmeanexp(ll_tar, axis=1) / jnp.sum(data.mask_tar, axis=-1)             # [batch]

        else:
            ll_all = F.logmeanexp(log_prob, axis=1)                                                 # [batch, point]
            ll = F.masked_mean(ll_all, data.mask, axis=-1)                                          # [batch]

            if split_set:
                ll_ctx = F.masked_mean(ll_all, data.mask_ctx, axis=-1)                              # [batch]
                ll_tar = F.masked_mean(ll_all, data.mask_tar, axis=-1)                              # [batch]

        ll = jnp.mean(ll)                                                                           # (1)

        if split_set:
            ll_ctx = jnp.mean(ll_ctx)                                                               # (1)
            ll_tar = jnp.mean(ll_tar)                                                               # (1)

            # aux is empty when return_aux=False, so this is a 3-tuple in that case.
            return ll, ll_ctx, ll_tar, *aux                                                         # (1) x 3, (aux)
        elif return_aux:
            return ll, *aux                                                                         # (1), (aux)
        else:
            return ll                                                                               # (1)

    @npf_io(flatten_input=True)
    def loss(
        self,
        data: NPData,
        *,
        num_samples: int = 1,
        num_generates: int = 40,
        loss_type: str = "sum_mean_log",    # or "log_mean_exp_sum_log"
        average: bool = False,
        stop_grad: bool = False,
        **kwargs,
    ) -> Array:
        """
        Training loss: the negated sum of the joint data log-likelihood and the
        auxiliary log-likelihood returned by ``__call__``.

        Args:
            data:          Input data, forwarded to ``log_likelihood``.
            num_samples:   Forwarded to ``log_likelihood``.
            num_generates: Forwarded to ``log_likelihood``.
            loss_type:     Accepted for interface compatibility; not used by the
                           current implementation.
            average:       Accepted for interface compatibility; not used by the
                           current implementation.
            stop_grad:     Forwarded to ``log_likelihood``.
            **kwargs:      Accepted and ignored.

        Returns:
            Scalar loss ``-(ll + train_ll)``.
        """

        # With joint=True and return_aux=True, log_likelihood returns
        # (ll, *aux); the unpacking below relies on __call__ supplying exactly
        # one auxiliary value.
        ll, train_ll = self.log_likelihood(
            data=data, skip_io=True,
            num_samples=num_samples, num_generates=num_generates,
            joint=True, return_aux=True, stop_grad=stop_grad,
        )
        return -(ll + train_ll)

class MPANPDirectBase(MPNPDirectMixin, CANPBase):
    """
    Attentive variant of the "direct" model.

    A generator (``auto_regress``) proposes extra pseudo-context inputs, the
    model labels them with its own (noised) prediction from the real context,
    and the final prediction attends over the real and generated context
    together. ``_encode`` and ``_decode`` are not defined in this class and are
    expected to come from the base classes.
    """

    # Generator called as auto_regress(ctx, num_generates, mask); required.
    auto_regress: nn.Module = None
    encoder:         nn.Module = None
    self_attention:  Optional[nn.Module] = None
    # Maps raw inputs to the query/key features fed to cross-attention; required.
    transform_qk:    nn.Module = None
    cross_attention: nn.Module = None
    decoder:         nn.Module = None
    # Not read in this class.
    min_sigma:       float = 0.1

    def __post_init__(self):
        """Validate that the modules specific to this class were supplied."""
        super().__post_init__()
        if self.transform_qk is None:
            raise ValueError("transform_qk is not specified")
        if self.auto_regress is None:
            raise ValueError("auto_regress is not specified")

    def _aggregate(
        self,
        x_tar:    Array[B, ([M],), T, X],
        x_ctx:    Array[B, ([M],), C, X],
        r_i_ctx:  Array[B, ([M],), C, R],
        mask_ctx: Array[B, C],
    ) -> Array[B, ([M],), T, R]:
        """
        Cross-attend from the targets (queries) to the context (keys/values).

        Callers in this class pass the ``transform_qk`` features as ``x_tar``
        and ``x_ctx``; they are used as queries and keys without further
        processing here.
        """

        return self.cross_attention(x_tar, x_ctx, r_i_ctx, mask=mask_ctx)                           # [batch x (*model), target, r_dim]

    def _autoregressive(self,
        x_ctx:      Array[B, ..., C, R],
        y_ctx:     Array[B, ..., C, R], #?
        mask:         Array[B, C],
        num_generates: int = 10,
    ) -> Tuple[Array[B, ..., G, R], Array[B, ..., G, R]]:
        """
        Run the generator on the concatenated context.

        Returns:
            ``(x_generated_ctx, log_sigma)``: the two halves of the generator
            output along its last axis.
        """

        ctx = jnp.concatenate([x_ctx, y_ctx], axis=-1)
        generated_ctx = self.auto_regress(ctx, num_generates, mask)
        x_generated_ctx, log_sigma = jnp.split(generated_ctx, 2, axis=-1)
        return x_generated_ctx, log_sigma

    @nn.compact
    @npf_io(flatten=True)
    def __call__(self,
        data:          NPData,
        *,
        num_samples:  int = 5,
        num_generates: int = 40,
        alpha: float = 1.0,
        beta: float = 0.1,
        return_aux:  bool = False,
        stop_grad:    bool = False,
        plotting: bool = False,
    ) -> Tuple[Array[B, [T], Y], Array[B, [T], Y]]:
        """
        Predict ``(mu, sigma)`` at ``data.x`` using real plus generated context.

        Requires a ``"sample"`` RNG stream.

        Args:
            data:          Input data; ``x``, ``x_ctx``, ``y_ctx``, ``mask`` and
                           ``mask_ctx`` are read.
            num_samples:   Number of repetitions along axis 1.
            num_generates: Number of pseudo-context points requested from the
                           generator.
            alpha:         Multiplier on the noise added to the generated labels.
            beta:          Mixing constant for the generated scale:
                           ``beta + (1 - beta) * softplus(log_sigma)``.
            return_aux:    If True, also return the scalar log-likelihood of the
                           context-plus-generated labels under the model.
            stop_grad:     If exactly ``True``, gradients are stopped through
                           the generated mean.
            plotting:      If True, sow the training-set predictions and the
                           combined context into ``'intermediates'``.

        Returns:
            ``(mu, sigma)``, or ``(mu, sigma, ll)`` when ``return_aux`` is set.
        """

        _x_ctx = F.repeat_axis(data.x_ctx, num_samples, axis=1)                                     # [batch, num_samples, context, x_dim]
        _y_ctx = F.repeat_axis(data.y_ctx, num_samples, axis=1)                                     # [batch, num_samples, context, y_dim]
        _x_tar = F.repeat_axis(data.x, num_samples, axis=1)                                         # [batch, num_samples, target,  x_dim]

        # `sample_key` drives the number of kept generated points, `key` the label noise.
        key = self.make_rng("sample")
        key, sample_key = random.split(key)

        x_generated_ctx, log_sigma = self._autoregressive(_x_ctx, _y_ctx, data.mask_ctx, num_generates)

        # One random count per batch element selects which generated points are
        # unmasked (presumably the first `_c` — depends on F.get_mask).
        num_ctx = random.randint(sample_key, shape=(data.x_ctx.shape[0],), minval=3, maxval=num_generates - 2)
        mask_generated_ctx = jax.vmap(lambda _c: F.get_mask(num_generates, start=0, stop=_c))(num_ctx)
        cr_i_ctx = self.transform_qk(_x_ctx)
        cr_i_generated_ctx = self.transform_qk(x_generated_ctx)

        # Label the generated inputs with the model's own mean given the real context.
        r_i_ctx = self._encode(_x_ctx, _y_ctx, data.mask_ctx)                                       # [batch, context, r_dim]
        generated_rep = self._aggregate(cr_i_generated_ctx, cr_i_ctx, r_i_ctx, data.mask_ctx)
        generated_mu, _ = self._decode_2(x_generated_ctx, generated_rep, mask_generated_ctx)

        # Identity check kept on purpose: only a literal Python True enables it.
        if stop_grad is True:
            generated_mu = jax.lax.stop_gradient(generated_mu)

        generated_sigma = beta + (1 - beta) * nn.softplus(log_sigma)
        # Bounds are passed positionally: the `a_min`/`a_max` keywords are
        # deprecated in recent JAX releases, positional bounds work everywhere.
        noise = alpha * generated_sigma * random.normal(key, generated_mu.shape)
        y_generated_ctx = generated_mu + jnp.clip(noise, -1, 1)

        x_ctx_generated_ctx = jnp.concatenate((_x_ctx, x_generated_ctx), axis=-2)
        y_ctx_generated_ctx = jnp.concatenate((_y_ctx, y_generated_ctx), axis=-2)
        mask_ctx_generated_ctx = jnp.concatenate((data.mask_ctx, mask_generated_ctx), axis=-1)

        r_i_ctx_generated_ctx = self._encode(x_ctx_generated_ctx, y_ctx_generated_ctx, mask_ctx_generated_ctx)
        cr_i_ctx_generated_ctx = self.transform_qk(x_ctx_generated_ctx)
        cr_i_tar = self.transform_qk(_x_tar)

        self.sow('intermediates', 'x_generated_ctx', x_generated_ctx)
        self.sow('intermediates', 'y_generated_ctx', y_generated_ctx)

        # Decode
        r_ctx = self._aggregate(cr_i_tar, cr_i_ctx_generated_ctx, r_i_ctx_generated_ctx, mask_ctx_generated_ctx)
        query = jnp.concatenate((_x_tar, r_ctx), axis=-1)
        mu, sigma = self._decode(query, data.mask)                                                  # [batch, sample, target, y_dim] x 2

        # Overwrite masked-out target positions with fixed values.
        mu    = F.masked_fill(mu,    data.mask, fill_value=0.,   mask_axis=(0, -2))                 # [batch, sample, target, y_dim]
        sigma = F.masked_fill(sigma, data.mask, fill_value=1e-6, mask_axis=(0, -2))                 # [batch, sample, target, y_dim]

        # Predictions at the combined context itself are needed both for
        # plotting and for the auxiliary likelihood; compute them once.
        if plotting or return_aux:
            r_train_ctx_generated_ctx = self._aggregate(cr_i_ctx_generated_ctx, cr_i_ctx_generated_ctx, r_i_ctx_generated_ctx, mask_ctx_generated_ctx)
            query_train = jnp.concatenate((x_ctx_generated_ctx, r_train_ctx_generated_ctx), axis=-1)
            mu_train, sigma_train = self._decode(query_train, mask_ctx_generated_ctx)

        if plotting:
            self.sow('intermediates', 'mu_train', mu_train)
            self.sow('intermediates', 'sigma_train', sigma_train)
            self.sow('intermediates', 'x_ctx_generated_ctx', x_ctx_generated_ctx)
            self.sow('intermediates', 'y_ctx_generated_ctx', y_ctx_generated_ctx)
            self.sow('intermediates', 'mask_ctx_generated_ctx', mask_ctx_generated_ctx)

        if return_aux:
            log_prob = MultivariateNormalDiag(mu_train, sigma_train).log_prob(y_ctx_generated_ctx)  # [batch, sample, point]
            ll = F.masked_sum(log_prob, mask_ctx_generated_ctx, axis=-1, non_mask_axis=1)           # [batch, sample]
            ll = F.logmeanexp(ll, axis=1) / jnp.sum(mask_ctx_generated_ctx, axis=-1)                # [batch]
            ll = jnp.mean(ll)                                                                       # (1)
            return mu, sigma, ll

        return mu, sigma


class MPNPDirectBase(MPNPDirectMixin, CNPBase):
    """
    Non-attentive variant of the "direct" model.

    A generator (``auto_regress``) proposes extra pseudo-context inputs, the
    model labels them with its own (noised) prediction from the real context,
    and the final prediction aggregates the encodings of the real and generated
    context together. ``_encode``, ``_aggregate`` and ``_decode`` are not
    defined in this class and are expected to come from the base classes.
    """

    encoder: nn.Module = None
    decoder: nn.Module = None
    # Generator called as auto_regress(ctx, num_generates, mask); required.
    auto_regress: nn.Module = None
    # Not read in this class.
    min_sigma: float = 0.1

    def __post_init__(self):
        """Validate that encoder, decoder and generator were all supplied."""
        super().__post_init__()
        if self.encoder is None:
            raise ValueError("encoder is not specified")
        if self.decoder is None:
            raise ValueError("decoder is not specified")
        if self.auto_regress is None:
            raise ValueError("auto_regress is not specified")

    def _autoregressive(self,
        x_ctx:      Array[B, ..., C, R],
        y_ctx:     Array[B, ..., C, R], #?
        mask:         Array[B, C],
        num_generates: int = 40,
    ) -> Tuple[Array[B, ..., G, R], Array[B, ..., G, R]]:
        """
        Run the generator on the concatenated context.

        Returns:
            ``(x_generated_ctx, log_sigma)``: the two halves of the generator
            output along its last axis. The second half is used by
            ``__call__`` as a pre-softplus scale, not as a label.
        """

        ctx = jnp.concatenate([x_ctx, y_ctx], axis=-1)
        generated_ctx = self.auto_regress(ctx, num_generates, mask)
        x_generated_ctx, log_sigma = jnp.split(generated_ctx, 2, axis=-1)
        return x_generated_ctx, log_sigma

    @nn.compact
    @npf_io(flatten=True)
    def __call__(self,
        data:          NPData,
        *,
        num_samples: int = 5,
        num_generates: int = 40,
        alpha: float = 1.0,
        beta: float = 0.1,
        return_aux: bool = False,
        stop_grad:    bool = False,
        plotting:   bool = False,
    ) -> Tuple[Array[B, [T], Y], Array[B, [T], Y]]:
        """
        Predict ``(mu, sigma)`` at ``data.x`` using real plus generated context.

        Requires a ``"sample"`` RNG stream.

        Args:
            data:          Input data; ``x``, ``x_ctx``, ``y_ctx``, ``mask`` and
                           ``mask_ctx`` are read.
            num_samples:   Number of repetitions along axis 1.
            num_generates: Number of pseudo-context points requested from the
                           generator.
            alpha:         Multiplier on the noise added to the generated labels.
            beta:          Mixing constant for the generated scale:
                           ``beta + (1 - beta) * softplus(log_sigma)``.
            return_aux:    If True, also return the scalar log-likelihood of the
                           context-plus-generated labels under the model.
            stop_grad:     If exactly ``True``, gradients are stopped through
                           the generated mean.
            plotting:      If True, sow the training-set predictions and the
                           combined context into ``'intermediates'``.

        Returns:
            ``(mu, sigma)``, or ``(mu, sigma, ll)`` when ``return_aux`` is set.
        """

        # Encode
        _x_ctx = F.repeat_axis(data.x_ctx, num_samples, axis=1)
        _y_ctx = F.repeat_axis(data.y_ctx, num_samples, axis=1)
        _x_tar = F.repeat_axis(data.x, num_samples, axis=1)

        # `sample_key` drives the number of kept generated points, `key` the label noise.
        key = self.make_rng("sample")
        key, sample_key = random.split(key)

        x_generated_ctx, log_sigma = self._autoregressive(_x_ctx, _y_ctx, data.mask_ctx, num_generates)

        # One random count per batch element selects which generated points are
        # unmasked (presumably the first `_c` — depends on F.get_mask).
        num_ctx = random.randint(sample_key, shape=(data.x_ctx.shape[0],), minval=3, maxval=num_generates - 2)
        mask_generated_ctx = jax.vmap(lambda _c: F.get_mask(num_generates, start=0, stop=_c))(num_ctx)

        # Label the generated inputs with the model's own mean given the real context.
        r_i_ctx = self._encode(_x_ctx, _y_ctx, data.mask_ctx)
        generated_rep = self._aggregate(x_generated_ctx, _x_ctx, r_i_ctx, data.mask_ctx)
        generated_mu, _ = self._decode_2(x_generated_ctx, generated_rep, mask_generated_ctx)

        generated_sigma = beta + (1 - beta) * nn.softplus(log_sigma)

        # Identity check kept on purpose: only a literal Python True enables it.
        if stop_grad is True:
            generated_mu = jax.lax.stop_gradient(generated_mu)

        # Bounds are passed positionally: the `a_min`/`a_max` keywords are
        # deprecated in recent JAX releases, positional bounds work everywhere.
        noise = alpha * generated_sigma * random.normal(key, generated_mu.shape)
        y_generated_ctx = generated_mu + jnp.clip(noise, -1, 1)

        r_i_generated_ctx = self._encode(x_generated_ctx, y_generated_ctx, mask_generated_ctx)

        r_i_ctx_generated_ctx = jnp.concatenate([r_i_ctx, r_i_generated_ctx], axis=-2)
        mask_ctx_generated_ctx = jnp.concatenate((data.mask_ctx, mask_generated_ctx), axis=-1)
        # NOTE(review): `data.x_ctx` is neither repeated over samples nor
        # extended with the generated inputs here — presumably the inherited
        # `_aggregate` ignores its context-input argument; confirm.
        r_ctx = self._aggregate(_x_tar, data.x_ctx, r_i_ctx_generated_ctx, mask_ctx_generated_ctx)  # [batch, target,  r_dim]

        self.sow('intermediates', 'x_generated_ctx', x_generated_ctx)
        self.sow('intermediates', 'y_generated_ctx', y_generated_ctx)

        # Decode
        query = jnp.concatenate((_x_tar, r_ctx), axis=-1)
        mu, sigma = self._decode(query, data.mask)                                                  # [batch,  target, y_dim] x 2

        # Overwrite masked-out target positions with fixed values.
        mu    = F.masked_fill(mu,    data.mask, fill_value=0.,   mask_axis=(0, -2))                 # [batch, target, y_dim]
        sigma = F.masked_fill(sigma, data.mask, fill_value=1e-6, mask_axis=(0, -2))                 # [batch, target, y_dim]

        # Predictions at the combined context itself are needed both for
        # plotting and for the auxiliary likelihood; compute them once.
        if plotting or return_aux:
            x_ctx_generated_ctx = jnp.concatenate((_x_ctx, x_generated_ctx), axis=-2)
            y_ctx_generated_ctx = jnp.concatenate((_y_ctx, y_generated_ctx), axis=-2)
            r_train_base_ctx = self._aggregate(x_ctx_generated_ctx, x_ctx_generated_ctx, r_i_ctx_generated_ctx, mask_ctx_generated_ctx)
            query_train = jnp.concatenate((x_ctx_generated_ctx, r_train_base_ctx), axis=-1)
            mu_train, sigma_train = self._decode(query_train, mask_ctx_generated_ctx)

        if plotting:
            self.sow('intermediates', 'mu_train', mu_train)
            self.sow('intermediates', 'sigma_train', sigma_train)
            self.sow('intermediates', 'x_ctx_generated_ctx', x_ctx_generated_ctx)
            self.sow('intermediates', 'y_ctx_generated_ctx', y_ctx_generated_ctx)
            self.sow('intermediates', 'mask_ctx_generated_ctx', mask_ctx_generated_ctx)

        if return_aux:
            log_prob = MultivariateNormalDiag(mu_train, sigma_train).log_prob(y_ctx_generated_ctx)  # [batch, sample, point]
            ll = F.masked_sum(log_prob, mask_ctx_generated_ctx, axis=-1, non_mask_axis=1)           # [batch, sample]
            ll = F.logmeanexp(ll, axis=1) / jnp.sum(mask_ctx_generated_ctx, axis=-1)                # [batch]
            ll = jnp.mean(ll)                                                                       # (1)
            return mu, sigma, ll

        return mu, sigma


class MPNPDirect:
    """
    Factory for the non-attentive "direct" model.

    Instantiating this class does not produce an ``MPNPDirect`` object: it
    returns a configured :class:`MPNPDirectBase`.
    """

    def __new__(cls,
        y_dim: int,
        r_dim: int = 128,
        encoder_dims: Sequence[int] = (128, 128, 128, 128, 128),
        decoder_dims: Sequence[int] = (128, 128, 128),
        auto_regress_type: str = 'set_generative_6',
    ):
        """
        Args:
            y_dim:             Output dimension; the decoder emits ``2 * y_dim``
                               features (mean and pre-softplus scale).
            r_dim:             Width of the encoder output.
            encoder_dims:      Hidden widths of the encoder MLP.
            decoder_dims:      Hidden widths of the decoder MLP.
            auto_regress_type: Name passed to ``autoregressive`` to choose the
                               generator.
        """
        # The generator output (dim_out=2) is split in half by
        # MPNPDirectBase._autoregressive.
        generator = autoregressive(auto_regress_type, dim_out=2)
        encoder_mlp = MLP(hidden_features=encoder_dims, out_features=r_dim)
        decoder_mlp = MLP(hidden_features=decoder_dims, out_features=(y_dim * 2))

        return MPNPDirectBase(
            encoder=encoder_mlp,
            decoder=decoder_mlp,
            auto_regress=generator,
        )


class MPANPDirect:
    """
    Factory for the attentive "direct" model.

    Instantiating this class does not produce an ``MPANPDirect`` object: it
    returns a configured :class:`MPANPDirectBase`.
    """

    def __new__(cls,
        y_dim: int,
        r_dim: int = 128,
        sa_heads: Optional[int] = 8,
        ca_heads: Optional[int] = 8,
        encoder_dims: Sequence[int] = (128, 128, 128, 128, 128),
        decoder_dims: Sequence[int] = (128, 128, 128),
        auto_regress_type: str = 'set_generative_6',
    ):
        """
        Args:
            y_dim:             Output dimension; the decoder emits ``2 * y_dim``
                               features (mean and pre-softplus scale).
            r_dim:             Width of the encoder, attention and
                               ``transform_qk`` outputs.
            sa_heads:          Number of self-attention heads, or ``None`` to
                               build the model without self-attention.
            ca_heads:          Number of cross-attention heads.
            encoder_dims:      Hidden widths of the encoder and ``transform_qk``
                               MLPs.
            decoder_dims:      Hidden widths of the decoder MLP.
            auto_regress_type: Name passed to ``autoregressive`` to choose the
                               generator.
        """
        # The encoder ends with an activation only when self-attention follows it.
        use_self_attention = sa_heads is not None
        encoder = MLP(
            hidden_features=encoder_dims,
            out_features=r_dim,
            last_activation=use_self_attention,
        )
        self_attention = (
            MultiheadSelfAttention(dim_out=r_dim, num_heads=sa_heads)
            if use_self_attention
            else None
        )

        cross_attention = MultiheadAttention(dim_out=r_dim, num_heads=ca_heads)
        decoder = MLP(hidden_features=decoder_dims, out_features=(y_dim * 2))

        # The generator uses a fixed 8 heads, independent of sa_heads/ca_heads;
        # its output (dim_out=2) is split in half by
        # MPANPDirectBase._autoregressive.
        auto_regress = autoregressive(auto_regress_type, dim_out=2, num_heads=8)
        transform_qk = MLP(hidden_features=encoder_dims, out_features=r_dim, last_activation=False)

        modules = dict(
            encoder=encoder,
            self_attention=self_attention,
            cross_attention=cross_attention,
            decoder=decoder,
            auto_regress=auto_regress,
            transform_qk=transform_qk,
        )
        return MPANPDirectBase(**modules)
