# -*- coding: utf-8 -*-
"""
Created on Mon Apr 29 21:45:18 2024

@author: javie
"""
# NN modules
import logging
import jax.numpy as jnp
from jax import vmap
import objax
from emlp.nn import Linear as emlpLinear

class SGD(objax.Module):
    """
    Plain stochastic gradient descent over the trainable variables of a model.

    Args:
        variables: collection whose ``objax.TrainVar`` entries are updated.
    """
    def __init__(self, variables: objax.VarCollection):
        # TrainRef handles write through to the model's own TrainVars.
        trainable = variables.subset(objax.TrainVar)
        self.refs = objax.ModuleList([objax.TrainRef(var) for var in trainable])

    def __call__(self, lr: float, gradients: list):
        """
        Apply one step ``param <- param - lr * grad`` to every tracked variable.

        Args:
            lr: learning rate.
            gradients: one gradient per entry of ``self.refs``, in the same order.
        """
        for ref, grad in zip(self.refs, gradients):
            ref.value = ref.value - lr * grad
            
class FA_Model(objax.Module):
    """
    Feature-averaged (symmetrized) version of a model.

    The output is the average over the identity and the first discrete
    generator ``g`` of ``G``::

        0.5 * (f(x) + rho_out(g)^T f(rho_in(g) x))

    NOTE(review): this coincides with the group average int_G g^{-1}.f(g.x) dg
    only when G = {e, g} (order two) and rho_out(g) is orthogonal, so that its
    transpose is its inverse. Confirm this for the groups used with this wrapper.
    """
    def __init__(self, model, G, repin, repout):
        self.model = model

        def act_in(sample):
            # rho_in(g) @ sample, for one input vector
            return repin.rho_dense(G.discrete_generators[0]) @ sample

        def act_out_inverse(sample):
            # rho_out(g)^T @ sample, used in place of rho_out(g)^{-1}
            return (repout.rho_dense(G.discrete_generators[0]).T) @ sample

        def extend(x):
            # Stack the batch with its transformed copy along a new leading axis of size 2.
            return jnp.stack([x, vmap(act_in)(x)])

        def average(y):
            # Map the second copy's output back, then average the two copies.
            return jnp.stack([y[0], vmap(act_out_inverse)(y[1])]).mean(axis=0)

        self.extend = extend
        # Evaluate the wrapped model independently on each of the two stacked copies.
        self.model_ext = vmap(lambda stacked: model(stacked))
        self.average = average
        self.N = model.N

    def __call__(self, x, training=False):
        """
        Return the averaged output int_G g^{-1}.f(g.x) dg (see class note).

        NOTE: ``training`` is accepted for interface compatibility but is not
        used; the wrapped model is called with its own default.
        """
        logging.debug(f"linear in shape: {x.shape}")
        copies = self.extend(x)
        per_copy = self.model_ext(copies)
        out = self.average(per_copy)
        logging.debug(f"linear out shape:{out.shape}")
        return out

    def get_particles(self):
        """Delegate to the wrapped model's ``get_particles``."""
        return self.model.get_particles()

    def get_param_shape(self):
        """Delegate to the wrapped model's ``get_param_shape``."""
        return self.model.get_param_shape()

    def set_particles(self, *args):
        """Forward all positional arguments to the wrapped model's ``set_particles``."""
        self.model.set_particles(*args)
    
###### CUSTOM MODULES #####
    
# class AlternativeLinear(objax.nn.Linear):
#     """ Basic equivariant Linear layer from repin to repout."""
#     def __init__(self, repin, repout, use_bias =True, w_init = None):
#         nin,nout = repin.size(),repout.size()
#         super().__init__(nin,nout)
#         self.nin, self.nout = nin, nout
#         self.rep_W = rep_W = repout*repin.T
#         rep_bias = repout

#         self.base_W = rep_W.equivariant_basis()
#         dim_W = self.base_W.shape[-1]
#         self.Pw = rep_W.equivariant_projector()

#         if use_bias:
#           self.base_bias = rep_bias.equivariant_basis()
#           dim_bias = self.base_bias.shape[-1]
#           self.Pb = rep_bias.equivariant_projector()

#         self.b = objax.variable.TrainVar(objax.random.uniform((dim_bias,))/jnp.sqrt(dim_bias)) if use_bias else None
#         self.w = objax.variable.TrainVar(objax.random.uniform((dim_W,))/jnp.sqrt(dim_W)) if w_init is None else w_init((dim_W,)) #orthogonal((nout, nin)))

