# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VAE model definitions."""

from flax import linen as nn
from jax import random
import jax.numpy as jnp
from functools import partial
import jax


# Group conic activation
@partial(jax.jit, static_argnames=['channel_axis', 'variant', 'eps', 'num_groups', 'project_axes', 'share_axis'])
def group_colu(input, channel_axis=-1, variant="soft", eps=1e-7, num_groups=100, project_axes=False, share_axis=True):
    """Group conic linear unit (CoLU) activation.

    The channels along ``channel_axis`` are split into two parts:

    * axis channels ``y``: the first channel if ``share_axis`` is True,
      otherwise the first ``num_groups`` channels;
    * cone-section channels ``x``: the remaining channels. These are reshaped
      into ``G = num_groups`` groups of ``S - 1`` channels each.

    Each group of ``x`` is multiplied by a gate derived from
    ``mask = y / (||x_group|| + eps)``. The norm is taken over the group's
    ``S - 1`` channels. The output is ``y`` followed by the gated ``x``, so it
    has the same shape and channel layout as ``input``.

    Channel layout (``C`` = number of channels):

    * ``share_axis=True``:  ``[1 axis] + G * [S - 1]``, with ``S = (C - 1) / G + 1``.
    * ``share_axis=False``: ``[G axes] + G * [S - 1]``, with ``S = C / G``.

    Special cases:

    * ``num_groups == 0`` returns ``input`` unchanged.
    * If every group has a single section channel (``share_axis`` and
      ``G == C - 1``, or not ``share_axis`` and ``2 * G == C``), an elementwise
      activation is applied instead: silu when ``variant == "soft"``,
      otherwise relu.

    Args:
      input: Array whose ``channel_axis`` dimension holds the channels.
      channel_axis: Index of the channel dimension. Must be negative.
      variant: Gate applied to ``mask``. One of:
        "soft" (``sigmoid(m - 0.5)``), "softapprox" (``sigmoid(4m - 2)``),
        "hard" (``clip(m, 0, 1)``), or "softmax" (see note in the body).
      eps: Added to norms before dividing.
      num_groups: Number of cones ``G``.
      project_axes: If True (requires ``share_axis=False``), the axis channels
        are themselves gated. The first axis ``y0`` scales the other ``G - 1``
        axes by ``y0 / (||y1|| + eps)``, passed through ``sigmoid(. - 0.5)``
        for "soft" and ``clip(., 0, 1)`` otherwise.
      share_axis: Whether all groups share a single axis channel.

    Returns:
      Array with the same shape as ``input``.

    Raises:
      AssertionError: If ``channel_axis`` is not negative, the channel count
        does not fit ``num_groups``, or ``project_axes`` is combined with
        ``share_axis``. These are ``assert`` statements, so they are stripped
        under ``python -O``.
      NotImplementedError: If ``variant`` is unknown (general case only).
    """
    if num_groups == 0:  # trivial case: nothing to gate
        return input
    num_channels = input.shape[channel_axis]

    # Pointwise case: every cone has exactly one section channel, so fall back
    # to an elementwise activation.
    if share_axis:
        pointwise = num_groups == num_channels - 1
    else:
        pointwise = num_groups * 2 == num_channels
    if pointwise:
        return jax.nn.silu(input) if variant == "soft" else jax.nn.relu(input)

    # The reshapes below count dimensions from the end (so leading batch
    # dimensions broadcast). Validate the axis before it is used at all.
    assert channel_axis < 0, "channel_axis must be negative"
    if share_axis:
        assert (num_channels - 1) % num_groups == 0, "Channel size minus one must be a multiple of number of cones"
        num_axes = 1
        group_size = (num_channels - 1) // num_groups + 1  # S
    else:
        assert num_channels % num_groups == 0, "Channel size must be a multiple of number of cones"
        num_axes = num_groups
        group_size = num_channels // num_groups  # S = C / G

    # y = axis channels (the first num_axes), x = cone-section channels (the rest).
    y, x = jnp.split(input, [num_axes], axis=channel_axis)
    x_old_shape, y_old_shape = x.shape, y.shape

    # Insert a group dimension immediately before the channel dimension.
    lead = input.shape[:channel_axis]
    trail = input.shape[input.ndim + channel_axis + 1:]  # () when channel_axis == -1
    x = x.reshape(lead + (num_groups, group_size - 1) + trail)  # NG(S-1)HW
    y = y.reshape(lead + (num_axes, 1) + trail)  # N11HW (shared axis) or NG1HW
    group_axis = channel_axis - 1

    # Per-group norm over the S-1 section channels.
    xn = jnp.linalg.norm(x, axis=channel_axis, keepdims=True)  # NG1HW

    if project_axes:
        assert not share_axis, "project_axes is not compatible with share_axis"
        # Treat the axes themselves as a cone: the first axis gates the other G-1.
        y0, y1 = jnp.split(y, [1], axis=group_axis)  # N11HW, N(G-1)1HW
        yn = jnp.linalg.norm(y1, axis=group_axis, keepdims=True)  # N11HW
        ymask = y0 / (yn + eps)
        ymask = jax.nn.sigmoid(ymask - .5) if variant == "soft" else jnp.clip(ymask, 0, 1)
        y = jnp.concatenate([y0, ymask * y1], axis=group_axis)

    mask = y / (xn + eps)  # NG1HW
    if variant == "softmax":
        # NOTE(review): `mask` has size 1 along channel_axis, so this softmax is
        # identically 1 and the sections pass through ungated. Confirm whether
        # a softmax over the group axis (channel_axis - 1) was intended.
        mask = jax.nn.softmax(mask, axis=channel_axis)
    elif variant == "softapprox":
        mask = jax.nn.sigmoid(4 * mask - 2)
    elif variant == "soft":
        mask = jax.nn.sigmoid(mask - .5)
    elif variant == "hard":
        mask = jnp.clip(mask, 0, 1)
    else:
        raise NotImplementedError(
            "variant must be one of 'soft', 'softapprox', 'softmax' or 'hard'.")

    x = (mask * x).reshape(x_old_shape)  # gate each group, then restore the flat layout
    y = y.reshape(y_old_shape)
    return jnp.concatenate([y, x], axis=channel_axis)


