import jax
import jax.numpy as jnp
import jax.random as jr
from jax import grad, vmap, jit
from jax import lax
from jax.tree_util import tree_map
import optax
import numpy as np
import functools
import random
import math


# Element Finder Task data generation ----------------------------------------------------


# generate sequences whose query index is drawn from all positions in the sequence

def generate_sequences_and_inds(batch_size, seq_len, key, seq_bound=10):
    """Sample a batch of integer sequences, each paired with a random query index.

    Arguments:
    - batch_size: number of sequences to draw
    - seq_len: number of elements per sequence
    - key: JAX PRNG key
    - seq_bound: elements are drawn from the integers in [-seq_bound, seq_bound]

    Returns:
    - inputs: array of shape (batch_size, seq_len + 1, 1); position 0 of every
      row is the query index, the remaining positions are the sequence
    - outputs: array of shape (batch_size,), the sequence element found at the
      query index of each row
    """
    # The key is split three ways; only the last two subkeys are consumed.
    _, seq_key, query_key = jr.split(key, 3)

    values = jr.randint(seq_key, (batch_size, seq_len), -seq_bound, seq_bound + 1)
    queries = jr.randint(query_key, (batch_size, 1), 0, seq_len)

    # Prepend the query index to each sequence.
    tokens = jnp.concatenate([queries, values], axis=1)

    # For row i pick candidates[query_i, i], i.e. values[i, query_i].
    query_column = tokens[:, 0]
    candidates = tokens[:, 1:].T
    targets = jnp.choose(query_column, candidates)

    return tokens[..., jnp.newaxis], targets

    

# generate sequences whose query index is drawn from a fixed set of choices

def generate_sequences_and_selected_inds(batch_size, idx_choices, seq_len, key, seq_bound=10):
    """Sample integer sequences whose query index comes from `idx_choices`.

    Arguments:
    - batch_size: number of sequences to draw
    - idx_choices: candidate query indices, sampled with `jr.choice`
    - seq_len: number of elements per sequence
    - key: JAX PRNG key
    - seq_bound: elements are drawn from the integers in [-seq_bound, seq_bound]

    Returns:
    - inputs: array of shape (batch_size, seq_len + 1, 1); position 0 of every
      row is the query index, the remaining positions are the sequence
    - outputs: array of shape (batch_size,), the sequence element found at the
      query index of each row
    """
    # The key is split three ways; only the last two subkeys are consumed.
    _, seq_key, query_key = jr.split(key, 3)

    values = jr.randint(seq_key, (batch_size, seq_len), -seq_bound, seq_bound + 1)
    queries = jr.choice(query_key, idx_choices, shape=(batch_size, 1))

    # Prepend the query index to each sequence.
    tokens = jnp.concatenate([queries, values], axis=1)

    # For row i pick candidates[query_i, i], i.e. values[i, query_i].
    query_column = tokens[:, 0]
    candidates = tokens[:, 1:].T
    targets = jnp.choose(query_column, candidates)

    return tokens[..., jnp.newaxis], targets



# generate sequences with a fixed query index 

def generate_sequences_and_fixed_query(batch_size, idx, seq_len, key, seq_bound=10, fixed_value=None):
    """Sample integer sequences that all share the same query index `idx`.

    Arguments:
    - batch_size: number of sequences to draw
    - idx: query index used for every row; must satisfy 0 <= idx < seq_len
    - seq_len: number of elements per sequence
    - key: JAX PRNG key
    - seq_bound: elements are drawn from the integers in [-seq_bound, seq_bound]
    - fixed_value: when not None, written into position `idx` of every sequence

    Returns:
    - inputs: array of shape (batch_size, seq_len + 1, 1); position 0 of every
      row is `idx`, the remaining positions are the sequence
    - outputs: array of shape (batch_size,), the sequence element at `idx`
    """
    assert 0 <= idx < seq_len

    # The key is split three ways; only the middle subkey is consumed here.
    _, seq_key, _ = jr.split(key, 3)
    values = jr.randint(seq_key, (batch_size, seq_len), -seq_bound, seq_bound + 1)

    if fixed_value is not None:
        # Overwrite the queried position in every row with the requested value.
        forced_column = jnp.repeat(fixed_value, batch_size)
        values = values.at[:, idx].set(forced_column)

    query_column = jnp.repeat(idx, batch_size)[:, jnp.newaxis]
    tokens = jnp.concatenate([query_column, values], axis=1)

    # For row i pick candidates[idx, i], i.e. values[i, idx].
    targets = jnp.choose(tokens[:, 0], tokens[:, 1:].T)

    return tokens[..., jnp.newaxis], targets

    

