"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import jax.numpy as jnp
from jax.numpy import sin,cos,exp
import jax
from sampling import BallSampler,SphereSampler,SphereSampler3d,PlanarSampler
from training import Trainer
from plotting import plotVelDenSphere,plotStats,plotVelDenplanar
from utils import *
import jax.random as random
import optax
import numpy as np
import copy
from jax import jacfwd

#initial condition 
#is called with full time variable 

# def rho0_u0(x):
#     return jnp.array([3/2 - (x[1]**2 + x[2]**2+ x[3]**2),x[1]-1,0.5])


# def rho0_u0(x):
#     r = jnp.pi/4
#     #print("x",x.shape)
#     dis = jnp.minimum((x[0]-jnp.pi/2)**2+(x[1]-jnp.pi/4)**2,r**2)
#     dis = jnp.sqrt(dis)
#     u0 = jnp.sin((x[0]-jnp.pi/2)) * (r-dis)/r * 8
#     u1 = jnp.cos((x[1]-jnp.pi/4)) * (r-dis)/r * 8
    

#     return jnp.array([u0,u1])

def sh_(x):
    """Evaluate a fixed blend of two degree-4 angular functions at ``x``.

    Only the direction of ``x`` matters: its first three components are
    divided by their Euclidean length before anything else is computed.
    The two terms match the real spherical harmonics Y_4^0 and Y_4^1 written
    in Cartesian form; they are mixed with weights 0.4 and 0.6.

    Args:
        x: indexable with at least three components (x[0], x[1], x[2]).
            A zero vector leads to a division by zero.

    Returns:
        The scalar ``0.4 * Y44 + 0.6 * Y45``.
    """
    length = jnp.sqrt(x[0]**2 + x[1]**2 + x[2]**2)
    a, b, c = (x[k] / length for k in range(3))

    # Norm of the already-normalised direction (analytically 1); kept so the
    # expressions below stay in their homogeneous form.
    r = jnp.sqrt(a**2 + b**2 + c**2)
    root_pi = jnp.sqrt(jnp.pi)

    zonal_poly = (8*r**4 - 40*r**2*a**2 - 40*r**2*b**2
                  + 35*a**4 + 70*a**2*b**2 + 35*b**4)
    Y44 = 3*zonal_poly/(16*root_pi*r**4)

    tesseral_poly = 4*r**2 - 7*a**2 - 7*b**2
    Y45 = 3*jnp.sqrt(10)*a*c*tesseral_poly/(8*root_pi*r**4)

    return 0.4 * Y44 + 0.6 * Y45

def rho0_u0(x):
    """Return the initial vector field at point ``x``.

    The field is the sum of two parts:
      * the rotation field about the z-axis, ``(-x[1], x[0], 0)``, scaled by 10;
      * the cross product of the gradient of ``sh_`` with ``x``, scaled by 5.

    Args:
        x: 3-component array (it is crossed with the gradient of ``sh_``).

    Returns:
        A 3-component array.
    """
    # Forward-mode gradient of the scalar function sh_ evaluated at x.
    grad_sh = jacfwd(sh_)(x)
    z_rotation = jnp.array([-x[1], x[0], 0]) * 10
    return z_rotation + jnp.cross(grad_sh, x) * 5

def init_w0_planar(x):
    """Evaluate the initial scalar field made of two identical radial bumps.

    Two centres, initially at (-0.4, 0, 0) and (0.4, 0, 0), are mapped through
    a matrix built from the normalised vector (0.3, -0.5, 0.8). Each centre
    contributes ``U/a * (2 - r/a**2) * exp(0.5 * (1 - r/a**2))`` where ``r`` is
    the *squared* distance from ``x`` to the centre, ``U = 1`` and ``a = 0.28``.

    Args:
        x: indexable with at least three components (x[0], x[1], x[2]).

    Returns:
        The sum of the two contributions (a scalar).
    """
    strength = 1
    core = 0.28

    normal = jnp.array([0.3, -0.5, 0.8])
    normal = normal / jnp.linalg.norm(normal)
    nx, ny, nz = normal[0], normal[1], normal[2]
    # Length of the projection of the normal onto the xy-plane.
    planar = jnp.sqrt(nx**2 + ny**2)

    frame_rows = jnp.array([
        [ny/planar, -nx/planar, 0],
        [nx*nz/planar, ny*nz/planar, -planar],
        [nx, ny, nz],
    ])
    rotation = frame_rows.transpose()

    seeds = jnp.array([[-0.4, 0, 0], [0.4, 0, 0]])
    centers = (rotation @ seeds.transpose()).transpose()

    def _bump(center):
        # sq is the squared distance; no square root is taken.
        sq = (x[0]-center[0])**2 + (x[1]-center[1])**2 + (x[2]-center[2])**2
        return strength/core*(2-sq/core**2)*jnp.exp(0.5*(1-sq/core**2))

    total = 0
    for k in range(2):
        total += _bump(centers[k])
    return total

def w0(x,vc,U,a):
    """Sum signed Gaussian bumps centred at the rows of ``vc``.

    Row ``i`` contributes ``sign_i * (U/a) * exp(0.5 * (1 - r**2 / a**2))``
    with ``r = ||x - vc[i]||``; the signs are +1 for the first row and -1 for
    the second.

    Args:
        x: evaluation point, broadcast-compatible with a row of ``vc``.
        vc: array of centres, one per row. Only two signs are defined, so
            more than two rows raises ``IndexError``.
        U: amplitude factor.
        a: width of each bump (also divides the amplitude).

    Returns:
        The accumulated scalar value.
    """
    signs = [1, -1]
    total = 0
    for row in range(vc.shape[0]):
        dist = jnp.linalg.norm(x - vc[row])
        falloff = jnp.exp(0.5*(1 - (dist**2)/(a**2)))
        total = total + (signs[row]*U/a)*falloff
    return total

