"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import numpy as np
import jax.numpy as jnp
from jax import vmap
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy.interpolate import griddata as int_grid
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.cm import get_cmap
from matplotlib.colors import ListedColormap
import healpy as hp
from healpy.newvisufunc import projview, newprojplot
from utils import sin_thres
import trimesh
from vedo import trimesh2vedo
import vedo
from vedo import *
#plotting function to generate the figures for the Tori problem
def plotVelDensTori(u,rho,T=[0,0.15,0.3],apx="",colorbar=True):
    num = len(T)
    fig,ax = plt.subplots(2,num,figsize=(21,14))
    
    plotVelsTori(u,T,ax[0])
    plotDensTori(rho,T,ax[1])
    
    fig.savefig("plots/2d_tori_{}.png".format(apx))
    plt.close(fig)

    
def _tori_streamplot(ax, X, Y, U, V, a):
    """Draw black streamlines of (U, V) on ``ax`` over [0, a]^2 and frame the unit square."""
    ax.set_xlim(0, a)
    ax.set_ylim(0, a)
    ax.streamplot(X, Y, U, V, density=0.45, arrowsize=4, linewidth=4, color='k')
    ax.axis('off')
    # The frame is always the (0,0)-(1,1) square, independent of ``a``.
    ax.add_patch(patches.Rectangle((0, 0), 1, 1, linewidth=1, edgecolor='k', facecolor='none'))


def plotVelsTori(u, T, ax1, ax2=None, rf_sols=None):
    """Draw velocity streamlines on [0, 1]^2 for each time in ``T``.

    Args:
        u: velocity callable evaluated with ``vmap`` on rows ``[t, x, y]``;
            output components 0 and 1 are used as the x / y velocity.
        T: sequence of times; ``T[i]`` is drawn on ``ax1[i]``.
        ax1: sequence of Axes for the model velocity.
        ax2: optional sequence of Axes; when given, the reference velocity from
            ``rf_sols[i]`` is interpolated onto the same grid and drawn on ``ax2[i]``.
        rf_sols: reference solutions (required with ``ax2``). Each item must expose
            ``.points`` and a ``'u_n'`` field (presumably pyvista/VTK datasets —
            TODO confirm).
    """
    N = 250
    a = 1
    X, Y = np.meshgrid(np.linspace(0, a, N), np.linspace(0, a, N))

    for i, t in enumerate(T):
        pts = jnp.vstack([np.ones(X.reshape(-1).shape) * t, X.reshape(-1), Y.reshape(-1)]).T
        vel = vmap(u)(pts)
        U = np.array(vel[:, 0].reshape(X.shape))
        V = np.array(vel[:, 1].reshape(Y.shape))
        _tori_streamplot(ax1[i], X, Y, U, V, a)

        if ax2 is not None:
            rf = rf_sols[i]
            points = np.array(rf.points[:, :2])
            # Only the reference velocity is needed for streamlines, so only it is
            # interpolated onto the plotting grid.
            u_ref = int_grid(points, np.array(rf['u_n']), (X, Y))
            U_ref = np.array(u_ref[:, :, 0])
            V_ref = np.array(u_ref[:, :, 1])
            _tori_streamplot(ax2[i], X, Y, U_ref, V_ref, a)