# NM-RNN design + training -----------------------------------------------------------

def nm_rnn(params, x0, z0, inputs, tau_x, tau_z, nln=jnp.tanh):
    """Run the neuromodulated low-rank RNN (NM-RNN) over one input sequence.

    A neuromodulatory (NM) network with state z is driven by the inputs and by
    feedback from the low-rank RNN state x. A sigmoid readout of z rescales the
    columns of the row-factor matrix of the low-rank recurrence at every step.

    Arguments:
    - params: dict of weights, see the key list below
    - x0: initial state of LR-RNN
    - z0: initial state of NM network
    - inputs: input sequence, scanned over its leading axis
    - tau_x : decay constant for LR-RNN
    - tau_z : decay constant for NM network
    - nln: elementwise nonlinearity applied to x and z (default tanh)

    Returns:
    - ys: readouts `C @ x`, stacked over time steps
    - xs: LR-RNN states after each update, stacked over time steps
    - zs: NM-network states after each update, stacked over time steps
    """

    # Shape legend as annotated by the original author (not checked at runtime):
    # D_in -- input dim, N -- hidden dim, R -- rank, O -- output dim,
    # dim_nm -- hidden dim of NM network

    U = params["row_factors"]               # N x R
    V = params["column_factors"]            # N x R
    B_xu = params["input_weights"]          # N x D_in
    C = params["readout_weights"]           # O x N
    W_zz = params["nm_rec_weight"]          # dim_nm x dim_nm
    B_zu = params["nm_input_weight"]        # dim_nm x D_in
    W_kz = params["nm_sigmoid_weight"]      # R x dim_nm
    b_k = params["nm_sigmoid_bias"]         # R
    W_zx = params['nm_feedback_weight']     # dim_nm x N
    b_z = params['nm_feedback_bias']        # dim_nm

    # (disabled input-gating variant)
    # W_xz = params['input_gating_weight']    # N x dim_nm
    # b_x = params['input_gating_bias']       # N

    # The recurrent term is normalised by the LR-RNN state size.
    N = x0.shape[0]

    def _step(x_and_z, u):
        x_p, z_p = x_and_z

        # update z: leak + NM recurrence + input drive + feedback from x,
        # all evaluated on the previous states
        z = (1.0 - (1. / tau_z)) * z_p
        z += (1. / tau_z) * W_zz @ nln(z_p)
        z += (1. / tau_z) * B_zu @ u
        z += (1. / tau_z) * (W_zx @ nln(x_p) + b_z)

        # update x: the gate s is computed from the *new* z and broadcast
        # across the columns of U
        s = jax.nn.sigmoid(W_kz @ z + b_k)
        x = (1.0 - (1. / tau_x)) * x_p
        h = V.T @ nln(x_p)
        x += (1. / (tau_x * N)) * (U * s) @ h # divide by N

        # (disabled input-gating variant)
        # x += (1. / tau_x) * nln(W_xz @ z + b_x)
        x += (1. / tau_x) * B_xu @ u

        # calculate y
        y = C @ x
        return (x, z), (y, x, z)

    _, (ys, xs, zs) = lax.scan(_step, (x0, z0), inputs)

    return ys, xs, zs

