import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
import seaborn as sns

def make_plots(pde,ode,system_name,plot_kernel=True,plot_PDF=True,plot_loss=True,make_gif=True,plot_intermediates=True,show_plots=True,x_lim=None):
    """Plot one evolving population density against a moving classifier position.

    Output files (all prefixed with ``system_name``):

    * ``_kernel.png`` -- ``-W`` evaluated on part of the grid (``plot_kernel``).
    * ``.png`` -- initial density ``pde.rho0``, final density ``pde.rho``,
      ``pde.g0`` (labelled rho-bar) and an arrow from ``ode.x0`` to ``ode.x``
      (``plot_PDF``). With ``plot_intermediates``, the first two snapshots of
      ``pde.rho_history`` / ``ode.x_history`` are overlaid.
    * ``.gif`` -- animation over ``pde.rho_history`` and ``ode.x_history``
      (``make_gif``).
    * ``_loss.png`` -- ``pde.loss`` and ``ode.loss`` against time ``dt*k``
      (``plot_loss``).

    Args:
        pde: population solver. Reads z_i, rho0, rho, g0, nT, W and loss.
            When snapshots are needed it also reads rho_history and subsample.
        ode: classifier solver. Reads x, x0, dt and loss. When snapshots are
            needed it also reads x_history.
        system_name: path prefix shared by every output file.
        plot_kernel, plot_PDF, plot_loss, make_gif, plot_intermediates:
            toggles for the individual outputs.
        show_plots: if True, call ``plt.show()`` after the density figure.
        x_lim: optional x-range for the density figure and the GIF. The GIF
            falls back to ``[-2.5, 5]`` when this is None.

    All figures are closed before returning.
    """
    z_i  = pde.z_i
    rho0 = pde.rho0
    rho  = pde.rho
    x    = ode.x
    nT   = pde.nT
    g0   = pde.g0
    x0   = ode.x0
    dt   = ode.dt
    # Both the intermediate overlay and the GIF index into the snapshot histories.
    if plot_intermediates or make_gif:
        rho_history = pde.rho_history
        x_history   = ode.x_history

    ########## plot kernel #################
    if plot_kernel:
        # Keep grid points below -z_i[0]. For a grid starting at -L this drops
        # the tail beyond +L.
        cropped_z_i = z_i[z_i<-z_i[0]]
        try:
            # Evaluate W(z, 0, 1.) point by point.
            W_vec = np.vectorize(pde.W)
            val = W_vec(cropped_z_i,0,1.)
        except Exception:
            # Fallback: call W once with 0 first and the whole array second.
            print("in exception")
            val = pde.W(0,cropped_z_i,1.)
        plt.figure()
        plt.plot(cropped_z_i,-val)
        plt.title("Kernel")
        plt.xlabel("z")
        plt.ylabel("W(z)")
        plt.savefig(system_name+"_kernel.png")

    # y-limits of the density figure. The GIF reuses them so the frames match.
    ylim = None

    ########## plot PDF #################
    if plot_PDF:
        fontsize=16
        plt.figure()
        plt.plot(z_i,rho0,color="#1f77b4",alpha=0.5,linestyle="--",label=r"$\tilde{\rho}$")
        plt.plot(z_i,rho,color="#1f77b4",label=r"$\rho$")
        plt.plot(z_i,g0,color="#ff7f0e",label=r"$\bar{\rho}$")
        if x_lim is not None:
            plt.xlim(x_lim)
        # Capture limits before adding the full-height vertical markers below.
        ylim = plt.ylim()
        plt.plot([x0,x0],ylim,alpha=0.3,color='k',linestyle="-.")
        plt.plot([x,x],ylim,alpha=0.5,color='k',linestyle="-.",label="x")
        ax = plt.gca()
        ax.fill_between(z_i,rho0,alpha=0.15,color="#1f77b4")
        ax.fill_between(z_i,rho,alpha=0.3,color="#1f77b4")
        ax.fill_between(z_i,g0,alpha=0.3,color="#ff7f0e")
        # Horizontal arrow from the initial to the final classifier position,
        # drawn at mid-height.
        y_val = np.mean(ylim)
        plt.arrow(x0,y_val,x-x0,0,length_includes_head=True,color="k",width=0.005,head_width=0.05,head_length=0.2)
        if plot_intermediates:
            # Overlay the first two saved snapshots with increasing opacity.
            plt.plot([x_history[0],x_history[0]],ylim,alpha=0.35,color='k',linestyle="-.")
            plt.plot([x_history[1],x_history[1]],ylim,alpha=0.4,color='k',linestyle="-.")
            for idx_ in range(2):
                plt.plot(z_i,rho_history[:,idx_],color="#1f77b4",alpha=0.5+0.05*idx_,linestyle="--")
                ax.fill_between(z_i,rho_history[:,idx_],alpha=0.15+0.05*idx_,color="#1f77b4")

        plt.legend(fontsize=fontsize)
        plt.ylabel("probability",fontsize=fontsize)
        plt.xlabel("z",fontsize=fontsize)
        ax.tick_params(axis='x', labelsize=fontsize)
        ax.tick_params(axis='y', labelsize=fontsize)
        plt.tight_layout()
        plt.savefig(system_name+".png")
        if show_plots:
            plt.show()

    ########## make gif #################
    if make_gif:
        fig = plt.figure()
        ax  = plt.axes()
        if ylim is None:
            # The density figure was skipped. Autoscale on the same static
            # curves to get the y-limits it would have used.
            for curve in (rho0, rho, g0):
                ax.plot(z_i, curve)
            ylim = ax.get_ylim()
            ax.cla()
        ax.set_ylim(ylim)
        ax.set_xlim(x_lim if x_lim is not None else [-2.5,5])
        rho_bar_line, = ax.plot([], [],color="#ff7f0e",label=r"$\bar{\rho}$")
        ic,   = ax.plot([], [],color="#1f77b4") # initial density
        line, = ax.plot([], [],color="#1f77b4",label=r"$\rho$")
        x_line,= ax.plot([],[],alpha=0.9,color="k",linestyle="--",label="x")
        x_ic,  = ax.plot([],[],alpha=0.5,color="k",linestyle="--")
        # Static decorations are set once rather than on every frame.
        ax.set_ylabel("probability")
        ax.set_xlabel("z")
        ax.legend(loc=1)

        def init():
            ic.set_data([], [])
            line.set_data([], [])
            x_line.set_data([],[])
            rho_bar_line.set_data([],[])
            return ic,line,x_line

        def animate(i):
            # Remove the previous frame's filled areas. Axes.collections is
            # read-only in recent matplotlib, so remove each artist individually.
            for patch in list(ax.collections):
                patch.remove()
            y = rho_history[:,i]
            rho_bar_line.set_data(z_i,g0)
            ic.set_data(z_i,rho0)
            line.set_data(z_i, y)
            x_ic.set_data([x0,x0],ylim)
            x_line.set_data([x_history[i],x_history[i]],ylim)
            ax.fill_between(z_i,rho0,alpha=0.2,color="#1f77b4")
            ax.fill_between(z_i,g0,alpha=0.3,color="#ff7f0e")
            ax.fill_between(z_i,y,alpha=0.3,color="#1f77b4")
            return ic,line,x_line

        # NOTE(review): assumes the histories hold one column per
        # `pde.subsample` steps, i.e. at least ceil(nT/subsample) entries -- confirm.
        n_frames = int(np.ceil(nT/pde.subsample))
        anim = FuncAnimation(fig, animate, init_func=init, frames=n_frames, interval=10, blit=True)
        anim.save(system_name+'.gif', writer='pillow')

    ####### plot losses ############
    if plot_loss:
        fig,(ax1,ax2)=plt.subplots(2,1,sharex=True)
        # Each loss gets its own time vector, so differing lengths don't raise.
        ax1.plot(dt*np.arange(len(pde.loss)),pde.loss)
        ax1.set_title("population loss")
        ax2.plot(dt*np.arange(len(ode.loss)),ode.loss)
        ax2.set_title("classifier loss")
        ax2.set_xlabel("time")
        # Figure-level title. plt.title() would overwrite ax2's title.
        fig.suptitle("loss")
        fig.savefig(system_name+"_loss.png")
    plt.close('all')