# # Multi-level group conic activation (work in progress; not working yet)
#     y, x = x.take(jnp.arange(dim_cone ** level), axis=-1), x.take(jnp.arange(level,dim_cone), axis=-1)
#     for level in range(depth):
#         if num_channels % dim_cone ** level == 0:
#             break
#         # axes, cone sections, rest
#         section_dim_from, section_dim_to = dim_cone ** level,dim_cone ** (level+1)
#         z, x = x.take(jnp.arange(section_dim_from, section_dim_to), axis=-1), x.take(jnp.arange(section_dim_to, num_channels), axis=-1)
#         zn = jnp.linalg.norm(x,axis=channel_axis,keepdims=True) # NG1HW, per-group norm, or the S dimension
#         mask = y / (zn + eps) # NG1HW


# def group_colu(x):
#     """spherical normalization"""
#     xn = jnp.linalg.norm(x,axis=-1,keepdims=True) # NG1HW, per-group norm, or the S dimension
#     return x / xn

class Encoder(nn.Module):
  """VAE encoder: a hidden Dense layer followed by two Dense heads.

  The hidden layer ``fc1`` has ``latents`` units. Its activation is
  ``group_colu`` when ``act_fn == "colu"`` and silu for any other value.
  The heads ``fc2_mean`` and ``fc2_logvar`` each produce 20 features.
  """

  latents: int = 500  # width of the hidden layer fc1
  model_name: str = "cifar10"  # not read by this module's __call__
  act_fn: str = "silu"  # "colu" selects group_colu; anything else uses silu
  num_groups: int = 100  # forwarded to group_colu
  variant: str = "soft"  # forwarded to group_colu
  share_axis: bool = True  # forwarded to group_colu

  @nn.compact
  def __call__(self, x):
    """Return ``(mean, logvar)`` computed from input ``x``."""
    hidden = nn.Dense(self.latents, name='fc1')(x)
    if self.act_fn == "colu":
      hidden = group_colu(hidden, num_groups=self.num_groups,
                          variant=self.variant, share_axis=self.share_axis)
    else:
      hidden = nn.silu(hidden)
    code_dim = 20
    mean_head = nn.Dense(code_dim, name='fc2_mean')
    logvar_head = nn.Dense(code_dim, name='fc2_logvar')
    mean_x = mean_head(hidden)
    logvar_x = logvar_head(hidden)
    return mean_x, logvar_x


