from npf.jax.typing import *

from abc import abstractmethod

import math

import jax
import numpy as np

from jax import numpy as jnp
from flax import linen as nn

from npf.jax import functional as F
from npf.jax.modules import (
    MLP,
)


# Public names exported by ``from <this module> import *``.
# NOTE(review): "MaskedMultiHeadAttention", "MaskedSelfAttention" and
# "BlockAttention" are not defined anywhere in this file. Unless they are
# provided by the ``from npf.jax.typing import *`` import, a star-import of
# this module will raise AttributeError. Also note that the SetGenerate*
# variants, Residual, MAB/ISAB/ISAB2 and GenerateMultiheadAttention are not
# listed here, so a star-import will not expose them — confirm this is intended.
__all__ = [
    "RNN",
    "autoregressive",
    "MaskedMultiHeadAttention",
    "MaskedSelfAttention",
    "BlockAttention",
    "GMMSetGenerate",
]

def autoregressive(auto_regress_type: str, **kwargs):
    """Construct the generator module registered under ``auto_regress_type``.

    Args:
        auto_regress_type: One of ``"rnn"``, ``"set_generative"``,
            ``"set_generative_2"`` through ``"set_generative_7"``,
            ``"gmm_set_generative"`` or ``"residual"``.
        **kwargs: Forwarded to the selected module's constructor. They are
            ignored for ``"rnn"``, which is always built without arguments.

    Returns:
        A new module instance of the selected class.

    Raises:
        ValueError: If ``auto_regress_type`` matches none of the keys above.
    """
    # Builders are lambdas so each class name is only looked up when its
    # entry is actually selected; keys are tried in order with ``==``.
    registry = (
        ("rnn", lambda: RNN()),
        ("set_generative", lambda: SetGenerate(**kwargs)),
        ("set_generative_2", lambda: SetGenerate2(**kwargs)),
        ("set_generative_3", lambda: SetGenerate3(**kwargs)),
        ("set_generative_4", lambda: SetGenerate4(**kwargs)),
        ("set_generative_5", lambda: SetGenerate5(**kwargs)),
        ("set_generative_6", lambda: SetGenerate6(**kwargs)),
        ("set_generative_7", lambda: SetGenerate7(**kwargs)),
        ("gmm_set_generative", lambda: GMMSetGenerate(**kwargs)),
        ("residual", lambda: Residual(**kwargs)),
    )
    for registered_name, build in registry:
        if auto_regress_type == registered_name:
            return build()
    raise ValueError(f"Unknown AutoRegressType: {auto_regress_type}")


class RNN(nn.Module):
    """Generate ``generate_num`` vectors with two stacked LSTM cells.

    Both cells start from a two-element carry built from the masked mean of
    ``r``: one component is the mean itself, the other is the mean plus
    standard Gaussian noise. At every step the first cell consumes a fresh
    Gaussian input vector, and its output feeds the second cell.
    """

    def setup(self):
        self.ls_1 = nn.LSTMCell()
        self.ls_2 = nn.LSTMCell()

    def __call__(self, r, generate_num, mask=None):
        """Run the stacked LSTM for ``generate_num`` steps.

        Args:
            r: Context representation; the set axis is ``-2`` and the feature
                axis is ``-1``.
            generate_num: Number of steps (generated vectors).
            mask: Optional mask forwarded to ``F.masked_mean``.

        Returns:
            The second cell's per-step outputs stacked along axis ``-2``, i.e.
            leading shape ``(*r.shape[:-2], generate_num, ...)``.
        """
        # Three independent keys. Previously one key drew both the step inputs
        # and the second carry's noise; reusing a JAX key violates the PRNG
        # contract and can correlate the two draws.
        input_key, noise_key_1, noise_key_2 = jax.random.split(self.make_rng("sample"), 3)

        r_front_shape, r_tail_shape = r.shape[:-2], r.shape[-1]
        generate_initial = jax.random.normal(input_key, (*r_front_shape, generate_num, r_tail_shape))

        # The masked mean is identical for every carry component; compute it once.
        r_mean = F.masked_mean(r, axis=-2, mask=mask, mask_axis=(0, -2))
        carry_shape = (*r_front_shape, r_tail_shape)
        r_1 = r_mean, r_mean + jax.random.normal(noise_key_1, carry_shape)
        r_2 = r_mean, r_mean + jax.random.normal(noise_key_2, carry_shape)

        outputs = []
        for i in range(generate_num):
            r_1, h = self.ls_1(r_1, generate_initial[..., i, :])
            r_2, h = self.ls_2(r_2, h)
            outputs.append(h)
        return jnp.stack(outputs, axis=-2)


