'''
  Physics-informed neural network (PINN) for the heat equation:
  laplace u = u_t
'''

import sys
sys.path.append('../')
import jax.numpy as jnp
import equinox as eqx
import numpy as np
import optax
import time
from jax.nn import gelu, silu, tanh
from jax.lax import scan, stop_gradient
from jax import random, jit, vmap, grad
import os
import scipy
import matplotlib.pyplot as plt
import argparse
import jax

# Command-line configuration. Parsing happens at import time so that `args`
# is available as a module-level global to the rest of the script.
parser = argparse.ArgumentParser(description="FSPINN")
parser.add_argument("--datatype", type=str, default='poisson', help="type of data")
parser.add_argument("--ntest", type=int, default=1000, help="the number of testing dataset")
parser.add_argument("--n_interior", type=int, default=5000,
                    help="the number of interior training dataset for each epochs")
parser.add_argument("--n_boundary", type=int, default=5000,
                    help="the number of boundary training dataset for each epochs")
parser.add_argument("--n_initial", type=int, default=5000,
                    help="the number of initial training dataset for each epochs")
parser.add_argument("--T", type=float, default=1e-2, help="terminal time")
parser.add_argument("--dim", type=int, default=2, help="dim of the problem")
parser.add_argument("--ite", type=int, default=20, help="the number of iteration")
parser.add_argument("--epochs", type=int, default=50000, help="the number of epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
# The seed drives the RNGs and is also embedded in the result file names,
# so the help text describes both roles instead of the former "the name".
parser.add_argument("--seed", type=int, default=0,
                    help="random seed (also used in the output file names)")
parser.add_argument("--features", type=int, default=100, help='width of the network')
parser.add_argument("--layers", type=int, default=5, help='depth of the network')
parser.add_argument("--max_k", type=int, default=3, help='max number of k')
parser.add_argument("--device", type=int, default=0, help="cuda number")
args = parser.parse_args()

# Expose only the requested CUDA device index to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)

def fourier_polynomial(k, x):
    """Return the product over all entries of exp(1j * k * x).

    `k` and `x` must have identical shapes; the result is a complex scalar.
    """
    assert k.shape == x.shape, 'wrong match k and x'
    phase = 1j * k * x
    per_axis = jnp.exp(phase)
    return jnp.prod(per_axis)

def get_cof_high(k, coe_1d, N_x):
    """Return the product of the entries of `coe_1d` selected by the index array `k`.

    `N_x` is accepted for signature compatibility but is not used.
    """
    selected = coe_1d[k]
    return jnp.prod(selected)

def generate_data(x_test, coe_1d, k_1d, k_full):
    """Evaluate the Fourier-series reference solution at space-time points.

    Each row of `x_test` holds the spatial coordinates followed by the time in
    the last column. For every row of `k_full` the per-axis coefficients
    `coe_1d[k_full]` are multiplied together, damped by
    exp(-sum(k_1d[k_full] ** 2) * t), and multiplied by exp(1j * k . x).
    The real part of the sum over all rows of `k_full` is returned as a
    column vector with one row per input point.
    """
    times = x_test[:, -1:]
    space = x_test[:, :-1]

    # Wave vectors and separable amplitudes, one row/entry per retained mode.
    wave_vectors = k_1d[k_full]
    amplitudes = jnp.prod(coe_1d[k_full], axis=1)

    # (points, modes) damping factors and (modes, points) complex exponentials.
    damping = jnp.exp(-jnp.sum(wave_vectors ** 2, axis=1) * times)
    modes = jnp.exp(1j * (wave_vectors @ space.T))

    series = jnp.sum(amplitudes * damping * modes.T, axis=1)
    return jnp.real(series)[:, jnp.newaxis]

def right_hand_side(x, u0_1d_fft, k_full, k_1d, N_x):
    """Sum the Fourier modes listed in `k_full` at the single point `x`.

    For every index row in `k_full`, the coefficient from `get_cof_high` is
    multiplied by `fourier_polynomial(k_1d[idx], x)`; the real part of the
    accumulated sum is returned as a length-1 array.
    """
    total = sum(
        get_cof_high(idx, u0_1d_fft, N_x) * fourier_polynomial(k_1d[idx], x)
        for idx in k_full
    )
    return jnp.stack([jnp.real(total)])

class MLP(eqx.Module):
    """Fully connected network with tanh activations between affine layers."""

    # One weight matrix and one bias vector per affine layer.
    matrices: list
    biases: list

    def __init__(self, N_features, N_layers, key):
        """Build `N_layers` affine layers.

        `N_features` is indexed as [input width, hidden width, ..., output width]:
        entry 0 is the input size, entry 1 the hidden size and the last entry
        the output size. `key` is the JAX PRNG key used for initialisation.
        """
        widths = [N_features[0]] + [N_features[1]] * (N_layers - 1) + [N_features[-1]]
        shapes = list(zip(widths[:-1], widths[1:]))

        # Weights: standard normal scaled by sqrt of the mean of fan-in and fan-out.
        weight_keys = random.split(key, N_layers + 1)
        self.matrices = [
            random.normal(layer_key, (n_in, n_out)) / jnp.sqrt((n_in + n_out) / 2)
            for (n_in, n_out), layer_key in zip(shapes, weight_keys)
        ]

        # Biases: standard normal, keyed from the last (otherwise unused) weight key.
        bias_keys = random.split(weight_keys[-1], N_layers)
        self.biases = [
            random.normal(layer_key, (n_out,))
            for (_, n_out), layer_key in zip(shapes, bias_keys)
        ]

    def __call__(self, inputs):
        """Apply the network: affine layer, then (tanh, affine) for each further layer."""
        hidden = inputs @ self.matrices[0] + self.biases[0]
        for weight, bias in zip(self.matrices[1:], self.biases[1:]):
            hidden = tanh(hidden)
            hidden = hidden @ weight + bias
        return hidden

class interior_points():
    """Uniform sampler of space-time points inside a hypercube times [0, T]."""

    def __init__(self, dim, interval=(-1, 1), T=1):
        # dim: number of spatial coordinates; interval: (low, high) bounds used
        # for every spatial axis; T: upper bound of the time coordinate.
        self.dim = dim
        self.interval = interval
        self.T = T

    def sample(self, num, key):
        """Return a (num, dim + 1) array: uniform spatial coordinates, then time."""
        subkeys = random.split(key, 2)
        low = self.interval[0]
        high = self.interval[1]
        space = jax.random.uniform(subkeys[0], shape=(num, self.dim), minval=low, maxval=high)
        times = jax.random.uniform(subkeys[1], shape=(num, 1), minval=0, maxval=self.T)
        return jnp.concatenate([space, times], axis=-1)


class boundary_points():
    """Sampler of labelled space-time points lying on the faces of a hypercube."""

    def __init__(self, dim, generate_data, interval=(-1, 1), T=1):
        # dim: number of spatial coordinates; generate_data: callable mapping the
        # sampled points to their target values; interval: (low, high) bounds of
        # every spatial axis; T: upper bound of the time coordinate.
        self.dim = dim
        self.interval = interval
        self.generate_data = generate_data
        self.T = T

    def sample(self, num, key):
        """Return `(points, targets)` for `num` boundary points.

        `points` has shape (num, dim + 1) with time in the last column;
        `targets` is whatever `generate_data(points)` returns.
        """
        subkeys = random.split(key, 5)
        low = self.interval[0]
        high = self.interval[1]

        # Candidate interior coordinates and candidate face values (low or high).
        inner = jax.random.uniform(subkeys[0], shape=(num, self.dim), minval=low, maxval=high)
        faces = jax.random.randint(subkeys[1], (num, self.dim), minval=0, maxval=2) * (high - low) + low

        # Per-row threshold drawn between the smallest score of the row and 1:
        # every coordinate whose score falls below it is moved onto a face.
        scores = jax.random.uniform(subkeys[2], shape=(num, self.dim))
        lowest = jnp.min(scores, axis=1)[:, jnp.newaxis]
        threshold = lowest + jax.random.uniform(subkeys[3], shape=(num, 1)) * (1 - lowest)
        on_face = scores < threshold  # (num, 1) threshold broadcasts across the dim axis
        coords = jnp.where(on_face, faces, inner)

        times = jax.random.uniform(subkeys[4], shape=(num, 1), minval=0, maxval=self.T)
        points = jnp.concatenate([coords, times], axis=-1)
        return points, self.generate_data(points)

class initial_points():
    """Sampler of labelled points at time zero inside a hypercube."""

    def __init__(self, dim, generate_data, interval=(-1, 1)):
        # dim: number of spatial coordinates; generate_data: callable mapping the
        # sampled points to their target values; interval: (low, high) bounds.
        self.dim = dim
        self.interval = interval
        self.generate_data = generate_data

    def sample(self, num, key):
        """Return `(points, targets)`; `points` is (num, dim + 1) with a zero time column."""
        low = self.interval[0]
        high = self.interval[1]
        space = jax.random.uniform(key, shape=(num, self.dim), minval=low, maxval=high)
        zero_time = jnp.zeros_like(space[:, 0:1])
        points = jnp.concatenate([space, zero_time], axis=-1)
        targets = self.generate_data(points)
        return points, targets

def net(model, *x):
    """Stack the positional inputs, feed them to `model` and return entry 0 of its output."""
    stacked = jnp.stack(list(x))
    output = model(stacked)
    return output[0]


def residual(model, x):
    """Return u_t - laplace(u) of the network at a single point.

    `x` is a 1-D vector; every entry except the last is treated as a spatial
    coordinate and the last entry as time. Derivatives are taken of `net`
    with respect to the individual scalar coordinates.
    """
    n_space = x.shape[0] - 1

    # Argument 0 of `net` is the model, so coordinate i is argument i + 1.
    second_derivatives = []
    for arg_index in range(1, n_space + 1):
        d2_net = grad(grad(net, argnums=arg_index), argnums=arg_index)
        second_derivatives.append(d2_net(model, *x))
    laplace_u = jnp.sum(jnp.stack(second_derivatives))

    u_t = grad(net, argnums=n_space + 1)(model, *x)
    return u_t - laplace_u


def boundary_and_initial(model, x):
    """Evaluate the network at one point `x`, unpacked into scalar coordinates."""
    prediction = net(model, *x)
    return prediction

def compute_loss(model, ob_x, ob_sup):
    """Return the training loss: mean squared residual plus 100x the supervised MSE.

    `ob_x` holds one collocation point per row. Each row of `ob_sup` is a
    point followed by its target value in the last column.
    """
    sup_points = ob_sup[:, :-1]
    sup_targets = ob_sup[:, -1]

    pde_residuals = vmap(residual, (None, 0))(model, ob_x)
    sup_predictions = vmap(boundary_and_initial, (None, 0))(model, sup_points)

    pde_term = (pde_residuals ** 2).mean()
    sup_term = ((sup_predictions - sup_targets) ** 2).mean()
    return pde_term + 100 * sup_term


# Wrapped version of compute_loss; make_step unpacks its result as (loss, grads).
compute_loss_and_grads = eqx.filter_value_and_grad(compute_loss)


@eqx.filter_jit
def make_step(model, ob_x, ob_sup, optim, opt_state):
    """Run one optimiser step and return `(loss, updated model, updated optimiser state)`.

    The loss is evaluated on the model as passed in, i.e. before the update.
    """
    loss, grads = compute_loss_and_grads(model, ob_x, ob_sup)
    # The optimiser is given only the array leaves of the model as parameters.
    params = eqx.filter(model, eqx.is_array)
    updates, new_opt_state = optim.update(grads, opt_state, params)
    new_model = eqx.apply_updates(model, updates)
    return loss, new_model, new_opt_state


def _sample_training_set(key, x_in_set, x_b_set, x_0_set, n_interior, n_boundary, n_initial):
    """Draw fresh collocation and supervised points; return `(ob_x, ob_sup, next_key)`."""
    keys = random.split(key, 4)
    ob_x = x_in_set.sample(n_interior, keys[0])
    x_b, y_b = x_b_set.sample(n_boundary, keys[1])
    x_0, y_0 = x_0_set.sample(n_initial, keys[2])
    ob_b = jnp.concatenate([x_b, y_b], -1)
    ob_0 = jnp.concatenate([x_0, y_0], -1)
    ob_sup = jnp.concatenate([ob_b, ob_0], 0)
    # The last sub-key is reserved as the seed of the next resampling.
    return ob_x, ob_sup, keys[-1]


def _relative_l2_error(u_pred, u_test):
    """Return ||u_pred - u_test|| / ||u_test|| over the flattened arrays."""
    return jnp.linalg.norm(u_pred.flatten() - u_test.flatten()) / jnp.linalg.norm(u_test.flatten())


def _save_results(model, u_pred, u_test, errors):
    """Write the model leaves and the prediction/error arrays under `results/`."""
    # Create the output directory on demand so the first save cannot fail on it.
    os.makedirs('results', exist_ok=True)
    path = f'results/heat_{args.dim}_{args.seed}.eqx'
    eqx.tree_serialise_leaves(path, model)
    path = f'results/heat_{args.dim}_{args.seed}.npz'
    np.savez(path, u_pred=u_pred, u_test=u_test, errors=errors)


def train(key):
    """Train the PINN and save the best model seen at the evaluation steps.

    Args:
        key: JAX PRNG key used for model initialisation and point sampling.

    Reads its hyperparameters from the module-level `args`, the 1-D Fourier
    coefficients from `data/fourier_t_1d.npz` and the test set from
    `data/test_point_t_{dim}.npz`. Training points are resampled and the
    relative L2 test error is evaluated every `args.epochs` steps, for
    `args.ite * args.epochs` steps in total. Whenever the error improves on
    the best value so far (initially 1), the model and results are written
    to `results/heat_{dim}_{seed}.eqx` / `.npz`.
    """
    keys = random.split(key, 4)
    N_x = 100
    # Get hyperparameters
    interval = [0, 2 * np.pi]
    dim = args.dim
    T = args.T
    N_interior = args.n_interior
    N_b = args.n_boundary
    N_0 = args.n_initial
    N_epochs = args.epochs
    ite = args.ite
    learning_rate = args.lr

    # Reference solution: 1-D Fourier coefficients and the matching integer wave numbers.
    u0_1d_fft = np.load('data/fourier_t_1d.npz')['u0_1d_fft']
    k_1d = np.around(np.fft.fftfreq(N_x) * N_x).astype('int32')

    # Keep only the 1-D modes whose coefficient magnitude exceeds the threshold,
    # then form every dim-fold combination of the retained wave numbers.
    filter_eps = (1e-8) ** (1 / dim)
    k_full = np.meshgrid(*([k_1d[abs(u0_1d_fft) > filter_eps]] * dim))
    k_full = np.concatenate([ks.reshape(-1, 1) for ks in k_full], -1)

    generate_data_by_points = lambda x: generate_data(x, k_1d=k_1d, k_full=k_full, coe_1d=u0_1d_fft)
    x_b_set = boundary_points(dim=dim, generate_data=generate_data_by_points, interval=interval, T=T)
    x_0_set = initial_points(dim=dim, generate_data=generate_data_by_points, interval=interval)
    # Collocation times must cover the same window [0, T] as the boundary data
    # and the test points; without T the sampler would fall back to [0, 1].
    x_in_set = interior_points(dim=dim, interval=interval, T=T)

    print('generating data')
    # Open the test file once; the test points are evaluated at the terminal time T.
    with np.load(f'data/test_point_t_{dim}.npz') as test_data:
        x_test = test_data['x_test']
        u_test = test_data['u_test']
    x_test = np.concatenate([x_test, T * np.ones_like(x_test[:, 0:1])], axis=-1)
    input_dim = dim + 1
    output_dim = 1

    N_features = [input_dim, args.features, output_dim]
    N_layers = args.layers
    # Choose the model
    model = MLP(N_features, N_layers, keys[2])

    # parameters of optimizer
    N_drop = 50000
    gamma = 0.9
    sc = optax.exponential_decay(learning_rate, N_drop, gamma)
    optim = optax.adam(learning_rate=sc)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    ob_x, ob_sup, sample_key = _sample_training_set(
        keys[-1], x_in_set, x_b_set, x_0_set, N_interior, N_b, N_0)

    errors = []
    error_min = 1
    print('starting training')
    T1 = time.time()
    for j in range(ite * N_epochs):
        loss, model, opt_state = make_step(model, ob_x, ob_sup, optim, opt_state)
        if j % N_epochs == 0:
            # Refresh the training points, then evaluate on the fixed test set.
            ob_x, ob_sup, sample_key = _sample_training_set(
                sample_key, x_in_set, x_b_set, x_0_set, N_interior, N_b, N_0)

            u_pred = vmap(net, (None, 0))(model, x_test)
            relative_error = _relative_l2_error(u_pred, u_test)
            errors.append(relative_error)
            print(f'epochs:{int(j/N_epochs)}, error_u: {relative_error:.2e}, loss: {loss:.2e}')
            print('++++++++++++++++++++++++')
            if relative_error < error_min:
                _save_results(model, u_pred, u_test, errors)
                error_min = relative_error
    T2 = time.time()
    execution_time = T2 - T1

    u_pred = vmap(net, (None, 0))(model, x_test)
    relative_error = _relative_l2_error(u_pred, u_test)
    print(f'u: {relative_error:.2e}')

    # save model and results
    if relative_error < error_min:
        _save_results(model, u_pred, u_test, errors)
        error_min = relative_error
    print(f'finial result u: {error_min:.2e},time:{execution_time:.6f}')


if __name__ == "__main__":
    # Seed NumPy's global RNG and build the JAX PRNG key from the same CLI seed.
    seed = args.seed
    np.random.seed(seed)
    key = random.PRNGKey(seed)
    train(key=key)