# Batched NM-RNN: only `inputs` (argument 3) is mapped over its leading axis;
# params, x0, z0, tau_x and tau_z are shared by every sequence in the batch.
batched_nm_rnn = jax.vmap(nm_rnn, in_axes=(None, None, None, 0, None, None))


def random_nmrnn_params(key, u, n, r, m, k, o, g=1.0):
    """Draw random initial parameters for the NM-RNN.

    Every entry is standard-normal noise divided by the square root of one of
    the dimensions (listed per entry in the table below).

    Arguments:
    key: JAX PRNG key
    u:  number of inputs
    n:  number of neurons in main network
    r:  rank of main network
    m:  number of neurons in NM network
    k:  dimension of NM input affecting weight matrix (either 1 or r)
    o:  number of outputs
    g:  accepted but not used by this function

    Returns a dict mapping parameter name to array.
    """
    # Twelve subkeys are requested although only the first ten are consumed;
    # the split count is left untouched so the generated values cannot shift.
    subkeys = jr.split(key, 12)

    # (name, shape, dimension whose square root divides the samples)
    layout = (
        ('row_factors', (n, r), r),
        ('column_factors', (n, r), r),
        ('input_weights', (n, u), u),
        ('readout_weights', (o, n), n),
        ('nm_rec_weight', (m, m), m),
        ('nm_input_weight', (m, u), u),
        ('nm_sigmoid_weight', (k, m), m),
        ('nm_sigmoid_bias', (k,), k),
        ('nm_feedback_weight', (m, n), n),
        ('nm_feedback_bias', (m,), m),
    )

    params = {}
    for slot, (name, shape, fan) in enumerate(layout):
        params[name] = jr.normal(subkeys[slot], shape) / jnp.sqrt(fan)
    return params

def fit_element_finder_nm_lrrnn(params, optimizer, x0, z0, num_batches, batch_size, tau_x, tau_z, seq_len, key, idx_choices=None, loss_window_size=10,
                                s0=None, num_ablated_components=None):
    """Train the NM-RNN on the element-finder task with freshly sampled batches.

    Arguments:
    - params: initial NM-RNN parameter dict
    - optimizer: object providing `init(params)` and
      `update(grads, opt_state, params)`; its updates are applied with
      `optax.apply_updates`
    - x0, z0: initial LR-RNN / NM-network states shared by all sequences
    - num_batches: number of training steps (one new batch per step)
    - batch_size: number of sequences per batch
    - tau_x, tau_z: decay constants passed to the network
    - seq_len: length of the generated sequences
    - key: JAX PRNG key, split once per step
    - idx_choices: if None, query indices are drawn from all positions;
      otherwise they are sampled from this set
    - loss_window_size: accepted but not used by this function
    - s0, num_ablated_components: when `num_ablated_components` is not None the
      loss is computed with `batched_ablated_nm_rnn(..., s0, num_ablated_components)`

    Returns a tuple:
    - params: parameters after the last step
    - lowest_loss_params: parameters that produced the lowest recorded loss
    - losses: list with the loss of every step
    - min_loss: lowest recorded loss
    - lowest_loss_idx: step index at which `min_loss` occurred
    The last four are None / empty when `num_batches` is 0.
    """
    # NOTE(review): `wandb` and `batched_ablated_nm_rnn` are not imported or
    # defined in the visible part of this file -- confirm they are provided
    # elsewhere before calling this function.
    opt_state = optimizer.init(params)

    @jit
    def _loss(params, batch_inputs, batch_targets):
        # Mean squared error between the first readout unit at the final time
        # step and the target element.
        if num_ablated_components is None:
            ys, _, _ = batched_nm_rnn(params, x0, z0, batch_inputs, tau_x, tau_z)
        else:
            ys, _, _ = batched_ablated_nm_rnn(params, x0, z0, batch_inputs, tau_x, tau_z, s0, num_ablated_components)
        return jnp.mean(((ys[:, -1, 0] - batch_targets)**2))

    @jit
    def _step(params, opt_state, batch_inputs, batch_targets):
        loss_value, grads = jax.value_and_grad(_loss)(params, batch_inputs, batch_targets)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    losses = []
    min_loss = None
    lowest_loss_idx = None
    lowest_loss_params = None
    for n in range(num_batches):
        key, subkey = jr.split(key, 2)

        if idx_choices is None:
            batch_inputs, batch_targets = generate_sequences_and_inds(batch_size, seq_len, subkey)
        else:
            batch_inputs, batch_targets = generate_sequences_and_selected_inds(batch_size, idx_choices, seq_len, subkey)

        # `loss_value` is evaluated on the parameters passed INTO `_step`, so
        # keep a handle on them: they, not the updated ones, achieved that loss.
        evaluated_params = params
        params, opt_state, loss_value = _step(params, opt_state, batch_inputs, batch_targets)
        losses.append(loss_value)
        wandb.log({"loss": loss_value})

        if min_loss is None or loss_value < min_loss:
            min_loss = loss_value
            lowest_loss_params = evaluated_params
            lowest_loss_idx = n

        if n % 500 == 0:
            print(f'step {n}, loss: {loss_value}')

    return params, lowest_loss_params, losses, min_loss, lowest_loss_idx


