# Plots the normalised impulse response ("memory") of multi-layer recurrent, TCN and S4 models and saves each figure as a PDF.

# Press ⌃R to execute it or replace it with your code.
# Press Double ⇧ to search everywhere for classes, files, tool windows, actions, and settings.

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch

from s4 import S4Block
from layers import SimpleRNN, TCN


class MyLSTM(torch.nn.Module):
    """Stacked LSTM whose hidden sequence is projected by a linear read-out.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Width of the LSTM hidden state.
        output_size: Number of features produced by the read-out layer.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()

        # The recurrent core is created before the read-out so that parameter
        # initialisation consumes the RNG in a fixed order.
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.read_out = torch.nn.Linear(
            in_features=hidden_size,
            out_features=output_size,
        )

    def forward(self, x):
        """Return the read-out applied to every LSTM hidden state.

        The LSTM is built with ``batch_first=True``; its final (h, c) state
        is discarded and only the per-step outputs are projected.
        """
        hidden_seq, _final_state = self.lstm(x)
        return self.read_out(hidden_seq)


class MyGRU(torch.nn.Module):
    """Stacked GRU whose hidden sequence is projected by a linear read-out.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Width of the GRU hidden state.
        output_size: Number of features produced by the read-out layer.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()

        # The recurrent core is created before the read-out so that parameter
        # initialisation consumes the RNG in a fixed order.
        self.gru = torch.nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.read_out = torch.nn.Linear(
            in_features=hidden_size,
            out_features=output_size,
        )

    def forward(self, x):
        """Return the read-out applied to every GRU hidden state.

        The GRU is built with ``batch_first=True``; its final hidden state is
        discarded and only the per-step outputs are projected.
        """
        hidden_seq, _final_state = self.gru(x)
        return self.read_out(hidden_seq)


