import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

# Colours shared by every figure: pde0 is drawn in blue, pde1 in orange.
_POP0_COLOR = "#1f77b4"
_POP1_COLOR = "#ff7f0e"
# Fixed horizontal range of the density figure; the animation reuses it.
_PDF_XLIM = (-2.55, 5.1)


def _plot_kernel(pde0, system_name):
    """Plot ``-W(z, 0, 1.)`` on the cropped grid and save ``<system_name>_kernel.png``."""
    z_i = pde0.z_i
    # Keep only points with z < -z_i[0]; presumably z_i is (close to)
    # symmetric about 0, so this trims the grid to the mirrored range -- TODO confirm.
    cropped_z_i = z_i[z_i < -z_i[0]]
    try:
        val = np.vectorize(pde0.W)(cropped_z_i, 0, 1.)
    except Exception:
        # Fallback: call W once on the whole array. Note the first two
        # arguments are swapped relative to the vectorized call above.
        print("in exception")
        val = pde0.W(0, cropped_z_i, 1.)
    plt.figure()
    # The negated values are plotted under the label "W(z)"; presumably W
    # returns the kernel with the opposite sign -- confirm.
    plt.plot(cropped_z_i, -val)
    plt.title("Kernel")
    plt.xlabel("z")
    plt.ylabel("W(z)")
    plt.savefig(system_name + "_kernel.png")