# LR-RNN ------------------------------------------------------------------------

def low_rank_rnn(params, x0, inputs, tau, nln=jnp.tanh):
    """Run the low-rank RNN over one input sequence.

    Arguments:
    - params: dict with "row_factors", "column_factors", "input_weights" and
      "readout_weights"
    - x0: initial hidden state
    - inputs: input sequence, scanned over its leading axis
    - tau   : decay constant
    - nln: elementwise nonlinearity applied to the state (default tanh)

    Returns:
    - ys: readouts `C @ x`, stacked over time steps
    - xs: hidden states after each update, stacked over time steps
    """

    # Shapes as annotated by the original author (not checked at runtime).
    U = params["row_factors"]       # D x R
    V = params["column_factors"]    # D x R
    B = params["input_weights"]     # D x M
    C = params["readout_weights"]   # O x D

    # The recurrent term is normalised by the state size.
    N = x0.shape[0]

    def _step(x, u):
        # Project onto the column factors first, then expand with the row
        # factors. This avoids building the full U @ V.T matrix at each step
        # and matches how nm_rnn evaluates its recurrence.
        latent = V.T @ nln(x)
        # Leak and recurrence are both evaluated on the previous state.
        x = (1.0 - (1. / tau)) * x + (1. / (tau * N)) * (U @ latent) # divide by N
        x += (1. / tau) * B @ u
        y = C @ x
        return x, (y, x)

    _, (ys, xs) = lax.scan(_step, x0, inputs)

    return ys, xs

# Batched low-rank RNN: only `inputs` (argument 2) is mapped over its leading
# axis; params, x0 and tau are shared by every sequence in the batch.
batched_low_rank_rnn = jax.vmap(low_rank_rnn, in_axes=(None, None, 0, None))

def batched_loss(params, x0, batch_inputs, tau, batch_targets):
    """Return the mean squared error between all batched low-rank RNN readouts
    and `batch_targets` (the two are combined with ordinary broadcasting)."""
    readouts = batched_low_rank_rnn(params, x0, batch_inputs, tau)[0]
    residual = readouts - batch_targets
    return jnp.mean(residual**2)

