import copy
from src import models as models
import numpy as np
import torch
import matplotlib.pyplot as plt


# Number of data symbols the network must store and later reproduce.
S = 10
# Hidden width; the weight construction later in this script partitions the
# hidden state into column groups of S, S + 1 and 1 units.
h_size = 2 * (S + 1)

# One RandomState seed per data split, so every split can be regenerated
# identically instead of being stored on disk.
__fixed_seeds__ = dict(train=0, test=1, valid=2, image=3)
device = torch.device('cpu')

# Keyword arguments forwarded to models.DirectModel.
kwargs = dict(
    alpha=0,
    architecture="GRU",
    batch_first=True,
    dropout=0,
    gain=1,
    hidden_size=h_size,
    initializer="orthogonal",
    input_size=4,  # one-hot input width: 2 data symbols + blank + delimiter
    max_angle=2,
    min_angle=0,
    num_layers=1,
    output_size=3,  # one-hot target width: 2 data symbols + blank
)

# Its parameters are subsequently (partially) overwritten with hand-built values.
model = models.DirectModel(output_only_last=False, **kwargs)

# Magnitudes used by the hand-built weights; the construction relies on a >> b
# (the exact value of a is otherwise arbitrary).
a = 10
b = 3
# Geometric factor: first-group unit i receives the data inputs scaled by
# encode_seq_ratio ** (i + 1). A ratio of 1/2 does not work.
encode_seq_ratio = 1 / 3

# Start from a copy of the freshly initialised parameters. Only the slices
# assigned in this script are replaced; every other entry keeps the value
# produced by DirectModel's initializer.
new_state_dict = copy.deepcopy(model.state_dict())

# --- Input-to-hidden weights (rnn.weight_ih_l0) ------------------------------
# Row offsets into the stacked per-gate weight matrices. Assuming `rnn` is a
# torch.nn.GRU (rows stacked as reset | update | new), `start` selects the
# new-gate rows and `start_r` the second gate block -- TODO confirm against
# DirectModel.
start = 2*2*(S+1)
start_r = 2*(S+1)

# LambdaRC is the block for hidden-unit group R (1: first S units, 2: next
# S+1 units, 3: final unit) and input-channel pair C (1: channels 0-1, the data
# symbols when elements=2; 2: channels 2-3, blank and delimiter).
Lambda11 = torch.zeros((S, 2))
Lambda12 = torch.zeros((S, 2))
Lambda21 = torch.zeros((S+1, 2))
Lambda22 = torch.zeros((S+1, 2))
Lambda31 = torch.zeros((1, 2))
Lambda32 = torch.zeros((1, 2))

# First group: unit `pos` sees the two data channels with opposite signs,
# scaled geometrically so every position gets a distinct magnitude.
# (Named `scale`, not `map`, so the builtin is not shadowed module-wide.)
for pos in range(S):
    scale = encode_seq_ratio ** (pos + 1)
    Lambda11[pos, 0] = scale
    Lambda11[pos, 1] = -scale

# Second group: the last two units are inhibited by both data channels; the
# last unit is also inhibited by the blank and excited by the delimiter.
Lambda21[-2:, :] = -a
Lambda22[S, 0] = -a
Lambda22[S, 1] = a

# Final unit: inhibited by data channels, excited by blank/delimiter.
Lambda31[0, :] = -a
Lambda32[0, :] = a

new_state_dict['rnn.weight_ih_l0'].data[start:start+S, :2] = Lambda11
new_state_dict['rnn.weight_ih_l0'].data[start:start+S, 2:] = Lambda12
new_state_dict['rnn.weight_ih_l0'].data[start+S:start+2*S+1, :2] = Lambda21
new_state_dict['rnn.weight_ih_l0'].data[start+S:start+2*S+1, 2:] = Lambda22
new_state_dict['rnn.weight_ih_l0'].data[start+2*S+1:, :2] = Lambda31
new_state_dict['rnn.weight_ih_l0'].data[start+2*S+1:, 2:] = Lambda32