def plotDensTori(rho, T, ax1, ax2=None, rf_sols=None, colorbar=True, clim=None):
    """Draw filled density contours on [0, 1]^2 for each time in ``T``.

    Args:
        rho: scalar callable evaluated with ``vmap`` on rows ``[t, x, y]``.
        T: sequence of times; ``T[i]`` is drawn on ``ax1[i]``.
        ax1: sequence of Axes for the model density.
        ax2: optional sequence of Axes; when given, the reference density from
            ``rf_sols[i]`` is interpolated onto the same grid and drawn on ``ax2[i]``.
        rf_sols: reference solutions (required with ``ax2``). Each item must expose
            ``.points`` and a ``'rho_n'`` field (presumably pyvista/VTK datasets —
            TODO confirm).
        colorbar: if True, append a colorbar to the right of ``ax1[-1]`` built from
            the last model contour.
        clim: optional ``(vmin, vmax)`` applied to every panel; when omitted each
            panel is auto-scaled.
    """
    N = 250
    a = 1
    X, Y = np.meshgrid(np.linspace(0, a, N), np.linspace(0, a, N))
    # Colour limits are forwarded only when requested, for both the model and the
    # reference panels, so clim=None works with ax2 as well.
    range_kw = {"vmin": clim[0], "vmax": clim[1]} if clim else {}

    plt_dens1 = None
    for i, t in enumerate(T):
        pts = jnp.vstack([np.ones(X.reshape(-1).shape) * t, X.reshape(-1), Y.reshape(-1)]).T
        dens = np.array(vmap(rho)(pts).reshape(X.shape))

        ax1[i].set_xlim(0, a)
        ax1[i].set_ylim(0, a)
        plt_dens1 = ax1[i].contourf(X, Y, dens, 150, **range_kw)
        ax1[i].axis('off')

        if ax2 is not None:
            rf = rf_sols[i]
            points = np.array(rf.points[:, :2])
            rho_ref = int_grid(points, np.array(rf['rho_n']), (X, Y))

            ax2[i].set_xlim(0, a)
            ax2[i].set_ylim(0, a)
            ax2[i].contourf(X, Y, rho_ref, 150, **range_kw)
            ax2[i].axis('off')

    # Nothing to attach a colorbar to when no panel was drawn (empty T).
    if colorbar and plt_dens1 is not None:
        divider1 = make_axes_locatable(ax1[-1])
        cax1 = divider1.append_axes("right", size="5%", pad=0.10)
        plt.colorbar(plt_dens1, cax=cax1)

def plotStats(stats, apx):
    """Save a log-scale plot of a loss history to ``plots/<apx>.png``.

    ``stats`` is handed straight to ``Axes.plot``; the x axis is labelled
    'steps (x100)'. The figure is closed after saving.
    """
    fig, axis = plt.subplots(1, 1, figsize=(10, 5))
    axis.plot(stats)
    axis.set_yscale('log')
    axis.set_xlabel('steps (x100)')
    axis.set_ylabel('loss')

    out_path = "plots/{}.png".format(apx)
    fig.savefig(out_path)
    plt.close(fig)
    
#plotting function to generate the figures for the ball problem
def plotVelDensBall(u,rho,T=[0,0.25,0.5],apx=""):
    box= 8
    #our plots
    fig1,ax1 = plt.subplots(1,3,figsize=(3*box,box))
    fig2,ax2 = plt.subplots(1,3,figsize=(3*box,box))
    
    for i,t in enumerate(T): 
        plotDensBall(t,rho,Z=0,ax=ax1[i])
        plotVelBall(t,u,Z=0,ax=ax2[i])
    
    fig1.tight_layout()
    fig2.tight_layout()
    fig1.savefig("plots/3d_slice_densplot_ours_{}.png".format(apx))
    fig2.savefig("plots/3d_slice_streamplot_ours_{}.png".format(apx))
    

def heal_sample(nside, T):
    """Return one ``[t, x, y, z]`` row per HEALPix pixel centre on the unit sphere.

    Pixels are enumerated in NESTED order (``nest=True``), so rows line up with
    maps plotted using ``nest=True``. Column 0 holds the constant time ``T``.

    NOTE(review): x uses ``sin(phi)`` and y uses ``cos(phi)``, which is swapped
    relative to the usual spherical-to-Cartesian convention. Presumably
    deliberate; confirm.
    """
    theta, phi = hp.pix2ang(nside, np.arange(hp.nside2npix(nside)), nest=True)
    angles = jnp.array(np.stack((theta, phi), axis=-1))
    colat = angles[:, 0]
    lon = angles[:, 1]

    sin_colat = jnp.sin(colat)
    xyz = jnp.stack(
        [sin_colat * jnp.sin(lon), sin_colat * jnp.cos(lon), jnp.cos(colat)],
        axis=1,
    )
    times = (jnp.ones(angles.shape[0]) * T).reshape(-1, 1)
    return jnp.concatenate([times, xyz], axis=1)

