from npf.jax.typing import *

from abc import abstractmethod

import math

import jax
import numpy as np

from jax import numpy as jnp
from flax import linen as nn

from npf.jax import functional as F
from npf.jax.modules import (
    MLP,
)


# Names exported by a star-import of this module.
# NOTE(review): ``SetGenerate`` is constructed by ``autoregressive`` below but
# is not listed here -- confirm whether the omission is intentional.
__all__ = [
    "RNN",
    "autoregressive",
    "BlockAttention",
    "GMMSetGenerate",
]


def autoregressive(auto_regress_type: str, **kwargs):
    """Build the generator module selected by ``auto_regress_type``.

    Args:
        auto_regress_type: One of ``"rnn"``, ``"set_generative"``,
            ``"gmm_set_generative"`` or ``"block_attention"``.
        **kwargs: Forwarded to the constructor of the selected module.
            Not forwarded for ``"rnn"``, which is built without arguments.

    Returns:
        A newly constructed module instance.

    Raises:
        ValueError: If ``auto_regress_type`` is not one of the names above.
    """
    if auto_regress_type == "rnn":
        # RNN declares no configurable fields, so kwargs are dropped here.
        return RNN()
    if auto_regress_type == "set_generative":
        return SetGenerate(**kwargs)
    if auto_regress_type == "gmm_set_generative":
        return GMMSetGenerate(**kwargs)
    if auto_regress_type == "block_attention":
        return BlockAttention(**kwargs)
    raise ValueError(f"Unknown AutoRegressType: {auto_regress_type}")


class RNN(nn.Module):
    """Generate representations step by step with two chained LSTM cells.

    At every step the first cell consumes one slice of a random input
    sequence and the second cell consumes the first cell's output; the second
    cell's outputs are collected as the generated representations.
    """

    def setup(self):
        self.ls_1 = nn.LSTMCell()
        self.ls_2 = nn.LSTMCell()

    def __call__(self, r, generate_num, mask=None):
        """Generate ``generate_num`` representations conditioned on ``r``.

        Args:
            r: Array whose second-to-last axis indexes the points that are
                averaged and whose last axis is the feature axis.
            generate_num: Number of generation steps (Python int).
            mask: Optional mask passed to ``F.masked_mean`` with
                ``mask_axis=(0, -2)``.

        Returns:
            The second cell's per-step outputs stacked along axis ``-2``.

        Requires the ``"sample"`` RNG stream.
        """
        # Every random draw gets its own key: reusing one key for several
        # draws makes the resulting samples statistically dependent.
        input_key, noise_key_1, noise_key_2 = jax.random.split(self.make_rng("sample"), 3)
        r_front_shape, r_tail_shape = r.shape[:-2], r.shape[-1]

        # Random per-step inputs fed to the first cell.
        generate_initial = jax.random.normal(
            input_key, (*r_front_shape, generate_num, r_tail_shape)
        )

        # The masked mean is shared by both halves of both carries; compute it once.
        r_mean = F.masked_mean(r, axis=-2, mask=mask, mask_axis=(0, -2))
        noise_shape = (*r_front_shape, r_tail_shape)
        # Each carry is a pair: the mean itself and a noise-perturbed copy of it.
        carry_1 = (r_mean, r_mean + jax.random.normal(noise_key_1, noise_shape))
        carry_2 = (r_mean, r_mean + jax.random.normal(noise_key_2, noise_shape))

        outputs = []
        for step in range(generate_num):
            carry_1, hidden = self.ls_1(carry_1, generate_initial[..., step, :])
            carry_2, hidden = self.ls_2(carry_2, hidden)
            outputs.append(hidden)
        return jnp.stack(outputs, axis=-2)


