#!/usr/bin/env python3

import csv
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


def generatePlot(
    filename,
    groundtruth,
    observations,
    outputs,
    outputs_at,
    *,
    close=True,
):
    """Save a three-panel x/y trajectory comparison figure to *filename*.

    The panels share both axes and sit side by side with no gap:

    1. the ground truth (black dashed line, with a dot on its first point)
       drawn over the observed signal,
    2. the network output,
    3. the network output labelled "(AT)".

    One legend covering all four curves is attached to the figure.

    Note: this function modifies global matplotlib rc settings (font family,
    tick label sizes, line width, legend style, ...); those changes persist
    after it returns.

    Parameters
    ----------
    filename : str or path-like
        Destination passed to ``Figure.savefig``.
    groundtruth, observations, outputs, outputs_at : array-like
        2-D arrays (or nested sequences) of points; column 0 is plotted as
        x and column 1 as y. Each may have a different number of rows.
    close : bool, keyword-only, default True
        Close the figure after saving so repeated calls do not accumulate
        open pyplot figures. Pass False to keep it open (it stays the
        current pyplot figure).
    """
    tic_font_size = "xx-small"
    label_font_size = "x-small"

    # The only palette colours actually used by the plot.
    gray = "#acacac"
    orange = "#d99a3d"
    blue = "#5079b3"

    # Global style. "sans-serif" is matplotlib's generic family name; the
    # previous value "sans serif" was not recognised and only produced a
    # font-not-found warning before falling back to the default font.
    plt.rc("font", family="sans-serif")
    plt.rc("text", usetex=False)
    plt.rc("xtick", labelsize=tic_font_size)
    plt.rc("ytick", labelsize=tic_font_size)
    plt.rc("axes", axisbelow=True)  # grid and ticks go beneath the curves
    matplotlib.rcParams["lines.linewidth"] = 0.8
    matplotlib.rcParams["legend.fancybox"] = True

    # Accept nested sequences as well as ndarrays.
    groundtruth = np.asarray(groundtruth)
    observations = np.asarray(observations)
    outputs = np.asarray(outputs)
    outputs_at = np.asarray(outputs_at)

    fig, axs = plt.subplots(
        1, 3,
        figsize=(6.75, 2.0),
        sharex=True,
        sharey=True,
        gridspec_kw={"wspace": 0},
    )

    # Fully transparent figure background; each axes gets its own
    # translucent white background in the loop below.
    fig.patch.set_facecolor(None)
    fig.patch.set_alpha(0.0)

    gt_line = axs[0].plot(
        groundtruth[:, 0], groundtruth[:, 1],
        color="black", label="Ground truth", linewidth=0.5,
        linestyle="--", dashes=(5, 1), zorder=10,
    )[0]
    obs_line = axs[0].plot(
        observations[:, 0], observations[:, 1],
        color=gray, label="Observed signal", zorder=1,
    )[0]
    out_line = axs[1].plot(
        outputs[:, 0], outputs[:, 1],
        color=orange, label="Network output",
    )[0]
    out_at_line = axs[2].plot(
        outputs_at[:, 0], outputs_at[:, 1],
        color=blue, label="Network output (AT)",
    )[0]

    # Mark the first ground-truth point; skipped when there are no rows,
    # since there is no first point to index.
    if len(groundtruth):
        axs[0].plot(
            [groundtruth[0, 0]], [groundtruth[0, 1]],
            label="_nolegend_", color="black", linewidth=0.5,
            marker=".", zorder=10,
        )

    for ax in axs:
        ax.set_xlabel("x", fontsize=label_font_size)
        ax.patch.set_facecolor("white")
        ax.patch.set_alpha(0.8)
        ax.grid(linewidth=0.5, linestyle="dashed", zorder=0)
        ax.tick_params(
            direction="in",
            bottom=True, top=True,
            left=True, right=True,
            zorder=3,
        )
        # Hide tick labels that are redundant between the shared axes.
        ax.label_outer()

    axs[0].set_ylabel("y", fontsize=label_font_size)

    # Pass handles and labels both as keywords: mixing a positional handle
    # list with ``labels=`` makes matplotlib warn and discard the
    # positional handles.
    handles = [gt_line, obs_line, out_line, out_at_line]
    fig.legend(
        handles=handles,
        labels=[handle.get_label() for handle in handles],
        fontsize=tic_font_size,
        loc="upper center",
        bbox_to_anchor=(0.45, 0.9),
        ncol=4,
    )

    # Save this specific figure rather than whatever pyplot considers current.
    fig.savefig(filename, bbox_inches="tight", pad_inches=0)
    if close:
        plt.close(fig)
    
# Disabled usage example. It is wrapped in a bare string literal that is
# never assigned or evaluated, so none of it runs on import or when this
# script is executed.
# It builds a 200-point spiral, data1 with shape (200, 2). The radius grows
# linearly from 0 to 1 while the angle steps by -0.1 rad per point. The same
# spiral is passed as all four inputs of generatePlot, which writes
# "test.pdf".
'''
data1 = np.zeros((200, 2))
    
for i in range(200) :
    l = i / (200 - 1)
    a = -i * 0.1
    x = 1.0 * np.cos(a) - 0.0 * np.sin(a)
    y = 1.0 * np.sin(a) + 0.0 * np.cos(a)
    data1[i][0] = l * x
    data1[i][1] = l * y

generatePlot(
    filename="test.pdf",
    groundtruth=data1,
    observations=data1,
    outputs=data1,
    outputs_at=data1
)
'''
     