"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import jax.numpy as jnp
from jax.numpy import sin,cos,exp
import jax
from sampling import BallSampler,SphereSampler,SphereSampler3d,PlanarSampler,ExplicitMeshSampler,ImplicitMeshSampler
from training import Trainer
from plotting import plotVelDenSphere,plotStats,plotVelDenplanar,plot_exp_mesh
from utils import *
import jax.random as random
import optax
import numpy as np
import copy
from jax import jacfwd
import trimesh
import igl
#initial condition 
#is called with full time variable 

# def rho0_u0(x):
#     return jnp.array([3/2 - (x[1]**2 + x[2]**2+ x[3]**2),x[1]-1,0.5])


# def rho0_u0(x):
#     r = jnp.pi/4
#     #print("x",x.shape)
#     dis = jnp.minimum((x[0]-jnp.pi/2)**2+(x[1]-jnp.pi/4)**2,r**2)
#     dis = jnp.sqrt(dis)
#     u0 = jnp.sin((x[0]-jnp.pi/2)) * (r-dis)/r * 8
#     u1 = jnp.cos((x[1]-jnp.pi/4)) * (r-dis)/r * 8
    

#     return jnp.array([u0,u1])

def sh_(x):
    """Evaluate a fixed mix of degree-4 spherical harmonics at the direction of ``x``.

    ``x`` is first projected onto the unit sphere. With z as the polar axis, the
    two terms below are algebraically the degree-4 harmonics Y_4^0 and the real
    Y_4^1 (``xz(7z^2 - 3r^2)`` form), combined as ``0.4*Y_4^0 + 0.6*Y_4^1``.

    Args:
        x: point whose first three components are used as (x, y, z).

    Returns:
        Scalar harmonic value (undefined at the origin, where the projection divides by 0).
    """
    norm_x = jnp.sqrt(x[0] ** 2 + x[1] ** 2 + x[2] ** 2)
    ux = x[0] / norm_x
    uy = x[1] / norm_x
    uz = x[2] / norm_x
    # Radius of the projected point (numerically ~1); kept so both terms stay
    # written in their homogeneous form.
    rad = jnp.sqrt(ux ** 2 + uy ** 2 + uz ** 2)
    sqrt_pi = jnp.sqrt(jnp.pi)

    zonal = 3 * (
        8 * rad ** 4
        - 40 * rad ** 2 * ux ** 2
        - 40 * rad ** 2 * uy ** 2
        + 35 * ux ** 4
        + 70 * ux ** 2 * uy ** 2
        + 35 * uy ** 4
    ) / (16 * sqrt_pi * rad ** 4)
    tesseral = 3 * jnp.sqrt(10) * ux * uz * (
        4 * rad ** 2 - 7 * ux ** 2 - 7 * uy ** 2
    ) / (8 * sqrt_pi * rad ** 4)

    return 0.4 * zonal + 0.6 * tesseral

def rho0_u0(x):
    """Initial velocity: rigid z-rotation plus a tangential field built from ``sh_``.

    The first term ``10 * (-y, x, 0)`` is a rotation about the z-axis. The second
    term is ``5 * (grad sh_(x) x x)``, which is perpendicular to ``x`` by
    construction of the cross product.

    Args:
        x: 3-vector (``jnp.cross`` requires the last axis to have length 3).

    Returns:
        3-vector velocity.
    """
    rotation_z = jnp.array([-x[1], x[0], 0]) * 10
    harmonic_gradient = jacfwd(sh_)(x)
    return rotation_z + jnp.cross(harmonic_gradient, x) * 5