def make_plots_both_moving(pde0,pde1,ode,system_name,plot_kernel=True,plot_PDF=True,plot_loss=True,make_gif=True,gif_subsample=5):
    """Plot two evolving population densities against a moving classifier position.

    Output files (all prefixed with ``system_name``):

    * ``_kernel.png`` -- ``-W`` of ``pde0`` on part of the grid (``plot_kernel``).
    * ``.png`` -- initial/final densities of both populations and an arrow
      from ``ode.x0`` to ``ode.x`` (``plot_PDF``).
    * ``.gif`` -- animation over ``rho_history`` of both solvers and
      ``ode.x_history`` (``make_gif``).
    * ``_loss.png`` -- both population losses and ``ode.loss`` against time
      ``dt*k`` (``plot_loss``).

    Args:
        pde0, pde1: population solvers sharing the grid ``pde0.z_i``. Reads
            rho0, rho and loss from both, plus z_i, nT and W from ``pde0``.
            ``rho_history`` is read when ``make_gif`` is set.
        ode: classifier solver. Reads x, x0, dt and loss. ``x_history`` is
            read when ``make_gif`` is set.
        system_name: path prefix shared by every output file.
        plot_kernel, plot_PDF, plot_loss, make_gif: toggles for the outputs.
        gif_subsample: the GIF shows every ``gif_subsample``-th history column.

    Raises:
        ValueError: if ``make_gif`` is set and ``gif_subsample`` is not positive.

    All figures are closed before returning.
    """
    z_i    = pde0.z_i
    rho0_0 = pde0.rho0
    rho0   = pde0.rho
    rho1_0 = pde1.rho0
    rho1   = pde1.rho
    x      = ode.x
    nT     = pde0.nT
    x0     = ode.x0
    dt     = ode.dt

    if make_gif:
        if gif_subsample <= 0:
            raise ValueError("gif_subsample must be a positive integer")
        rho_history0 = pde0.rho_history
        rho_history1 = pde1.rho_history
        x_history    = ode.x_history

    ########## plot kernel #################
    if plot_kernel:
        # Keep grid points below -z_i[0]. For a grid starting at -L this drops
        # the tail beyond +L.
        cropped_z_i = z_i[z_i<-z_i[0]]
        try:
            # Evaluate W(z, 0, 1.) point by point.
            W_vec = np.vectorize(pde0.W)
            val = W_vec(cropped_z_i,0,1.)
        except Exception:
            # Fallback: call W once with 0 first and the whole array second.
            print("in exception")
            val = pde0.W(0,cropped_z_i,1.)
        plt.figure()
        plt.plot(cropped_z_i,-val)
        plt.title("Kernel")
        plt.xlabel("z")
        plt.ylabel("W(z)")
        plt.savefig(system_name+"_kernel.png")

    # Axis limits of the density figure. The GIF reuses them so the frames match.
    xlim = ylim = None

    ########## plot PDF #################
    if plot_PDF:
        plt.figure()
        plt.plot(z_i,rho0_0,color= "#1f77b4",linestyle="--",alpha=0.5)
        plt.plot(z_i,rho0,color = "#1f77b4",label=r"$\rho^0$")
        plt.plot(z_i,rho1_0,color="#ff7f0e",linestyle="--",alpha=0.5)
        plt.plot(z_i,rho1,color="#ff7f0e",label=r"$\rho^1$")
        # Capture limits before adding the full-height vertical markers below.
        ylim = plt.ylim()
        xlim = plt.xlim()
        plt.plot([x0,x0],ylim,alpha=0.3,color='k',linestyle="-.")
        plt.plot([x,x],ylim,alpha=0.5,color='k',linestyle="-.",label="x")
        ax = plt.gca()
        ax.fill_between(z_i,rho0_0,alpha=0.15,color="#1f77b4")
        ax.fill_between(z_i,rho0,alpha=0.35,color="#1f77b4")
        ax.fill_between(z_i,rho1_0,alpha=0.15,color="#ff7f0e")
        ax.fill_between(z_i,rho1,alpha=0.35,color="#ff7f0e")
        # Horizontal arrow from the initial to the final classifier position,
        # drawn at mid-height.
        y_val = np.mean(ylim)
        plt.arrow(x0,y_val,x-x0,0,length_includes_head=True,color="k",width=0.0005,head_width=0.01,head_length=0.04)
        plt.legend()
        plt.ylabel("probability")
        plt.xlabel("z")
        plt.tight_layout()
        plt.savefig(system_name+".png")

    ########## make gif #################
    if make_gif:
        fig    = plt.figure()
        ax     = plt.axes()
        if xlim is None or ylim is None:
            # The density figure was skipped. Autoscale on the same static
            # curves to get the limits it would have used.
            for curve in (rho0_0, rho0, rho1_0, rho1):
                ax.plot(z_i, curve)
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            ax.cla()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ic0,   = ax.plot([], [], linestyle="--",color="#1f77b4")
        ic1,   = ax.plot([], [], linestyle="--",color="#ff7f0e")
        # Labels match the static density figure (rho^0 / rho^1).
        line0, = ax.plot([], [], color="#1f77b4",label=r"$\rho^0$")
        line1, = ax.plot([], [], color="#ff7f0e",label=r"$\rho^1$")
        x0obj, = ax.plot([], [], linestyle="-.",color='k',alpha=0.3)
        xobj,  = ax.plot([], [], linestyle="-.",color='k',alpha=0.5,label="x")
        # Static decorations are created once rather than on every frame.
        ax.set_ylabel("probability")
        ax.set_xlabel("z")
        legend = ax.legend()

        def init():
            for artist in (ic0, ic1, line0, line1, x0obj, xobj):
                artist.set_data([], [])
            return ic0,ic1,line0,line1,x0obj,xobj,legend

        def animate(frame):
            k = frame*gif_subsample  # history column shown in this frame
            # Remove the previous frame's filled areas. Axes.collections is
            # read-only in recent matplotlib, so remove each artist individually.
            for patch in list(ax.collections):
                patch.remove()
            y0 = rho_history0[:,k]
            y1 = rho_history1[:,k]
            x_now = x_history[k]
            ic0.set_data(z_i,rho0_0)
            line0.set_data(z_i, y0)
            ic1.set_data(z_i,rho1_0)
            line1.set_data(z_i, y1)
            x0obj.set_data([x0,x0],ylim)
            xobj.set_data([x_now,x_now],ylim)
            ax.fill_between(z_i,rho0_0,alpha=0.15,color="#1f77b4")
            ax.fill_between(z_i,y0,alpha=0.3,color="#1f77b4")
            ax.fill_between(z_i,rho1_0,alpha=0.15,color="#ff7f0e")
            ax.fill_between(z_i,y1,alpha=0.3,color="#ff7f0e")
            return ic0,ic1,line0,line1,x0obj,xobj,legend

        # NOTE(review): assumes the histories have at least nT - gif_subsample + 1
        # columns (one per step) -- confirm against the solver.
        n_frames = int(nT // gif_subsample)
        anim = FuncAnimation(fig, animate, init_func=init,
                             frames=n_frames, interval=5, blit=True)
        anim.save(system_name+'.gif', writer='pillow')

    ####### plot losses ############
    if plot_loss:
        fig,(ax1,ax2)=plt.subplots(2,1,sharex=True)
        # Every series is plotted against time dt*k, so the shared x-axis is
        # consistent and differing lengths don't raise.
        ax1.plot(dt*np.arange(len(pde0.loss)),pde0.loss)
        ax1.plot(dt*np.arange(len(pde1.loss)),pde1.loss)
        ax1.legend([r"$\rho^0$",r"$\rho^1$"])
        ax1.set_title("population loss")
        ax2.plot(dt*np.arange(len(ode.loss)),ode.loss)
        ax2.set_title("classifier loss")
        ax2.set_xlabel("time")
        fig.savefig(system_name+"_loss.png")
    plt.close('all')