# --- Hidden-to-hidden weights (rnn.weight_hh_l0) -----------------------------
# HRC is the block mapping previous-hidden-state column group C onto
# hidden-unit row group R, with groups [:S], [S:2*S+1], [2*S+1:] (sizes S,
# S+1, 1). Every block except H22 is all zeros, so assigning them below clears
# the initial values of rows start: (the new-gate rows, assuming an
# nn.GRU-style layout).
H11 = torch.zeros((S, S))
H12 = torch.zeros((S, S+1))
H13 = torch.zeros((S, 1))
H21 = torch.zeros((S+1, S))
H22 = torch.zeros((S+1, S+1))
H23 = torch.zeros((S+1, 1))
H31 = torch.zeros((1, S))
H32 = torch.zeros((1, S+1))
H33 = torch.zeros((1, 1))
# Superdiagonal of H22: middle-group unit j reads middle-group unit j+1 with
# weight a (the last unit has no successor, hence the skip).
for j in range(H22.shape[0]):
    if j == H22.shape[0] - 1:
        pass
    else:
        H22[j, j+1] = a

# Anti-diagonal coupling: first-group unit k reads middle-group unit S-1-k
# (column index -2-k of the S+1 middle columns) with weight -a.
Hr12 = torch.zeros((S, S+1))
for k in range(Hr12.shape[0]):
    Hr12[k, -2 - k] = -a


new_state_dict['rnn.weight_hh_l0'].data[start:start+S, :S] = H11
new_state_dict['rnn.weight_hh_l0'].data[start:start+S, S:2*S+1] = H12
new_state_dict['rnn.weight_hh_l0'].data[start:start+S, 2*S+1:] = H13
new_state_dict['rnn.weight_hh_l0'].data[start+S:start+2*S+1, :S] = H21
new_state_dict['rnn.weight_hh_l0'].data[start+S:start+2*S+1, S:2*S+1] = H22
new_state_dict['rnn.weight_hh_l0'].data[start+S:start+2*S+1, 2*S+1:] = H23
new_state_dict['rnn.weight_hh_l0'].data[start+2*S+1:, :S] = H31
new_state_dict['rnn.weight_hh_l0'].data[start+2*S+1:, S:2*S+1] = H32
new_state_dict['rnn.weight_hh_l0'].data[start+2*S+1:, 2*S+1:] = H33

# NOTE(review): start_r = h_size. In torch.nn.GRU's stacked layout
# (reset | update | new) that offset selects the update-gate rows despite the
# `_r` suffix -- confirm which gate is intended. Rows 0:start_r and the rest of
# this gate block are never written here and keep their initial values.
new_state_dict['rnn.weight_hh_l0'].data[start_r:start_r+S, S:2*S+1] = Hr12

# --- Read-out layer (classifier.weight / classifier.bias) --------------------
# Presumably classifier.weight has shape (output_size, h_size) = (3, h_size),
# nn.Linear-style -- confirm. Rows :2 are the two data-symbol outputs, row 2
# is the blank output; column groups match the hidden groups [:S], [S:2*S+1],
# [2*S+1:].
V11 = torch.zeros((2, S))
V12 = torch.zeros((2, S+1))
V13 = torch.zeros((2, 1))
V21 = torch.zeros((1, S))
V22 = torch.zeros((1, S+1))
V23 = torch.zeros((1, 1))
V23[0, 0] = -a  # blank output is inhibited by the final hidden unit

# Output 0 reads even-indexed first-group units with weight +1; output 1 reads
# odd-indexed first-group units with weight -1.
for ii in range(V11.shape[1]):
    if ii % 2 == 0:
        V11[0, ii] = 1
    else:
        V11[1, ii] = -1
# Blank output is inhibited (-a) by all middle-group units except the last.
for jj in range(V22.shape[1]):
    if jj < V22.shape[1] - 1:
        V22[0, jj] = -a

new_state_dict['classifier.weight'].data[:2, :S] = V11
new_state_dict['classifier.weight'].data[:2, S:2*S+1] = V12
new_state_dict['classifier.weight'].data[:2, 2*S+1:] = V13
new_state_dict['classifier.weight'].data[2:, :S] = V21
new_state_dict['classifier.weight'].data[2:, S:2*S+1] = V22
new_state_dict['classifier.weight'].data[2:, 2*S+1:] = V23

# Only the blank output gets a non-zero bias, ((3 - 2*S) / 2) * a.
new_state_dict['classifier.bias'].data = torch.tensor([0, 0, ((3-2*S)/2)*a])