def fit_element_finder_lrrnn(params, optimizer, x0, num_batches, batch_size, tau, seq_len, key):
    """Train the low-rank RNN on the element-finder task with freshly sampled batches.

    Arguments:
    - params: initial low-rank RNN parameter dict
    - optimizer: object providing `init(params)` and
      `update(grads, opt_state, params)`; its updates are applied with
      `optax.apply_updates`
    - x0: initial hidden state shared by all sequences
    - num_batches: number of training steps (one new batch per step)
    - batch_size: number of sequences per batch
    - tau: decay constant passed to the network
    - seq_len: length of the generated sequences
    - key: JAX PRNG key, split once per step

    Returns a tuple:
    - params: parameters after the last step
    - lowest_loss_params: parameters that produced the lowest recorded loss
    - losses: list with the loss of every step
    - min_loss: lowest recorded loss
    - lowest_loss_idx: step index at which `min_loss` occurred
    The last four are None / empty when `num_batches` is 0.
    """
    # NOTE(review): `wandb` is not imported in the visible part of this file --
    # confirm it is provided elsewhere before calling this function.
    opt_state = optimizer.init(params)

    @jit
    def _loss(params, x0, batch_inputs, tau, batch_targets):
        # Mean squared error between the first readout unit at the final time
        # step and the target element.
        ys, _ = batched_low_rank_rnn(params, x0, batch_inputs, tau)
        return jnp.mean(((ys[:, -1, 0] - batch_targets)**2))

    @jit
    def _step(params, opt_state, x0, batch_inputs, batch_targets):
        loss_value, grads = jax.value_and_grad(_loss)(params, x0, batch_inputs, tau, batch_targets)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    losses = []
    min_loss = None
    lowest_loss_idx = None
    lowest_loss_params = None
    for n in range(num_batches):
        key, subkey = jr.split(key, 2)
        batch_inputs, batch_targets = generate_sequences_and_inds(batch_size, seq_len, subkey)

        # `loss_value` is evaluated on the parameters passed INTO `_step`, so
        # keep a handle on them: they, not the updated ones, achieved that loss.
        evaluated_params = params
        params, opt_state, loss_value = _step(params, opt_state, x0, batch_inputs, batch_targets)
        losses.append(loss_value)
        wandb.log({"loss": loss_value})

        if min_loss is None or loss_value < min_loss:
            min_loss = loss_value
            lowest_loss_params = evaluated_params
            lowest_loss_idx = n

        if n % 500 == 0:
            print(f'step {n}, loss: {loss_value}')

    return params, lowest_loss_params, losses, min_loss, lowest_loss_idx


def random_lrrnn_params(key, u, n, r, o, g=1.0):
    """Draw random initial parameters for the low-rank RNN.

    Arguments:
    key: JAX PRNG key
    u:  number of inputs
    n:  number of neurons
    r:  rank of the recurrent weights
    o:  number of outputs
    g:  gain multiplying the scale of both recurrent factors

    Returns a dict with 'row_factors', 'column_factors', 'input_weights' and
    'readout_weights'.
    """
    row_key, col_key, in_key, out_key = jr.split(key, 4)

    input_scale = 1.0 / jnp.sqrt(u)     # scaling of input weights
    recurrent_scale = g / jnp.sqrt(n)   # scaling of recurrent weights
    readout_scale = 1.0 / jnp.sqrt(n)   # scaling of output weights

    params = {}
    params['row_factors'] = jr.normal(row_key, (n, r)) * recurrent_scale
    params['column_factors'] = jr.normal(col_key, (n, r)) * recurrent_scale
    params['input_weights'] = jr.normal(in_key, (n, u)) * input_scale
    params['readout_weights'] = jr.normal(out_key, (o, n)) * readout_scale
    return params

# LSTM --------------------------------------------------------------------------

