import os
import pdb
import pickle

import numpy as np
import scipy
import scipy.io
import torch
from torch import nn
# NOTE: keep these two star imports in this order; names from the later one
# shadow same-named names from the earlier one.
from RNNvanilla import *
from torch.optim.lr_scheduler import *

# Condition table read from the MATLAB file at import time.
dat = scipy.io.loadmat('data/condsForSimJ2moMuscles.mat')

# Network size and scale factors.
n = 300      # number of recurrent units
g = 1.5      # not referenced in this file's visible code
h = 1.0      # scale of the random input weights (see gen_model)

# With Euler integration the time units are arbitrary: only the ratio
# dt / tau sets the effective integration step.
dt = 0.005
tau = 0.05

num_inputs, num_outputs = 16, 8

# Time index preceding the go cue, from which the network should actually
# respond; passed to train_model as start_index (used in the loss function).
start_index = 180

activation_func = nn.ReLU()

def gen_model(A, seed=0):
    """Assemble the parameter dict for a network built around recurrent matrix A.

    Only the input weights are random. They are drawn from a zero-mean normal
    with standard deviation ``h / sqrt(num_inputs)`` using a NumPy Generator
    seeded with ``seed``. The output weights, both biases and the initial
    state all start at zero.

    Args:
        A: square recurrent weight matrix; stored unchanged under
            'weight_matrix'. Its size determines the number of units.
        seed: seed for the input-weight RNG, so B is reproducible.

    Returns:
        dict with keys 'weight_matrix', 'init_state', 'output_weight_matrix',
        'input_weight_matrix', 'state_bias' and 'output_bias'.

    Raises:
        ValueError: if A is not a square 2-D matrix.
    """
    shape = tuple(np.shape(A))
    if len(shape) != 2 or shape[0] != shape[1]:
        raise ValueError(f"A must be a square 2-D matrix, got shape {shape}")
    # Size every per-unit array from A itself, rather than from the
    # module-level ``n``, so the parameters always agree with the recurrent
    # matrix. For an n x n A this is identical to using ``n``.
    size = shape[0]

    init_rng = np.random.default_rng(seed)
    B = init_rng.normal(size=(size, num_inputs), scale=h / np.sqrt(num_inputs))
    C = np.zeros((num_outputs, size))
    state_bias = np.zeros((size,))
    output_bias = np.zeros((num_outputs,))
    init_state = np.zeros((size,))
    return {'weight_matrix': A, 'init_state': init_state,
            'output_weight_matrix': C, 'input_weight_matrix': B,
            'state_bias': state_bias, 'output_bias': output_bias}

def get_inputs(dat):
    """Build the network input and target arrays from the condition table.

    ``dat['condsForSim']`` is indexed as a 2-D grid ``[condition, delay]``.
    Each cell must provide 'plan', 'goEnvelope' and 'muscle' entries. Cells
    are visited condition-major: every delay of condition 0, then every delay
    of condition 1, and so on.

    Returns:
        u: np.array of per-cell inputs, each being 'plan' and 'goEnvelope'
            joined with np.hstack.
        targets: np.array of per-cell 'muscle' entries, in the same order.
        n_conditions: size of the table's first dimension.
        n_delays: size of the table's second dimension.
    """
    table = dat['condsForSim']
    n_conditions = table.shape[0]
    n_delays = table.shape[1]

    cells = (table[c, d] for c in range(n_conditions) for d in range(n_delays))
    # Build (input, target) per cell in the same order as a nested loop would.
    rows = [
        (np.hstack([cell['plan'], cell['goEnvelope']]), cell['muscle'])
        for cell in cells
    ]
    u = np.array([inp for inp, _ in rows])
    targets = np.array([tgt for _, tgt in rows])
    return u, targets, n_conditions, n_delays

if __name__ == '__main__':

    # Recurrent weight matrices, indexed below as Alist[i1][i2]. The pickle is
    # a local, trusted file; never point this at untrusted data, because
    # unpickling can execute code.
    with open('Alist_rnn.pkl', 'rb') as f:
        Alist = pickle.load(f)

    dat = scipy.io.loadmat('data/condsForSimJ2moMuscles.mat')

    # Each entry is one set of trainable parameter groups, passed to DaleRNN
    # as `train`; one model is trained per entry.
    trainables = [['weights', 'outputs', 'inputs', 'state_bias', 'output_bias']]

    u, targets, _, _ = get_inputs(dat)
    lr = 1e-3
    max_lr = 1e-2  # not used below: train_model is called with lr_scheduler=None

    # Create the output directory before training, so torch.save cannot fail
    # on a missing folder after a full 5000-epoch run.
    out_dir = 'Dale_models_Areps'
    os.makedirs(out_dir, exist_ok=True)

    # Only Alist[0] is trained: the earlier loop bound, len([Alist]), always
    # evaluated to 1. NOTE(review): set this to len(Alist) if every group in
    # Alist is meant to be trained.
    n_groups = 1
    for i1 in range(n_groups):
        for i2, A in enumerate(Alist[i1]):
            for j, train in enumerate(trainables):
                params = {'tau': tau, 'activation_func': activation_func}
                params.update(gen_model(A))
                # Use this iteration's trainable set (not always trainables[0]).
                rnn = DaleRNN(**params, train=train)
                rnn = train_model(rnn, targets, u, dt=dt, num_epochs=5000,
                                  start_index=start_index, lr=lr,
                                  lr_scheduler=None)
                # Save model to disk as model_<i1>_<i2>_<j>.pkl.
                torch.save(rnn, os.path.join(out_dir, 'model_%d_%d_%d.pkl' % (i1, i2, j)))

