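# Plots of the gradient scale G_f for several parameterizations f.
# memory_bias / memory_bias_discrete plot G_f against the mode lambda
# (negative modes, and modes in about (-1, 1), respectively);
# memory_bias_weight / memory_bias_discrete_weight plot it against the weight w.
# Running this file writes gradient_scale*.pdf and gradient_scale*.png to the
# current working directory.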

import copy
import itertools

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


def memory_bias():
    """Plot the gradient scale G_f(lambda) over negative modes lambda.

    Each curve corresponds to one parameterization f (see the legend labels)
    and is evaluated on the grid lambda_k = -(k * dt + epsilon), k = 1..T,
    i.e. lambda in roughly [-10, 0). The y axis is log-scaled. The figure is
    written to ``./gradient_scale.pdf`` and ``./gradient_scale.png`` in the
    current working directory.
    """
    T = 1000
    dt = 0.01
    epsilon = 1e-7

    # Strictly negative modes; the epsilon shift keeps the grid off zero.
    modes = np.array([-(k * dt + epsilon) for k in range(1, T + 1)])

    # Every expression below is positive for lambda < 0, so no abs() is needed.
    df_memory_function = pd.DataFrame()
    df_memory_function[r"linear, $f(w)=-w, w>0$"] = 1 / (modes**2 + epsilon)
    df_memory_function[r"exp, $f(w)=-exp(w), w \in \mathbb{R}$"] = 1 / (np.abs(modes) + epsilon)
    df_memory_function[r"softplus, $f(w) = -\log(1+exp(w)), w \in \mathbb{R}$"] = (1 - np.exp(modes)) / modes**2
    df_memory_function[r"inverse, $f(w) = -1/w, w >0$"] = np.ones_like(modes)

    # Create a fresh figure with the requested size/DPI up front: pyplot ignores
    # figsize/dpi when it merely re-activates an already existing figure number.
    fig, ax = plt.subplots(figsize=(12, 10), dpi=450)

    line_styles = itertools.cycle(["-", "--", "-.", ":"])
    line_widths = itertools.cycle([2.5 + 0.3 * i for i in range(4)])
    for column, line_style, line_width in zip(
        df_memory_function.columns, line_styles, line_widths
    ):
        ax.plot(
            modes,
            df_memory_function[column],
            line_style,
            linewidth=line_width,
            label=column,
        )

    ax.set_xlabel(r"Mode $\lambda$", fontsize=16)
    ax.set_ylabel(r"Gradient scale at mode $\lambda$: $G_f(\lambda)$", fontsize=14)
    ax.set_yscale("log")
    ax.legend()
    fig.tight_layout()
    fig.subplots_adjust(top=0.93)  # keep some headroom above the axes
    fig.savefig("./gradient_scale.pdf")
    fig.savefig("./gradient_scale.png")
    plt.close(fig)  # release the large 450-dpi canvas once both files are written


def memory_bias_weight():
    """Plot the gradient scale G_f(w) over the trainable weight w.

    ``w`` runs over ``k * dt`` for k = 1-T..T (about [-10, 10]), shifted by
    +/-epsilon so that no grid point is exactly zero. Points outside a
    parameterization's stated domain (``w > 0`` for the linear and inverse
    cases) are set to NaN so matplotlib leaves them out, the same way
    ``memory_bias_discrete_weight`` handles out-of-domain weights. The y axis
    is log-scaled and the figure is written to ``./gradient_scale_weight.pdf``
    and ``./gradient_scale_weight.png``.
    """
    T = 2000
    dt = 0.005
    epsilon = 1e-7

    weights = np.array(
        [k * dt + (epsilon if k * dt > 0 else -epsilon) for k in range(1 - T, T + 1)]
    )
    positive = weights > 0  # domain of the linear and inverse parameterizations

    # TODO: (left unspecified by the original author)
    df_memory_function = pd.DataFrame()
    df_memory_function[r"linear, $f(w)=-w, w>0$"] = np.where(
        positive, 1 / (weights**2 + epsilon), np.nan
    )
    df_memory_function[r"exp, $f(w)=-\exp(w), w \in \mathbb{R}$"] = 1 / (np.exp(weights) + epsilon)
    # sigmoid(w) / softplus(w)^2, where logaddexp(0, w) == log(1 + exp(w))
    # computed without overflow or cancellation.
    df_memory_function[r"softplus, $f(w) = -\log(1+exp(w)), w \in \mathbb{R}$"] = (
        1 / (1 + np.exp(-weights))
    ) / np.logaddexp(0, weights) ** 2
    df_memory_function[r"inverse, $f(w) = -\frac{1}{w}, w >0$"] = np.where(positive, 1.0, np.nan)

    # Create a fresh figure with the requested size/DPI up front: pyplot ignores
    # figsize/dpi when it merely re-activates an already existing figure number.
    fig, ax = plt.subplots(figsize=(12, 10), dpi=450)

    line_styles = itertools.cycle(["-", "--", "-.", ":"])
    line_widths = itertools.cycle([2.5 + 0.3 * i for i in range(4)])
    for column, line_style, line_width in zip(
        df_memory_function.columns, line_styles, line_widths
    ):
        ax.plot(
            weights,
            df_memory_function[column],
            line_style,
            linewidth=line_width,
            label=column,
        )

    ax.set_xlabel(r"Weight $w$", fontsize=16)
    ax.set_ylabel(r"Gradient scale at weight $w$: $G_f(w)$", fontsize=14)
    ax.set_yscale("log")
    ax.legend()
    fig.tight_layout()
    fig.subplots_adjust(top=0.93)  # keep some headroom above the axes
    fig.savefig("./gradient_scale_weight.pdf")
    fig.savefig("./gradient_scale_weight.png")
    plt.close(fig)  # release the large 450-dpi canvas once both files are written


def memory_bias_discrete():
    """Plot the discrete-setting gradient scale G_f(lambda) over modes lambda.

    Modes are ``k * dt`` for k = 1-T..T (about [-1, 1]), shifted by
    +/-epsilon so that no grid point is exactly zero. Every column evaluates
    G_f = |f'(w)| / ((1 - f(w))^2 + epsilon) rewritten in terms of the mode
    lambda = f(w). Grid points outside the range of f -- (-1, 1) for the
    linear, sigmoid and inverse parameterizations, (0, 1) for lru -- are NaN
    and therefore not drawn. The y axis is log-scaled and the figure is written
    to ``./gradient_scale_discrete.pdf`` and ``./gradient_scale_discrete.png``.
    """
    T = 2000
    dt = 0.0005
    epsilon = 1e-7

    modes = np.array(
        [k * dt + (epsilon if k * dt > 0 else -epsilon) for k in range(1 - T, T + 1)]
    )
    # Restrict each curve to the range of its f (NaN propagates silently,
    # which also keeps np.log away from non-positive modes).
    lam = np.where(np.abs(modes) < 1, modes, np.nan)
    lam_lru = np.where((modes > 0) & (modes < 1), modes, np.nan)

    df_memory_function = pd.DataFrame()
    df_memory_function[r"linear, $f(w)=w, w \in (-1, 1)$"] = 1 / ((1 - lam) ** 2 + epsilon)
    # f(w) = tanh(w/2) gives f'(w) = (1 - lambda^2) / 2.
    df_memory_function[r"sigmoid, $f(w) = \frac{1-\exp(-w)}{1+\exp(-w)}, w \in \mathbb{R}$"] = (
        (1 - lam**2) / (2 * ((1 - lam) ** 2 + epsilon))
    )
    # f(w) = exp(-exp(w)) gives |f'(w)| = lambda * |log(lambda)|.
    df_memory_function[r"lru, $f(w)=\exp(-\exp(w)), w \in \mathbb{R}$"] = (
        lam_lru * np.abs(np.log(lam_lru)) / ((1 - lam_lru) ** 2 + epsilon)
    )
    df_memory_function[r"inverse, $f(w) = 1 - \frac{1}{w}, w > 0.5$"] = np.where(
        np.isnan(lam), np.nan, 1.0
    )

    # Create a fresh figure with the requested size/DPI up front: pyplot ignores
    # figsize/dpi when it merely re-activates an already existing figure number.
    fig, ax = plt.subplots(figsize=(12, 10), dpi=450)

    line_styles = itertools.cycle(["-", "--", "-.", ":"])
    line_widths = itertools.cycle([2.5 + 0.3 * i for i in range(4)])
    for column, line_style, line_width in zip(
        df_memory_function.columns, line_styles, line_widths
    ):
        ax.plot(
            modes,
            df_memory_function[column],
            line_style,
            linewidth=line_width,
            label=column,
        )

    ax.set_xlabel(r"Mode $\lambda$", fontsize=16)
    ax.set_ylabel(r"Gradient scale at mode $\lambda$: $G_f(\lambda)$", fontsize=14)
    ax.set_yscale("log")
    ax.legend()
    fig.tight_layout()
    fig.subplots_adjust(top=0.93)  # keep some headroom above the axes
    fig.savefig("./gradient_scale_discrete.pdf")
    fig.savefig("./gradient_scale_discrete.png")
    plt.close(fig)  # release the large 450-dpi canvas once both files are written


def memory_bias_discrete_weight():
    """Plot the discrete-setting gradient scale G_f(w) over the weight w.

    ``w`` runs over ``k * dt`` for k = 1-T..T (about [-4, 4]), shifted by
    +/-epsilon so that no grid point is exactly zero. Every column evaluates
    G_f(w) = |f'(w)| / ((1 - f(w))^2 + epsilon) for its parameterization f.
    Weights outside a parameterization's domain (|w| > 1 for linear, w < 0.5
    for inverse) are NaN and therefore not drawn. The y axis is log-scaled and
    the figure is written to ``./gradient_scale_discrete_weight.pdf`` and
    ``./gradient_scale_discrete_weight.png``.
    """
    T = 2000
    dt = 0.002
    epsilon = 1e-7

    weights = np.array(
        [k * dt + (epsilon if k * dt > 0 else -epsilon) for k in range(1 - T, T + 1)]
    )
    exp_w = np.exp(weights)

    df_memory_function = pd.DataFrame()
    df_memory_function[r"linear, $f(w)=w, w \in (-1, 1)$"] = np.where(
        np.abs(weights) > 1, np.nan, 1 / ((1 - weights) ** 2 + epsilon)
    )
    # f(w) = tanh(w/2): f'(w) = 2e^{-w} / (1 + e^{-w})^2 and
    # 1 - f(w) = 2e^{-w} / (1 + e^{-w}), hence |f'(w)| / (1 - f(w))^2 = e^w / 2.
    df_memory_function[r"sigmoid, $f(w) = \frac{1-\exp(-w)}{1+\exp(-w)}, w \in \mathbb{R}$"] = 0.5 * exp_w
    # f(w) = exp(-exp(w)): |f'(w)| = exp(w - exp(w)); -expm1(-x) == 1 - exp(-x)
    # without cancellation when exp(w) is small.
    df_memory_function[r"lru, $f(w)=\exp(-\exp(w)), w \in \mathbb{R}$"] = np.exp(weights - exp_w) / (
        (-np.expm1(-exp_w)) ** 2 + epsilon
    )
    df_memory_function[r"inverse, $f(w) = 1 - \frac{1}{w}, w > 0.5$"] = np.where(weights < 0.5, np.nan, 1.0)

    # Create a fresh figure with the requested size/DPI up front: pyplot ignores
    # figsize/dpi when it merely re-activates an already existing figure number.
    fig, ax = plt.subplots(figsize=(12, 10), dpi=450)

    line_styles = itertools.cycle(["-", "--", "-.", ":"])
    line_widths = itertools.cycle([2.5 + 0.3 * i for i in range(4)])
    for column, line_style, line_width in zip(
        df_memory_function.columns, line_styles, line_widths
    ):
        ax.plot(
            weights,
            df_memory_function[column],
            line_style,
            linewidth=line_width,
            label=column,
        )

    ax.set_xlabel(r"Weight $w$", fontsize=16)
    ax.set_ylabel(r"Gradient scale at weight $w$: $G_f(w)$", fontsize=14)
    ax.set_yscale("log")
    ax.legend()
    fig.tight_layout()
    fig.subplots_adjust(top=0.93)  # keep some headroom above the axes
    fig.savefig("./gradient_scale_discrete_weight.pdf")
    fig.savefig("./gradient_scale_discrete_weight.png")
    plt.close(fig)  # release the large 450-dpi canvas once both files are written


# Script entry point: build every figure, in the same order as the functions
# are defined above.
if __name__ == "__main__":
    for make_figure in (
        memory_bias,
        memory_bias_weight,
        memory_bias_discrete,
        memory_bias_discrete_weight,
    ):
        make_figure()



# See PyCharm help at https://www.jetbrains.com/help/pycharm/