def simple_lstm(params, c0, h0, inputs, gate_fn=jax.nn.sigmoid, activation_fn=jnp.tanh):
    """Run an LSTM with a linear readout over one input sequence.

    Arguments:
    - params: dict of LSTM weights, biases and readout weights; may also hold
      'embedding_weights' (see below)
    - c0: initial cell (memory) state
    - h0: initial hidden state
    - inputs: a single input sequence, scanned over its leading axis. When
      params contains 'embedding_weights', `inputs` is first used to index
      into that table.
    - gate_fn: nonlinearity of the input, forget and output gates
    - activation_fn: nonlinearity of the candidate and of the cell readout

    Returns:
    - carry: the final (cell state, hidden state) pair
    - y: readouts `C @ h`, stacked over time steps
    """
    # Shape legend as annotated by the original author (not checked at runtime):
    # N -- hidden size, M -- input_size, O -- output_size

    # input gate
    W_iu = params['weights_iu']     # N x M
    W_ih = params['weights_ih']     # N x N
    b_ih = params['bias_ih']        # added to the gate pre-activation

    # forget gate
    W_fu = params['weights_fu']     # N x M
    W_fh = params['weights_fh']     # N x N
    b_fh = params['bias_fh']        # added to the gate pre-activation

    # candidate cell update
    W_gu = params['weights_gu']     # N x M
    W_gh = params['weights_gh']     # N x N
    b_gh = params['bias_gh']        # added to the candidate pre-activation

    # output gate
    W_ou = params['weights_ou']     # N x M
    W_oh = params['weights_oh']     # N x N
    b_oh = params['bias_oh']        # added to the gate pre-activation

    C = params["readout_weights"]   # O x N

    def _step(carry, u):
        c, h = carry
        i = gate_fn(W_iu @ u + W_ih @ h + b_ih)
        f = gate_fn(W_fu @ u + W_fh @ h + b_fh)
        g = activation_fn(W_gu @ u + W_gh @ h + b_gh)
        o = gate_fn(W_ou @ u + W_oh @ h + b_oh)

        # keep a gated share of the old memory and add the gated candidate
        new_c = f * c + i * g
        new_h = o * activation_fn(new_c)

        y = C @ new_h

        return (new_c, new_h), y

    if 'embedding_weights' in params:
        # look up one embedding row per input entry (e.g. token ids)
        embedding = params['embedding_weights']
        inputs = embedding[inputs]

    carry, y = lax.scan(_step, (c0, h0), inputs)
    return carry, y

def initialize_carry(hidden_size, key):
    """Draw a random initial (cell state, hidden state) pair for the LSTM.

    Both vectors have shape (hidden_size,) and are standard-normal samples
    multiplied by 1 / sqrt(hidden_size). Each is drawn with its own subkey.
    """
    cell_key, hidden_key = jr.split(key)
    scale = 1.0 / jnp.sqrt(hidden_size)
    state_shape = (hidden_size,)

    cell_state = jr.normal(cell_key, state_shape) * scale
    hidden_state = jr.normal(hidden_key, state_shape) * scale
    return cell_state, hidden_state

# should initial carry also be batched?
# Batched LSTM: only `inputs` (argument 3) is mapped over its leading axis;
# params and the initial carry (c0, h0) are shared by every sequence.
batched_simple_lstm = jax.vmap(simple_lstm, in_axes=(None, None, None, 0))


def lstm_batched_loss(params, c0, h0, batch_inputs, batch_targets):
    """Return the mean squared error between all batched LSTM readouts and
    `batch_targets` (the two are combined with ordinary broadcasting)."""
    readouts = batched_simple_lstm(params, c0, h0, batch_inputs)[1]
    residual = readouts - batch_targets
    return jnp.mean(residual**2)