def init_w0_mesh(x,geo):
    """Initial vorticity at a mesh sample from two same-sign vortices.

    The distance ``r`` to vortex ``k`` is obtained by interpolating the per-vertex
    values ``geo[k, :]`` over the sample's triangle with barycentric weights.
    Each vortex contributes ``(U/a) * (2 - r^2/a^2) * exp((1 - r^2/a^2)/2)`` with
    ``U = 1`` and ``a = 0.3``.

    Args:
        x: sample vector. Assumes ``x[6:9]`` are barycentric weights and
           ``x[9:12]`` the triangle's vertex indices stored as floats — TODO
           confirm layout against the sampler.
        geo: array indexed ``geo[k, vertex]`` for ``k`` in {0, 1}; presumably the
           geodesic distances registered via ``pde.set_geodesic`` — confirm.

    Returns:
        Scalar vorticity.
    """
    # NOTE(review): unlike init_w0_spike, both vortices share the same sign here;
    # confirm this is intentional.
    amplitude = 1
    width = 0.3
    weights = x[6:9]
    corners = x[9:].astype(jnp.int32)
    vorticity = 0
    for k in range(2):
        corner_dist = jnp.stack(
            [geo[k, corners[0]], geo[k, corners[1]], geo[k, corners[2]]], axis=0
        )
        dist = jnp.sum(corner_dist * weights)
        ratio = dist ** 2 / width ** 2
        vorticity += amplitude / width * (2 - ratio) * jnp.exp(0.5 * (1 - ratio))
    return vorticity

def init_w0_spike(x,geo):
    """Initial vorticity at a mesh sample from two opposite-sign Gaussian vortices.

    Vortex ``k`` contributes ``(s_k * U / a) * exp((1 - r_k^2 / a^2) / 2)`` with
    signs ``s = (+1, -1)``, ``U = -0.1`` and ``a = 0.02``. Here ``r_k`` is the
    barycentric interpolation of ``geo[k, :]`` over the sample's triangle.

    Args:
        x: sample vector. Assumes ``x[6:9]`` are barycentric weights and
           ``x[9:12]`` the triangle's vertex indices stored as floats — TODO confirm.
        geo: array indexed ``geo[k, vertex]`` for ``k`` in {0, 1}; presumably
           per-vertex geodesic distances to each vortex centre — confirm.

    Returns:
        Scalar vorticity.
    """
    signs = [1, -1]
    strength = -0.1
    width = 0.02
    weights = x[6:9]
    corners = x[9:].astype(jnp.int32)

    total = 0
    for k, sign in enumerate(signs):
        corner_dist = jnp.stack([geo[k, corners[j]] for j in range(3)], axis=-1)
        dist = jnp.sum(corner_dist * weights)
        total = total + (sign * strength / width) * jnp.exp(0.5 * (1 - (dist ** 2) / (width ** 2)))
    return total

def init_implicit(x):
    """Initial vorticity for the implicit-mesh setup: two fixed vortices via ``w0``.

    Uses strength ``U = -0.1`` and width ``a = 0.2``; the signs (+1, -1) come
    from ``w0``.

    Args:
        x: 3D point.

    Returns:
        Scalar vorticity.
    """
    # An alternative pair of centres was noted in the original as possibly for a
    # "lucy" mesh (unconfirmed):
    # [[0.07806439, -0.24506364, -0.12179182], [-0.10616396, -0.4880706, 0.07443176]]
    centres = jnp.array([
        [0.10661555, 0.11196036, -0.1899763],
        [-0.0386528, 0.11696957, -0.18989322],
    ])
    return w0(x, centres, -0.1, 0.2)

# def mesh_init_w(x):
    
#     pass

# def mesh_init_u0(x):
    
#     pass

def w0(x,vc,U,a):
    """Sum of Gaussian vortices centred at the rows of ``vc``, with signs +1, -1.

    Row ``i`` contributes ``(s_i * U / a) * exp((1 - r_i^2 / a^2) / 2)``, where
    ``r_i`` is the Euclidean distance from ``x`` to ``vc[i]``.

    Args:
        x: point (same length as a row of ``vc``).
        vc: array of centres, one per row; at most two rows are supported.
        U: vortex strength.
        a: vortex width.

    Returns:
        Scalar vorticity.

    Raises:
        IndexError: if ``vc`` has more than two rows (no sign is defined for them).
    """
    signs = [1, -1]
    total = 0
    for idx, centre in enumerate(vc):
        dist = jnp.linalg.norm(x - centre)
        total = total + (signs[idx] * U / a) * jnp.exp(0.5 * (1 - (dist ** 2) / (a ** 2)))
    return total