def plotDensSphere(T, rho, w_, fig, fig1, fig2, nside):
    """Draw Mollweide maps of ``rho`` and ``w_`` sampled at every HEALPix pixel at time ``T``.

    Args:
        T: time value placed in column 0 of every sample.
        rho, w_: scalar callables evaluated with ``vmap`` on the ``[t, x, y, z]``
            rows produced by ``heal_sample``.
        fig, fig1, fig2: passed as the ``fig`` argument of ``hp.mollview``.
            ``fig`` shows ``rho`` rotated by 180 degrees in longitude, ``fig1``
            shows ``rho`` unrotated, and ``fig2`` shows ``w_`` unrotated.
        nside: HEALPix resolution parameter.
    """
    pts = heal_sample(nside, T)

    w = np.array(vmap(w_)(pts))
    density = np.array(vmap(rho)(pts))
    print(density.shape, np.max(density), np.min(density))

    # heal_sample enumerates pixels in NESTED order, hence nest=True on every map.
    hp.mollview(density, fig, rot=(180, 0, 0), nest=True)
    hp.mollview(density, fig1, rot=(0, 0, 0), nest=True)
    hp.mollview(w, fig2, rot=(0, 0, 0), nest=True)

def plotVelsSphere(T, u, fig, nside):
    """Overlay a line-integral-convolution (LIC) texture of ``u`` on a log-magnitude Mollweide map.

    Output components 0 and 1 of ``u`` (evaluated with ``vmap`` on the rows of
    ``heal_sample``) are fed to ``hp.line_integral_convolution`` as its Q / U
    inputs. NOTE(review): that routine is meant for polarisation fields, so
    using velocity components here is an untested visualisation hack — confirm.

    ``heal_sample`` enumerates pixels in NESTED order. ``line_integral_convolution``
    and the ``hp.mollview`` calls below (no ``nest=True``) use RING order, so the
    components are converted to RING once, before any smoothing.
    """
    pts = heal_sample(nside, T)
    vel = np.asarray(vmap(u)(pts), dtype=np.float64)

    # Semi-transparent greyscale colormap so the LIC texture overlays the magnitude map.
    cmap_colors = get_cmap('binary', 256)(np.linspace(0, 1, 256))
    cmap_colors[..., 3] = 0.4
    cmap = ListedColormap(cmap_colors)

    Q = hp.smoothing(hp.reorder(vel[:, 0], n2r=True), np.deg2rad(5))
    U = hp.smoothing(hp.reorder(vel[:, 1], n2r=True), np.deg2rad(5))
    lic = hp.line_integral_convolution(Q, U)
    lic = hp.smoothing(lic, np.deg2rad(0.5))

    hp.mollview(np.log(1 + np.sqrt(Q**2 + U**2) * 100), fig, cmap='inferno', cbar=False)
    # NOTE(review): the title looks like a leftover placeholder — TODO confirm.
    hp.mollview(lic, fig, cmap=cmap, cbar=False, reuse_axes=True, title='WMAP K')


def plotVelDenSphere(u,time,nside=128,apx=""):
    box= 8
    #our plots
    fig1,_ = plt.subplots(1,3,figsize=(3*box,box))
    fig2,_ = plt.subplots(1,3,figsize=(3*box,box))
    fig3,_ = plt.subplots(1,3,figsize=(3*box,box))

    mag = lambda x: jnp.linalg.norm(u(x[1:])[:-1])
    w = lambda x: u(x[1:])[-1]
    plotDensSphere(time,mag,w,fig1,fig2,fig3,nside)
    #plotVelsSphere(time,uk,fig2,nside)
    
    fig1.tight_layout()
    fig2.tight_layout()
    fig1.savefig("plots/densplot_{}.png".format(apx))
    fig2.savefig("plots/densplot_inv_{}.png".format(apx))
    fig3.savefig("plots/w_{}.png".format(apx))
    