#(-0.1423 0.2637 0.9541)
#(0.1423 0.2637 0.9541)

def set_w0(x):
    """Evaluate ``w0`` at ``x`` for a fixed pair of centres.

    Uses ``U = -0.1``, ``a = 0.02`` and two centres that differ only in the
    sign of their third coordinate.

    Args:
        x: evaluation point, forwarded unchanged to ``w0``.

    Returns:
        Whatever ``w0`` returns for these settings.
    """
    # Earlier two-component centres, kept for reference:
    #vc = np.array([[2.64674716679,1.0895543644],[0.49484548679,1.0895543644]])
    centers = jnp.array([[0.2637, 0.9541, -0.1423],
                         [0.2637, 0.9541, 0.1423]])
    return w0(x, centers, -0.1, 0.02)


# def vortices_func(mesh, vi, U, a):
#     X = mesh.vertices
#     v = X[vi,:]
#     w = jnp.zeros_like(mesh.nv)
#     for i in range(v.shape[0]):

#         vc =  np.tile(v[i,:], (mesh.nv, 1))
#         r = np.linalg.norm(X - vc, axis=1)

#         w = w + (U/a)*np.exp(0.5*( 1 - r**2/a**2))
#     return w

def runBallExperiment(params, key, pinn, apx, loss, pde, sched,time_step,advect_time_step,mlp,load_path=''):
    """Train the model one outer time step at a time, plotting and saving after each.

    Time step 0 is fitted with ``trainer.trainModel_init`` (160 rounds); every
    later step restarts Adam with a piecewise-constant schedule and runs
    ``trainer.trainModel`` (80 rounds), passing the parameters saved at the end
    of the previous step and those saved at the end of step 0. Each round is
    1000 optimiser steps and is followed by a ``plotStats`` call.

    Args:
        params: initial model parameters.
        key: JAX PRNG key; split once before every training round.
        pinn: callable used for plotting as ``pinn(x, params)``.
        apx: string suffix appended to plot labels and checkpoint names.
        loss: forwarded to ``Trainer``.
        pde: object whose ``setInitial`` receives ``init_w0_planar``.
        sched: learning rate (or schedule) for the Adam optimiser of step 0.
        time_step: forwarded to ``Trainer``; also scales the plotted time.
        advect_time_step: number of outer time steps to run.
        mlp: not used by the active code in this function.
        load_path: when non-empty, ``loadState(load_path)`` replaces
            ``params`` and the optimiser state before training starts.

    Returns:
        None. Results are emitted through the plotting helpers and
        ``saveState``.
    """
    # Define the PDE's initial condition.
    pde.setInitial(init_w0_planar)
    #pde.setInitial(set_w0)

    opt = optax.adam(learning_rate=sched)
    opt_st = opt.init(params)

    if load_path != '':
        params, opt_st = loadState(load_path)

    # fix sampler.....
    sampler = PlanarSampler(heal=False,T=0.5,N=1000)
    trainer = Trainer(opt,loss,sampler,time_step)

    steps_per_round = int(1e3)
    init_rounds = 160
    advect_rounds = 80
    stats = []

    for t_i in range(advect_time_step):
        if t_i == 0:
            # Fit the initial condition.
            for _ in range(init_rounds):
                tkey,key = random.split(key)
                params, opt_st, stats = trainer.trainModel_init(
                    params, tkey, t_i, opt_st, stats=stats, steps=steps_per_round)
                plotStats(stats[5:],apx="3d_ball_experiment_" + str(t_i) + apx)
            # Snapshot of the step-0 fit, handed to every later training call.
            params_init = copy.deepcopy(params)
        else:
            # Fresh optimiser for this time step: 1e-5, scaled by 1e-1 at
            # 40000 updates and by a further 1e-2 at 60000.
            sched_2 = optax.piecewise_constant_schedule(
                init_value=1e-5,
                boundaries_and_scales={40000:1e-1,
                                       60000:1e-2},
            )
            opt = optax.adam(learning_rate=sched_2)
            opt_st = opt.init(params)
            trainer.set_opt(opt)

            for _ in range(advect_rounds):
                tkey,key = random.split(key)
                # param_b holds the parameters saved at the end of the previous step.
                params, opt_st, stats = trainer.trainModel(
                    params, param_b, params_init, tkey, t_i, opt_st,
                    stats=stats, steps=steps_per_round)
                plotStats(stats[5:],apx="3d_ball_experiment_" + str(t_i) + apx)

        time = t_i * time_step

        #plotVelDenSphere(lambda x: pinn(x,params),time,apx=apx + str(t_i))
        plotVelDenplanar(lambda x: pinn(x,params),time,apx=apx + str(t_i))

        # Keep a copy for the next step and write it to disk.
        param_b = copy.deepcopy(params)
        saveState(param_b,opt_st,stats,"training_dumps/" +str(t_i)+'_'+ apx)

        # tkey,key = random.split(key)
        # x = random.normal(tkey,shape=(3,))
        # params = mlp.init(tkey,x)

        # # # if t_i<11:
        # # #     scale = 8e-2
        # # # else:
        # # #     scale = 8e-1
        # # # params = jax.tree_map(lambda x: x*scale, params)
        # params = params.unfreeze()['params']