def set_w0(x):
    """Initial vorticity: two narrow opposite-sign vortices via ``w0``.

    Uses strength ``U = -0.1`` and width ``a = 0.02``. The centres are the
    points originally noted as (-0.1423, 0.2637, 0.9541) and
    (0.1423, 0.2637, 0.9541), with their coordinates cyclically permuted —
    presumably to match a different axis convention; confirm.

    Args:
        x: 3D point.

    Returns:
        Scalar vorticity.
    """
    centres = jnp.array([
        [0.2637, 0.9541, -0.1423],
        [0.2637, 0.9541, 0.1423],
    ])
    return w0(x, centres, -0.1, 0.02)


# def vortices_func(mesh, vi, U, a):
#     X = mesh.vertices
#     v = X[vi,:]
#     w = jnp.zeros_like(mesh.nv)
#     for i in range(v.shape[0]):

#         vc =  np.tile(v[i,:], (mesh.nv, 1))
#         r = np.linalg.norm(X - vc, axis=1)

#         w = w + (U/a)*np.exp(0.5*( 1 - r**2/a**2))
#     return w

def _drop_unreferenced_vertices(v, f):
    """Remove vertices not referenced by any face and re-index ``f`` to match.

    Kept vertices retain their relative order, so when the referenced vertices
    already form a prefix of ``v`` this is a no-op.

    Args:
        v: vertex array, one row per vertex.
        f: integer face array indexing rows of ``v``.

    Returns:
        ``(v_kept, f_kept)``: the referenced vertices and ``f`` rewritten to index
        into ``v_kept`` (original dtype of ``f`` preserved).
    """
    f = np.asarray(f)
    used = np.zeros(v.shape[0], dtype=bool)
    used[f.reshape(-1)] = True
    # New index of each kept vertex = number of kept vertices that precede it.
    new_index = np.cumsum(used) - 1
    return v[used], new_index[f].astype(f.dtype)


