
import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.nn.initializers import normal
from koopman.networks.dynamics_models.ops import discretize, compute_measurement, increasing_im_init, \
    log_step_init, random_im_init, s4d_kernel_zoh
from flax import linen as nn
from typing import Callable
from jax import jit
import math

from koopman.networks.common import MLP



############## Reward model #########################

class BatchedRewardModel(nn.Module):
    """Reward head operating on batches of (state, action) pairs.

    The states and actions are joined along their last axis and the result
    is fed to ``MLP([64, 64, 1], ...)`` built with the configured activation.
    """
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, states, actions):
        """Return the MLP output for ``states`` and ``actions`` joined on axis -1."""
        reward_head = MLP([64, 64, 1], activations=self.activations)
        mlp_input = jnp.concatenate((states, actions), axis=-1)
        return reward_head(mlp_input)





############## Diagonal KOOPMAN DYNAMICS MODEL (Fast) #########################



@jit
def koopman_forward(Vand_K, L, actions_emb, start_state_rep):
    """Roll a batch of action sequences through the diagonal Koopman model.

    Both real-valued inputs pack a complex tensor: the first half of the last
    axis holds the real part and the second half the imaginary part.

    Args:
        Vand_K: forwarded unchanged to ``compute_measurement`` for every
            batch element (the model in this file builds it with
            ``jnp.vander`` from the discretized eigenvalues).
        L: complex matrix that the complex action embeddings are
            right-multiplied by.
        actions_emb: rank-3 real array of packed action embeddings; the
            leading axis is the one mapped over.
        start_state_rep: rank-2 real array of packed start states; the
            leading axis is the one mapped over.

    Returns:
        The result of ``compute_measurement`` with its real and imaginary
        parts concatenated along the last axis.
    """
    half_action = actions_emb.shape[-1] // 2
    half_state = start_state_rep.shape[-1] // 2

    # Unpack [real | imag] halves into complex tensors.
    controls = actions_emb[:, :, :half_action] + 1j * actions_emb[:, :, half_action:]
    initial_state = start_state_rep[:, :half_state] + 1j * start_state_rep[:, half_state:]

    # Vand_K is shared; controls and start states are mapped over axis 0.
    batched_measurement = jax.vmap(compute_measurement, in_axes=(None, 0, 0), out_axes=0)
    rollout = batched_measurement(Vand_K, controls @ L, initial_state)

    # Pack the complex result back into the [real | imag] layout.
    return jnp.concatenate((jnp.real(rollout), jnp.imag(rollout)), axis=-1)

@jit
def koopman_forward_single(K, L, action_emb, start_state_rep):
    """Advance a batch of states by one step of the diagonal Koopman model.

    Both real-valued inputs pack a complex tensor: the first half of the last
    axis holds the real part and the second half the imaginary part.

    Args:
        K: 1-D complex vector, multiplied element-wise with the complex state.
        L: complex matrix that the complex action embedding is
            right-multiplied by.
        action_emb: rank-2 real array of packed action embeddings.
        start_state_rep: rank-2 real array of packed start states.

    Returns:
        ``K * state + action @ L`` with real and imaginary parts concatenated
        along the last axis.
    """
    half_action = action_emb.shape[-1] // 2
    half_state = start_state_rep.shape[-1] // 2

    # Unpack [real | imag] halves into complex tensors.
    control = action_emb[:, :half_action] + 1j * action_emb[:, half_action:]
    state = start_state_rep[:, :half_state] + 1j * start_state_rep[:, half_state:]

    # Diagonal linear update: element-wise eigenvalue scaling plus control term.
    next_state = K[None, :] * state + control @ L

    return jnp.concatenate((jnp.real(next_state), jnp.imag(next_state)), axis=-1)