#         logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}")

#     def __call__(self, x): # (cin) -> (cout)
#         logging.debug(f"linear in shape: {x.shape}")
#         W = (self.base_W@self.w.value).reshape((self.nout, self.nin))
#         out = x@W.T
#         if self.b is not None:
#           b = self.base_bias@self.b.value
#           out += b
#         logging.debug(f"linear out shape:{out.shape}")
#         return out

class AlternativeLinear(objax.nn.Linear):
    """
    Basic equivariant Linear layer from repin to repout.

    This alternative version was created to ensure a truly "lower dimensional"
    equivariant layer. Instead of projecting a "same dimension" matrix onto the
    equivariant space, only the coefficients in the equivariant basis are
    trained. The basis then maps them to the ambient weight matrix.

    Args:
        repin, repout: input / output representations of the map x -> out.
        use_bias: whether to add an equivariant bias, which lives in ``repout``.
        w_init: optional initializer, called as ``w_init((dim_W,))``. Its result
            is wrapped in a ``TrainVar`` unless it already is one.
        transpose: if True, the weight is parameterized in ``repin * repout.T``
            (``self.nin``/``self.nout`` are swapped) and applied as ``x @ W``.
            The bias still lives in ``repout``, the rep of the layer's output.
    """
    def __init__(self, repin, repout, use_bias=True, w_init=None, transpose=False):
        self.transpose = transpose
        # The bias is added to the layer's output, which is always the caller's
        # `repout`. It must be captured before the swap below; otherwise a
        # transposed layer would get a bias of the *input* size.
        rep_bias = repout
        if self.transpose:
            repin, repout = repout, repin
        nin, nout = repin.size(), repout.size()
        super().__init__(nin, nout)
        self.nin, self.nout = nin, nout
        self.rep_W = rep_W = repout*repin.T

        self.base_W = rep_W.equivariant_basis()
        dim_W = self.base_W.shape[-1]
        self.Pw = rep_W.equivariant_projector()

        if use_bias:
            self.base_bias = rep_bias.equivariant_basis()
            dim_bias = self.base_bias.shape[-1]
            self.Pb = rep_bias.equivariant_projector()
            self.b = objax.variable.TrainVar(objax.random.normal((dim_bias,))/jnp.sqrt(dim_bias))
        else:
            self.b = None

        if w_init is None:
            self.w = objax.variable.TrainVar(objax.random.normal((dim_W,))/jnp.sqrt(dim_W))
        else:
            w0 = w_init((dim_W,))
            # __call__ reads `self.w.value`, so a raw array from the initializer
            # must be wrapped (as objax.nn.Linear does). Initializers that already
            # return a TrainVar keep working.
            self.w = w0 if isinstance(w0, objax.variable.TrainVar) else objax.variable.TrainVar(w0)

        logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}")

    def __call__(self, x, training=False):  # (cin) -> (cout)
        """Apply the layer. ``training`` is accepted for API compatibility and ignored."""
        logging.debug(f"linear in shape: {x.shape}")
        # Map basis coefficients to the ambient (nout, nin) weight matrix.
        W = (self.base_W@self.w.value).reshape((self.nout, self.nin))
        out = x@W if self.transpose else x@W.T
        if self.b is not None:
            out = out + self.base_bias@self.b.value
        logging.debug(f"linear out shape:{out.shape}")
        return out
    
# class CustomLinear(objax.nn.Linear):
#     """ Basic Linear layer from repin to repout."""
#     def __init__(self, nin, nout, use_bias =True, w_init = None):
#         super().__init__(nin,nout)
#         self.nin, self.nout = nin, nout

#         self.b = objax.variable.TrainVar(objax.random.uniform((nout,))/jnp.sqrt(nout)) if use_bias else None
#         dim_w = nin*nout
#         self.w = objax.variable.TrainVar(objax.random.normal((dim_w,))/(2*nin)) if w_init is None else w_init((dim_w,))#orthogonal((nout, nin)))


#     def __call__(self, x): # (cin) -> (cout)
#         logging.debug(f"linear in shape: {x.shape}")
#         W = self.w.value.reshape((self.nout, self.nin))
#         out = x@W.T
#         if self.b is not None:
#           out += self.b.value
#         logging.debug(f"linear out shape:{out.shape}")
#         return out
    