class Decoder(nn.Module):
  """VAE decoder: a hidden Dense layer followed by an output Dense layer.

  The hidden layer ``fc1`` has ``latents`` units. Its activation is
  ``group_colu`` when ``act_fn == "colu"`` and silu otherwise. The output
  layer ``fc2`` has 3072 units when ``model_name == 'cifar10'`` and 784
  otherwise (presumably flattened 32x32x3 and 28x28 images; confirm against
  the data pipeline). The output is returned without a final activation.
  """
  latents: int = 500  # width of the hidden layer fc1
  model_name: str = "cifar10"  # selects the output width of fc2
  act_fn: str = "silu"  # "colu" selects group_colu; anything else uses silu
  num_groups: int = 100  # forwarded to group_colu
  variant: str = "soft"  # forwarded to group_colu
  share_axis: bool = True  # forwarded to group_colu

  @nn.compact
  def __call__(self, z):
    """Map latent codes ``z`` to reconstruction outputs (pre-activation)."""
    h = nn.Dense(self.latents, name='fc1')(z)
    if self.act_fn != "colu":
      h = nn.silu(h)
    else:
      h = group_colu(h, num_groups=self.num_groups,
                     variant=self.variant, share_axis=self.share_axis)
    if self.model_name == 'cifar10':
      out_features = 3072
    else:
      out_features = 784
    return nn.Dense(out_features, name='fc2')(h)


class VAE(nn.Module):
  """Full VAE model: Encoder -> reparameterize -> Decoder.

  All hyperparameters are forwarded unchanged to both the Encoder and the
  Decoder.
  """

  latents: int = 500  # hidden-layer width in encoder and decoder
  model_name: str = "cifar10"  # selects the decoder output width
  act_fn: str = "silu"  # "colu" selects group_colu; anything else uses silu
  num_groups: int = 100  # forwarded to group_colu
  variant: str = "soft"  # forwarded to group_colu
  share_axis: bool = True  # forwarded to group_colu

  def setup(self):
    hparams = dict(
        latents=self.latents,
        model_name=self.model_name,
        act_fn=self.act_fn,
        num_groups=self.num_groups,
        variant=self.variant,
        share_axis=self.share_axis,
    )
    self.encoder = Encoder(**hparams)
    self.decoder = Decoder(**hparams)

  def __call__(self, x, z_rng):
    """Encode ``x``, sample a latent with ``z_rng``, and decode it.

    Returns:
      ``(recon_x, mean, logvar)``. ``recon_x`` is the raw decoder output
      (no sigmoid applied).
    """
    mean, logvar = self.encoder(x)
    recon_x = self.decoder(reparameterize(z_rng, mean, logvar))
    return recon_x, mean, logvar

  def generate(self, z):
    """Decode latent codes ``z`` and squash the result with a sigmoid."""
    decoded = self.decoder(z)
    return nn.sigmoid(decoded)


def reparameterize(rng, mean, logvar):
  """Sample ``mean + noise * exp(logvar / 2)`` with standard-normal noise.

  Args:
    rng: JAX PRNG key used to draw the noise.
    mean: Mean of the Gaussian.
    logvar: Log-variance of the Gaussian. The noise is drawn with its shape.

  Returns:
    A sample with the broadcast shape of ``mean`` and ``logvar``.
  """
  noise = random.normal(rng, logvar.shape)
  scale = jnp.exp(logvar * 0.5)
  return mean + noise * scale


def model(latents, model_name="binarized_mnist", act_fn="silu", num_groups=100, variant="soft", share_axis=True):
  """Construct a VAE with the given hyperparameters.

  Note that the default ``model_name`` here ("binarized_mnist") differs from
  the VAE class default ("cifar10").
  """
  hparams = dict(
      latents=latents,
      model_name=model_name,
      act_fn=act_fn,
      num_groups=num_groups,
      variant=variant,
      share_axis=share_axis,
  )
  return VAE(**hparams)
