"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import jax.numpy as jnp
from jax.numpy import sin,cos,exp
import jax
from sampling import SphereSampler3dSpace
from training import Trainer
from plotting import plotVelDenSphere,plotStats,plotVelDenplanar
from utils import *
import jax.random as random
import optax
import numpy as np
import copy
from jax import jacfwd

import optax
import flax
import gzip
import struct
import array
import os
from jax.scipy.ndimage import map_coordinates

def sh_(x):
    """Evaluate a fixed blend of two degree-4 spherical-harmonic terms.

    The first three components of ``x`` are normalised to unit length and
    the result is ``0.4 * Y44 + 0.6 * Y45``.  The two closed forms coincide
    with the real spherical harmonics of degree 4 and order 0 (``Y44``) and
    order 1 (``Y45``) written in Cartesian coordinates.

    Args:
        x: Indexable with at least three components; only ``x[0]``,
            ``x[1]`` and ``x[2]`` are read.  A zero vector leads to a
            division by zero.

    Returns:
        The scalar value of the blended function at the direction of ``x``.
    """
    length = jnp.sqrt(x[0]**2+x[1]**2+x[2]**2)
    ux = x[0]/length
    uy = x[1]/length
    uz = x[2]/length
    # Norm of the already-normalised vector (analytically 1); it is kept so
    # the expressions below stay in their general homogeneous form.
    rad = jnp.sqrt(ux**2+uy**2+uz**2)

    zonal_poly = (8*rad**4 - 40*rad**2*ux**2 - 40*rad**2*uy**2
                  + 35*ux**4 + 70*ux**2*uy**2 + 35*uy**4)
    order0 = 3*zonal_poly/(16*jnp.sqrt(jnp.pi)*rad**4)

    tesseral_poly = (4*rad**2 - 7*ux**2 - 7*uy**2)
    order1 = 3*jnp.sqrt(10)*ux*uz*tesseral_poly/(8*jnp.sqrt(jnp.pi)*rad**4)

    return 0.4 * order0 + 0.6 * order1

def rho0_u0(x):
    """Return the vector ``u`` at point ``x``.

    ``u`` is the sum of two parts:

    * ``(-x[1], x[0], 0)`` scaled by 10, and
    * the cross product of the gradient of ``sh_`` (forward-mode Jacobian
      via ``jacfwd``) with ``x``, scaled by 5.

    Args:
        x: 3-component array; it is passed to ``jnp.cross`` and to the
            differentiated ``sh_``.

    Returns:
        A 3-component array.
    """
    grad_sh = jacfwd(sh_)(x)
    z_rotation_part = jnp.array([-x[1],x[0],0]) * 10
    harmonic_part = jnp.cross(grad_sh,x) * 5
    return z_rotation_part + harmonic_part

def init_w0_planar(x):
    """Evaluate the sum of two identical radial bumps at point ``x``.

    Two centres start at ``(-0.4, 0, 0)`` and ``(0.4, 0, 0)`` and are mapped
    through a fixed orthonormal frame built from the unit vector along
    ``(0.3, -0.5, 0.8)``.  Each centre contributes
    ``U/a * (2 - d2/a**2) * exp(0.5 * (1 - d2/a**2))`` with ``U = 1``,
    ``a = 0.3`` and ``d2`` the squared distance from ``x`` to the centre.

    Args:
        x: Indexable with at least three components.

    Returns:
        The scalar sum of both contributions.
    """
    U = 1
    a = 0.3

    axis = jnp.array([0.3,-0.5,0.8])
    axis = axis/jnp.linalg.norm(axis)
    # Length of the axis projected onto the xy-plane; shared by the frame rows.
    planar = jnp.sqrt(axis[0]**2+axis[1]**2)
    frame_rows = [
        [axis[1]/planar, -axis[0]/planar, 0],
        [axis[0]*axis[2]/planar, axis[1]*axis[2]/planar, -planar],
        [axis[0], axis[1], axis[2]],
    ]
    rotation = jnp.array(frame_rows).T

    centres = jnp.array([[-0.4,0,0],[0.4,0,0]])
    centres = (rotation@centres.T).T

    total = 0
    for k in range(2):
        # Squared distance from x to the k-th centre.
        d2 = (x[0]-centres[k,0])**2 + (x[1]-centres[k,1])**2 + (x[2]-centres[k,2])**2
        total += U/a*(2-d2/a**2)*jnp.exp(0.5*(1-d2/a**2))
    return total

def w0(x,vc,U,a):
    """Sum signed Gaussian-shaped bumps centred at the rows of ``vc``.

    Row ``i`` contributes ``sign_i * U/a * exp(0.5 * (1 - r**2/a**2))`` where
    ``r`` is the Euclidean distance from ``x`` to ``vc[i]`` and the signs are
    ``+1`` for the first row and ``-1`` for the second.

    Args:
        x: Point at which to evaluate.
        vc: 2-D array of centres, one per row.  Only two signs are defined,
            so more than two rows raises ``IndexError``.
        U: Amplitude factor.
        a: Width of each bump.

    Returns:
        The accumulated scalar (the integer ``0`` when ``vc`` has no rows).
    """
    signs = [1,-1]
    total = 0
    for idx, centre in enumerate(vc):
        dist = jnp.linalg.norm(x - centre)
        bump = jnp.exp(0.5*( 1 - (dist**2)/(a**2)))
        total = total + (signs[idx]*U/a)*bump
    return total

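# Reference points (approximately on the unit sphere):
#   (-0.1423, 0.2637, 0.9541)
#   ( 0.1423, 0.2637, 0.9541)
# set_w0 below uses the same values with the components cyclically
# permuted, i.e. (0.2637, 0.9541, -/+0.1423).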

def set_w0(x):
    """Evaluate ``w0`` at ``x`` with a fixed pair of centres.

    Uses ``U = -0.1``, ``a = 0.02`` and the two centres
    ``(0.2637, 0.9541, -0.1423)`` and ``(0.2637, 0.9541, 0.1423)``.

    Args:
        x: Point forwarded unchanged to ``w0``.

    Returns:
        Whatever ``w0`` returns for these parameters.
    """
    centres = jnp.array([[0.2637, 0.9541, -0.1423],[0.2637, 0.9541, 0.1423]])
    amplitude = -0.1
    width = 0.02
    return w0(x,centres,amplitude,width)

def w0_func(x):
    """Sample an image packed into ``x`` at the rotated position in ``x``.

    ``x[7:]`` is reshaped to a 28x28 image.  ``x[:3]`` is multiplied by the
    inverse of the fixed frame built from the unit vector along
    ``(0.3, -0.5, 0.8)``; the first two rotated components are mapped from
    ``[-1, 1]`` to ``[0, 28]`` and used as coordinates for a linear
    (``order=1``) ``map_coordinates`` lookup.  The sampled value is passed
    through ``tanh``.

    Args:
        x: 1-D array whose entries from index 7 onwards hold exactly
            28*28 image values; entries 3..6 are not read here.

    Returns:
        ``tanh`` of the interpolated image value.
    """
    image = x[7:].reshape(28,28)

    axis = jnp.array([0.3,-0.5,0.8])
    axis = axis/jnp.linalg.norm(axis)
    # Length of the axis projected onto the xy-plane; shared by the frame rows.
    planar = jnp.sqrt(axis[0]**2+axis[1]**2)
    frame = jnp.array([
        [axis[1]/planar, -axis[0]/planar, 0],
        [axis[0]*axis[2]/planar, axis[1]*axis[2]/planar, -planar],
        [axis[0], axis[1], axis[2]],
    ]).T
    frame_inv = jnp.linalg.inv(frame)

    # Overwrite the position part of x with its rotated counterpart.
    x = x.at[:3].set(frame_inv@x[:3])

    # Both coordinates are scaled by the image width (shape[1]).
    pixel_coords = (x[:2] + 1)/2*image.shape[1]
    sampled = map_coordinates(image, pixel_coords, order=1)
    return jnp.tanh(sampled)



def runBallExperiment(params, key, pinn, apx, loss, pde, sched,time_step,advect_time_step,mlp,load_path=''):
    """Run the initial-fit stage of the 3D ball experiment and dump results.

    Steps, in order:

    1. Register ``w0_func`` as the initial condition on ``pde``.
    2. Build an Adam optimiser with learning rate ``sched`` and initialise
       its state from ``params``.
    3. If ``load_path`` is non-empty, replace ``params`` and the optimiser
       state with what ``loadState(load_path)`` returns.
    4. Call ``trainer.trainModel_init`` 200 times with ``steps=1000`` each,
       plotting the statistics (minus the first five entries) after every
       call.
    5. Plot the result with ``plotVelDenSphere`` and save a deep copy of the
       parameters, the optimiser state and the statistics under
       ``training_dumps/``.

    Args:
        params: Model parameters to optimise.
        key: JAX PRNG key; split once per training round.
        pinn: Callable invoked as ``pinn(x, params)`` for plotting.
        apx: String suffix used in plot and dump names.
        loss: Loss object handed to ``Trainer``.
        pde: Object exposing ``setInitial``.
        sched: Learning rate (or schedule) for ``optax.adam``.
        time_step: Forwarded to ``Trainer``.
        advect_time_step: Not used by this function.
        mlp: Not used by this function.
        load_path: Optional checkpoint path; ``''`` means start fresh.

    Returns:
        None.
    """
    # Image-based initial condition; set_w0 is the analytic alternative.
    pde.setInitial(w0_func)

    optimizer = optax.adam(learning_rate=sched)
    opt_st = optimizer.init(params)

    # A non-empty path overrides both the parameters and the optimiser state.
    if load_path != '':
        params, opt_st = loadState(load_path)

    # NOTE (carried over from the original): the sampler still needs fixing.
    sampler = SphereSampler3dSpace(heal=False,T=0.5,N=1000)
    trainer = Trainer(optimizer,loss,sampler,time_step)

    steps_per_round = int(1e3)
    stats = []
    plot_time = 0
    # Only the initial time index is trained here.
    t_i = 0
    for _ in range(200):
        tkey,key = random.split(key)
        params, opt_st,stats = trainer.trainModel_init(params,tkey,t_i, opt_st, stats=stats,steps=steps_per_round)
        plotStats(stats[5:],apx="3d_ball_experiment_" + str(t_i) + apx)

    plotVelDenSphere(sampler,lambda x: pinn(x,params),plot_time,apx=apx + str(t_i))
    # Save a deep copy of the parameters rather than the live object.
    params_snapshot = copy.deepcopy(params)
    saveState(params_snapshot,opt_st,stats,"training_dumps/" +str(t_i)+'_'+ apx)