def random_lstm_params(key, u, n, o, embed_size=None, n_embeds=None):
    """Draw random initial parameters for the LSTM.

    Arguments:
    key: JAX PRNG key
    u:  number of inputs (ignored for the weight shapes when embed_size is given)
    n:  hidden size
    o:  number of outputs
    embed_size: if not None, width of the embedding rows; the input-side
        weights then take `embed_size` columns instead of `u`
    n_embeds: number of embedding rows; required when embed_size is given

    Returns a dict with, per gate (i, f, g, o), an input weight, a recurrent
    weight and a bias, plus 'readout_weights' and, when embed_size is given,
    'embedding_weights'.
    """
    subkeys = jr.split(key, 14)

    with_embedding = embed_size is not None
    if with_embedding:
        assert n_embeds is not None
    input_dim = embed_size if with_embedding else u

    input_scale = 1.0 / jnp.sqrt(input_dim)   # scaling of input weights
    hidden_scale = 1.0 / jnp.sqrt(n)          # scaling of recurrent weights
    readout_scale = 1.0 / jnp.sqrt(n)         # scaling of output weights

    # One (input weight, recurrent weight, bias) triple per gate; each triple
    # consumes three consecutive subkeys.
    gate_names = (
        ('weights_iu', 'weights_ih', 'bias_ih'),
        ('weights_fu', 'weights_fh', 'bias_fh'),
        ('weights_gu', 'weights_gh', 'bias_gh'),
        ('weights_ou', 'weights_oh', 'bias_oh'),
    )

    params = {}
    for slot, (in_name, rec_name, bias_name) in enumerate(gate_names):
        base = 3 * slot
        params[in_name] = jr.normal(subkeys[base], (n, input_dim)) * input_scale
        params[rec_name] = jr.normal(subkeys[base + 1], (n, n)) * hidden_scale
        params[bias_name] = jr.normal(subkeys[base + 2], (n,)) * hidden_scale

    params['readout_weights'] = jr.normal(subkeys[12], (o, n)) * readout_scale

    if with_embedding:
        # The embedding table is left unscaled.
        params['embedding_weights'] = jr.normal(subkeys[13], (n_embeds, embed_size))

    return params


def fit_element_finder_lstm(params, optimizer, init_carry, num_batches, batch_size, seq_len, key):
    """Train the LSTM on the element-finder task with freshly sampled batches.

    Arguments:
    - params: initial LSTM parameter dict
    - optimizer: object providing `init(params)` and
      `update(grads, opt_state, params)`; its updates are applied with
      `optax.apply_updates`
    - init_carry: (c0, h0) pair of initial cell / hidden states shared by all
      sequences
    - num_batches: number of training steps (one new batch per step)
    - batch_size: number of sequences per batch
    - seq_len: length of the generated sequences
    - key: JAX PRNG key, split once per step

    Returns a tuple:
    - params: parameters after the last step
    - lowest_loss_params: parameters that produced the lowest recorded loss
    - losses: list with the loss of every step
    - min_loss: lowest recorded loss
    - lowest_loss_idx: step index at which `min_loss` occurred
    The last four are None / empty when `num_batches` is 0.
    """
    # NOTE(review): `wandb` is not imported in the visible part of this file --
    # confirm it is provided elsewhere before calling this function.
    opt_state = optimizer.init(params)

    @jit
    def _loss(params, c0, h0, batch_inputs, batch_targets):
        # Mean squared error between the first readout unit at the final time
        # step and the target element.
        _, ys = batched_simple_lstm(params, c0, h0, batch_inputs)
        return jnp.mean(((ys[:, -1, 0] - batch_targets)**2))

    @jit
    def _step(params, opt_state, init_carry, batch_inputs, batch_targets):
        (c0, h0) = init_carry
        loss_value, grads = jax.value_and_grad(_loss)(params, c0, h0, batch_inputs, batch_targets)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    losses = []
    min_loss = None
    lowest_loss_params = None
    lowest_loss_idx = None

    for n in range(num_batches):
        key, subkey = jr.split(key, 2)
        batch_inputs, batch_targets = generate_sequences_and_inds(batch_size, seq_len, subkey)

        # `loss_value` is evaluated on the parameters passed INTO `_step`, so
        # keep a handle on them: they, not the updated ones, achieved that loss.
        evaluated_params = params
        params, opt_state, loss_value = _step(params, opt_state, init_carry, batch_inputs, batch_targets)
        losses.append(loss_value)
        wandb.log({"loss": loss_value})

        if min_loss is None or loss_value < min_loss:
            min_loss = loss_value
            lowest_loss_params = evaluated_params
            lowest_loss_idx = n

        if n % 500 == 0:
            print(f'step {n}, loss: {loss_value}')

    return params, lowest_loss_params, losses, min_loss, lowest_loss_idx