class CustomLinear(objax.nn.Linear):
    """
    Basic (non-equivariant) Linear layer from nin to nout.

    This alternative version was created to ensure consistency with the
    application of the equivariant linear layer. It stores a flat weight vector
    and applies it the same way. In particular, "projected parameters" act
    equivariantly as expected.

    Args:
        nin, nout: input / output dimensions of the map x -> out.
        use_bias: whether to add a trainable bias of size ``nout``.
        w_init: optional initializer, called as ``w_init((nin * nout,))``. Its
            result is wrapped in a ``TrainVar`` unless it already is one.
        transpose: if True, ``self.nin``/``self.nout`` are swapped and the weight
            is applied as ``x @ W``. The bias always has the size of the
            layer's output (the caller's ``nout``).
    """
    def __init__(self, nin, nout, use_bias=True, w_init=None, transpose=False):
        self.transpose = transpose
        # Dimension of what __call__ actually returns. It is captured before the
        # swap so a transposed layer's bias matches its output, not its input.
        out_dim = nout
        if self.transpose:
            nin, nout = nout, nin
        super().__init__(nin, nout)
        self.nin, self.nout = nin, nout

        self.b = objax.variable.TrainVar(objax.random.uniform((out_dim,))/jnp.sqrt(out_dim)) if use_bias else None
        dim_w = nin*nout
        if w_init is None:
            # NOTE(review): fixed 1/4 scale, independent of nin. Confirm it is intended.
            self.w = objax.variable.TrainVar(objax.random.normal((dim_w,))/4)
        else:
            w0 = w_init((dim_w,))
            # __call__ reads `self.w.value`, so raw arrays must be wrapped.
            self.w = w0 if isinstance(w0, objax.variable.TrainVar) else objax.variable.TrainVar(w0)

    def __call__(self, x, training=False):  # (cin) -> (cout)
        """Apply the layer. ``training`` is accepted for API compatibility and ignored."""
        logging.debug(f"linear in shape: {x.shape}")
        W = self.w.value.reshape((self.nout, self.nin))
        out = x@W if self.transpose else x@W.T
        if self.b is not None:
            out = out + self.b.value
        logging.debug(f"linear out shape:{out.shape}")
        return out

    
class MFAverage(objax.Module):
    """
    Basic operation to "average" the output of N neurons.

    The input is reshaped to ``(-1, N, output_dim)``, so its total size must be
    a multiple of ``N * output_dim``. It is then averaged over the neuron axis,
    giving shape ``(-1, output_dim)``. A 1-D input therefore yields
    ``(1, output_dim)``.
    """
    def __init__(self, N, output_dim):
        super().__init__()
        self.N = N
        self.output_dim = output_dim

    def __call__(self, x, training=False):
        grouped = x.reshape((-1, self.N, self.output_dim))
        return grouped.mean(axis=1)
    
class ShallowEMLPNoLinearOut(objax.Module):
    """
    Simplest possible shallow model, in its "equivariant" version. It could be
    deprecated, as it is the same as the MLP version with ``equivariant=True``.

    Computes MFAverage(sigma(W.x)): the output of ``N`` hidden blocks of size
    ``rep_out.size()`` is averaged, with no trainable output layer.
    """
    def __init__(self, N, rep_in, rep_out, activation=None, alternative=False, use_bias=True):
        super().__init__()
        if alternative:
            self.linear = AlternativeLinear(rep_in, N*rep_out, use_bias=use_bias)
        else:
            # use_bias only applies to AlternativeLinear; emlpLinear uses its defaults.
            self.linear = emlpLinear(rep_in, N*rep_out)
        self.MFAverage = MFAverage(N, rep_out.size())
        self.activation = activation if activation is not None else objax.functional.relu
        self.N = N

    def __call__(self, x, training=False):
        """Return MFAverage(activation(linear(x))). ``training`` is ignored."""
        pre_activation = self.linear(x)
        hidden = self.activation(pre_activation)
        return self.MFAverage(hidden)
    