def runBallExperiment(params, key, pinn, apx, loss, pde, sched,time_step,advect_time_step,mlp,load_path=''):
    """Fit the PINN's initial condition on a triangle mesh, then advance it in time.

    Loads the mesh and computes per-vertex geodesic distances to two source
    vertices, which are passed to ``pde.set_geodesic``. It then registers the
    matching initial condition via ``pde.setInitial``. Step ``t_i == 0`` trains
    with ``trainer.trainModel_init`` (200 rounds); every later step re-creates an
    Adam optimizer and trains with ``trainer.trainModel`` (40 rounds). After each
    step the field is plotted on the mesh and a checkpoint is written to
    ``training_dumps/<t_i>_<apx>``.

    Args:
        params: initial network parameters (optax-compatible pytree).
        key: JAX PRNG key; split once per training call.
        pinn: callable ``pinn(x, params)`` used for plotting.
        apx: string suffix for plot and checkpoint names.
        loss: loss object forwarded to ``Trainer``.
        pde: object providing ``set_geodesic`` and ``setInitial``.
        sched: learning-rate schedule for the initial-condition fit.
        time_step: forwarded to ``Trainer``; ``t_i * time_step`` is the plot time.
        advect_time_step: number of outer time steps.
        mlp: unused.
        load_path: if non-empty, params and optimizer state are restored with
            ``loadState`` before training.
    """
    trimesh_loader = False
    # The trimesh loader was used with './mesh_process/spot/spot/spot_quadrangulated.obj'.
    path = './mesh_process/libigl-tutorial-data/hand.mesh'
    default_normal = np.ones(3) / np.sqrt(3)

    if trimesh_loader:
        mesh = trimesh.load(path)
        v_normal = igl.per_vertex_normals(mesh.vertices, mesh.faces)
        face_normal = igl.per_face_normals(mesh.vertices, mesh.faces, default_normal)
        v = mesh.vertices
        f = mesh.faces
        a_0 = igl.exact_geodesic(v, f, np.array([2009]), np.arange(v.shape[0]))
        a_1 = igl.exact_geodesic(v, f, np.array([2007]), np.arange(v.shape[0]))
        geodesics = jnp.stack([a_0, a_1], axis=0)
        v = jnp.array(mesh.vertices)
        v_normal = jnp.array(v_normal)
        f = jnp.array(mesh.faces)
        pde.set_geodesic(geodesics)
        pde.setInitial(init_w0_spike)
    else:
        v, t, f = igl.read_mesh(path)
        # Keep only vertices referenced by the faces (presumably dropping interior
        # tet vertices, since `t` is unused) and re-index the faces to match.
        # Without the re-index, faces point at the wrong vertices whenever an
        # unreferenced vertex precedes a referenced one.
        v, f = _drop_unreferenced_vertices(v, f)

        # Scale the mesh 4x about its vertex centroid.
        v_mean = np.mean(v, axis=0)
        v = (v - v_mean) * 4 + v_mean
        v_normal = igl.per_vertex_normals(v, f)
        face_normal = igl.per_face_normals(v, f, default_normal)

        targets = np.arange(v.shape[0]).reshape(-1, 1)
        a_0 = igl.exact_geodesic(v, f, np.array([1674]).reshape(-1, 1), targets)
        a_1 = igl.exact_geodesic(v, f, np.array([4171]).reshape(-1, 1), targets)
        geodesics = jnp.stack([a_0, a_1], axis=0)

        v = np.array(v)
        v_normal = jnp.array(v_normal)
        f = np.array(f)

        pde.set_geodesic(geodesics)
        pde.setInitial(init_w0_mesh)

    opt = optax.adam(learning_rate=sched)
    opt_st = opt.init(params)
    if load_path:
        params, opt_st = loadState(load_path)

    smp = ExplicitMeshSampler(v, f, face_normal, T=0.5, N=8000)

    implicit = False
    if implicit:
        pde.setInitial(init_implicit)
        smp = ImplicitMeshSampler('')

    trainer = Trainer(opt, loss, smp, time_step)

    steps_per_call = 1000   # optimizer steps per trainer call
    init_rounds = 200       # trainer calls for the initial-condition fit (t_i == 0)
    advect_rounds = 40      # trainer calls for each later time step
    stats = []
    # `params_init` (set at t_i == 0) and `param_b` (set at the end of every step)
    # are always bound before the t_i >= 1 branch reads them.
    for t_i in range(advect_time_step):
        stats_apx = "3d_ball_experiment_" + str(t_i) + apx
        if t_i == 0:
            for _ in range(init_rounds):
                tkey, key = random.split(key)
                params, opt_st, stats = trainer.trainModel_init(
                    params, tkey, t_i, opt_st, stats=stats, steps=steps_per_call)
                plotStats(stats[5:], apx=stats_apx)
            params_init = copy.deepcopy(params)
        else:
            smp.N = 1200
            # Fresh optimizer (and optimizer state) for every time step.
            sched_2 = optax.exponential_decay(
                init_value=1e-6, transition_steps=40000, decay_rate=1e-2)
            opt = optax.adam(learning_rate=sched_2)
            opt_st = opt.init(params)
            trainer.set_opt(opt)
            for _ in range(advect_rounds):
                tkey, key = random.split(key)
                params, opt_st, stats = trainer.trainModel(
                    params, param_b, params_init, tkey, t_i, opt_st,
                    stats=stats, steps=steps_per_call)
                plotStats(stats[5:], apx=stats_apx)

        t_now = t_i * time_step
        plot_exp_mesh(lambda x: pinn(x, params), t_now, v, v_normal, f, apx=apx + str(t_i))
        # Snapshot of this step's solution, passed to the next step's training.
        param_b = copy.deepcopy(params)
        saveState(param_b, opt_st, stats, "training_dumps/" + str(t_i) + '_' + apx)