# --- Recurrent biases ---------------------------------------------------------
# Zero both bias vectors (3 gate blocks of h_size each) ...
new_state_dict['rnn.bias_ih_l0'].data = torch.zeros(3*h_size)
new_state_dict['rnn.bias_hh_l0'].data = torch.zeros(3*h_size)

# ... then bias the gate block starting at start_r: +b for its first S units,
# -b for the remaining S+2 units (together exactly h_size entries).
b1 = torch.zeros(S)
b1[:] = b
b2 = torch.zeros(S+2)
b2[:] = -b

new_state_dict['rnn.bias_hh_l0'].data[start_r:start_r+S] = b1
new_state_dict['rnn.bias_hh_l0'].data[start_r+S:start_r+2*S+2] = b2


# Install the hand-built parameters and switch the model to float64, matching
# the float64 tensors fed to it later in this script.
model.load_state_dict(new_state_dict)
model.double()

########################################################################################################################
########################################################################################################################


def generate_data(min_lag=100, max_lag=120, num=100, set='test', slen=2, elements=2):
    """Generate a reproducible copy-task dataset as integer symbol codes.

    Symbols ``0 .. elements-1`` are data, ``elements`` is the blank and
    ``elements + 1`` is the delimiter. For each trial a lag ``T`` is drawn and

    * input:  ``slen`` data symbols, ``T - 1`` blanks, one delimiter, then
      blanks up to the common row length;
    * target: blanks through the delimiter step, the same ``slen`` data
      symbols, then blanks.

    Every row has length ``max_lag + 2 * slen``. The RNG is seeded from
    ``__fixed_seeds__[set]``, so a split is identical on every call.

    Args:
        min_lag: smallest lag ``T``; must be at least 1.
        max_lag: exclusive upper bound on ``T`` (NumPy ``randint`` semantics).
            If equal to ``min_lag``, every trial uses that lag.
        num: number of trials.
        set: split name, a key of ``__fixed_seeds__``.
        slen: number of data symbols per trial.
        elements: number of distinct data symbols.

    Returns:
        ``(inp, targ)``: two int64 arrays of shape ``(num, max_lag + 2 * slen)``.

    Raises:
        ValueError: if ``min_lag < 1`` or ``max_lag < min_lag``.
        KeyError: if ``set`` is not a key of ``__fixed_seeds__``.
    """
    # T = 0 would request a negative-length blank run, so reject it up front
    # instead of failing at random depending on the draw.
    if min_lag < 1:
        raise ValueError(f"min_lag must be >= 1, got {min_lag}")
    if max_lag < min_lag:
        raise ValueError(f"max_lag ({max_lag}) must be >= min_lag ({min_lag})")

    random = np.random.RandomState(__fixed_seeds__[set])
    seq = random.randint(0, high=elements, size=(num, slen))

    length = max_lag + 2 * slen
    inp = np.zeros((num, length), dtype='int64')
    targ = np.zeros((num, length), dtype='int64')

    blank = elements
    delim = elements + 1
    for i in range(num):
        # randint's upper bound is exclusive (T in [min_lag, max_lag)) and it
        # raises when low == high, so the equal-lag case is handled directly.
        # The draw order is unchanged, keeping seeded data reproducible.
        if min_lag == max_lag:
            T = min_lag
        else:
            T = random.randint(min_lag, max_lag)
        inp[i] = np.concatenate((
            seq[i],
            np.full(T - 1, fill_value=blank),
            np.full(1, fill_value=delim),
            np.full(max_lag + slen - T, fill_value=blank),
        ))
        targ[i] = np.concatenate((
            np.full(slen + T, fill_value=blank),
            seq[i],
            np.full(max_lag - T, fill_value=blank),
        ))

    return inp, targ


def to_OHE(slen=S, elements=2):
    """Return one-hot encoded inputs and targets for the seeded test split.

    Calls ``generate_data`` with its default lags, trial count and split, so
    the result is deterministic.

    Args:
        slen: number of data symbols per trial.
        elements: number of distinct data symbols.

    Returns:
        ``(X, y)`` as float64 tensors. ``X`` has shape
        ``(num_trials, seq_len, elements + 2)`` (data symbols, blank,
        delimiter); ``y`` has shape ``(num_trials, seq_len, elements + 1)``
        (data symbols, blank).
    """
    inp, targ = generate_data(slen=slen, elements=elements)
    # Fixed class counts instead of each trial's max() + 1, so the encoded
    # width never depends on which symbols a trial happens to contain. Indexing
    # rows of an identity matrix encodes the whole batch at once, replacing a
    # per-trial np.concatenate that was quadratic in the number of trials.
    X = np.eye(elements + 2)[inp]
    y = np.eye(elements + 1)[targ]

    X = torch.as_tensor(X).double()
    y = torch.as_tensor(y).double()

    return X, y