class SetGenerate(nn.Module):
    """Generate a set of ``generate_num`` vectors conditioned on a context set.

    Standard Gaussian noise with the same feature width as the context is
    refined by an ``ISAB`` that cross-attends between context and noise.
    """

    dim_out: int = 128     # feature width of each generated vector
    dim_hidden: int = 128  # width of the attention projections inside ISAB
    num_heads: int = 8     # number of attention heads

    def setup(self):
        self.isab = ISAB(
            dim_out=self.dim_out,
            dim_hidden=self.dim_hidden,
            num_heads=self.num_heads,
        )

    def __call__(self, r, generate_num, mask=None):
        """Sample noise and transform it into a generated set.

        Args:
            r: Context set; set axis ``-2``, feature axis ``-1``.
                NOTE(review): the generated-set mask is built from
                ``r.shape[0]`` only, so ``r`` is presumably 3-D
                ``(batch, set, features)`` — confirm.
            generate_num: Number of elements to generate.
            mask: Optional context mask forwarded to ``ISAB``.

        Returns:
            Array of shape ``(*r.shape[:-2], generate_num, dim_out)``.
        """
        _, noise_key = jax.random.split(self.make_rng("sample"), 2)
        noise_shape = (*r.shape[:-2], generate_num, r.shape[-1])
        noise = jax.random.normal(noise_key, noise_shape)
        # Every generated element is a real set member, so nothing is masked.
        noise_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        return self.isab(r, noise, mask, noise_mask)

class SetGenerate2(nn.Module):
    """Set generator whose output is split into an "x" half and a "y" half.

    Gaussian noise is refined by an ``ISAB``. The first half of the output
    features is squashed into ``(-2, 2)``, and the second half is returned
    unchanged.
    """

    dim_out: int = 128     # total feature width of the output (must be even)
    dim_hidden: int = 128  # width of the attention projections inside ISAB
    num_heads: int = 8     # number of attention heads

    def setup(self):
        self.isab = ISAB(
            dim_out=self.dim_out,
            dim_hidden=self.dim_hidden,
            num_heads=self.num_heads,
        )

    def __call__(self, r, generate_num, mask=None):
        """Generate ``generate_num`` elements conditioned on the context ``r``.

        Args:
            r: Context set; set axis ``-2``, feature axis ``-1``.
            generate_num: Number of elements to generate.
            mask: Optional context mask forwarded to ``ISAB``.

        Returns:
            Array of shape ``(*r.shape[:-2], generate_num, dim_out)``.
        """
        _, noise_key = jax.random.split(self.make_rng("sample"), 2)
        noise = jax.random.normal(noise_key, (*r.shape[:-2], generate_num, r.shape[-1]))
        noise_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        raw = self.isab(r, noise, mask, noise_mask)

        x_part, y_part = jnp.split(raw, 2, axis=-1)
        # 4 * (sigmoid(t) - 1/2) maps the x half into the open interval (-2, 2).
        x_part = 4 * (nn.sigmoid(x_part) - 1 / 2)
        return jnp.concatenate([x_part, y_part], axis=-1)


