import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

def make_plots(pde,ode,system_name,plot_PDF=True,plot_loss=True,make_gif=True):
    """Save diagnostic figures for a simulation on a 2-D ``(z_1, z_2)`` grid.

    Depending on the flags, writes these files (all prefixed with
    ``system_name``):

    * ``<system_name>3d.png`` -- surface plots of ``pde.rho0``, ``pde.g0``,
      ``pde.rho`` and ``ode.prob`` (when ``plot_PDF``).
    * ``<system_name>2d.png`` -- the same four fields as heat maps; the three
      densities share one colour scale taken from ``pde.rho_history``
      (when ``plot_PDF``).
    * ``<system_name>.gif`` -- animation of ``pde.rho_history[:, :, i]`` with
      a line computed from ``ode.x_history[:, i]`` overlaid (when ``make_gif``).
    * ``<system_name>_loss.png`` -- ``pde.loss`` and ``ode.loss`` against
      ``dt * step`` (when ``plot_loss``).

    Parameters
    ----------
    pde : object
        Population solver. Attributes read: ``rho0``, ``rho``, ``g0``, ``nT``,
        ``rx``, ``ry``, ``z_x``, ``z_y``, ``rho_history`` and ``loss``.
    ode : object
        Classifier solver. Attributes read: ``dt``, ``prob``, ``x_history``
        and ``loss``.
    system_name : str
        Prefix for every output file name.
    plot_PDF, plot_loss, make_gif : bool
        Select which outputs are produced.

    All pyplot figures are closed (``plt.close('all')``) before returning.
    """
    fontsize = 14
    rho0 = pde.rho0
    rho = pde.rho
    g0 = pde.g0
    nT = pde.nT
    rx = pde.rx
    ry = pde.ry
    rho_history = pde.rho_history
    x_history = ode.x_history
    dt = ode.dt

    # imshow extent is (left, right, bottom, top); with the default
    # origin="upper", row 0 of every array is drawn at `top` = ry.min().
    plot_bounds = [rx.min(), rx.max(), ry.max(), ry.min()]
    # One colour scale for the density panels and every gif frame. Computed
    # outside `if plot_PDF` because the gif needs it as well.
    color_min = rho_history.min()
    color_max = rho_history.max()

    def decorate(ax, title, three_d=False, **title_kw):
        """Label the z_1/z_2 axes, enlarge tick labels and set the title."""
        ax.set_xlabel(r"$z_1$", fontsize=fontsize)
        ax.set_ylabel(r"$z_2$", fontsize=fontsize)
        tick_axes = ("x", "y", "z") if three_d else ("x", "y")
        for axis in tick_axes:
            ax.tick_params(axis=axis, labelsize=fontsize)
        ax.set_title(title, fontsize=fontsize + 2, **title_kw)

    ########## plot PDF #################
    if plot_PDF:
        prob = ode.prob

        # ---- 3-D surfaces ----
        fig = plt.figure()
        surfaces = [
            (rho0, r"$\rho_0$"),
            (g0, r"$\bar{\rho_0}$"),
            (rho, r"$\rho_t$"),
        ]
        for k, (field, title) in enumerate(surfaces, start=1):
            ax = fig.add_subplot(2, 2, k, projection='3d')
            ax.plot_surface(pde.z_x, pde.z_y, field, cmap="jet")
            decorate(ax, title, three_d=True)

        ax = fig.add_subplot(2, 2, 4, projection='3d')
        ax.plot_surface(pde.z_x, pde.z_y, prob, cmap="jet")
        # Project contours of prob onto the x = -5 and y = 5 walls of the box.
        # NOTE(review): offsets are hard-coded; presumably the grid edges -- confirm.
        ax.contour(pde.z_x, pde.z_y, prob, zdir='x', offset=-5, cmap='jet', alpha=0.5)
        ax.contour(pde.z_x, pde.z_y, prob, zdir='y', offset=5, cmap='jet', alpha=0.5)
        ax.set_zlim([-0.1, 1.1])
        decorate(ax, "probability", three_d=True, pad=0)
        fig.tight_layout()
        fig.savefig(system_name + "3d.png")

        # ---- 2-D heat maps ----
        # `prob` is on the same grid as the densities (same z_x/z_y in the
        # surface plot), so it uses the same extent; it only gets its own
        # colour scale (vmin/vmax None = autoscale).
        heatmaps = [
            (rho0, r"$\rho_0$", color_min, color_max),
            (g0, r"$\bar{\rho}$", color_min, color_max),
            (rho, r"$\rho_t$", color_min, color_max),
            (prob, "probability", None, None),
        ]
        fig = plt.figure()
        for k, (field, title, vmin, vmax) in enumerate(heatmaps, start=1):
            ax = fig.add_subplot(2, 2, k)
            im = ax.imshow(field, cmap="jet", extent=plot_bounds, vmin=vmin, vmax=vmax)
            ax.set_ylim([plot_bounds[2], plot_bounds[3]])
            cb = fig.colorbar(im, ax=ax)
            cb.ax.tick_params(labelsize=fontsize)
            decorate(ax, title)
        fig.tight_layout()
        fig.savefig(system_name + "2d.png")

    ########## make gif #################
    if make_gif:
        fig = plt.figure()
        ax = plt.axes()
        im = ax.imshow(rho_history[:, :, 0], cmap='jet', extent=plot_bounds,
                       vmin=color_min, vmax=color_max)
        line, = ax.plot([], [], lw=2, color='w')

        def init():
            im.set_data([[]])
            line.set_data([], [])
            return im, line

        def animate(i):
            state = x_history[:, i]
            im.set_data(rho_history[:, :, i])
            # Overlay the line state[0]*(z1 - state[2]) + state[1]*(z2 - state[3]) = 0
            # solved for z2 (presumably the classifier's decision boundary).
            line.set_data(rx, -state[0] / state[1] * (rx - state[2]) + state[3])
            return im, line

        anim = FuncAnimation(fig, animate, init_func=init, frames=nT, interval=5, blit=True)
        anim.save(system_name + '.gif', writer='pillow')

    ####### plot losses ############
    if plot_loss:
        fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
        # Each series gets a time axis of its own length.
        ax1.plot(dt * np.arange(len(pde.loss)), pde.loss)
        ax1.set_title("population loss")
        ax2.plot(dt * np.arange(len(ode.loss)), ode.loss)
        ax2.set_title("classifier loss")
        ax2.set_xlabel("time step")
        # Figure-level title; plt.title() would overwrite ax2's title.
        fig.suptitle("loss")
        fig.savefig(system_name + "_loss.png")
    plt.close('all')