class MySSM(torch.nn.Module):
    """Stack of single-layer ``torch.nn.RNN`` modules with a Tanh after each.

    The first RNN maps ``input_size`` to ``hidden_size``; the remaining
    ``num_layers - 1`` RNNs map ``hidden_size`` to ``hidden_size``. A linear
    read-out is applied to the output of the last stage.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Width of every recurrent layer.
        output_size: Number of features produced by the read-out layer.
        num_layers: Total number of recurrent layers; must be at least 1.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()

        # Explicit check instead of ``assert`` so that the validation is not
        # stripped when Python runs with ``-O``.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers!r}")

        self.ssm = torch.nn.RNN(input_size, hidden_size, 1, batch_first=True)
        self.activation = torch.nn.Tanh()

        # The first layer lives in ``self.ssm``; only the extra ones go here.
        self.ssm_layers = torch.nn.ModuleList(
            [
                torch.nn.RNN(hidden_size, hidden_size, 1, batch_first=True)
                for _ in range(num_layers - 1)
            ]
        )

        self.read_out = torch.nn.Linear(hidden_size, output_size)

        self.layers = num_layers

    def forward(self, x):
        """Run ``x`` through every recurrent layer and the read-out.

        The final hidden state returned by each RNN is discarded; only the
        per-step output sequence is passed on.
        """
        x, _ = self.ssm(x)
        x = self.activation(x)

        for layer in self.ssm_layers:
            x, _ = layer(x)
            x = self.activation(x)

        return self.read_out(x)


def rnn_constructor(input_size, hidden_size, output_size, num_layers, name="rnn"):
    """Build a recurrent model selected by ``name``.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Width of the recurrent hidden state.
        output_size: Number of features produced by the model.
        num_layers: Number of stacked recurrent layers.
        name: One of ``"rnn"``, ``"lstm"``, ``"gru"`` or ``"myssm"``.

    Returns:
        A ``(model, return_state)`` tuple. ``return_state`` is ``False`` for
        every model built here; ``memory_evaluate`` uses it to decide whether
        the model output has to be indexed with ``[0]``.

    Raises:
        ValueError: If ``name`` is not one of the supported model names.
    """
    if name == "rnn":
        model_cls = SimpleRNN
    elif name == "lstm":
        model_cls = MyLSTM
    elif name == "gru":
        model_cls = MyGRU
    elif name == "myssm":
        model_cls = MySSM
    else:
        # Without this branch an unknown name fell through to the return
        # statement and failed with an UnboundLocalError.
        raise ValueError(
            f"unknown model name {name!r}; "
            "expected one of 'rnn', 'lstm', 'gru', 'myssm'"
        )

    model = model_cls(input_size, hidden_size, output_size, num_layers)
    return_state = False

    return model, return_state


def cnn_constructor(
    input_channels,
    n_classes,
    num_layers,
    channel_sizes=None,
    kernel_size=7,
    dropout=0.0,
):
    """Build a ``TCN`` whose channel list is repeated ``num_layers`` times.

    Args:
        input_channels: Forwarded to ``TCN`` as its first argument.
        n_classes: Forwarded to ``TCN`` as its second argument.
        num_layers: How many times ``channel_sizes`` is repeated.
        channel_sizes: Per-layer channel sizes for one repetition. Defaults
            to ``[30]`` when omitted.
        kernel_size: Forwarded to ``TCN`` as ``kernel_size``.
        dropout: Forwarded to ``TCN`` as ``dropout``.

    Returns:
        A ``(model, return_state)`` tuple; ``return_state`` is always
        ``False``.
    """
    # ``None`` sentinel instead of a list literal in the signature: a mutable
    # default is shared between calls.
    if channel_sizes is None:
        channel_sizes = [30]

    model = TCN(
        input_channels,
        n_classes,
        channel_sizes * num_layers,
        kernel_size=kernel_size,
        dropout=dropout,
    )
    return_state = False

    return model, return_state


def transformer_constructor(input_size, num_heads, hidden_size):
    """Build a single dropout-free, batch-first Transformer encoder layer.

    Args:
        input_size: Model dimension (``d_model``) of the encoder layer.
        num_heads: Number of attention heads (``nhead``).
        hidden_size: Width of the feed-forward sub-layer
            (``dim_feedforward``).

    Returns:
        A ``(model, return_state)`` tuple; ``return_state`` is always
        ``False``.
    """
    encoder_layer = torch.nn.TransformerEncoderLayer(
        d_model=input_size,
        nhead=num_heads,
        dim_feedforward=hidden_size,
        dropout=0.0,
        batch_first=True,
    )
    return encoder_layer, False


def ssm_constructor(hidden_size, layers=1, final_act="gelu"):
    """Chain ``layers`` freshly built ``S4Block`` modules sequentially.

    Args:
        hidden_size: First positional argument passed to every ``S4Block``.
        layers: Number of blocks in the chain.
        final_act: Forwarded to every ``S4Block`` as ``final_act``.

    Returns:
        A ``(model, return_state)`` tuple, where ``model`` is a
        ``torch.nn.Sequential`` of the blocks and ``return_state`` is always
        ``False``.
    """
    blocks = []
    for _ in range(layers):
        # transposed=False is passed through to S4Block unchanged.
        blocks.append(S4Block(hidden_size, final_act=final_act, transposed=False))

    return torch.nn.Sequential(*blocks), False


def memory_evaluate(model, T, return_state):
    """Feed a unit impulse of length ``T`` to ``model`` and return its output.

    The input has shape ``(1, T, 1)`` and is zero everywhere except at the
    first time step, where it is one.

    Args:
        model: Callable applied to the impulse tensor.
        T: Number of time steps in the impulse.
        return_state: When true, the model output is indexed with ``[0]``
            before being returned; otherwise it is returned as is.

    Returns:
        The model output, or its first element when ``return_state`` is true.
    """
    impulse = torch.zeros(1, T, 1)
    impulse[:, 0] = 1

    response = model(impulse)
    return response[0] if return_state else response


import seaborn as sns
import pandas as pd
import numpy as np
import torch

import numpy as np
import torch
import matplotlib.pyplot as plt


def rnn_plot(name="rnn"):
    """Plot impulse-response curves of ``name`` models for 1 to 4 layers.

    For each depth, five randomly initialised models are evaluated with
    ``memory_evaluate``. Each response is shifted so its last value is zero,
    made non-negative, and normalised to sum to one. The curve drawn for a
    depth is the pointwise minimum over the five repetitions. The log-log
    figure is written to ``./{name}_multilayers.pdf``.

    Args:
        name: Model name understood by ``rnn_constructor``.
    """
    seq_len = 128
    hidden_size = 128
    n_repeats = 5
    depths = [1, 2, 3, 4]

    curves = np.zeros((len(depths), n_repeats, seq_len))

    for row, depth in enumerate(depths):
        for rep in range(n_repeats):
            # The sample itself is unused; the draw is retained because it
            # advances the global RNG and therefore the model initialisation.
            torch.randn(1, seq_len, 1)

            model, return_state = rnn_constructor(1, hidden_size, 1, depth, name)
            response = memory_evaluate(model, seq_len, return_state)
            response = torch.squeeze(response).detach().numpy()

            response = np.abs(response - response[-1])  # tail = 0, then positive
            curves[row, rep, :] = response / np.sum(response)  # sums to one

    plt.clf()
    steps = np.arange(seq_len)
    for row, depth in enumerate(depths):
        # Pointwise minimum across repetitions (not the mean).
        envelope = curves[row].min(axis=0)
        plt.plot(steps, envelope, label=f"layers={depth}")

    plt.xlabel(r"time $t$", fontsize=14)
    plt.ylabel(r"memory $\rho$", fontsize=14)
    plt.xscale("log")
    plt.yscale("log")
    plt.legend()
    plt.savefig(f"./{name}_multilayers.pdf")


def tcn_plot():
    """Plot impulse-response curves of TCN models with 1 to 5 layers.

    One randomly initialised model is built per depth with
    ``cnn_constructor``. Its response is shifted so its last value is zero,
    made non-negative, and normalised to sum to one. The figure uses a
    logarithmic y-axis and is written to ``./tcn_multilayers.pdf``.
    """
    seq_len = 128
    curves = pd.DataFrame()

    for depth in range(1, 6):
        # The sample itself is unused; the draw is retained because it
        # advances the global RNG and therefore the model initialisation.
        torch.randn(1, seq_len, 1)

        model, return_state = cnn_constructor(1, 1, depth)
        response = memory_evaluate(model, seq_len, return_state)
        response = torch.squeeze(response).detach().numpy()

        response = np.abs(response - response[-1])  # tail = 0, then positive
        curves[f"layers={depth}"] = response / np.sum(response)  # sums to one

    plt.clf()
    plt.plot(np.arange(seq_len), curves)
    plt.xlabel(r"time $t$", fontsize=14)
    plt.ylabel(r"memory $\rho$", fontsize=14)
    plt.yscale("log")
    plt.legend(curves.keys())
    plt.savefig("./tcn_multilayers.pdf")


def ssm_plot(final_act="gelu", epsilon=1e-10):
    """Plot impulse-response curves of S4 stacks with 1 to 3 layers.

    One randomly initialised model is built per depth with
    ``ssm_constructor``. Its response is shifted by the mean of its last
    entry, made non-negative, and divided by its sum plus ``epsilon``. The
    figure uses a logarithmic y-axis and is written to
    ``./ssm_multilayers_{final_act}.pdf``.

    Args:
        final_act: Forwarded to ``ssm_constructor`` and used in the file name.
        epsilon: Added to the normalisation sum to avoid dividing by zero.
    """
    seq_len = 2048
    hid_dim = 1
    curves = pd.DataFrame()

    for depth in (1, 2, 3):
        # The sample itself is unused; the draw is retained because it
        # advances the global RNG and therefore the model initialisation.
        torch.randn(1, seq_len, 1)

        model, return_state = ssm_constructor(hid_dim, depth, final_act=final_act)
        response = memory_evaluate(model, seq_len, return_state)
        response = torch.squeeze(response).detach().numpy()

        # In-place arithmetic is kept so the array dtype is left untouched.
        response -= np.mean(response[-1])  # tail = 0
        response = np.abs(response)  # non-negative
        response /= np.sum(response) + epsilon  # normalise

        curves[f"layers={depth}"] = response

    plt.clf()
    plt.plot(np.arange(seq_len), curves)
    plt.xlabel(r"time $t$", fontsize=14)
    plt.ylabel(r"memory $\rho$", fontsize=14)
    plt.yscale("log")
    plt.legend(curves.keys())
    plt.savefig(f"./ssm_multilayers_{final_act}.pdf")

# Press the green button in the gutter to run the script.
if __name__ == "__main__":
    # Seed both generators so the randomly initialised models are repeatable.
    np.random.seed(1234)
    torch.random.manual_seed(1234)

    # Recurrent families, in the order the figures are produced.
    for rnn_name in ("rnn", "lstm", "gru", "myssm"):
        rnn_plot(rnn_name)

    tcn_plot()

    # One S4 figure per final activation.
    for activation in ("gelu", "linear", "tanh", "relu", "sigmoid", "softplus"):
        ssm_plot(activation)