class SetGenerate3(nn.Module):
    """Set generator seeded with heavy-tailed Student-t noise (df = 2).

    The noise is refined by an ``ISAB``. The first half of the output
    features is squashed into ``(-2, 2)``, and the second half is returned
    unchanged.
    """

    dim_out: int = 128     # total feature width of the output (must be even)
    dim_hidden: int = 128  # width of the attention projections inside ISAB
    num_heads: int = 8     # number of attention heads

    def setup(self):
        self.isab = ISAB(
            dim_out=self.dim_out,
            dim_hidden=self.dim_hidden,
            num_heads=self.num_heads,
        )

    def __call__(self, r, generate_num, mask=None):
        """Generate ``generate_num`` elements conditioned on the context ``r``.

        Args:
            r: Context set; set axis ``-2``, feature axis ``-1``.
            generate_num: Number of elements to generate.
            mask: Optional context mask forwarded to ``ISAB``.

        Returns:
            Array of shape ``(*r.shape[:-2], generate_num, dim_out)``.
        """
        _, noise_key = jax.random.split(self.make_rng("sample"), 2)
        noise = jax.random.t(noise_key, 2.0, (*r.shape[:-2], generate_num, r.shape[-1]))
        noise_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        raw = self.isab(r, noise, mask, noise_mask)

        x_part, y_part = jnp.split(raw, 2, axis=-1)
        # 4 * (sigmoid(t) - 1/2) maps the x half into the open interval (-2, 2).
        x_part = 4 * (nn.sigmoid(x_part) - 1 / 2)
        return jnp.concatenate([x_part, y_part], axis=-1)


class SetGenerate4(nn.Module):
    """Set generator seeded with Student-t noise (df = 2) and a gentle x squash.

    The noise is refined by an ``ISAB``. The first half of the output
    features is mapped into ``(-2, 2)`` by ``4 * (sigmoid(x / 4) - 1/2)``,
    which is flatter than ``4 * (sigmoid(x) - 1/2)``. The second half is
    returned unchanged.
    """

    dim_out: int = 128     # total feature width of the output (must be even)
    dim_hidden: int = 128  # width of the attention projections inside ISAB
    num_heads: int = 8     # number of attention heads

    def setup(self):
        self.isab = ISAB(dim_out=self.dim_out, dim_hidden=self.dim_hidden, num_heads=self.num_heads)

    def __call__(self, r, generate_num, mask=None):
        """Generate ``generate_num`` elements conditioned on the context ``r``.

        Args:
            r: Context set; set axis ``-2``, feature axis ``-1``.
            generate_num: Number of elements to generate.
            mask: Optional context mask forwarded to ``ISAB``.

        Returns:
            Array of shape ``(*r.shape[:-2], generate_num, dim_out)``.
        """
        _, sample_key = jax.random.split(self.make_rng("sample"), 2)
        r_front_shape, r_tail_shape = r.shape[:-2], r.shape[-1]
        generate_initial = jax.random.t(sample_key, 2.0, (*r_front_shape, generate_num, r_tail_shape))
        generate_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        out = self.isab(r, generate_initial, mask, generate_mask)

        x_generated_ctx, y_generated_ctx = jnp.split(out, 2, axis=-1)
        # nn.sigmoid is numerically stable. The former hand-written
        # 1 / (1 + exp(-x / 4)) overflowed exp() for very negative x, and its
        # gradient then evaluated to inf / inf = NaN.
        x_generated_ctx = 4 * (nn.sigmoid(x_generated_ctx / 4) - 1 / 2)
        return jnp.concatenate([x_generated_ctx, y_generated_ctx], axis=-1)