def planar_sample(nside, T):
    """Return ``nside * nside`` rows ``[t, x, y, z]`` sampling a rotated planar grid.

    A regular grid with coordinates ``(k / nside - 0.5) * 2 * pi`` for
    ``k = 0 .. nside-1`` is laid out in the z = 0 plane. It is then multiplied by an
    orthonormal matrix ``R`` whose third row is the unit vector
    ``n = (0.3, -0.5, 0.8) / |.|``. Column 0 holds the constant time ``T``.

    NOTE(review): because ``R`` (not ``R.T``) is applied, the sampled plane's
    normal is ``R @ e_z = (0, -sqrt(n0^2 + n1^2), n2)`` rather than ``n`` itself.
    Confirm this is the intended slice.
    """
    side = nside
    ticks = (np.arange(side) / side - 0.5) * 2 * jnp.pi
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    flat = np.stack((grid_x, grid_y, np.zeros_like(grid_x)), axis=-1).reshape([-1, 3])

    normal = jnp.array([0.3, -0.5, 0.8])
    normal = normal / jnp.linalg.norm(normal)
    nx, ny, nz = normal[0], normal[1], normal[2]
    rho_xy = jnp.sqrt(nx**2 + ny**2)
    rot = jnp.array([
        [ny / rho_xy, -nx / rho_xy, 0],
        [nx * nz / rho_xy, ny * nz / rho_xy, -rho_xy],
        [nx, ny, nz],
    ])
    rotated = (rot @ flat.T).T

    times = (jnp.ones(rotated.shape[0]) * T).reshape(-1, 1)
    return jnp.concatenate([times, rotated], axis=1)

def _first_axes(ax_like):
    """Return a single Axes from ``ax_like``: an Axes itself, or the first of a list/tuple/ndarray."""
    if isinstance(ax_like, np.ndarray):
        return ax_like.ravel()[0]
    if isinstance(ax_like, (list, tuple)):
        return ax_like[0]
    return ax_like


def plotDensPlanar(T, rho, w_, fig, fig1, fig2, nside):
    """Show ``rho`` and ``w_`` sampled on the ``planar_sample`` grid as images.

    Args:
        T: time value placed in column 0 of every sample.
        rho, w_: scalar callables evaluated with ``vmap`` on the sample rows.
        fig: Axes (or a sequence whose first item is used) for the ``rho`` image.
        fig1: unused; kept for signature compatibility.
        fig2: Axes (or a sequence whose first item is used) for the ``w_`` image.
        nside: grid resolution; both images are ``nside`` x ``nside``.
    """
    pts = planar_sample(nside, T)

    w = np.array(vmap(w_)(pts)).reshape((nside, nside))
    density = np.array(vmap(rho)(pts))
    print(density.shape, np.max(density), np.min(density))
    density = density.reshape((nside, nside))

    # Callers pass either a bare Axes (plt.subplots(1, 1)) or a sequence of Axes.
    _first_axes(fig).imshow(density)
    _first_axes(fig2).imshow(w)


def plotVelDenplanar(u,time,nside=512,apx=""):
    box= 8
    #our plots
    fig1,ax1 = plt.subplots(1,1,figsize=(3*box,box))
    fig2,ax2 = plt.subplots(1,1,figsize=(3*box,box))
    fig3,ax3 = plt.subplots(1,1,figsize=(3*box,box))

    mag = lambda x: jnp.linalg.norm(u(x[1:])[:-1])
    w = lambda x: u(x[1:])[-1]
    plotDensPlanar(time,mag,w,ax1,ax2,ax3,nside)
    #plotVelsSphere(time,uk,fig2,nside)
    
    fig1.tight_layout()
    fig3.tight_layout()
    fig1.savefig("plots/densplot_{}.png".format(apx))
    # fig2.savefig("plots/densplot_inv_{}.png".format(apx))
    fig3.savefig("plots/w_{}.png".format(apx))

