from jax import random, jit 
import jax.numpy as jnp
import flax
import optax
import flax.linen as nn
from math import prod
from typing import Sequence


# Wrapper around a Flax-based neural network that holds a single set of parameters
class flaxNN:
  """Wrap a Flax module together with one set of variables.

  The variables produced by ``model.init`` are stored on the instance and
  can be replaced from a flat parameter vector. The layout of that vector
  follows the iteration order of ``flax.traverse_util.flatten_dict`` over
  the stored variables.

  Attributes:
    model: The wrapped Flax module.
    variables: Variables currently passed to the module by ``apply``.
    shapes: Shape of every leaf array, in flattened-dict order.
    variables_len: Total number of scalar entries over all leaves.
    app: ``jit``-compiled ``model.apply``.
  """

  def __init__(self, fnn, batch):
    """Initialise the module's variables and record their layout.

    Args:
      fnn: Flax module exposing ``init`` and ``apply``.
      batch: Example input handed to ``fnn.init``.
    """
    self.model = fnn
    # Fixed seed: every wrapper starts from the same initial variables.
    self.variables = self.model.init(random.PRNGKey(0), batch)
    leaves = flax.traverse_util.flatten_dict(self.variables).values()
    self.shapes = [leaf.shape for leaf in leaves]
    # Count entries from the shapes rather than materialising a flattened
    # copy of every leaf just to take its length.
    self.variables_len = sum(prod(shape) for shape in self.shapes)
    self.app = jit(self.model.apply)

  def apply(self, batch):
    """Run the jitted module on ``batch`` using the stored variables."""
    return self.app(self.variables, batch)

  def create_vars_from_params(self, params_update):
    """Build a variables tree from a flat parameter vector.

    Args:
      params_update: Flat array-like supporting ``len``, slicing and
        ``.reshape`` on its slices, with exactly ``variables_len`` entries.

    Returns:
      A frozen dict with the same keys as ``self.variables`` whose leaves
      are consecutive slices of ``params_update`` reshaped to ``shapes``.

    Raises:
      AssertionError: If ``len(params_update)`` differs from
        ``variables_len``.
    """
    if len(params_update) != self.variables_len:
      # Raised explicitly instead of via ``assert`` so the check is not
      # stripped under ``python -O``; type and message are unchanged.
      raise AssertionError("wrong number of parameter")
    params = flax.traverse_util.flatten_dict(self.variables)
    offset = 0
    # ``list(params)`` snapshots the keys so the dict can be rewritten
    # while walking it; ``shapes`` was recorded in the same key order.
    for key, shape in zip(list(params), self.shapes):
      size = prod(shape)
      params[key] = params_update[offset:offset + size].reshape(shape)
      offset += size
    return flax.core.freeze(flax.traverse_util.unflatten_dict(params))

  def update_parameters(self, params_update):
    """Replace the stored variables with ones built from ``params_update``."""
    self.variables = self.create_vars_from_params(params_update)

# Wrapper around a Flax-based neural network that holds a collection of parameter sets, keyed by index
class flaxNNX(flaxNN):
    """Flax wrapper that keeps several variable sets, selected by index.

    Variable sets are stored in ``var_dict`` under a caller-chosen key and
    are created from flat parameter vectors with the inherited
    ``create_vars_from_params``.

    Attributes:
        var_dict: Mapping from index to the variables stored for it.
        applymaybefaster: Alias of ``apply``, kept for existing callers.
    """

    def __init__(self, fnn, batch):
        """Initialise the base wrapper and start with no stored sets.

        Args:
            fnn: Flax module exposing ``init`` and ``apply``.
            batch: Example input handed to ``fnn.init``.
        """
        super().__init__(fnn, batch)
        self.var_dict = {}
        # ``apply`` already dispatches to the jit-compiled ``self.app``.
        # Wrapping ``apply`` in a second ``jit`` would trace ``idx``, and a
        # traced value cannot be used as a dict key, so every call failed.
        # Expose the working method under the old attribute name instead.
        self.applymaybefaster = self.apply

    def apply(self, idx, batch):
        """Run the jitted module on ``batch`` with the set stored at ``idx``.

        Args:
            idx: Key previously passed to ``update_parameters``.
            batch: Input forwarded to the module.

        Raises:
            AssertionError: If nothing is stored under ``idx``.
        """
        if idx not in self.var_dict:
            # Raised explicitly instead of via ``assert`` so the check is
            # not stripped under ``python -O``; type and message unchanged.
            raise AssertionError("no variables with that index")
        return self.app(self.var_dict[idx], batch)

    def update_parameters(self, idx, params_update):
        """Store variables built from ``params_update`` under ``idx``."""
        self.var_dict[idx] = self.create_vars_from_params(params_update)

# From Flax basic examples:
# https://github.com/google/flax/blob/9fe386508e35c31ca3e1f151bf1a4a87d71486ae/README.md
class MLP(nn.Module):
  """Stack of dense layers with ReLU after every layer except the last.

  Attributes:
    features: Width of each dense layer, in order; the final entry is the
      width of the output layer.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    """Pass ``x`` through the layers and return the un-activated output."""
    hidden_widths = self.features[:-1]
    for width in hidden_widths:
      x = nn.Dense(width)(x)
      x = nn.relu(x)
    # The last layer is left linear so its outputs can be used as logits.
    output_layer = nn.Dense(self.features[-1])
    return output_layer(x)