class SetGenerate5(nn.Module):
    """Set generator seeded with fixed-width Gaussian noise.

    Noise of width ``dim_noise``, which is independent of the context width,
    is refined by an ``ISAB``. The first half of the output features is
    mapped into ``(-2, 2)`` by ``4 * (sigmoid(x / 4) - 1/2)``. The second
    half is returned unchanged.
    """

    dim_out: int = 128     # total feature width of the output (must be even)
    dim_hidden: int = 128  # width of the attention projections inside ISAB
    num_heads: int = 8     # number of attention heads
    dim_noise: int = 128   # feature width of the Gaussian noise fed to ISAB

    def setup(self):
        self.isab = ISAB(dim_out=self.dim_out, dim_hidden=self.dim_hidden, num_heads=self.num_heads)

    def __call__(self, r, generate_num, mask=None):
        """Generate ``generate_num`` elements conditioned on the context ``r``.

        Args:
            r: Context set; set axis ``-2``, feature axis ``-1``.
            generate_num: Number of elements to generate.
            mask: Optional context mask forwarded to ``ISAB``.

        Returns:
            Array of shape ``(*r.shape[:-2], generate_num, dim_out)``.
        """
        _, sample_key = jax.random.split(self.make_rng("sample"), 2)
        generate_initial = jax.random.normal(sample_key, (*r.shape[:-2], generate_num, self.dim_noise))
        generate_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        out = self.isab(r, generate_initial, mask, generate_mask)

        x_generated_ctx, y_generated_ctx = jnp.split(out, 2, axis=-1)
        # nn.sigmoid is numerically stable, unlike 1 / (1 + exp(-x / 4)),
        # whose gradient becomes NaN once exp() overflows for very negative x.
        x_generated_ctx = 4 * (nn.sigmoid(x_generated_ctx / 4) - 1 / 2)
        return jnp.concatenate([x_generated_ctx, y_generated_ctx], axis=-1)


class SetGenerate6(nn.Module):
    """Set generator using ``ISAB2`` and fixed-width Gaussian noise.

    Noise of width ``dim_noise`` is refined by an ``ISAB2``, whose first
    attention stage stays at ``dim_hidden`` width. The first half of the
    output features is mapped into ``(-2, 2)`` by
    ``4 * (sigmoid(x / 4) - 1/2)``. The second half is returned unchanged.
    """

    dim_out: int = 128     # total feature width of the output (must be even)
    dim_hidden: int = 128  # width of the attention projections inside ISAB2
    num_heads: int = 8     # number of attention heads
    dim_noise: int = 128   # feature width of the Gaussian noise fed to ISAB2

    def setup(self):
        self.isab = ISAB2(dim_out=self.dim_out, dim_hidden=self.dim_hidden, num_heads=self.num_heads)

    def __call__(self, r, generate_num, mask=None):
        """Generate ``generate_num`` elements conditioned on the context ``r``.

        Args:
            r: Context set; set axis ``-2``, feature axis ``-1``.
            generate_num: Number of elements to generate.
            mask: Optional context mask forwarded to ``ISAB2``.

        Returns:
            Array of shape ``(*r.shape[:-2], generate_num, dim_out)``.
        """
        _, sample_key = jax.random.split(self.make_rng("sample"), 2)
        generate_initial = jax.random.normal(sample_key, (*r.shape[:-2], generate_num, self.dim_noise))
        generate_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        out = self.isab(r, generate_initial, mask, generate_mask)

        x_generated_ctx, y_generated_ctx = jnp.split(out, 2, axis=-1)
        # nn.sigmoid is numerically stable, unlike 1 / (1 + exp(-x / 4)),
        # whose gradient becomes NaN once exp() overflows for very negative x.
        x_generated_ctx = 4 * (nn.sigmoid(x_generated_ctx / 4) - 1 / 2)
        return jnp.concatenate([x_generated_ctx, y_generated_ctx], axis=-1)