class BatchedDiagonalKoopmanDynamicsModel(nn.Module):
    """Koopman dynamics model with a diagonal, complex state-transition operator.

    The eigenvalues are stored as a real part (a constant or a learnable
    vector) plus a learnable imaginary part of ``state_dim // 4`` entries;
    each eigenvalue is paired with the one of opposite imaginary sign.
    The continuous-time pair ``(K_complex, L_complex)`` is passed through
    ``discretize`` together with a learnable step size.
    """
    state_dim: int
    action_emb_dim: int
    real_init_type: str = 'constant'
    real_init_value: float = -0.5
    im_init_type: str = 'increasing_freq'
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    def init_koopman_imaginary_params(self):
        """Create the learnable imaginary parts ``K_im``.

        Raises:
            Exception: if ``im_init_type`` is neither ``'increasing_freq'``
                nor ``'random'``.
        """
        if self.im_init_type == 'increasing_freq':
            im_initializer = increasing_im_init()
        elif self.im_init_type == 'random':
            im_initializer = random_im_init()
        else:
            raise Exception('Imaginary part initialization method not implemented')
        self.K_im = self.param("K_im", im_initializer, (self.state_dim // 4,))

    def init_koopman_real_params(self):
        """Set the real parts ``K_real`` (constant scalar or learnable vector).

        Raises:
            Exception: if ``real_init_type`` is neither ``'constant'`` nor
                ``'learnable'``.
        """
        if self.real_init_type == 'constant':
            self.K_real = self.real_init_value
            return
        if self.real_init_type != 'learnable':
            raise Exception('Real part initialization method not implemented')

        # Parameter starts as ones, so the scaled value starts at real_init_value.
        raw_real = self.param("K_real", nn.initializers.ones, (self.state_dim // 4,))
        self.K_r = raw_real * self.real_init_value
        # NOTE(review): with the default real_init_value of -0.5 the initial
        # value lies outside the clip range [-0.4, -0.1] -- confirm intended.
        self.K_real = jnp.clip(self.K_r, -0.4, -0.1)

    def setup(self):
        """Create parameters and derive the discretized operators."""
        # Parameters are created in a fixed order: real part, imaginary part,
        # control matrix, step size.
        self.init_koopman_real_params()
        self.init_koopman_imaginary_params()
        self.action_encoder = MLP([128, self.action_emb_dim * 2], activations=self.activations)
        # Last axis of L stores (real, imaginary) components.
        self.L = self.param("L", normal(0.1), (self.action_emb_dim, self.state_dim // 2, 2))

        # The step size is learned in log space.
        self.step = jnp.exp(self.param("log_step", log_step_init(), (1,)))

        # Pair every eigenvalue with the one of opposite imaginary sign.
        positive_branch = self.K_real + 1j * self.K_im
        negative_branch = self.K_real - 1j * self.K_im
        self.K_complex = jnp.concatenate((positive_branch, negative_branch), axis=-1)
        self.L_complex = self.L[:, :, 0] + 1j * self.L[:, :, 1]
        self.K_dis, self.L_dis = discretize(self.K_complex, self.L_complex, self.step)

    def __call__(self, actions, start_state_rep):
        """Predict state representations for one action or an action sequence.

        A rank-2 ``actions`` array is handled as a single step via
        ``koopman_forward_single``; any other rank is handled as a sequence
        (axis 1 is the time axis) via ``koopman_forward``.
        """
        encoded_actions = self.action_encoder(actions)
        if actions.ndim == 2:
            return koopman_forward_single(
                self.K_dis, self.L_dis, encoded_actions, start_state_rep
            )

        # Powers K^0 .. K^T of the discretized eigenvalues, T = sequence length.
        horizon = actions.shape[1]
        vandermonde = jnp.vander(self.K_dis, horizon + 1, increasing=True)
        return koopman_forward(
            vandermonde, self.L_dis, encoded_actions, start_state_rep
        )


#######################################################################



############## DENSE KOOPMAN DYNAMICS MODEL (Not so good gradient behaviour) #########################

class BatchedDenseKoopmanDynamicsModel(nn.Module):
    """Koopman dynamics model with dense real matrices ``K`` and ``L``.

    Each step applies ``next = state @ K + encode(action) @ L`` and feeds the
    prediction back in as the next state.
    """
    state_dim: int
    action_emb_dim: int
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    def setup(self):
        """Create the action encoder and the two dense operators."""
        self.action_encoder = MLP([256, self.action_emb_dim], activations=self.activations)
        self.K = self.param("K", normal(stddev=0.5), (self.state_dim, self.state_dim))
        self.L = self.param("L", normal(stddev=0.5), (self.action_emb_dim, self.state_dim))

    @nn.compact
    def __call__(self, actions, start_state_rep):
        """Unroll the linear dynamics over axis 1 of ``actions``.

        Returns the per-step predictions stacked along a new axis 1.
        """
        state = start_state_rep
        trajectory = []
        for step_idx in range(actions.shape[1]):
            control = self.action_encoder(actions[:, step_idx])
            state = state @ self.K + control @ self.L
            trajectory.append(state)
        return jnp.stack(trajectory, axis=1)

#######################################################################







#################### REGULAR DYNAMICS MODEL ###########################
class RegularDynamicsModel(nn.Module):
    """Autoregressive MLP dynamics model for a single (unbatched) trajectory."""
    state_dim: int
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    def setup(self):
        """Create the MLP that maps (state, action) to the next state."""
        self.dynamics_block = MLP([256, self.state_dim], activations=self.activations)

    def __call__(self, actions, start_state_rep):
        '''
        Predict the next state(s) from a start state and one or more actions.

        input shapes:
                    actions: n * a_dim (or a_dim for a single step)
                    start_states_rep: d
        output shapes:
                    pred_states: n * d (or the single MLP output for one step)
        '''
        # Single action: one application of the dynamics block.
        if actions.ndim == 1:
            return self.dynamics_block(
                jnp.concatenate((start_state_rep, actions), axis=-1)
            )

        # Action sequence: feed each prediction back in as the next state.
        state = start_state_rep
        rollout = []
        for step_idx in range(actions.shape[0]):
            step_input = jnp.concatenate((state, actions[step_idx]), axis=-1)
            state = self.dynamics_block(step_input)
            rollout.append(state)
        return jnp.stack(rollout, 0)

class BatchedRegularDynamicsModel(nn.Module):
    """Applies ``RegularDynamicsModel`` across a leading batch axis.

    The parameters are not mapped (``variable_axes={"params": None}``) and
    the parameter RNG is not split, so one set of weights serves the batch.
    """
    state_dim: int
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, actions, start_state_rep):
        """Map the single-trajectory model over axis 0 of both inputs."""
        batched_cls = nn.vmap(
            RegularDynamicsModel,
            variable_axes={"params": None},
            split_rngs={"params": False},
            in_axes=0,
            out_axes=0,
        )
        per_sample_model = batched_cls(self.state_dim, self.activations)
        return per_sample_model(actions, start_state_rep)

class EnsembleBatchedRegularDynamicsModel(nn.Module):
    """Ensemble of ``num_ensemble`` ``BatchedRegularDynamicsModel`` members.

    The parameters are mapped (``variable_axes={"params": True}``) and the
    parameter RNG is split; the inputs are not mapped (``in_axes=None``), and
    the outputs are stacked on a new leading axis.
    """
    state_dim: int
    num_ensemble: int
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, actions, start_state_rep):
        """Run every ensemble member on the same ``actions`` / start state."""
        ensemble_cls = nn.vmap(
            BatchedRegularDynamicsModel,
            variable_axes={"params": True},
            split_rngs={"params": True},
            in_axes=None,
            out_axes=0,
            axis_size=self.num_ensemble,
        )
        members = ensemble_cls(self.state_dim, self.activations)
        return members(actions, start_state_rep)


#######################################################################





############### Transformer Dynamics model #############################


    
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional encoding to a sequence.

    The encoding is laid out as ``[sin block | cos block]`` along the feature
    axis (not interleaved). It has no learnable parameters.

    Attributes:
        d_model: size of the last axis of the input.
        max_len: largest number of positions an encoding is generated for.
    """
    d_model: int
    max_len: int = 5000

    @nn.compact
    def __call__(self, x):
        """Return ``x`` plus the positional encoding.

        Axis 1 of ``x`` is used as the position axis and the last axis must
        have size ``d_model``.
        """
        # Only build the rows that are actually needed (at most max_len).
        num_positions = min(x.shape[1], self.max_len)
        pos = jnp.arange(num_positions)[:, jnp.newaxis]
        div_term = jnp.exp(jnp.arange(0, self.d_model, 2) * -jnp.log(10000.0) / self.d_model)
        angles = pos * div_term
        pos_encoding = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
        # For odd d_model the sin/cos blocks together have d_model + 1 columns;
        # drop the surplus column so the encoding broadcasts against x.
        pos_encoding = pos_encoding[:, :self.d_model]
        return x + pos_encoding


    
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Attributes:
        num_heads: number of attention heads.
        head_dim: feature size of each head.
    """
    num_heads: int
    head_dim: int = 64

    @nn.compact
    def __call__(self, inputs, mask=None):
        """Apply self-attention to ``inputs``.

        Args:
            inputs: rank-3 array; axis 0 is the batch axis and axis 1 the
                sequence axis.
            mask: optional (seq, seq) array with 1 where attention is allowed
                and 0 where it is blocked; it is broadcast over batch and
                heads.

        Returns:
            An array with the same shape as ``inputs``.
        """
        batch_size, seq_len = inputs.shape[0], inputs.shape[1]

        # Joint projection to queries, keys and values for all heads.
        qkv = nn.Dense(3 * self.num_heads * self.head_dim)(inputs)
        qkv = jnp.reshape(qkv, (batch_size, seq_len, 3, self.num_heads, self.head_dim))
        # -> (3, batch, heads, seq, head_dim)
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv

        dot_product = jnp.einsum('bhid,bhjd->bhij', q, k)
        scaled_dot_product = dot_product / jnp.sqrt(self.head_dim)

        if mask is not None:
            mask = jnp.expand_dims(mask, axis=(0, 1))
            # Large negative bias removes blocked positions from the softmax.
            scaled_dot_product += (1.0 - mask) * -1e9

        attention_weights = nn.softmax(scaled_dot_product, axis=-1)
        # (batch, heads, seq, head_dim)
        attention_output = jnp.einsum('bhij,bhjd->bhid', attention_weights, v)

        # Move heads next to head_dim -> (batch, seq, heads, head_dim) so the
        # reshape below concatenates the heads of each token. Permuting with
        # (1, 0, 2, 3) instead would mix values across batch elements and
        # time steps once reshaped.
        attention_output = jnp.transpose(attention_output, (0, 2, 1, 3))
        attention_output = jnp.reshape(
            attention_output, (batch_size, seq_len, self.num_heads * self.head_dim)
        )

        # Project back to the input feature size.
        output = nn.Dense(inputs.shape[-1])(attention_output)

        return output
    

class TransformerEncoderBlock(nn.Module):
    """One encoder block: self-attention and an MLP, each followed by a
    residual connection and layer normalization (normalization is applied
    after the residual sum)."""
    input_dim : int  # Equals the output dimension because of the residual connections
    num_heads : int = 8
    dim_feedforward : int = 256
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    def setup(self):
        """Create the attention layer, the feed-forward MLP and both norms."""
        self.self_attn = MultiheadAttention(num_heads=self.num_heads)
        # Feed-forward part maps back to input_dim so it can be added to x.
        self.linear = MLP([self.dim_feedforward, self.input_dim], activations=self.activations)

        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()

    def __call__(self, x, mask=None):
        """Return the block output for ``x``; ``mask`` is passed to the attention layer."""
        # Attention sub-layer: residual add, then normalize.
        attended = self.norm1(x + self.self_attn(x, mask=mask))
        # Feed-forward sub-layer: residual add, then normalize.
        return self.norm2(attended + self.linear(attended))
    
class BatchedTransformerDynamicsModel(nn.Module):
    """Dynamics model built from a single causally-masked transformer block.

    Every position carries an encoded action joined with a state token; the
    block output is projected to ``state_dim`` to give the predicted states.
    """
    state_dim : int
    action_emb_dim: int
    num_heads : int = 4
    use_masked_token: int = 1
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    def setup(self):
        """Create the encoder block, positional encoding, action encoder and output layer."""
        token_dim = self.state_dim + self.action_emb_dim
        self.transformer_encoder_block = TransformerEncoderBlock(token_dim, self.num_heads)
        self.positional_encoding = PositionalEncoding(token_dim)
        self.action_encoder = MLP([64, self.action_emb_dim], activations=self.activations)
        self.final_fc = nn.Dense(self.state_dim)

    def __call__(self, actions, start_state_rep):
        '''
        input shapes:
                    actions: batch_size x seq_len x a_dim
                    start_states_rep: batch_size x d
        output shapes:
                    pred_states: batch_size x seq_len x d
        '''
        batch_size, seq_len = actions.shape[0], actions.shape[1]

        action_tokens = self.action_encoder(actions)

        # The start state always occupies the first position.
        first_state_token = jnp.expand_dims(start_state_rep, axis=1)
        if self.use_masked_token:
            # Remaining positions get all-zero ("masked") state tokens.
            blank_tokens = jnp.zeros((batch_size, seq_len - 1, self.state_dim))
            state_tokens = jnp.concatenate((first_state_token, blank_tokens), axis=1)
        else:
            # Remaining positions repeat the start state.
            state_tokens = first_state_token.repeat(seq_len, axis=1)

        # Each input token is [encoded action | state token].
        tokens = jnp.concatenate((action_tokens, state_tokens), axis=-1)
        tokens = self.positional_encoding(tokens)

        # Lower-triangular mask: position i may attend to positions <= i.
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.int32))
        encoded = self.transformer_encoder_block(tokens, mask=causal_mask)

        # Project the block output to the state dimension.
        return self.final_fc(encoded)


############### GRU Dynamics model #############################

class GRUEncoder(nn.Module):
    """Thin wrapper around a single ``nn.GRUCell`` named ``"gru_cell"``.

    ``hidden_dim`` is declared as a field but is not read inside this class.
    """
    hidden_dim: int

    def setup(self):
        """Create the wrapped GRU cell."""
        self.gru = nn.GRUCell(name="gru_cell")

    def __call__(self, hidden_state, inputs):
        """Run one GRU step and return whatever the cell returns."""
        cell_result = self.gru(hidden_state, inputs)
        return cell_result



class BatchedGRUDynamicsModel(nn.Module):
    """Recurrent dynamics model: a GRU consumes (action, state-token) pairs
    and a dense layer maps every hidden state to a predicted state."""
    state_dim : int
    action_emb_dim: int
    hidden_dim: int = 64
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    def setup(self):
        """Create the action encoder, the GRU wrapper and the output layer."""
        self.action_encoder = MLP([64, self.action_emb_dim], activations=self.activations)
        self.gru = GRUEncoder(self.hidden_dim)
        self.final_fc = nn.Dense(self.state_dim)

    def __call__(self, actions, start_state_rep):
        """Predict one state per time step.

        Axis 0 of ``actions`` is the batch axis and axis 1 the time axis;
        the predictions are stacked along axis 1.
        """
        batch_size, horizon = actions.shape[0], actions.shape[1]

        action_tokens = self.action_encoder(actions)

        # State tokens: the start state at t = 0, zeros for every later step.
        blank_tokens = jnp.zeros((batch_size, horizon - 1, self.state_dim))
        state_tokens = jnp.concatenate(
            (jnp.expand_dims(start_state_rep, axis=1), blank_tokens), axis=1
        )
        gru_inputs = jnp.concatenate((action_tokens, state_tokens), axis=-1)

        # The recurrent state starts at zero.
        carry = jnp.zeros((batch_size, self.hidden_dim))

        predictions = []
        for t in range(horizon):
            # Only the first element returned by the cell is used.
            carry, _ = self.gru(carry, gru_inputs[:, t, :])
            predictions.append(self.final_fc(carry))

        return jnp.stack(predictions, axis=1)




#################### DSSM Dynamics Model ###########################

def causal_convolution(u, K, nofft=False):
    """Causally convolve the 1-D signal ``u`` with the 1-D kernel ``K``.

    Output sample ``t`` is ``sum_{s <= t} K[s] * u[t - s]``; only the first
    ``len(u)`` samples of the full convolution are returned.

    Args:
        u: 1-D input signal.
        K: 1-D kernel. On the FFT path it must have the same length as ``u``.
        nofft: if True, use a direct convolution instead of the FFT.

    Returns:
        Array of length ``u.shape[0]``.

    Raises:
        AssertionError: on the FFT path, if ``K`` and ``u`` differ in length.
    """
    if nofft:
        # `convolve` is not imported at module level, so it is imported here;
        # without this the direct path fails with a NameError.
        from jax.scipy.signal import convolve
        return convolve(u, K, mode="full")[: u.shape[0]]

    assert K.shape[0] == u.shape[0]
    # Zero-pad both operands to twice the length so the circular convolution
    # computed by the FFT equals the linear one on the first len(u) samples.
    ud = jnp.fft.rfft(jnp.pad(u, (0, K.shape[0])))
    Kd = jnp.fft.rfft(jnp.pad(K, (0, u.shape[0])))
    return jnp.fft.irfft(ud * Kd)[: u.shape[0]]

class S4dLayer(nn.Module):
    """Diagonal state-space (S4D) layer applied to one 1-D sequence.

    Attributes:
        N: number of diagonal state entries.
        scaling: initialization scheme for the imaginary part of ``A``;
            ``"lin"`` and ``"inv"`` are the values handled in ``setup``.
    """
    N: int
    scaling: str = "lin"

    # Special parameters with multiplicative factor on lr and no weight decay (handled by main train script)
    lr = {
        "A_re": 0.1,
        "A_im": 0.1,
        "log_step": 0.1,
    }

    def setup(self):
        """Create the learned parameters and derive the complex ``A`` and ``C``."""

        def inverse_spacing(key, shape):
            # (n / pi) * (n / (2k + 1) - 1) for k = 0 .. n-1
            n = shape[-1]
            return (n / jnp.pi) * (n / (2 * jnp.arange(n) + 1) - 1)

        def linear_spacing(key, shape):
            # pi * k for k = 0 .. n-1
            return jnp.pi * jnp.ones(shape) * jnp.arange(shape[-1])

        # Both schemes start the real part at -0.5; only the imaginary part differs.
        if self.scaling == "inv":
            self.A_re = self.param("A_re", nn.initializers.constant(-0.5), (self.N,))
            self.A_im = self.param("A_im", inverse_spacing, (self.N,))
        elif self.scaling == "lin":
            self.A_re = self.param("A_re", nn.initializers.constant(-0.5), (self.N,))
            self.A_im = self.param("A_im", linear_spacing, (self.N,))

        # The real part is capped at -1e-4, so it always stays negative.
        self.A = jnp.clip(self.A_re, None, -1e-4) + 1j * self.A_im

        # C is stored as (real, imaginary) pairs on its last axis.
        c_parts = self.param("C", normal(stddev=.5 ** .5), (self.N, 2))
        self.C = c_parts[..., 0] + 1j * c_parts[..., 1]

        self.D = self.param("D", nn.initializers.ones, (1,))
        # The step size is learned in log space.
        self.step = jnp.exp(self.param("log_step", log_step_init(), (1,)))

    def __call__(self, u):
        """Convolve ``u`` with the layer's kernel and add the skip term ``D * u``.

        The kernel is built by ``s4d_kernel_zoh`` for length ``u.shape[0]``.
        """
        kernel = s4d_kernel_zoh(self.C, self.A, u.shape[0], self.step)
        return causal_convolution(u, kernel) + self.D * u

class DSSMCell(nn.Module):
    """Applies one ``S4dLayer(N=64)`` along axis 1 of its input, then a dense
    layer and the activation."""
    out_dim: int  = 128
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    def setup(self):
        """Create the S4D layer and the output dense layer."""
        self.dssm = S4dLayer(N=64)
        self.final_dense = nn.Dense(self.out_dim)

    def __call__(self, inputs):
        """Return ``activations(final_dense(ssm_output))`` for ``inputs``."""
        # The same S4D layer instance is mapped over axis 1 of the input.
        # NOTE(review): this is plain jax.vmap over a bound Flax method rather
        # than nn.vmap -- confirm this is intended.
        per_column_ssm = jax.vmap(self.dssm.__call__, in_axes=1, out_axes=1)
        ssm_output = per_column_ssm(inputs)
        return self.activations(self.final_dense(ssm_output))
           

class BatchedDSSMDynamicsModel(nn.Module):
    """Dynamics model that runs a ``DSSMCell`` over (action, state-token)
    sequences and projects the result to ``state_dim``."""
    state_dim : int
    action_emb_dim: int
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    def setup(self):
        """Create the action encoder, the DSSM cell and the output layer."""
        self.action_encoder = MLP([64, self.action_emb_dim], activations=self.activations)
        self.dssm_cell = DSSMCell(out_dim=self.state_dim//4)
        self.final_fc = nn.Dense(self.state_dim)

    def __call__(self, actions, start_state_rep):
        """Predict one state per time step.

        Axis 0 of ``actions`` is the batch axis and axis 1 the time axis.
        """
        batch_size, horizon = actions.shape[0], actions.shape[1]

        action_tokens = self.action_encoder(actions)

        # State tokens: the start state at t = 0, zeros for every later step.
        blank_tokens = jnp.zeros((batch_size, horizon - 1, self.state_dim))
        state_tokens = jnp.concatenate(
            (jnp.expand_dims(start_state_rep, axis=1), blank_tokens), axis=1
        )
        cell_inputs = jnp.concatenate((action_tokens, state_tokens), axis=-1)

        # The cell is mapped over the batch axis.
        batched_cell = jax.vmap(self.dssm_cell.__call__, in_axes=0, out_axes=0)
        features = batched_cell(cell_inputs)

        # Project the cell features to the state dimension.
        return self.final_fc(features)