def _plot_pdf(z_i, rho0_0, rho0, rho1_0, rho1, x0, x, system_name):
    """Plot initial and final densities of both populations and save ``<system_name>.png``.

    Returns:
        ``(xlim, ylim)`` of the axes, read after the four density curves are
        drawn and before the marker lines, so the animation can reuse them.
    """
    fontsize = 16
    fig = plt.figure()
    ax = fig.gca()
    ax.plot(z_i, rho0_0, color=_POP0_COLOR, linestyle="--", alpha=0.5, label=r"$\tilde{\rho}$")
    ax.plot(z_i, rho0, color=_POP0_COLOR, label=r"$\rho$")
    ax.plot(z_i, rho1_0, color=_POP1_COLOR, linestyle="--", alpha=0.5, label=r"$\tilde{\tau}$")
    ax.plot(z_i, rho1, color=_POP1_COLOR, label=r"$\tau$")
    ax.set_xlim(_PDF_XLIM)
    ylim = ax.get_ylim()
    xlim = ax.get_xlim()

    # Vertical markers for the initial (x0) and final (x) classifier position.
    ax.plot([x0, x0], ylim, alpha=0.3, color='k', linestyle="-.")
    ax.plot([x, x], ylim, alpha=0.5, color='k', linestyle="-.", label="x")

    ax.fill_between(z_i, rho0_0, alpha=0.15, color=_POP0_COLOR)
    ax.fill_between(z_i, rho0, alpha=0.35, color=_POP0_COLOR)
    ax.fill_between(z_i, rho1_0, alpha=0.15, color=_POP1_COLOR)
    ax.fill_between(z_i, rho1, alpha=0.35, color=_POP1_COLOR)

    # Horizontal arrow from x0 to x, drawn at mid-height of the y-range.
    y_val = np.mean(ylim)
    ax.arrow(x0, y_val, x - x0, 0, length_includes_head=True, color="k",
             width=0.005, head_width=0.05, head_length=0.2)

    ax.legend(fontsize=fontsize)
    ax.tick_params(axis='x', labelsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.set_ylabel("probability", fontsize=fontsize)
    ax.set_xlabel("z", fontsize=fontsize)
    fig.tight_layout()
    fig.savefig(system_name + ".png")
    return xlim, ylim


def _autoscaled_ylim(z_i, curves, xlim):
    """Return the y-limits Matplotlib autoscales to for *curves* over *z_i*.

    This is used when the density figure is skipped but the animation still
    needs its axis limits. A throw-away figure is drawn and closed immediately.
    """
    fig, ax = plt.subplots()
    try:
        for curve in curves:
            ax.plot(z_i, curve)
        ax.set_xlim(xlim)
        return ax.get_ylim()
    finally:
        plt.close(fig)


def _make_gif(z_i, rho0_0, rho1_0, rho_history0, rho_history1, x_history,
              x0, nT, xlim, ylim, system_name, subsample=5):
    """Animate both densities and the classifier position; save ``<system_name>.gif``.

    Every *subsample*-th step is rendered, giving ``nT // subsample`` frames.
    The ``rho_history*`` arrays are indexed as ``[:, step]`` and ``x_history``
    as ``[step]``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(xlim=xlim, ylim=ylim)
    ax.set_xlabel("z")
    ax.set_ylabel("probability")

    ic0, = ax.plot([], [], linestyle="--", color=_POP0_COLOR)
    ic1, = ax.plot([], [], linestyle="--", color=_POP1_COLOR)
    line0, = ax.plot([], [], color=_POP0_COLOR, label=r"$\rho^0$")
    line1, = ax.plot([], [], color=_POP1_COLOR, label=r"$\rho^1$")
    x0obj, = ax.plot([], [], linestyle="-.", color='k', alpha=0.3)
    xobj, = ax.plot([], [], linestyle="-.", color='k', alpha=0.5, label="x")
    lines = (ic0, ic1, line0, line1, x0obj, xobj)

    def init():
        for artist in lines:
            artist.set_data([], [])
        return lines + (ax.legend(),)

    def animate(frame):
        step = frame * subsample
        # Axes.collections is a read-only ArtistList in Matplotlib >= 3.7, so the
        # previous frame's fills are removed one by one instead of .clear().
        for coll in list(ax.collections):
            coll.remove()

        y0 = rho_history0[:, step]
        y1 = rho_history1[:, step]
        x = x_history[step]

        ic0.set_data(z_i, rho0_0)
        line0.set_data(z_i, y0)
        ic1.set_data(z_i, rho1_0)
        line1.set_data(z_i, y1)
        fills = (
            ax.fill_between(z_i, rho0_0, alpha=0.15, color=_POP0_COLOR),
            ax.fill_between(z_i, y0, alpha=0.3, color=_POP0_COLOR),
            ax.fill_between(z_i, rho1_0, alpha=0.15, color=_POP1_COLOR),
            ax.fill_between(z_i, y1, alpha=0.3, color=_POP1_COLOR),
        )
        x0obj.set_data([x0, x0], ylim)
        xobj.set_data([x, x], ylim)
        # Return the fills as well so blitting redraws them.
        return lines + (ax.legend(),) + fills

    anim = FuncAnimation(fig, animate, init_func=init,
                         frames=int(nT // subsample), interval=5, blit=True)
    anim.save(system_name + '.gif', writer='pillow')


def _plot_loss(pde0, pde1, ode, system_name):
    """Plot the population and classifier losses against time; save ``<system_name>_loss.png``."""
    fontsize = 16

    def _time(series):
        # Each series gets its own time axis, so differing lengths do not crash.
        return ode.dt * np.arange(len(series))

    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
    ax1.plot(_time(pde0.loss), pde0.loss)
    ax1.plot(_time(pde1.loss), pde1.loss)
    # NOTE(review): pde1 is labelled \tau in the density figure but \gamma here -- confirm.
    ax1.legend([r"$\rho$", r"$\gamma$"], fontsize=fontsize, loc='best')
    ax1.set_title("population loss", fontsize=fontsize)
    ax1.tick_params(axis='y', labelsize=fontsize)

    ax2.plot(_time(ode.loss), ode.loss)
    ax2.set_xlabel("time step", fontsize=fontsize)
    ax2.set_title("classifier loss", fontsize=fontsize)
    ax2.tick_params(axis='x', labelsize=fontsize)
    ax2.tick_params(axis='y', labelsize=fontsize)
    fig.tight_layout()
    fig.savefig(system_name + "_loss.png")


def make_plots_both_moving(pde0, pde1, ode, system_name, plot_kernel=True,
                           plot_PDF=True, plot_loss=True, make_gif=True):
    """Save diagnostic figures for two evolving populations and a moving classifier.

    Args:
        pde0: First population solver. Reads ``z_i`` (grid), ``rho0`` (initial
            density), ``rho`` (final density), ``nT``, ``W`` (kernel callable)
            and ``loss``. When ``make_gif`` is set it also reads ``rho_history``,
            indexed ``[:, step]`` for steps below ``nT``.
        pde1: Second population solver. Reads ``rho0``, ``rho``, ``loss`` and,
            when ``make_gif`` is set, ``rho_history``.
        ode: Classifier solver. Reads ``x0``, ``x``, ``dt``, ``loss`` and, when
            ``make_gif`` is set, ``x_history``.
        system_name: Prefix for every output file.
        plot_kernel: Write ``<system_name>_kernel.png``.
        plot_PDF: Write ``<system_name>.png``, showing initial and final
            densities.
        plot_loss: Write ``<system_name>_loss.png``.
        make_gif: Write ``<system_name>.gif``. This works even when
            ``plot_PDF`` is False.

    All Matplotlib figures are closed before returning, including on error.
    """
    z_i = pde0.z_i
    rho0_0, rho0 = pde0.rho0, pde0.rho
    rho1_0, rho1 = pde1.rho0, pde1.rho
    x0, x = ode.x0, ode.x

    try:
        if plot_kernel:
            _plot_kernel(pde0, system_name)

        xlim, ylim = _PDF_XLIM, None
        if plot_PDF:
            xlim, ylim = _plot_pdf(z_i, rho0_0, rho0, rho1_0, rho1, x0, x, system_name)

        if make_gif:
            if ylim is None:
                # The density figure was skipped; derive the same limits it would use.
                ylim = _autoscaled_ylim(z_i, (rho0_0, rho0, rho1_0, rho1), xlim)
            _make_gif(z_i, rho0_0, rho1_0, pde0.rho_history, pde1.rho_history,
                      ode.x_history, x0, pde0.nT, xlim, ylim, system_name)

        if plot_loss:
            _plot_loss(pde0, pde1, ode, system_name)
    finally:
        plt.close('all')