class SetGenerate7(nn.Module):
    """Set generator that refines Gaussian noise with an ``ISAB2``.

    The noise has the same feature width as the context. The ``ISAB2``
    output is returned as-is, with no squashing of any feature half.
    """

    dim_out: int = 128     # feature width of each generated vector
    dim_hidden: int = 128  # width of the attention projections inside ISAB2
    num_heads: int = 8     # number of attention heads

    def setup(self):
        self.isab = ISAB2(
            dim_out=self.dim_out,
            dim_hidden=self.dim_hidden,
            num_heads=self.num_heads,
        )

    def __call__(self, r, generate_num, mask=None):
        """Generate ``generate_num`` elements conditioned on the context ``r``.

        Args:
            r: Context set; set axis ``-2``, feature axis ``-1``.
            generate_num: Number of elements to generate.
            mask: Optional context mask forwarded to ``ISAB2``.

        Returns:
            Array of shape ``(*r.shape[:-2], generate_num, dim_out)``.
        """
        _, noise_key = jax.random.split(self.make_rng("sample"), 2)
        noise_shape = (*r.shape[:-2], generate_num, r.shape[-1])
        noise = jax.random.normal(noise_key, noise_shape)
        # Every generated element is a real set member, so nothing is masked.
        noise_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        return self.isab(r, noise, mask, noise_mask)


class GMMSetGenerate(nn.Module):
    """Set generator seeded with noise from a learned-mean Gaussian mixture.

    Each generated element picks one of ``num_gmm`` learned mean vectors
    uniformly at random and adds unit Gaussian noise to it. The result is
    refined by an ``ISAB``.
    """

    dim_out: int = 128     # feature width of generated vectors and of each mixture mean
    dim_hidden: int = 128  # width of the attention projections inside ISAB
    num_heads: int = 8     # number of attention heads
    num_gmm: int = 3       # number of mixture components

    def setup(self):
        self.isab = ISAB(dim_out=self.dim_out, dim_hidden=self.dim_hidden, num_heads=self.num_heads)
        # NOTE(review): the means have dim_out columns but are added to noise
        # of width r.shape[-1], so this presumably requires dim_out == r.shape[-1].
        self.means = self.param("means", jax.nn.initializers.normal(), (self.num_gmm, self.dim_out))

    def __call__(self, r, generate_num, mask=None):
        """Generate ``generate_num`` elements conditioned on the context ``r``.

        Args:
            r: Context set; set axis ``-2``, feature axis ``-1``.
            generate_num: Number of elements to generate.
            mask: Optional context mask forwarded to ``ISAB``.

        Returns:
            Array of shape ``(*r.shape[:-2], generate_num, dim_out)``.
        """
        # Separate keys for the component choice and the Gaussian noise; the
        # previous code reused one key for both draws.
        component_key, noise_key = jax.random.split(self.make_rng("sample"), 2)
        sample_shape = (*r.shape[:-2], generate_num)

        component = jax.random.choice(component_key, jnp.arange(self.num_gmm), sample_shape)
        # Integer-array indexing yields one mean row per element: (*sample_shape, dim_out).
        component_means = self.means[component]
        generate_initial = jax.random.normal(noise_key, (*sample_shape, r.shape[-1])) + component_means

        generate_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        return self.isab(r, generate_initial, mask, generate_mask)