class ShallowMLPNoLinearOut(objax.Module):
    """
    Simplest possible shallow model, in its "free" (non-equivariant) version.

    Computes MFAverage(sigma(W.x)) over ``N`` hidden blocks of size
    ``rep_out.size()``, with no trainable output layer. With
    ``equivariant=True`` the layer is equivariant: emlpLinear, or
    AlternativeLinear when ``alternative``.

    NOTE(review): get_particles/set_particles treat ``self.linear.w`` as a flat
    vector and, when equivariant, read ``self.linear.base_W``. They therefore
    assume ``alternative=True`` (AlternativeLinear / CustomLinear).
    objax.nn.Linear stores a 2-D ``(nin, nout)`` weight, and emlpLinear
    presumably has no ``base_W``. Verify before using them with
    ``alternative=False``.
    """
    def __init__(self, N, rep_in, rep_out, activation=None, alternative=False, use_bias=True, equivariant=False):
        super().__init__()
        self.equivariant = equivariant
        if self.equivariant:
            self.linear = emlpLinear(rep_in, N*rep_out) if not alternative else AlternativeLinear(rep_in, N*rep_out, use_bias=use_bias)
        else:
            self.linear = objax.nn.Linear(rep_in.size(), N*(rep_out.size()), use_bias=use_bias) if not alternative else CustomLinear(rep_in.size(), N*(rep_out.size()), use_bias=use_bias)
        self.MFAverage = MFAverage(N, rep_out.size())
        self.activation = objax.functional.relu if activation is None else activation
        self.N = N

    def __call__(self, x, training=False):
        """Return MFAverage(activation(linear(x))). ``training`` is ignored."""
        x = self.linear(x)
        x = self.activation(x)
        return self.MFAverage(x)

    def get_particles(self):
        """
        Return the layer weights grouped per neuron, reshaped to ``(N, -1)``.

        In the equivariant case the basis coefficients are first mapped to the
        ambient weight vector through ``base_W``.
        """
        if self.equivariant:
            return ((self.linear.base_W)@self.linear.w.value).reshape((self.N, -1))
        return self.linear.w.value.reshape(self.N, -1)

    def set_particles(self, particles):
        """
        Overwrite the layer's weight vector with ``particles`` flattened.

        This is best-effort: on failure a message is printed instead of raising.
        In the equivariant case this assigns basis coefficients directly, not
        the ambient weights returned by get_particles.
        """
        try:
            self.linear.w.assign(particles.reshape(-1))
        except Exception as err:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
            logging.debug("set_particles failed: %r", err)
            print("Please Ensure the particles have the correct shape")

    def get_param_shape(self):
        """Return the shape of the layer's trainable weight variable."""
        return self.linear.w.shape

    
    
class ShallowEMLP(objax.Module):
    """
    Simple shallow model, in its "equivariant" version (to be DEPRECATED),
    corresponding to the activation (1/N) * a^T.sigma(W.x).
    """
    def __init__(self, N, rep_in, hidden_rep, rep_out, activation=None, alternative=False, use_bias=True):
        super().__init__()
        if alternative:
            self.linear1 = AlternativeLinear(rep_in, N*hidden_rep, use_bias=use_bias)
            # Output layer stored transposed so it is parameterized per hidden neuron.
            self.linear2 = AlternativeLinear(N*hidden_rep, rep_out, use_bias=use_bias, transpose=True)
        else:
            self.linear1 = emlpLinear(rep_in, N*hidden_rep)
            self.linear2 = emlpLinear(N*hidden_rep, rep_out)
        self.activation = activation if activation is not None else objax.functional.relu
        self.N = N

    def __call__(self, x, training=False):
        """Return (1/N) * linear2(activation(linear1(x))). ``training`` is ignored."""
        hidden = self.activation(self.linear1(x))
        readout = self.linear2(hidden)
        return (1/self.N)*readout
    
