"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import jax.numpy as jnp
from jax.numpy import sin,cos,exp
import jax
from sampling import BallSampler,SphereSampler,SphereSampler3d
from training import Trainer
from plotting import plotVelDenSphere,plotStats
from utils import *
import jax.random as random
import optax
import numpy as np
import copy
from jax import jacfwd

#initial condition 
#is called with full time variable 

# def rho0_u0(x):
#     return jnp.array([3/2 - (x[1]**2 + x[2]**2+ x[3]**2),x[1]-1,0.5])


# def rho0_u0(x):
#     r = jnp.pi/4
#     #print("x",x.shape)
#     dis = jnp.minimum((x[0]-jnp.pi/2)**2+(x[1]-jnp.pi/4)**2,r**2)
#     dis = jnp.sqrt(dis)
#     u0 = jnp.sin((x[0]-jnp.pi/2)) * (r-dis)/r * 8
#     u1 = jnp.cos((x[1]-jnp.pi/4)) * (r-dis)/r * 8
    

#     return jnp.array([u0,u1])

def sh_(x):
    """Blend two real degree-4 spherical harmonics at the direction of ``x``.

    ``x`` is divided by its Euclidean norm first, so only its direction
    matters (``x == 0`` divides by zero and yields NaN).

    Substituting ``ux**2 + uy**2 = rad**2 - uz**2`` into the closed forms below
    identifies them:

    * ``y40`` is the real harmonic Y_4^0 = 3 (35 z^4 - 30 z^2 r^2 + 3 r^4) / (16 sqrt(pi) r^4).
    * ``y41`` is the m = 1 (x*z-type) real harmonic
      3 sqrt(10) x z (7 z^2 - 3 r^2) / (8 sqrt(pi) r^4).

    Args:
        x: point whose first three entries are Cartesian coordinates.

    Returns:
        Scalar ``0.4 * y40 + 0.6 * y41``.
    """
    length = jnp.sqrt(x[0] ** 2 + x[1] ** 2 + x[2] ** 2)
    ux = x[0] / length
    uy = x[1] / length
    uz = x[2] / length
    # Norm of the projected point (1 up to rounding). It is kept explicit so
    # the harmonics stay in their homogeneous, r-weighted form.
    rad = jnp.sqrt(ux ** 2 + uy ** 2 + uz ** 2)
    sqrt_pi = jnp.sqrt(jnp.pi)

    y40 = 3 * (8 * rad ** 4
               - 40 * rad ** 2 * ux ** 2
               - 40 * rad ** 2 * uy ** 2
               + 35 * ux ** 4
               + 70 * ux ** 2 * uy ** 2
               + 35 * uy ** 4) / (16 * sqrt_pi * rad ** 4)
    y41 = (3 * jnp.sqrt(10) * ux * uz
           * (4 * rad ** 2 - 7 * ux ** 2 - 7 * uy ** 2)
           / (8 * sqrt_pi * rad ** 4))
    return 0.4 * y40 + 0.6 * y41

def rho0_u0(x):
    """Initial velocity: a rigid rotation about z plus a harmonic-driven swirl.

    Args:
        x: a single 3D point (``jnp.cross`` with ``x`` and the indexing in
            ``sh_`` require three Cartesian components).

    Returns:
        Length-3 velocity ``10 * (-y, x, 0) + 5 * (grad sh_(x)) x x``.
        The second term is perpendicular to ``x`` because it is a cross
        product with ``x``.
    """
    grad_sh = jacfwd(sh_)(x)
    # Rotation generator about the z axis (a Killing field on the sphere).
    rotation = jnp.array([-x[1], x[0], 0]) * 10
    swirl = jnp.cross(grad_sh, x) * 5
    return rotation + swirl


def w0(x, vc, U, a, ak=None):
    """Superpose Gaussian vortex blobs centred at the rows of ``vc``.

    Vortex ``i`` contributes ``ak[i] * U / a * exp(0.5 * (1 - r_i**2 / a**2))``
    with ``r_i = ||x - vc[i]||``. The contributions are summed.

    Args:
        x: query point. ``jnp.linalg.norm`` is taken over the whole difference
            ``x - vc[i]`` (no axis), so ``x`` should be a single point shaped
            like one row of ``vc``.
        vc: vortex centres, one per row.
        U: strength scale.
        a: core radius (must be non-zero).
        ak: optional per-vortex sign/amplitude, one entry per row of ``vc``.
            Defaults to alternating ``+1, -1, +1, ...``. This reproduces the
            previous fixed ``[1, -1]`` for two vortices and no longer raises
            ``IndexError`` when ``vc`` has more than two rows.

    Returns:
        The summed field. This is the Python int ``0`` when ``vc`` has no rows.

    Raises:
        ValueError: if ``ak`` is given and its length differs from ``vc.shape[0]``.
    """
    count = vc.shape[0]
    if ak is None:
        ak = [1 if i % 2 == 0 else -1 for i in range(count)]
    elif len(ak) != count:
        raise ValueError(f"ak has {len(ak)} entries but vc has {count} vortices")

    w = 0
    for centre, amp in zip(vc, ak):
        r = jnp.linalg.norm(x - centre)
        w = w + (amp * U / a) * jnp.exp(0.5 * (1 - (r ** 2) / (a ** 2)))
    return w