class MAB(nn.Module):
    """Multihead attention block with residual LayerNorm and an output projection.

    Queries are projected from ``q``; keys and values are both projected from
    ``v``. Heads are formed by splitting the ``dim_hidden`` projected features
    into ``num_heads`` chunks that are stacked along axis 0.
    """

    dim_out: int = 128     # width of the final ``fc_real_out`` projection
    dim_hidden: int = 128  # width of the q/k/v projections; split across heads
    num_heads: int = 8     # number of heads; must divide ``dim_hidden``

    def setup(self):
        self.fc_q   = nn.Dense(features=self.dim_hidden)
        self.fc_k   = nn.Dense(features=self.dim_hidden)
        self.fc_v   = nn.Dense(features=self.dim_hidden)
        self.fc_out = nn.Dense(features=self.dim_hidden)
        self.fc_real_out = nn.Dense(features=self.dim_out)
        self.ln1    = nn.LayerNorm()
        self.ln2    = nn.LayerNorm()

    def scatter(self, x):
        """Split the last axis into ``num_heads`` chunks and concatenate them on axis 0."""
        return jnp.concatenate(jnp.split(x, self.num_heads, axis=-1), axis=0)

    def gather(self, x):
        """Invert ``scatter``: split axis 0 into ``num_heads`` chunks and rejoin them on the last axis."""
        return jnp.concatenate(jnp.split(x, self.num_heads, axis=0), axis=-1)

    def attend(self, q, k, v, mask=None):
        """Multi-head scaled dot-product attention of ``q`` over ``k``/``v``.

        Args:
            q, k, v: Projected queries, keys and values (last axis ``dim_hidden``).
            mask: Optional boolean-convertible mask over the key axis. It is
                repeated across every query row and once per head.
                NOTE(review): for 3-D inputs it presumably has shape
                ``(batch, num_keys)`` — confirm for the 4-D/5-D branches.

        Returns:
            Attention output with the heads concatenated back on the last axis.
        """
        q_, k_, v_ = self.scatter(q), self.scatter(k), self.scatter(v)

        # q and k live in the dim_hidden-wide projection space, so that width
        # sets the softmax temperature. Scaling by dim_out tied it to the
        # unrelated output-layer width (dim_out=1 disabled scaling entirely).
        A_logits = q_ @ k_.swapaxes(-2, -1) / math.sqrt(self.dim_hidden)
        if mask is not None:
            mask = jnp.bool_(mask)
            # Repeat the key mask for every query row, then tile it per head to
            # match the head-stacked leading axis produced by scatter().
            mask = jnp.stack([mask] * q.shape[-2], axis=-2)
            mask = jnp.concatenate([mask] * self.num_heads, axis=0)
            if A_logits.ndim == 4:
                mask = jnp.expand_dims(mask, axis=1)
            if A_logits.ndim == 5:
                mask = jnp.expand_dims(mask, axis=(-2, -1,))
            # TODO: check whether jax.nn.softmax(A_logits, where=mask, ...) can replace the two lines below.
            A = jax.nn.softmax(jnp.where(mask, A_logits, -float('inf')), axis=-1)
            # A row whose keys are all masked softmaxes to NaN; give it zero weight instead.
            A = jnp.where(jnp.isnan(A), 0., A)
        else:
            A = jax.nn.softmax(A_logits, axis=-1)
        return self.gather(A @ v_)

    def __call__(self, q, v, mask=None):
        """Attend from ``q`` over ``v`` (used as both keys and values).

        Returns:
            Array with ``q``'s leading shape and ``dim_out`` features.
        """
        q, k, v = self.fc_q(q), self.fc_k(v), self.fc_v(v)
        out = self.ln1(q + self.attend(q, k, v, mask))
        out = self.ln2(out + nn.relu(self.fc_out(out)))
        out = self.fc_real_out(out)
        return out


class ISAB(nn.Module):
    """Two-stage cross attention between a context set and a generated set.

    Stage 1 (``mab0``) lets each context element attend over the generated
    samples. Stage 2 (``mab1``) lets each generated sample attend over the
    stage-1 output. Both stages emit ``dim_out`` features.
    """

    dim_out: int = 128     # output width of both MAB stages
    dim_hidden: int = 128  # attention projection width of both MAB stages
    num_heads: int = 8     # number of attention heads

    def setup(self):
        block_kwargs = dict(
            dim_out=self.dim_out,
            dim_hidden=self.dim_hidden,
            num_heads=self.num_heads,
        )
        self.mab0 = MAB(**block_kwargs)
        self.mab1 = MAB(**block_kwargs)

    def __call__(self, context, generate_sample, mask_context=None, mask_generate=None):
        """Return the generated samples after attending to ``context``.

        Args:
            context: Context set (set axis ``-2``).
            generate_sample: Initial generated set (set axis ``-2``).
            mask_context: Optional mask over the context elements.
            mask_generate: Optional mask over the generated elements.

        Returns:
            Array with ``generate_sample``'s leading shape and ``dim_out`` features.
        """
        # Stage 1: context elements are the queries; mask_generate masks the keys.
        context_view = self.mab0(context, generate_sample, mask_generate)
        # Stage 2: generated samples query the context-length stage-1 output.
        return self.mab1(generate_sample, context_view, mask_context)