class BlockAttention(nn.Module):
    """Generate representations with one masked self-attention pass.

    Noisy copies of the masked mean of ``r`` are appended to ``r`` along
    axis ``-2``; the concatenation goes through
    ``MaskedMultiheadSelfAttention`` and only the appended positions are
    returned.
    """

    dim_out: int = 128
    num_heads: int = 8

    def setup(self):
        self.masked_selfattn = MaskedMultiheadSelfAttention(
            dim_out=self.dim_out,
            num_heads=self.num_heads,
        )

    def __call__(self, r, generate_num, mask=None):
        """Return the attention outputs at the ``generate_num`` appended positions.

        Args:
            r: Array whose second-to-last axis indexes points and whose last
                axis is the feature axis.
            generate_num: Number of positions to append and return.
            mask: Optional mask forwarded to ``F.masked_mean`` and to the
                attention module.

        Requires the ``"sample"`` RNG stream.
        """
        # Only the second half of the split is consumed.
        _, sample_key = jax.random.split(self.make_rng("sample"), 2)
        lead_shape = r.shape[:-2]
        feature_dim = r.shape[-1]

        pooled = F.masked_mean(r, axis=-2, mask=mask, mask_axis=(0, -2))
        pooled_copies = jnp.stack([pooled] * generate_num, axis=-2)
        noise = jax.random.normal(sample_key, (*lead_shape, generate_num, feature_dim))
        # Alternative explored in the original: seed with pure noise (no pooled mean).
        seeds = pooled_copies + noise

        extended = jnp.concatenate([r, seeds], axis=-2)
        attended = self.masked_selfattn(extended, mask=mask, generate_num=generate_num)
        return attended[..., -generate_num:, :]


class SetGenerate(nn.Module):
    """Generate a set of representations from Gaussian noise and a context.

    Standard-normal seeds and the context ``r`` are passed through ``ISAB``.
    The seeds are pure noise; they are not centred on the context mean.
    """

    dim_out: int = 128
    num_heads: int = 8

    def setup(self):
        self.isab = ISAB(dim_out=self.dim_out, num_heads=self.num_heads)

    def __call__(self, r, generate_num, mask=None):
        """Return the ``ISAB`` output for ``generate_num`` random seeds.

        Args:
            r: Context array whose second-to-last axis indexes points and
                whose last axis is the feature axis.
            generate_num: Number of seeds to draw.
            mask: Optional context mask, forwarded to ``ISAB`` as
                ``mask_context``.

        Requires the ``"sample"`` RNG stream.
        """
        # Only the second half of the split is consumed; this keeps the
        # sampled values identical to the previous implementation, which also
        # computed (and then discarded) a masked mean of ``r``.
        _, sample_key = jax.random.split(self.make_rng("sample"), 2)
        r_front_shape, r_tail_shape = r.shape[:-2], r.shape[-1]

        generate_initial = jax.random.normal(
            sample_key, (*r_front_shape, generate_num, r_tail_shape)
        )
        # All generated positions are valid. The mask is built from
        # ``r.shape[0]`` only -- assumes a single leading batch axis.
        generate_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        return self.isab(r, generate_initial, mask, generate_mask)


class GMMSetGenerate(nn.Module):
    """Generate a set of representations from Gaussian-mixture seeds.

    Each seed is a learned component mean (chosen uniformly at random from
    ``num_gmm`` rows of ``means``) plus standard-normal noise. Seeds and the
    context ``r`` are then passed through ``ISAB``.
    """

    dim_out: int = 128
    num_heads: int = 8
    num_gmm: int = 3

    def setup(self):
        self.isab = ISAB(dim_out=self.dim_out, num_heads=self.num_heads)
        self.means = self.param(
            "means",
            jax.nn.initializers.normal(),
            (self.num_gmm, self.dim_out),
        )

    def __call__(self, r, generate_num, mask=None):
        """Return the ``ISAB`` output for ``generate_num`` mixture seeds.

        Args:
            r: Context array whose second-to-last axis indexes points and
                whose last axis is the feature axis. The component means have
                ``dim_out`` features and are added to noise with
                ``r.shape[-1]`` features, so the two must be broadcastable.
            generate_num: Number of seeds to draw.
            mask: Optional context mask, forwarded to ``ISAB`` as
                ``mask_context``.

        Requires the ``"sample"`` RNG stream.
        """
        # Separate keys for the component choice and for the noise: drawing
        # both from the same key would make them statistically dependent.
        component_key, noise_key = jax.random.split(self.make_rng("sample"), 2)
        r_front_shape, r_tail_shape = r.shape[:-2], r.shape[-1]
        seed_shape = (*r_front_shape, generate_num)

        # Uniformly pick one mixture component per generated position.
        component_ids = jax.random.choice(
            component_key, jnp.arange(self.num_gmm), seed_shape
        )
        noise = jax.random.normal(noise_key, (*seed_shape, r_tail_shape))
        # Indexing with an integer array gathers one mean row per position.
        generate_initial = noise + self.means[component_ids]

        # All generated positions are valid. The mask is built from
        # ``r.shape[0]`` only -- assumes a single leading batch axis.
        generate_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        return self.isab(r, generate_initial, mask, generate_mask)


