import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Reset Matplotlib to its built-in defaults, then layer the seaborn
# "whitegrid" theme on top so reruns start from a known state.
mpl.rcParams.update(mpl.rcParamsDefault)
sns.set(style='whitegrid')

# Qualitative palette; each method picks its color by list position.
colors = sns.color_palette("Set2")

# Figure-wide defaults: high DPI, small default size, readable fonts and
# thin default lines (individual plot calls may override these).
plt.rcParams.update({
    "figure.dpi": 300,
    "figure.figsize": (4, 2),
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 10,
    "lines.linewidth": 1,
})

# Method identifiers in plotting order. Each id is substituted into the
# CSV filename and also selects a palette color by its position here.
methods = 'lru ffm lifgate sigmoidgate'.split()

# Human-readable label for each method id, used in legends and titles.
name = dict(
    lifgate='Ours',
    sigmoidgate='Ours w/ sigmoid',
    ffm='FFM',
    lru='LRU',
)

# Layout of each reshaped contribution row:
#   [0, N_INPUT)                   -> averaged as the input contribution
#   [N_INPUT, N_INPUT + N_NOISE)   -> averaged as the noise contribution
#   remaining positions up to SEQ_LEN are not used by the SNR.
N_INPUT = 15
N_NOISE = 250
SEQ_LEN = 280

o_grads = {}
plt.figure(figsize=(4, 4))

for i, m in enumerate(methods):
    o_grad = pd.read_csv(f'visual_match_250_{m}_1000.csv', header=None)
    # Column 1 holds the values; rows of SEQ_LEN are presumably successive
    # checkpoints (inferred from the 500/1000 tick labels) — TODO confirm.
    matrix = np.reshape(o_grad.values[:, 1], (-1, SEQ_LEN))

    # Vectorized over all rows at once instead of a per-row Python loop.
    rc_in = matrix[:, :N_INPUT].mean(axis=1)
    rc_noise = matrix[:, N_INPUT:N_INPUT + N_NOISE].mean(axis=1)
    # Share of the contribution attributed to the input region; the 1e-9
    # keeps the ratio finite when both means are zero.
    snrs = rc_in / (rc_in + rc_noise + 1e-9)
    for snr in snrs:
        print(f'SNR={round(snr, 3)}')

    x = np.arange(matrix.shape[0])
    plt.plot(x, snrs, marker='o', linewidth=2.0, label=name[m], color=colors[i])
    # Keep the last row per method for the contribution figure.
    o_grads[m] = matrix[-1]

plt.legend()
# Tick positions/labels are hard-coded and assume 20 rows (indices 0..19).
plt.xticks([9, 19], labels=[500, 1000])
plt.xlim(0, 19)
plt.ylim(0, 1)
plt.tight_layout()
# savefig does not create missing directories; make sure plt/ exists.
os.makedirs('plt', exist_ok=True)
plt.savefig('plt/snr.pdf')

# One subplot per method showing its final-row contribution profile, with
# each position range shaded in its own color: (start, stop, color),
# where stop=None means "to the end of the row".
CONTRIBUTION_REGIONS = (
    (0, 15, '#c5e0b4'),
    (15, 265, '#ffe699'),
    (265, None, '#b4c7e7'),
)

plt.figure(figsize=(9, 4))
# Position axis is the same for every method; build it once.
x = np.arange(280)

for i, m in enumerate(methods):
    plt.subplot(2, 2, i + 1)
    y = o_grads[m]

    for start, stop, color in CONTRIBUTION_REGIONS:
        plt.fill_between(x[start:stop], 0, y[start:stop], color=color)

    plt.title(name[m])
    plt.xlim(0, 280)
    plt.ylim(0, 0.2)

# Lay out once, after every subplot and title exists, rather than on each
# loop iteration.
plt.tight_layout()
# savefig does not create missing directories; make sure plt/ exists.
os.makedirs('plt', exist_ok=True)
plt.savefig('plt/contribution.pdf')