def plot_exp_mesh(u,time,v,v_normal,faces,nside=512,apx=""):
    box= 8
    #our plots
    # fig1,ax1 = plt.subplots(1,1,figsize=(3*box,box))
    # #fig2,ax2 = plt.subplots(1,1,figsize=(3*box,box))
    # fig3,ax3 = plt.subplots(1,1,figsize=(3*box,box))

    # fig1 = plt.figure()
    # ax1 = fig1.add_subplot(projection='3d')
    # fig3 = plt.figure()
    # ax3 = fig3.add_subplot(projection='3d')
    
    mag = lambda x: jnp.linalg.norm(u(x)[:-1])
    w = lambda x: u(x)[-1]
    pts = jnp.concatenate([v,v_normal],axis=-1)
    w = vmap(w)(pts)
    w = np.array(w)
    density = vmap(mag)(pts)
    density = np.array(density)

    plt_ = vedo.Plotter(offscreen=True) 
    mesh = trimesh.Trimesh(vertices=v,faces=faces)
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',density)
    camera = plt_.camera
    camera.Azimuth(260)
    plt_.show(vmeshes)
    plt_.screenshot("plots/densplot1_{}.png".format(apx))
    camera.Azimuth(180)
    plt_.screenshot("plots/densplot2_{}.png".format(apx))

    plt_ = vedo.Plotter(offscreen=True)
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',w)
    camera_2 = plt_.camera
    camera_2.Azimuth(260)
    plt_.show(vmeshes)
    plt_.screenshot("plots/w1_{}.png".format(apx))
    camera_2.Azimuth(180)
    plt_.screenshot("plots/w2_{}.png".format(apx))


    # plt_ = vedo.Plotter(offscreen=True) 
    # mesh = trimesh.Trimesh(vertices=v,faces=faces)
    # vmeshes = trimesh2vedo(mesh)
    # vmeshes.cmap('jet',density)
    # plt_.show(vmeshes)
    # plt_.screenshot("plots/densplot1_{}.png".format(apx))


    # plt_ = vedo.Plotter(offscreen=True)
    # vmeshes = trimesh2vedo(mesh)
    # vmeshes.cmap('jet',w)
    # plt_.show(vmeshes)
    # plt_.screenshot("plots/w1_{}.png".format(apx))



    # face_colors = np.mean(density[faces], axis=1)

    # ax1.plot_trisurf(v[:,0],v[:,1],v[:,2],triangles = faces, color=plt.cm.viridis(face_colors), shade=True, linewidth=0, antialiased=True)
    
    # face_colors_2 = np.mean(w[faces], axis=1)
    # ax3.plot_trisurf(v[:,0],v[:,1],v[:,2],triangles = faces, color=plt.cm.viridis(face_colors_2), shade=True, linewidth=0, antialiased=True)

    # fig1.tight_layout()
    # fig3.tight_layout()
    # fig1.savefig("plots/densplot_{}.png".format(apx))
    # fig3.savefig("plots/w_{}.png".format(apx))