class ShallowMLP(objax.Module):
    """
    Simple shallow model, in its "free" version, corresponding to the
    activation (1/N) * a^T.sigma(W.x).

    With ``equivariant=True`` both layers are equivariant: emlpLinear, or
    AlternativeLinear when ``alternative``. The output layer is built with
    ``transpose=True`` in the alternative variants.

    NOTE(review): get_particles/set_particles treat the weights as flat vectors
    and, when equivariant, read ``base_W``. They therefore assume
    ``alternative=True``. objax.nn.Linear stores ``linear1.w`` as
    ``(nin, N*hidden)``, so ``reshape(N, -1)`` does not group it per neuron.
    """
    def __init__(self, N, rep_in, hidden_rep, rep_out, activation=None, alternative=False, use_bias=True, equivariant=False):
        super().__init__()
        self.equivariant = equivariant
        if self.equivariant:
            self.linear1 = emlpLinear(rep_in, N*hidden_rep) if not alternative else AlternativeLinear(rep_in, N*hidden_rep, use_bias=use_bias)
            self.linear2 = emlpLinear(N*hidden_rep, rep_out) if not alternative else AlternativeLinear(N*hidden_rep, rep_out, use_bias=use_bias, transpose=True)
        else:
            self.linear1 = objax.nn.Linear(rep_in.size(), N*hidden_rep.size(), use_bias=use_bias) if not alternative else CustomLinear(rep_in.size(), (N*hidden_rep).size(), use_bias=use_bias)
            # objax.nn.Linear takes (nin, nout). The original multiplied a
            # tuple by N here and omitted nout, which raised TypeError.
            self.linear2 = objax.nn.Linear(N*hidden_rep.size(), rep_out.size(), use_bias=use_bias) if not alternative else CustomLinear((N*hidden_rep).size(), rep_out.size(), use_bias=use_bias, transpose=True)
        self.activation = objax.functional.relu if activation is None else activation
        self.N = N

    def __call__(self, x, training=False):
        """Return (1/N) * linear2(activation(linear1(x))). ``training`` is ignored."""
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        return (1/self.N)*x

    def get_particles(self):
        """
        Return the (first-layer, second-layer) weights, each reshaped to ``(N, -1)``.

        In the equivariant case the basis coefficients are first mapped to
        ambient weight vectors through ``base_W``.
        """
        if self.equivariant:
            return ((self.linear1.base_W)@self.linear1.w.value).reshape((self.N, -1)), ((self.linear2.base_W)@self.linear2.w.value).reshape((self.N, -1))
        return self.linear1.w.value.reshape(self.N, -1), self.linear2.w.value.reshape(self.N, -1)

    def set_particles(self, particles1=None, particles2=None):
        """
        Overwrite the layers' weight vectors with the given flattened particles.

        An argument left as None keeps that layer unchanged. This is
        best-effort: on failure a message is printed instead of raising. In the
        equivariant case this assigns basis coefficients directly.
        """
        try:
            if particles1 is not None:
                self.linear1.w.assign(particles1.reshape(-1))
            if particles2 is not None:
                self.linear2.w.assign(particles2.reshape(-1))
        except Exception as err:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
            logging.debug("set_particles failed: %r", err)
            print("Please Ensure the particles have the correct shape")

    def get_param_shape(self):
        """Return the shapes of the two layers' trainable weight variables."""
        return self.linear1.w.shape, self.linear2.w.shape



class ShallowNeuronKernel(objax.Module):
    """
    Basic building block of Shallow NeuronKernel.

    A fixed mesh of ``samples`` points, drawn from a normal distribution with
    standard deviation ``stddev`` (seeded by ``seed``), is used to estimate,
    for every pair of particles (i, l)::

        K[i, l] = (1/samples) * sum_s sum_j sigma(W_i[j] . x_s) * sigma(W_l[j] . x_s)

    where each particle is reshaped to a ``(nout, nin)`` matrix W.
    """
    def __init__(self, samples, rep_in, rep_out, activation=None, stddev=1, seed=0):
        super().__init__()
        self.nin, self.nout = rep_in.size(), rep_out.size()
        self.samples, self.stddev = samples, stddev
        rng = objax.random.Generator(seed=seed)
        points = objax.random.normal((self.samples, self.nin), stddev=self.stddev, generator=rng)
        self.mesh = objax.variable.StateVar(points)
        self.activation = objax.functional.relu if activation is None else activation

    def __call__(self, theta1, theta2, training=False):
        """
        Return the kernel matrix between two sets of particles.

        Args:
            theta1, theta2: arrays reshapeable to ``(-1, nout, nin)``, with M1
                and M2 particles respectively.
            training: ignored.

        Returns:
            Array of shape ``(M1, M2)``.
        """
        def features(theta):
            weights = theta.reshape((-1, self.nout, self.nin))
            # (samples, nin) x (nin, nout, M) -> (samples, nout, M)
            return self.activation(jnp.einsum("ij, jkl -> ikl", self.mesh, weights.T))

        feats1 = features(theta1)
        feats2 = features(theta2)
        # Contract over samples (k) and output units (j) -> (M1, M2).
        return (1/self.samples)*jnp.einsum("kji, kjl -> il", feats1, feats2)