class MAB(nn.Module):
    """Multihead attention block.

    Projects queries and key/value inputs with dense layers, applies
    multihead attention with a residual connection and layer norm, then a
    residual ReLU-dense layer followed by a second layer norm.
    """

    dim_out: int = 128
    num_heads: int = 8

    def setup(self):
        self.fc_q = nn.Dense(features=self.dim_out)
        self.fc_k = nn.Dense(features=self.dim_out)
        self.fc_v = nn.Dense(features=self.dim_out)
        self.fc_out = nn.Dense(features=self.dim_out)
        self.ln1 = nn.LayerNorm()
        self.ln2 = nn.LayerNorm()

    def scatter(self, x):
        """Split the last axis into ``num_heads`` chunks, concatenated on axis 0."""
        per_head = jnp.split(x, self.num_heads, axis=-1)
        return jnp.concatenate(per_head, axis=0)

    def gather(self, x):
        """Inverse of ``scatter``: split axis 0 into heads, concatenate on the last axis."""
        per_head = jnp.split(x, self.num_heads, axis=0)
        return jnp.concatenate(per_head, axis=-1)

    def attend(self, q, k, v, mask=None):
        """Multihead dot-product attention of ``q`` over ``k``/``v``.

        ``mask`` (optional) marks which key positions may be attended to; it
        is repeated for every query position and every head. Query rows whose
        keys are all masked out receive all-zero attention weights.
        """
        q_heads = self.scatter(q)
        k_heads = self.scatter(k)
        v_heads = self.scatter(v)
        # Logits are scaled by sqrt(dim_out), not by the per-head width.
        scores = (q_heads @ jnp.swapaxes(k_heads, -2, -1)) / math.sqrt(self.dim_out)

        if mask is None:
            weights = jax.nn.softmax(scores, axis=-1)
            return self.gather(weights @ v_heads)

        key_mask = jnp.bool_(mask)
        # One copy of the key mask per query position ...
        key_mask = jnp.stack([key_mask] * q.shape[-2], axis=-2)
        # ... and per head, matching the head layout produced by ``scatter``.
        key_mask = jnp.concatenate([key_mask] * self.num_heads, axis=0)
        if scores.ndim == 4:
            key_mask = jnp.expand_dims(key_mask, axis=1)

        # TODO: check whether jax.nn.softmax(scores, where=key_mask, initial=0, axis=-1)
        # can replace the two lines below.
        weights = jax.nn.softmax(jnp.where(key_mask, scores, -jnp.inf), axis=-1)
        # A fully masked row is softmax over all -inf, i.e. NaN; zero it out.
        weights = jnp.where(jnp.isnan(weights), 0., weights)
        return self.gather(weights @ v_heads)

    def __call__(self, q, v, mask=None):
        """Attend from ``q`` to ``v`` (used as both keys and values)."""
        q_proj = self.fc_q(q)
        k_proj = self.fc_k(v)
        v_proj = self.fc_v(v)
        attended = self.ln1(q_proj + self.attend(q_proj, k_proj, v_proj, mask))
        return self.ln2(attended + nn.relu(self.fc_out(attended)))


class ISAB(nn.Module):
    """Two chained ``MAB`` blocks linking a context and generated samples.

    First the context attends to the generated samples; then the generated
    samples attend to that intermediate result.
    """

    dim_out: int = 128
    num_heads: int = 8

    def setup(self):
        self.mab0 = MAB(dim_out=self.dim_out, num_heads=self.num_heads)
        self.mab1 = MAB(dim_out=self.dim_out, num_heads=self.num_heads)

    def __call__(self, context, generate_sample, mask_context=None, mask_generate=None):
        """Return one output row per generated sample.

        Args:
            context: Context representations (queries of the first block).
            generate_sample: Generated seeds (keys/values of the first block,
                queries of the second).
            mask_context: Optional mask over context positions, applied in
                the second block where the context-derived rows act as keys.
            mask_generate: Optional mask over generated positions, applied in
                the first block where the seeds act as keys.
        """
        context_view = self.mab0(context, generate_sample, mask=mask_generate)
        return self.mab1(generate_sample, context_view, mask=mask_context)