def plotVelBall(T, u, ax, Z=0):
    """Streamplot of the in-plane velocity on the z = ``Z`` slice of the unit ball at time ``T``.

    Args:
        T: time value placed in column 0 of every sample.
        u: velocity callable evaluated with ``vmap`` on rows ``[t, x, y, z]``;
            output components 0 and 1 are used as the in-plane velocity.
        ax: Axes to draw on; if ``None`` a new single-panel figure is created.
        Z: height of the slice.

    Velocities with ``x^2 + y^2 + Z^2 >= 1`` are masked with NaN. Streamlines are
    coloured by ``U^2 + V^2``, and a circle of radius 1.05 is drawn as an outline.
    """
    N = 250
    a = 1.1

    X, Y = np.meshgrid(np.linspace(-a, a, N), np.linspace(-a, a, N))
    exterior = X**2 + Y**2 + Z**2 >= 1
    pts = jnp.vstack([np.ones(X.reshape(-1).shape) * T, X.reshape(-1), Y.reshape(-1),
                      np.ones(X.reshape(-1).shape) * Z]).T

    if ax is None:
        # Everything below draws on a single Axes, so create exactly one.
        _, ax = plt.subplots(figsize=(7, 7))
    ax.set_xlim(-a, a)
    ax.set_ylim(-a, a)

    vel = vmap(u)(pts)
    U = np.array(vel[:, 0].reshape(X.shape))
    V = np.array(vel[:, 1].reshape(Y.shape))
    # Mask the outside of the ball.
    U[exterior] = np.nan
    V[exterior] = np.nan
    ax.streamplot(X, Y, U, V, density=0.35, color=U**2 + V**2, arrowsize=5, linewidth=3)

    # Outline for aesthetics.
    circle = plt.Circle((0, 0), 1.05, fill=False, lw=3, color='k')
    ax.add_patch(circle)
    ax.axis('off')
    

def plotDensBall(T, rho, ax, Z=0):
    """Filled-contour plot (20 levels) of ``rho`` on the z = ``Z`` slice of the unit ball.

    ``rho`` is evaluated with ``vmap`` on rows ``[T, x, y, Z]`` over a 250 x 250 grid
    spanning [-1.1, 1.1]^2. Values with ``x^2 + y^2 + Z^2 >= 1`` are set to NaN, and
    a unit circle is drawn as an outline.
    """
    half_width = 1.1
    res = 250
    axis_vals = np.linspace(-half_width, half_width, res)
    X, Y = np.meshgrid(axis_vals, axis_vals)
    outside = X**2 + Y**2 + Z**2 >= 1

    flat_x = X.reshape(-1)
    ones = np.ones(flat_x.shape)
    pts = jnp.vstack([ones * T, flat_x, Y.reshape(-1), ones * Z]).T

    density = np.array(vmap(rho)(pts).reshape(X.shape))
    density[outside] = np.nan
    ax.contourf(X, Y, density, 20)
    ax.add_patch(plt.Circle((0, 0), 1.0, fill=False, lw=3, color='k'))

    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)
    ax.axis('off')