def make_plots_both_moving(pde0,pde1,ode,system_name,plot_kernel=True,plot_PDF=True,plot_loss=True,make_gif=True):
    """Save diagnostic figures for two 1-D populations and a scalar classifier.

    Depending on the flags, writes these files (all prefixed with
    ``system_name``):

    * ``<system_name>_kernel.png`` -- ``-pde0.W`` evaluated on a cropped copy
      of the grid ``pde0.z_i`` (when ``plot_kernel``).
    * ``<system_name>.png`` -- initial (dashed) and current (solid) densities
      of both populations, vertical lines at ``ode.x0`` and ``ode.x``, a
      sigmoid centred at ``ode.x`` and an arrow from ``x0`` to ``x``
      (when ``plot_PDF``).
    * ``<system_name>.gif`` -- animation of ``pde0.rho_history[:, k]``,
      ``pde1.rho_history[:, k]`` and ``ode.x_history[k]`` using every 5th
      stored step ``k`` (when ``make_gif``).
    * ``<system_name>_loss.png`` -- both population losses and the classifier
      loss against ``dt * step`` (when ``plot_loss``).

    Parameters
    ----------
    pde0, pde1 : object
        Population solvers. Read from both: ``rho0``, ``rho``,
        ``rho_history`` and ``loss``; from ``pde0`` only: ``z_i``, ``nT`` and
        ``W``.
    ode : object
        Classifier solver. Attributes read: ``x``, ``x0``, ``dt``,
        ``x_history`` and ``loss``.
    system_name : str
        Prefix for every output file name.
    plot_kernel, plot_PDF, plot_loss, make_gif : bool
        Select which outputs are produced.

    All pyplot figures are closed (``plt.close('all')``) before returning.
    """
    z_i = pde0.z_i
    rho0_0 = pde0.rho0
    rho0 = pde0.rho
    rho1_0 = pde1.rho0
    rho1 = pde1.rho
    x = ode.x
    x0 = ode.x0
    nT = pde0.nT
    dt = ode.dt
    blue = "#1f77b4"
    orange = "#ff7f0e"

    # Axis limits shared by the static PDF figure and the gif. They come from
    # the PDF figure when it is drawn; otherwise the gif derives them from data.
    xlim = ylim = None

    ########## plot kernel #################
    if plot_kernel:
        # Keep only z < -z_i[0]; presumably z_i starts negative, so this gives
        # a range symmetric about 0 -- confirm.
        cropped_z_i = z_i[z_i < -z_i[0]]
        try:
            val = np.vectorize(pde0.W)(cropped_z_i, 0, 1.)
        except Exception:
            # Fall back to calling W once on the whole array.
            # NOTE(review): the fallback passes the grid as W's *second*
            # argument, the vectorized call as its first; the two agree only
            # if W is symmetric in those arguments -- confirm.
            print("in exception")
            val = pde0.W(0, cropped_z_i, 1.)
        fig, ax = plt.subplots()
        ax.plot(cropped_z_i, -val)
        ax.set_title("Kernel")
        ax.set_xlabel("z")
        ax.set_ylabel("W(z)")
        fig.savefig(system_name + "_kernel.png")

    ########## plot PDF #################
    if plot_PDF:
        fig, ax = plt.subplots()
        ax.plot(z_i, rho0_0, color=blue, linestyle="--", alpha=0.5)
        ax.plot(z_i, rho0, color=blue, label=r"$\rho^0$")
        ax.plot(z_i, rho1_0, color=orange, linestyle="--", alpha=0.5)
        ax.plot(z_i, rho1, color=orange, label=r"$\rho^1$")
        # Record the limits fitted to the densities only (before the vertical
        # lines are added); the gif reuses them.
        ylim = ax.get_ylim()
        xlim = ax.get_xlim()
        ax.plot([x0, x0], ylim, alpha=0.3, color='k', linestyle="-.")
        ax.plot([x, x], ylim, alpha=0.5, color='k', linestyle="-.", label="x")
        ax.plot(z_i, 0.1 / (1. + np.exp(-3. * (z_i - x))), alpha=0.5, color="k")
        ax.fill_between(z_i, rho0_0, alpha=0.15, color=blue)
        ax.fill_between(z_i, rho0, alpha=0.35, color=blue)
        ax.fill_between(z_i, rho1_0, alpha=0.15, color=orange)
        ax.fill_between(z_i, rho1, alpha=0.35, color=orange)
        # Horizontal arrow from x0 to x at mid-height of the recorded y-range.
        ax.arrow(x0, np.mean(ylim), x - x0, 0, length_includes_head=True, color="k",
                 width=0.0005, head_width=0.01, head_length=0.04)
        ax.legend()
        ax.set_ylabel("probability")
        ax.set_xlabel("z")
        fig.tight_layout()
        fig.savefig(system_name + ".png")

    ########## make gif #################
    if make_gif:
        rho_history0 = pde0.rho_history
        rho_history1 = pde1.rho_history
        x_history = ode.x_history
        if xlim is None:
            xlim = (z_i.min(), z_i.max())
        if ylim is None:
            lo = min(0.0, rho_history0.min(), rho_history1.min())
            hi = max(rho_history0.max(), rho_history1.max())
            ylim = (lo, hi + 0.05 * (hi - lo))

        fig = plt.figure()
        ax = plt.axes(xlim=xlim, ylim=ylim)
        ax.set_ylabel("probability")
        ax.set_xlabel("z")
        ic0, = ax.plot([], [], linestyle="--", color=blue)
        ic1, = ax.plot([], [], linestyle="--", color=orange)
        line0, = ax.plot([], [], color=blue, label=r"$\rho^0$")
        line1, = ax.plot([], [], color=orange, label=r"$\rho^1$")
        x0obj, = ax.plot([], [], linestyle="-.", color='k', alpha=0.3)
        xobj, = ax.plot([], [], linestyle="-.", color='k', alpha=0.5, label="x")
        artists = (ic0, ic1, line0, line1, x0obj, xobj)

        def init():
            for artist in artists:
                artist.set_data([], [])
            return artists + (ax.legend(),)

        subsample = 5  # animate every `subsample`-th stored step for the gif

        def animate(i):
            step = i * subsample
            # Remove the previous frame's fill_between patches. Axes.collections
            # is read-only in Matplotlib >= 3.7, so remove each artist instead
            # of clearing the list.
            for coll in list(ax.collections):
                coll.remove()
            y0 = rho_history0[:, step]
            y1 = rho_history1[:, step]
            ic0.set_data(z_i, rho0_0)
            line0.set_data(z_i, y0)
            ax.fill_between(z_i, rho0_0, alpha=0.15, color=blue)
            ax.fill_between(z_i, y0, alpha=0.3, color=blue)
            ic1.set_data(z_i, rho1_0)
            line1.set_data(z_i, y1)
            ax.fill_between(z_i, rho1_0, alpha=0.15, color=orange)
            ax.fill_between(z_i, y1, alpha=0.3, color=orange)
            x_now = x_history[step]
            x0obj.set_data([x0, x0], ylim)
            xobj.set_data([x_now, x_now], ylim)
            return artists + (ax.legend(),)

        anim = FuncAnimation(fig, animate, init_func=init,
                             frames=int(nT // subsample), interval=5, blit=True)
        anim.save(system_name + '.gif', writer='pillow')

    ####### plot losses ############
    if plot_loss:
        fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
        # The axes share x, so every series is plotted against time dt*step
        # (each with a time axis of its own length).
        ax1.plot(dt * np.arange(len(pde0.loss)), pde0.loss)
        ax1.plot(dt * np.arange(len(pde1.loss)), pde1.loss)
        ax1.legend([r"$\rho^0$", r"$\rho^1$"])
        ax1.set_title("population loss")
        ax2.plot(dt * np.arange(len(ode.loss)), ode.loss)
        ax2.set_xlabel("time step")
        ax2.set_title("classifier loss")
        fig.savefig(system_name + "_loss.png")
    plt.close('all')