#(-0.1423 0.2637 0.9541)
#(0.1423 0.2637 0.9541)

def set_w0(x):
    """Initial vorticity: two Gaussian vortices evaluated with ``w0``.

    The two centres differ only in the sign of their third coordinate. The
    strength is ``-0.1`` and the core radius is ``0.02``. The per-vortex
    signs come from ``w0``.
    """
    strength = -0.1
    core_radius = 0.02
    centres = jnp.array([
        [0.2637, 0.9541, -0.1423],
        [0.2637, 0.9541, 0.1423],
    ])
    return w0(x, centres, strength, core_radius)


# def vortices_func(mesh, vi, U, a):
#     X = mesh.vertices
#     v = X[vi,:]
#     w = jnp.zeros_like(mesh.nv)
#     for i in range(v.shape[0]):

#         vc =  np.tile(v[i,:], (mesh.nv, 1))
#         r = np.linalg.norm(X - vc, axis=1)

#         w = w + (U/a)*np.exp(0.5*( 1 - r**2/a**2))
#     return w

def runBallExperiment(params, key, pinn, apx, loss, pde, sched,time_step,advect_time_step,mlp,load_path=''):
    """Fit the initial condition on the sphere, then train ``advect_time_step - 1`` further steps.

    Step ``t_i == 0`` trains with ``trainer.trainModel_init`` and the optimizer
    built from ``sched``. Each later step re-initializes an Adam state under a
    fixed exponentially-decaying schedule. It then trains with
    ``trainer.trainModel``, passing the previous step's parameters
    (``param_b``) and the step-0 parameters (``params_init``).

    After every step the field is plotted with ``plotVelDenSphere`` and the
    state is saved to ``training_dumps/<t_i>_<apx>``.

    Args:
        params: initial network parameters. Replaced by the loaded ones when
            ``load_path`` is set.
        key: JAX PRNG key, split once per training round.
        pinn: callable ``pinn(x, params)``, used for plotting.
        apx: suffix appended to plot and dump names.
        loss: passed to ``Trainer``.
        pde: object whose initial condition is set to ``rho0_u0``.
        sched: learning-rate schedule for the step-0 optimizer.
        time_step: passed to ``Trainer``. Also scales ``t_i`` to the plotted time.
        advect_time_step: total number of steps, including step 0.
        mlp: unused here; kept for interface compatibility.
        load_path: when non-empty, ``(params, opt_st)`` are restored from it
            with ``loadState`` before training.
    """
    pde.setInitial(rho0_u0)

    opt = optax.adam(learning_rate=sched)
    opt_st = opt.init(params)
    # Truthiness check, so that None is treated like '' (the previous
    # `not load_path == ''` would have passed None to loadState).
    if load_path:
        params, opt_st = loadState(load_path)

    smp = SphereSampler3d(heal=False, T=0.5, N=1000)
    trainer = Trainer(opt, loss, smp, time_step)

    steps_per_round = 1000  # optimizer steps per trainer call
    init_rounds = 120       # rounds used to fit the initial condition (t_i == 0)
    advect_rounds = 40      # rounds per later step (was the same for every t_i)

    # The optimizer for t_i >= 1 is identical on every step, so build it once.
    # Its state is still re-initialized from the current params on each step.
    advect_sched = optax.exponential_decay(init_value=1e-6, transition_steps=40000, decay_rate=1e-2)
    advect_opt = optax.adam(learning_rate=advect_sched)

    stats = []
    for t_i in range(advect_time_step):
        if t_i == 0:
            for _ in range(init_rounds):
                tkey, key = random.split(key)
                params, opt_st, stats = trainer.trainModel_init(params, tkey, t_i, opt_st, stats=stats, steps=steps_per_round)
                # stats[5:] skips the first entries; presumably warm-up noise (TODO confirm).
                plotStats(stats[5:], apx="3d_ball_experiment_" + str(t_i) + apx)
            # Snapshot of the fitted initial condition, reused by every later step.
            params_init = copy.deepcopy(params)
        else:
            opt_st = advect_opt.init(params)
            trainer.set_opt(advect_opt)
            for _ in range(advect_rounds):
                tkey, key = random.split(key)
                params, opt_st, stats = trainer.trainModel(params, param_b, params_init, tkey, t_i, opt_st, stats=stats, steps=steps_per_round)
                plotStats(stats[5:], apx="3d_ball_experiment_" + str(t_i) + apx)

        t_phys = t_i * time_step
        plotVelDenSphere(lambda x: pinn(x, params), t_phys, apx=apx + str(t_i))
        # Parameters of this step, passed to trainModel on the next step.
        param_b = copy.deepcopy(params)
        saveState(param_b, opt_st, stats, "training_dumps/" + str(t_i) + '_' + apx)