def plot_impicit_mesh(u,time,sampler,nside=512,apx=""):
    box= 8
    #our plots
    # fig1,ax1 = plt.subplots(1,1,figsize=(3*box,box))
    # #fig2,ax2 = plt.subplots(1,1,figsize=(3*box,box))
    # fig3,ax3 = plt.subplots(1,1,figsize=(3*box,box))

    # fig1 = plt.figure()
    # ax1 = fig1.add_subplot(projection='3d')
    # fig3 = plt.figure()
    # ax3 = fig3.add_subplot(projection='3d')
    v = sampler.v
    print(v.shape)
    v_normal = sampler.normals
    faces = sampler.f
    mag = lambda x: jnp.linalg.norm(u(x)[:-1])
    w = lambda x: u(x)[-1]
    pts = jnp.concatenate([v,v_normal],axis=-1)
    
    func_w = vmap(w)
    func_mag = vmap(mag)
    n_tw = pts.shape[0]
    w_ = []
    step = n_tw//16
    # for i in range(128):
    #     sdf.append(vmap(func_mlp)(samples[i*step:(i+1)*step]))
    density_ = []
    for i in range(16):
        w = func_w(pts[i*step:(i+1)*step])
        w_.append(np.array(w))
        density = func_mag(pts[i*step:(i+1)*step])
        density_.append(np.array(density))
    w_.append(func_w(pts[16*step:]))
    density_.append(func_mag(pts[16*step:]))
    w = np.concatenate(w_,axis=0)
    density = np.concatenate(density_,axis=0)

    # plt_ = vedo.Plotter(offscreen=True) 
    # mesh = trimesh.Trimesh(vertices=v,faces=faces)
    # vmeshes = trimesh2vedo(mesh)
    # vmeshes.cmap('jet',density)
    # camera = plt_.camera
    # camera.Azimuth(260)
    # plt_.show(vmeshes)
    # plt_.screenshot("plots/densplot1_{}.png".format(apx))
    # camera.Azimuth(180)
    # plt_.screenshot("plots/densplot2_{}.png".format(apx))

    # plt_ = vedo.Plotter(offscreen=True)
    # vmeshes = trimesh2vedo(mesh)
    # vmeshes.cmap('jet',w)
    # camera_2 = plt_.camera
    # camera_2.Azimuth(260)
    # plt_.show(vmeshes)
    # plt_.screenshot("plots/w1_{}.png".format(apx))
    # camera_2.Azimuth(180)
    # plt_.screenshot("plots/w2_{}.png".format(apx))
    

    plt_ = vedo.Plotter(offscreen=True) 
    mesh = trimesh.Trimesh(vertices=v,faces=faces,process=False)
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',density)
    camera = plt_.camera
    #camera.Azimuth(180)
    vmeshes.add_scalarbar(title="my bar") 
    plt_.show(vmeshes)
    plt_.screenshot("plots/densplot1_{}.png".format(apx))

    plt_ = vedo.Plotter(offscreen=True) 
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',density)
    camera = plt_.camera
    camera.Elevation(120)
    vmeshes.add_scalarbar(title="my bar") 
    plt_.show(vmeshes)
    plt_.screenshot("plots/densplot2_{}.png".format(apx))

    plt_ = vedo.Plotter(offscreen=True) 
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',density)
    camera = plt_.camera
    camera.Elevation(-120)
    vmeshes.add_scalarbar(title="my bar") 
    plt_.show(vmeshes)
    plt_.screenshot("plots/densplot3_{}.png".format(apx))

    plt_ = vedo.Plotter(offscreen=True) 
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',density)
    camera = plt_.camera
    vmeshes.add_scalarbar(title="my bar") 
    plt_.show(vmeshes)
    plt_.screenshot("plots/densplot4_{}.png".format(apx))

    plt_ = vedo.Plotter(offscreen=True)
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',w)
    camera = plt_.camera
    #camera.Azimuth(180)
    vmeshes.add_scalarbar(title="my bar")
    plt_.show(vmeshes)
    plt_.screenshot("plots/w1_{}.png".format(apx))

    plt_ = vedo.Plotter(offscreen=True) 
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',w)
    camera = plt_.camera
    camera.Elevation(120)
    vmeshes.add_scalarbar(title="my bar") 
    plt_.show(vmeshes)
    plt_.screenshot("plots/w2_{}.png".format(apx))

    plt_ = vedo.Plotter(offscreen=True) 
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',w)
    camera = plt_.camera
    camera.Elevation(-120)
    vmeshes.add_scalarbar(title="my bar") 
    plt_.show(vmeshes)
    plt_.screenshot("plots/w3_{}.png".format(apx))

    plt_ = vedo.Plotter(offscreen=True) 
    vmeshes = trimesh2vedo(mesh)
    vmeshes.cmap('jet',w)
    camera = plt_.camera
    vmeshes.add_scalarbar(title="my bar") 
    plt_.show(vmeshes)
    plt_.screenshot("plots/w4_{}.png".format(apx))


    # face_colors = np.mean(density[faces], axis=1)

    # ax1.plot_trisurf(v[:,0],v[:,1],v[:,2],triangles = faces, color=plt.cm.viridis(face_colors), shade=True, linewidth=0, antialiased=True)
    
    # face_colors_2 = np.mean(w[faces], axis=1)
    # ax3.plot_trisurf(v[:,0],v[:,1],v[:,2],triangles = faces, color=plt.cm.viridis(face_colors_2), shade=True, linewidth=0, antialiased=True)

    # fig1.tight_layout()
    # fig3.tight_layout()
    # fig1.savefig("plots/densplot_{}.png".format(apx))
    # fig3.savefig("plots/w_{}.png".format(apx))  
        
    
        
        