class ISAB2(nn.Module):
    """Two-stage cross attention whose first stage stays at hidden width.

    Identical in flow to a context→generated→context attention pair, except
    that ``mab0`` projects to ``dim_hidden`` features and only ``mab1``
    projects to ``dim_out``.
    """

    dim_out: int = 128     # output width of the second (final) MAB stage
    dim_hidden: int = 128  # projection width of both stages and output width of stage 1
    num_heads: int = 8     # number of attention heads

    def setup(self):
        shared = {"dim_hidden": self.dim_hidden, "num_heads": self.num_heads}
        self.mab0 = MAB(dim_out=self.dim_hidden, **shared)
        self.mab1 = MAB(dim_out=self.dim_out, **shared)

    def __call__(self, context, generate_sample, mask_context=None, mask_generate=None):
        """Return the generated samples after attending to ``context``.

        Args:
            context: Context set (set axis ``-2``).
            generate_sample: Initial generated set (set axis ``-2``).
            mask_context: Optional mask over the context elements.
            mask_generate: Optional mask over the generated elements.

        Returns:
            Array with ``generate_sample``'s leading shape and ``dim_out`` features.
        """
        # Stage 1: context queries over the generated samples (hidden-width output).
        context_view = self.mab0(context, generate_sample, mask_generate)
        # Stage 2: generated samples query the stage-1 result, masked by context.
        return self.mab1(generate_sample, context_view, mask_context)


class GenerateMultiheadAttention(nn.Module):
    """Multihead attention block with separate key and value inputs.

    Same structure as ``MAB`` (residual LayerNorm, feed-forward, output
    projection), but keys are projected from ``k`` and values from ``v``
    independently.
    """

    dim_out: int = 128     # width of the final ``fc_real_out`` projection
    dim_hidden: int = 128  # width of the q/k/v projections; split across heads
    num_heads: int = 8     # number of heads; must divide ``dim_hidden``

    def setup(self):
        self.fc_q   = nn.Dense(features=self.dim_hidden)
        self.fc_k   = nn.Dense(features=self.dim_hidden)
        self.fc_v   = nn.Dense(features=self.dim_hidden)
        self.fc_out = nn.Dense(features=self.dim_hidden)
        self.fc_real_out = nn.Dense(features=self.dim_out)
        self.ln1    = nn.LayerNorm()
        self.ln2    = nn.LayerNorm()

    def scatter(self, x):
        """Split the last axis into ``num_heads`` chunks and concatenate them on axis 0."""
        return jnp.concatenate(jnp.split(x, self.num_heads, axis=-1), axis=0)

    def gather(self, x):
        """Invert ``scatter``: split axis 0 into ``num_heads`` chunks and rejoin them on the last axis."""
        return jnp.concatenate(jnp.split(x, self.num_heads, axis=0), axis=-1)

    def attend(self, q, k, v, mask=None):
        """Multi-head scaled dot-product attention of ``q`` over ``k``/``v``.

        Args:
            q, k, v: Projected queries, keys and values (last axis ``dim_hidden``).
            mask: Optional boolean-convertible mask over the key axis. It is
                repeated across every query row and once per head.

        Returns:
            Attention output with the heads concatenated back on the last axis.
        """
        q_, k_, v_ = self.scatter(q), self.scatter(k), self.scatter(v)
        # Scale by the q/k projection width (dim_hidden), not the output width
        # dim_out. This block is built with dim_out=1 in Residual, which
        # previously meant no scaling at all.
        A_logits = q_ @ k_.swapaxes(-2, -1) / math.sqrt(self.dim_hidden)

        if mask is not None:
            mask = jnp.bool_(mask)
            # Repeat the key mask for every query row, then tile it per head.
            mask = jnp.stack([mask] * q.shape[-2], axis=-2)
            mask = jnp.concatenate([mask] * self.num_heads, axis=0)
            if A_logits.ndim == 4:
                mask = jnp.expand_dims(mask, axis=1)
            # TODO: check whether jax.nn.softmax(A_logits, where=mask, ...) can replace the two lines below.
            A = jax.nn.softmax(jnp.where(mask, A_logits, -float('inf')), axis=-1)
            # A row whose keys are all masked softmaxes to NaN; give it zero weight instead.
            A = jnp.where(jnp.isnan(A), 0., A)
        else:
            A = jax.nn.softmax(A_logits, axis=-1)

        return self.gather(A @ v_)

    def __call__(self, q, k, v, mask=None):
        """Attend from ``q`` over keys ``k`` and values ``v``.

        Returns:
            Array with ``q``'s leading shape and ``dim_out`` features.
        """
        q, k, v = self.fc_q(q), self.fc_k(k), self.fc_v(v)
        out = self.ln1(q + self.attend(q, k, v, mask))
        out = self.ln2(out + nn.relu(self.fc_out(out)))
        out = self.fc_real_out(out)
        return out


