"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

from flax import linen as nn
import jax.numpy as jnp
import jax
from jax import jacfwd, grad, vmap
import pickle

import jax.scipy as jsp

from flax import linen as nn

import jax.numpy as jnp
from jax.nn.initializers import uniform as uniform_init
from jax import lax
from jax.random import uniform
from typing import Any, Callable, Sequence, Tuple
from functools import partial
import jax

def siren_init(weight_std, dtype):
    """Build an initializer that samples uniformly from ``[-weight_std, weight_std)``.

    Despite its name, ``weight_std`` is used as the half-width of the uniform
    sampling interval, not as a standard deviation.

    Args:
        weight_std: Half-width of the sampling interval.
        dtype: Default dtype of the arrays produced by the returned
            initializer. May be a real floating or a complex dtype.

    Returns:
        A function ``init_fun(key, shape, dtype=dtype)`` returning an array of
        the given shape. For complex dtypes the real and imaginary parts are
        drawn independently, each from ``[-weight_std, weight_std)``.
    """
    def init_fun(key, shape, dtype=dtype):
        if jnp.issubdtype(dtype, jnp.complexfloating):
            # jax.random.uniform only accepts real floating dtypes, so sample
            # the two components separately in the real dtype that matches
            # the requested complex dtype (works for any complex width).
            real_dtype = jnp.zeros((), dtype).real.dtype
            key_real, key_imag = jax.random.split(key)
            real_part = uniform(key_real, shape, real_dtype) * 2 * weight_std - weight_std
            imag_part = uniform(key_imag, shape, real_dtype) * 2 * weight_std - weight_std
            return real_part + 1j * imag_part

        # Map a sample from [0, 1) onto [-weight_std, weight_std).
        return uniform(key, shape, dtype) * 2 * weight_std - weight_std

    return init_fun


class Sine(nn.Module):
    """Parameter-free sine activation: ``sin(w0 * inputs)``.

    The inputs are converted to ``dtype`` before being scaled by ``w0``.
    """

    # Frequency factor applied to the inputs before the sine.
    w0: float = 1.0
    # dtype the inputs are converted to.
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        """Return ``sin(w0 * inputs)`` computed on ``inputs`` cast to ``dtype``."""
        scaled = self.w0 * jnp.asarray(inputs, self.dtype)
        return jnp.sin(scaled)


class SirenLayer(nn.Module):
    """Dense layer followed by a scaled activation: ``act(w0 * (x @ kernel + bias))``.

    The kernel is created through ``siren_init`` with a bound that depends on
    whether this is the first layer of the network; the bias, when enabled,
    starts at zero.
    """

    # Number of output features.
    features: int = 32
    # Factor applied to the pre-activation; also divides the init bound of
    # non-first layers.
    w0: float = 1.0
    # Numerator of the init bound for non-first layers: sqrt(c / fan_in) / w0.
    c: float = 6.0
    # When True the init bound is 1 / fan_in instead.
    is_first: bool = False
    # Whether to add a learned bias.
    use_bias: bool = True
    # Activation applied to the scaled pre-activation.
    act: Callable = jnp.sin
    # Forwarded as ``precision`` to ``lax.dot_general``.
    precision: Any = None
    # dtype that inputs, kernel and bias are converted to.
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        """Apply the layer along the last axis of ``inputs``."""
        x = jnp.asarray(inputs, self.dtype)
        fan_in = x.shape[-1]

        # Bound handed to siren_init (init proposed in the SIREN paper).
        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = jnp.sqrt(self.c / fan_in) / self.w0

        weights = jnp.asarray(
            self.param(
                "kernel", siren_init(bound, self.dtype), (fan_in, self.features)
            ),
            self.dtype,
        )

        # Contract the last axis of x with the first axis of the kernel;
        # no batch dimensions.
        contracting_dims = ((x.ndim - 1,), (0,))
        batch_dims = ((), ())
        pre_activation = lax.dot_general(
            x,
            weights,
            (contracting_dims, batch_dims),
            precision=self.precision,
        )

        if self.use_bias:
            offset = self.param("bias", nn.initializers.zeros_init(), (self.features,))
            pre_activation = pre_activation + jnp.asarray(offset, self.dtype)

        return self.act(self.w0 * pre_activation)
class Siren(nn.Module):
    """Stack of ``SirenLayer`` modules forming a SIREN network.

    The network is ``num_layers - 1`` hidden layers of width ``hidden_dim``
    followed by one output layer of width ``output_dim``. When
    ``num_layers <= 1`` only the output layer is created.
    """

    # Width of every hidden layer.
    hidden_dim: int = 256
    # Width of the output layer.
    output_dim: int = 3
    # Total number of SirenLayers, output layer included.
    num_layers: int = 5
    # w0 passed to every layer except the first hidden one.
    w0: float = 1.0
    # w0 passed to the first hidden layer.
    w0_first_layer: float = 1.0
    # Whether the layers add a learned bias.
    use_bias: bool = True
    # Activation used by the output layer in place of the sine.
    final_activation: Callable = lambda x: x  # Identity
    # dtype used for the input cast and forwarded to every layer.
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        """Run ``inputs`` through the hidden layers and the output layer."""
        x = jnp.asarray(inputs, self.dtype)

        for layernum in range(self.num_layers - 1):
            is_first = layernum == 0

            x = SirenLayer(
                features=self.hidden_dim,
                w0=self.w0_first_layer if is_first else self.w0,
                is_first=is_first,
                use_bias=self.use_bias,
                # Forward the dtype; otherwise each layer would recast to its
                # own float32 default and this module's dtype would be ignored.
                dtype=self.dtype,
            )(x)

        # Last layer, with different activation function
        x = SirenLayer(
            features=self.output_dim,
            w0=self.w0,
            is_first=False,
            use_bias=self.use_bias,
            act=self.final_activation,
            dtype=self.dtype,
        )(x)

        return x


        



# class Siren(MLP):
#     layers: list[eqx.Module]
#     input_scale: float

#     def __init__(self,
#                  in_features: int,
#                  hidden_features: int,
#                  hidden_layers: int,
#                  out_features: int,
#                  key: jax.random.PRNGKey,
#                  first_omega_0: float = 30,
#                  hidden_omega_0: float = 30,
#                  input_scale: float = 1,
#                  **kwargs):
#         keys = jax.random.split(key, hidden_layers + 2)
#         self.input_scale = input_scale

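#         # Section 3.2
#         # For inputs in [-1, 1], first_omega_0 = 30 spans them to [-30, 30].
#         # Inputs are multiplied by input_scale in single_call, so divide
#         # first_omega_0 by it here to scale the periods back.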
#         first_omega_0 = first_omega_0 / input_scale

#         self.layers = [
#             SineLayer(in_features,
#                       hidden_features,
#                       keys[0],
#                       is_first=True,
#                       omega_0=first_omega_0)
#         ] + [
#             SineLayer(hidden_features,
#                       hidden_features,
#                       keys[i + 1],
#                       omega_0=hidden_omega_0) for i in range(hidden_layers)
#         ] + [Linear(hidden_features, out_features, keys[-1], False)]

#     def single_call(self, x, z):
#         x = jnp.hstack([self.input_scale * x, z])
#         for i in range(len(self.layers)):
#             x = self.layers[i](x)
#         return x