def plot_hiddens(hiddens, lower=0, upper=h_size):
    """Plot hidden-unit trajectories for every sequence, save to ``toy.pdf`` and show.

    Each unit ``i`` in ``range(lower, upper)`` is drawn for every sequence,
    offset vertically by ``3 * i`` so the units stack on top of each other.

    Args:
        hiddens: array indexed as ``hiddens[t, seq, unit]``. Presumably the
            caller passes hidden states of shape (time, batch, hidden) --
            TODO confirm against DirectModel's returned layout.
        lower: first hidden unit to draw.
        upper: one past the last hidden unit to draw.
    """
    fig, ax1 = plt.subplots()
    plt.ylim(lower, upper+0.8)
    plt.yticks(np.arange(lower+1, upper+1, 1.0))
    plt.xlabel('Time-Steps')
    plt.ylabel('Hidden-State Dimension Index')
    ax1.xaxis.label.set_fontsize(18)
    ax1.yaxis.label.set_fontsize(18)
    plt.axvline(x=1, color='#848FA2')  # vertical marker at time-step 1
    # Traces go on a twin axis whose y scale is hidden, so the visible axis
    # keeps its per-unit tick labels.
    ax2 = ax1.twinx()
    ax1.tick_params(left=False)

    # Sequence count is the size of axis 1. Reading it as len(hiddens[1])
    # indexes the second time step and fails for single-step input.
    num_sequences = hiddens.shape[1]
    for j in range(num_sequences):  # iterate across sequences
        for i in range(lower, upper):  # iterate over hidden-state dimensions
            ax2.plot(hiddens[:, j, i] + 3 * i, linewidth=1, color='#2D3142', alpha=0.05)

    ax2.axes.get_yaxis().set_visible(False)
    fig.tight_layout()
    plt.savefig("toy.pdf")
    plt.show()


########################################################################################################################
########################################################################################################################

# One-hot test split: X holds the inputs, y the desired outputs.
X, y = to_OHE()

# Hand-chosen initial hidden state, shape (1, 1, h_size): the first S units
# start at 0 and the remaining S+2 at -1, except unit h_size-3 (the
# second-to-last unit of the middle group), which starts at +1.
h0 = np.zeros((1, 1, h_size))

h0[0, 0, :S] = 0
h0[0, 0, S:] = -1
h0[0, 0, -3] = 1

# Same initial state for every trial: shape (1, num_trials, h_size), i.e.
# (num_layers, batch, hidden) as nn.GRU expects for h0 -- assumption, confirm
# DirectModel forwards hidden_init unchanged.
h0s = np.tile(h0, [1, X.shape[0], 1])
h0s = torch.from_numpy(h0s).to(device=X.device, dtype=X.dtype).double()


idx = 26  # choose a test trial

Y_hat, hidden = model.forward(X, hidden_init=h0s, return_for_inspection=True)

# Input Sequence (symbol index at each time step)
print(torch.argmax(X[idx, :, :], axis=1))

# Desired Output Sequence (symbol index at each time step)
print(torch.argmax(y[idx, :, :], axis=1))

# RNN Output Sequence
# NOTE(review): axis=0 here vs axis=1 above. This is only per-time-step if
# Y_hat is laid out (batch, classes, time); if it is (batch, time, classes)
# like X and y, this takes the argmax over time instead -- confirm.
print(torch.argmax(Y_hat[idx, :, :], axis=0))

# Move hidden states to NumPy and swap the first two axes; plot_hiddens indexes
# its input as [time, sequence, unit]. Presumably `hidden` arrives as
# (batch, time, hidden) since batch_first=True -- TODO confirm.
hidden = hidden.detach().numpy()
hidden = np.transpose(hidden, (1, 0, 2))
plot_hiddens(hiddens=hidden)

