import jax
import jax.numpy as jnp
import numpy as np
from models import Siren
import jax.random as random
import flax
import optax

# Draw a fresh random seed in [0, 2**32) on every run; it is printed below.
seed = np.random.randint(2**32)

# Passed to Siren as num_layers.
layers = 5
# Ask flax to return frozen variable collections; the result of init is
# unfrozen a few lines below.
flax.config.update('flax_return_frozendict', True)
key =  random.PRNGKey(seed)
print("Random initial seed:", seed)
# Dummy 3-component input, used here only to initialise the network.
# NOTE(review): the same `key` feeds both this draw and mlp.init below.
x = random.normal(key,shape=(3,))
mlp = Siren(num_layers=layers,output_dim=1,w0=30,w0_first_layer=30,use_bias=True)
params = mlp.init(key,x)
# Keep only the unfrozen 'params' collection.
params = params.unfreeze()['params']
# Evaluate the network given a bare parameter tree and an input x.
func_mlp_ = lambda params,x: mlp.apply({'params':params}, x)

# Load the sample positions and take their extrema along axis 0.
# assumes samples is (num_samples, 3) — TODO confirm against the .npy file
samples = jnp.array(np.load('./samples_.npy'))
max_samples = np.max(samples,axis=0)
min_samples = np.min(samples,axis=0)

# Normalise positions: shift by the midpoint of the extrema and divide by the
# largest half-extent, so every coordinate ends up within [-1, 1].
center = (max_samples+min_samples)/2
scale = np.max(max_samples-center)
samples = (samples-center)/scale

# SDF values and masks are reshaped to single columns; normals are used as loaded.
samples_sdf = jnp.array(np.load('./sample_sdfs_.npy')).reshape(-1,1)
samples_normals = jnp.array(np.load('./samples_normals.npy'))
samples_masks = jnp.array(np.load('./sample_masks_.npy')).reshape(-1,1)
# Rebind `x` from the dummy init input to the packed dataset: one row per sample,
# concatenated along the last axis as [position | sdf | normal | mask].
# loss_func indexes a row as [:3], [3], [4:7] and [7].
x = jnp.concatenate([samples,samples_sdf,samples_normals,samples_masks],axis=-1)
# Number of rows in the packed dataset.
samples_nums =  x.shape[0]

import jax.numpy as np
from jax import grad, vmap, value_and_grad,jit,jacfwd
from jax import random
import optax
from functools import partial
from tqdm import tqdm


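# Reference only (kept commented out): a PyTorch version of the SDF loss.
# loss_func below uses the same four terms and the same weights.
# def sdf(model_output, gt):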
#     '''
#        x: batch of input coordinates
#        y: usually the output of the trial_soln function
#        '''
#     gt_sdf = gt['sdf']
#     gt_normals = gt['normals']

#     coords = model_output['model_in']
#     pred_sdf = model_output['model_out']

#     gradient = diff_operators.gradient(pred_sdf, coords)

#     # Wherever boundary_values is not equal to zero, we interpret it as a boundary constraint.
#     sdf_constraint = torch.where(gt_sdf != -1, pred_sdf, torch.zeros_like(pred_sdf))
#     inter_constraint = torch.where(gt_sdf != -1, torch.zeros_like(pred_sdf), torch.exp(-1e2 * torch.abs(pred_sdf)))
#     normal_constraint = torch.where(gt_sdf != -1, 1 - F.cosine_similarity(gradient, gt_normals, dim=-1)[..., None],
#                                     torch.zeros_like(gradient[..., :1]))
#     grad_constraint = torch.abs(gradient.norm(dim=-1) - 1)
#     # Exp      # Lapl
#     # -----------------
#     return {'sdf': torch.abs(sdf_constraint).mean() * 3e3,  # 1e4      # 3e3
#             'inter': inter_constraint.mean() * 1e2,  # 1e2                   # 1e3
#             'normal_constraint': normal_constraint.mean() * 1e2,  # 1e2
#             'grad_constraint': grad_constraint.mean() * 5e1}  # 1e1      # 5e1


def loss_func(params, x_):
    """Return the weighted SDF-fitting loss for one packed sample.

    ``x_`` is indexed as ``[:3]`` (query point), ``[3]`` (stored SDF value,
    read but not used by any term), ``[4:7]`` (normal) and ``[7]`` (mask).

    Samples whose mask equals 1 contribute the network output itself and a
    ``1 - cosine_similarity`` term between the network gradient and the
    normal; all other samples contribute ``exp(-100 * |output|)`` instead.
    The ``| ||gradient|| - 1 |`` term is applied to every sample.
    NOTE(review): mask == 1 presumably marks on-surface samples — confirm
    against the code that writes sample_masks_.npy.

    Args:
        params: parameter tree passed to ``func_mlp_``.
        x_: one packed sample, indexed as described above.

    Returns:
        The sum of the four term means, weighted by 3e3, 1e2, 1e2 and 5e1.
    """
    point, _stored_sdf, normal, mask = x_[:3], x_[3], x_[4:7], x_[7]
    masked = mask == 1

    def network(p):
        return func_mlp_(params, p)

    pred = network(point)
    # First row of the forward-mode Jacobian of the network w.r.t. the point.
    grad_pred = jacfwd(network)(point)[0]

    zeros = jnp.zeros_like(pred)
    sdf_term = jnp.where(masked, pred, zeros)
    inter_term = jnp.where(masked, zeros, jnp.exp(-1e2 * jnp.abs(pred)))
    misalignment = 1 - optax.cosine_similarity(grad_pred, normal, epsilon=0.5)
    normal_term = jnp.where(
        masked, misalignment, jnp.zeros_like(grad_pred[..., :1])
    )
    grad_norm_term = jnp.abs(jnp.linalg.norm(grad_pred, axis=-1) - 1)

    return (
        jnp.mean(jnp.abs(sdf_term)) * 3e3
        + jnp.mean(inter_term) * 1e2
        + jnp.mean(normal_term) * 1e2
        + jnp.mean(grad_norm_term) * 5e1
    )

def loss_(params, x):
    """Return the mean of ``loss_func`` over the leading axis of ``x``.

    Args:
        params: parameter tree, shared unchanged across all rows.
        x: batch of packed samples; each slice along axis 0 is handed to
            ``loss_func`` as one sample.

    Returns:
        The mean of the per-sample losses.
    """
    # Map over rows of x only; params is broadcast (in_axes=None).
    per_sample = vmap(loss_func, in_axes=(None, 0))(params, x)
    return jnp.mean(per_sample)

def sampling(key, N):
    """Draw ``N`` distinct indices from ``range(samples_nums)``.

    Args:
        key: JAX PRNG key used for the draw.
        N: number of indices to draw (without replacement).

    Returns:
        Array of shape ``(N,)`` holding the chosen indices.
    """
    draw_shape = (N,)
    return jax.random.choice(
        key, samples_nums, shape=draw_shape, replace=False
    )

    

@jit
def train_step(params_1, pts_samples, opt_st):
    """Run one jitted optimiser step on the rows of ``x`` given by ``pts_samples``.

    Reads the module-level dataset ``x`` and optimiser ``opt``.

    Args:
        params_1: current parameter tree.
        pts_samples: integer indices selecting the mini-batch rows of ``x``.
        opt_st: current optimiser state.

    Returns:
        ``(params_1, opt_st, lval)``: the updated parameters, the updated
        optimiser state and the loss value of this step (computed before
        the update is applied).
    """
    # The whole step is already compiled by the outer jit, so value_and_grad
    # is used directly instead of being wrapped in a second, nested jit.
    lval, lgrad = value_and_grad(loss_)(params_1, x[pts_samples])
    update, opt_st = opt.update(lgrad, opt_st, params_1)
    params_1 = optax.apply_updates(params_1, update)
    return params_1, opt_st, lval

# Number of training steps averaged into each logged loss value in trainModel.
log_rate  = 100
# Mini-batch size: number of dataset rows trainModel passes to each train_step call.
N = 1000

def trainModel(params_1, key, opt_st, stats=None, steps=int(1e4), hyper_debug=False):
    """Train for one pass over a fresh random permutation of the dataset.

    The dataset indices are shuffled once, then consumed in consecutive
    mini-batches of ``N`` rows, one ``train_step`` call per batch. Every
    ``log_rate`` steps the average loss of that window is shown on the
    progress bar and appended to ``stats`` as a one-element list.

    Args:
        params_1: current parameter tree.
        key: JAX PRNG key used to shuffle the dataset indices.
        opt_st: current optimiser state.
        stats: list to append logged losses to; a new list is created when
            omitted.
        steps: NOTE(review): accepted for compatibility but ignored — the
            number of steps is always ``samples_nums // N``.
        hyper_debug: when true, print the step index and loss of every step.

    Returns:
        ``(params_1, opt_st, stats)`` after the pass.
    """
    # A fresh list per call: a mutable default would be shared across calls.
    if stats is None:
        stats = []

    # One full pass over the data; this deliberately overrides the argument.
    steps = samples_nums // N
    bar = tqdm(range(steps))

    shuffle_key = random.split(key, 5)[0]
    pts_samples_full = jax.random.permutation(shuffle_key, samples_nums)

    run_loss = 0
    for i in bar:
        batch = pts_samples_full[i * N:(i + 1) * N]
        params_1, opt_st, lval = train_step(params_1, batch, opt_st)
        # Fetch the scalar from the device once per step.
        step_loss = lval.item()
        run_loss += step_loss / log_rate
        if hyper_debug:
            print(i, step_loss)
        if not ((i + 1) % log_rate):
            bar.set_description("avg_loss:{:f}".format(run_loss))
            stats.append([run_loss])
            run_loss = 0

    return params_1, opt_st, stats

# Learning-rate schedule: exponential decay starting at 1e-5, with a decay
# rate of 1e-2 per 120000 transition steps.
sched_2 = optax.exponential_decay(init_value = 1e-5,transition_steps=120000,decay_rate=1e-2)
# Alternative kept for reference: a constant learning rate.
#sched_2 = optax.constant_schedule(1e-6)
# Adam optimiser driven by the schedule above; train_step reads `opt` as a global.
opt = optax.adam(learning_rate=sched_2)
opt_st = opt.init(params)


# Passed (as int) to trainModel's `steps` argument in the loop at the end of the
# file; trainModel currently overrides that argument.
eps=1e3
# Accumulates the logged average losses across all trainModel calls.
stats = []

import matplotlib.pyplot as plt
def plotStats(stats, apx):
    """Plot ``stats`` on a log-scaled y axis and save the figure as ``./<apx>.png``.

    Args:
        stats: sequence of logged loss values to plot.
        apx: file-name stem of the saved PNG.
    """
    target = "./{}.png".format(apx)

    figure, axes = plt.subplots(1, 1, figsize=(10, 5))
    axes.plot(stats)

    axes.set_xlabel('steps (x100)')
    axes.set_ylabel('loss')
    axes.set_yscale('log')

    figure.savefig(target)
    # Close the figure explicitly so repeated calls do not keep figures open.
    plt.close(figure)

import pickle


def saveState(params, stats, path):
    """Pickle ``params`` to ``<path>_model`` and ``stats`` to ``<path>_stats``.

    Args:
        params: object written to the ``_model`` file.
        stats: object written to the ``_stats`` file.
        path: common file-name prefix for both output files.
    """
    # The model is written first, then the stats.
    for suffix, payload in (("_model", params), ("_stats", stats)):
        with open(path + suffix, 'wb') as handle:
            pickle.dump(payload, handle)

# Number of trainModel calls to run; each call is one pass over the shuffled dataset.
iter_step = 50000
for i in range(iter_step):
    # Split off a fresh key for this pass and keep the remainder for the next one.
    tkey,key = random.split(key)
    params, opt_st,stats = trainModel(params, tkey, opt_st, stats=stats,steps=int(eps))
    # Re-plot the loss history after every pass, skipping the first 5 logged entries.
    plotStats(stats[5:],apx="siren_fit_")

    # Overwrite the checkpoint files (savings_model / savings_stats) after every pass.
    saveState(params,stats,"savings")