class Residual(nn.Module):
    """Generate (x, y) pairs: x from a set generator, y by attention on the context.

    The x half is produced by an ``ISAB`` over Gaussian noise and squashed
    into ``(-2, 2)``. The y half is a squashed attention readout of the
    context plus Gaussian noise with standard deviation 0.25.
    """

    dim_out: int = 2       # total output width; the first half is x, the second half is y
    num_heads: int = 8     # number of attention heads
    dim_hidden: int = 128  # attention projection width, used by both sub-modules

    def setup(self):
        half = self.dim_out // 2
        # dim_hidden is forwarded to both sub-modules; previously ISAB silently
        # used its own default and ignored this field.
        self.isab = ISAB(dim_out=half, dim_hidden=self.dim_hidden, num_heads=self.num_heads)
        self.gma = GenerateMultiheadAttention(dim_out=half, dim_hidden=self.dim_hidden, num_heads=self.num_heads)

    def __call__(self, r, generate_num, mask=None):
        """Generate ``generate_num`` (x, y) elements conditioned on the context ``r``.

        Args:
            r: Context set; set axis ``-2``, feature axis ``-1``.
                NOTE(review): feature 0 is used as the context x and the
                remaining features as the context y, which presumably assumes
                1-D x — confirm.
            generate_num: Number of elements to generate.
            mask: Optional context mask forwarded to both sub-modules.

        Returns:
            Array of shape ``(*r.shape[:-2], generate_num, 2 * (dim_out // 2))``.
        """
        _, sample_key, noise_key = jax.random.split(self.make_rng("sample"), 3)
        r_front_shape, r_tail_shape = r.shape[:-2], r.shape[-1]
        half = self.dim_out // 2

        generate_initial = jax.random.normal(sample_key, (*r_front_shape, generate_num, r_tail_shape))
        generate_mask = jnp.ones((r.shape[0], generate_num), dtype=jnp.bool_)
        generated_x = self.isab(r, generate_initial, mask, generate_mask)
        generated_x = 4 * (nn.sigmoid(generated_x) - 0.5)

        # Generated x queries the context x (keys) and reads out context y (values).
        generated_y_mean = 4 * (nn.sigmoid(self.gma(generated_x, r[..., :1], r[..., 1:], mask)) - 0.5)
        generate_noise = 0.25 * jax.random.normal(noise_key, (*r_front_shape, generate_num, half))
        generated_y = generated_y_mean + generate_noise
        return jnp.concatenate([generated_x, generated_y], axis=-1)