class MaskedMultiheadAttention(nn.Module):
    """Multihead attention with a block-structured attention mask.

    The sequence is treated as a leading run of positions covered by ``mask``
    followed by ``generate_num`` trailing positions. When ``mask`` is given,
    ``attend`` builds a square mask from four blocks (see the comments there).
    When ``mask`` is ``None`` no masking of any kind is applied.
    """
    dim_out: int
    num_heads: int = 8

    def setup(self):
        self.fc_q   = nn.Dense(features=self.dim_out)
        self.fc_k   = nn.Dense(features=self.dim_out)
        self.fc_v   = nn.Dense(features=self.dim_out)
        self.fc_out = nn.Dense(features=self.dim_out)
        self.ln1    = nn.LayerNorm()
        self.ln2    = nn.LayerNorm()

    def scatter(self, x):
        """Split the last axis into ``num_heads`` chunks, concatenated on axis 0."""
        return jnp.concatenate(jnp.split(x, self.num_heads, axis=-1), axis=0)

    def gather(self, x):
        """Inverse of ``scatter``: split axis 0 into heads, concatenate on the last axis."""
        return jnp.concatenate(jnp.split(x, self.num_heads, axis=0), axis=-1)

    def attend(self, q, k, v, mask=None, generate_num=None):
        """Multihead dot-product attention with an optional block mask.

        When ``mask`` is not ``None``, ``generate_num`` must be an integer: it
        is negated and used for slicing below.
        NOTE(review): presumably ``mask`` covers exactly the first
        ``q.shape[-2] - generate_num`` positions -- confirm against callers.
        """
        q_, k_, v_ = self.scatter(q), self.scatter(k), self.scatter(v)
        # Logits are scaled by sqrt(dim_out), not by the per-head width.
        A_logits = q_ @ k_.swapaxes(-2, -1) / math.sqrt(self.dim_out)

        if mask is not None:
            # mask = jnp.bool_(mask)
            # mask = jnp.stack([mask] * (q.shape[-2]-generate_num), axis=-2)
            # Repeat the mask once per head, matching the layout from ``scatter``.
            mask = jnp.concatenate([mask] * self.num_heads, axis=0)
            # Upper-left block: mask values placed on the diagonal only, so a
            # masked-in leading position attends to itself and nothing else.
            mask_ul = jnp.tile(jnp.expand_dims(mask, axis=-1), mask.shape[-1]) * jnp.diag(jnp.ones(mask.shape[-1]))
            # Lower-left block: every trailing position may attend to the
            # masked-in leading positions.
            mask_ll = jnp.stack([mask] * generate_num, axis=-2)
            # Upper-right block: all zeros, leading positions never attend to
            # trailing positions.
            mask_ur = jnp.swapaxes(jnp.zeros(mask_ll.shape), -1, -2)
            # Lower-right block: lower-triangular, trailing position i attends
            # to trailing positions 0..i.
            mask_lr = jnp.tril(jnp.ones((*mask_ll.shape[:-2], generate_num, generate_num)))
            # Assemble the four blocks into one square mask.
            attention_mask = jnp.zeros((*mask.shape[:-1], mask.shape[-1] + generate_num, mask.shape[-1] + generate_num))
            attention_mask = attention_mask.at[..., :-generate_num, :-generate_num].set(mask_ul)
            attention_mask = attention_mask.at[..., :-generate_num, -generate_num:].set(mask_ur)
            attention_mask = attention_mask.at[..., -generate_num:, :-generate_num].set(mask_ll)
            attention_mask = attention_mask.at[..., -generate_num:, -generate_num:].set(mask_lr)
            attention_mask = jnp.bool_(attention_mask)
            if A_logits.ndim == 4:
                attention_mask = jnp.expand_dims(attention_mask, axis=1)
            # A = jax.nn.softmax(A_logits, where=mask, initial=0, axis=-1)  # TODO: Check below code can be replaced with this.
            A = jax.nn.softmax(jnp.where(attention_mask, A_logits, -float('inf')), axis=-1)
            # Rows with every entry masked out give NaN after softmax; zero them.
            A = jnp.where(jnp.isnan(A), 0., A)
        else:
            A = jax.nn.softmax(A_logits, axis=-1)

        return self.gather(A @ v_)

    def __call__(self, q, k, v, mask=None, generate_num=None):
        """Project ``q``/``k``/``v``, attend, then apply the residual layers and norms."""
        q, k, v = self.fc_q(q), self.fc_k(k), self.fc_v(v)
        out = self.ln1(q + self.attend(q, k, v, mask, generate_num))
        out = self.ln2(out + nn.relu(self.fc_out(out)))
        return out


class MaskedMultiheadSelfAttention(MaskedMultiheadAttention):
    """Self-attention variant: one input serves as queries, keys and values."""

    def __call__(self, q, mask=None, generate_num=None):
        """Run the parent attention with ``q`` used for all three inputs."""
        queries = keys = values = q
        return super().__call__(
            queries,
            keys,
            values,
            mask=mask,
            generate_num=generate_num,
        )
