import sys

sys.path.append('../')
import jax.numpy as jnp
import equinox as eqx
import numpy as np
import optax
import time
from jax.nn import gelu, silu, tanh
from jax.lax import scan, stop_gradient
from jax import random, jit, vmap, grad,hessian
import os
import scipy
import matplotlib.pyplot as plt
import argparse
import jax

from utils import count_model_parameters

parser = argparse.ArgumentParser(description="PINN")
# NOTE(review): --datatype is not read anywhere in this script.
parser.add_argument("--datatype", type=str, default='poisson', help="type of data")
parser.add_argument("--n_interior", type=int, default=500,
                    help="the number of interior training dataset for each epochs")
parser.add_argument("--n_boundary", type=int, default=500,
                    help="the number of boundary training dataset for each epochs")
parser.add_argument("--dim", type=int, default=2, help="dim of the problem")
# Selects the reference data file ../data/test_point_high_{dim}_{n_valid}.npz.
parser.add_argument("--n_valid", type=int, default=2, help="valid coef in dataset")
# Training runs ite * epochs optimiser steps; points are resampled every `epochs` steps.
parser.add_argument("--ite", type=int, default=100, help="the number of iteration")
parser.add_argument("--epochs", type=int, default=500, help="the number of epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--features", type=int, default=256, help='width of the network')
parser.add_argument("--layers", type=int, default=2, help='depth of the network')
parser.add_argument("--device", type=int, default=0, help="cuda number")
args = parser.parse_args()

# Restrict the process to the selected GPU.
# NOTE(review): jax is already imported above; this only takes effect if JAX has not yet
# initialised its backend (no array created so far) — confirm when changing import order.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)

def high_ifft(c_pred, k_set, x_test, batch_size=50000):
    '''
    Evaluate a truncated Fourier series at one point and return its real part.

    Computes Re( sum_n c_n * prod_j exp(i * k_{n,j} * x_j) ), which equals
    Re( sum_n c_n * exp(i * k_n . x) ). The modes are accumulated in chunks of
    ``batch_size`` so only a (chunk, d) intermediate array is materialised at a time.

    :param c_pred: (N,) Fourier coefficients
    :param k_set: (N, d) wave vectors, one row per coefficient
    :param x_test: (d,) evaluation point
    :param batch_size: number of modes handled per chunk
    :return: real part of the series value at ``x_test``
    '''
    n_modes = k_set.shape[0]
    n_chunks = (n_modes + batch_size - 1) // batch_size
    total = 0.0
    for chunk in range(n_chunks):
        lo = chunk * batch_size
        hi = min(lo + batch_size, n_modes)
        # exp(i k x) factorises per dimension; the product over axis 1 yields exp(i k . x).
        phase = jnp.prod(jnp.exp(1j * k_set[lo:hi] * x_test), axis=1)
        total += jnp.sum(c_pred[lo:hi] * phase)
    return jnp.real(total)

class PirateNet(eqx.Module):
    '''
    Fourier-feature network with gated hidden layers and adaptive residual blocks.

    ``N_features`` is ``[width_in, width_hidden, width_out]``. The input point is mapped to
    ``[cos(x B), sin(x B)]``, so ``width_in`` must equal ``2 * B.shape[1]``.
    '''
    # per residual block: first dense layer, width_in -> width_hidden
    matrices1: list
    biases1: list
    # per residual block: second dense layer, width_hidden -> width_hidden
    matrices2: list
    biases2: list
    # per residual block: third dense layer, width_hidden -> width_in (matches the skip path)
    matrices3: list
    biases3: list
    # final linear read-out, width_in -> width_out
    matrices_end: jnp.ndarray
    biases_end: jnp.ndarray
    # two encoders producing the gates u and v, width_in -> width_hidden
    matrices_modified: list
    biases_modified: list
    # residual mixing weight, initialised to 0 so every block starts as the identity
    alpha: jnp.ndarray
    # Fourier projection matrix; static, so it is not an array leaf and is never optimised.
    # NOTE(review): recent equinox versions may warn about JAX arrays in static fields;
    # consider a regular field + stop_gradient instead — confirm before changing (it alters
    # the serialised leaf layout).
    B: jnp.ndarray = eqx.field(static=True)

    def __init__(self, N_features, N_layers,B,key):
        '''
        :param N_features: [width_in, width_hidden, width_out]
        :param N_layers: number of residual blocks
        :param B: Fourier projection matrix of shape (dim, width_in // 2)
        :param key: JAX PRNG key used for all random initialisation
        '''
        keys = random.split(key, N_layers+1)
        self.matrices1 = [
            random.normal(k, (N_features[0], N_features[1])) / jnp.sqrt((N_features[0]+N_features[1]) / 2)*0.01
            for k in keys[:-1]
        ]
        keys = random.split(keys[-1], N_layers+1)
        self.biases1 = [
            random.normal(k, (N_features[1],))*0.01
            for k in keys[:-1]
        ]
        keys = random.split(keys[-1], N_layers + 1)
        self.matrices2 = [
            random.normal(k, (N_features[1], N_features[1])) / jnp.sqrt((N_features[1] + N_features[1]) / 2)*0.01
            for k in keys[:-1]
        ]
        keys = random.split(keys[-1], N_layers + 1)
        self.biases2 = [
            random.normal(k, (N_features[1],))*0.01
            for k in keys[:-1]
        ]
        keys = random.split(keys[-1], N_layers + 1)
        self.matrices3 = [
            random.normal(k, (N_features[1], N_features[0])) / jnp.sqrt((N_features[1] + N_features[0]) / 2)*0.01
            for k in keys[:-1]
        ]
        keys = random.split(keys[-1], N_layers + 1)
        self.biases3 = [
            random.normal(k, (N_features[0],))*0.01
            for k in keys[:-1]
        ]

        # Each remaining draw gets its own subkey (keys[-1] is still unused here). Reusing
        # the parent `key` would make matrices_modified[0] == matrices_modified[1], hence
        # u == v and the gate f*u + (1-f)*v would collapse to u.
        key_u, key_v, key_end_w, key_end_b = random.split(keys[-1], 4)
        self.matrices_modified = [
            random.normal(key_u, (N_features[0], N_features[1])) / jnp.sqrt((N_features[0] + N_features[1]) / 2)*0.01,
            random.normal(key_v, (N_features[0], N_features[1])) / jnp.sqrt((N_features[0] + N_features[1]) / 2)*0.01]

        self.biases_modified = [jnp.zeros((N_features[1],)), jnp.zeros((N_features[1],))]

        self.matrices_end = random.normal(key_end_w, (N_features[0], N_features[2])) / jnp.sqrt((N_features[0] + N_features[2]) / 2)
        self.biases_end = random.normal(key_end_b, (N_features[2],))

        self.B = B
        self.alpha = jnp.zeros((1,))

    def __call__(self, inputs):
        '''
        Evaluate the network.

        :param inputs: coordinates of one point, shape (dim,) as called in this file
                       (batches are handled with vmap)
        :return: array of shape (width_out,)
        '''
        # Fourier features of the coordinates rescaled by 1/(2*pi).
        inputs=inputs/2/jnp.pi
        inputs = jnp.hstack([jnp.cos(inputs@self.B),jnp.sin(inputs@self.B)])

        # Gates shared by every residual block.
        u = inputs @ self.matrices_modified[0] + self.biases_modified[0]
        v = inputs @ self.matrices_modified[1] + self.biases_modified[1]

        u = tanh(u)
        v = tanh(v)

        for i in range(0, len(self.matrices1)):
            f = inputs @ self.matrices1[i] + self.biases1[i]
            f = tanh(f)
            f = f * u + (1 - f) * v

            f = f @ self.matrices2[i] + self.biases2[i]
            f = tanh(f)
            f = f * u + (1 - f) * v

            f = f @ self.matrices3[i] + self.biases3[i]
            f = tanh(f)

            # Adaptive skip connection: alpha == 0 leaves the block as the identity.
            inputs = self.alpha*f+(1-self.alpha)*inputs

        f = inputs @ self.matrices_end + self.biases_end
        return f


class interior_points:
    '''
    Uniform sampler over the hyper-cube [interval[0], interval[1]] ** dim.
    '''

    def __init__(self, dim, interval=(-1, 1)):
        self.dim = dim
        self.interval = interval

    def sample(self, num, key):
        '''
        Draw ``num`` i.i.d. uniform points.

        :param num: number of points
        :param key: JAX PRNG key
        :return: array of shape (num, dim)
        '''
        low = self.interval[0]
        high = self.interval[1]
        return random.uniform(key, shape=(num, self.dim), minval=low, maxval=high)


class boundary_points():
    '''
    Sampler for points on the boundary of the hyper-cube [interval[0], interval[1]] ** dim,
    paired with target values from ``generate_data``.
    '''

    def __init__(self, dim, generate_data, interval=(-1, 1)):
        self.dim = dim
        self.interval = interval
        # callable mapping one point of shape (dim,) to its target value
        self.generate_data = generate_data

    def sample(self, num, key):
        '''
        Draw ``num`` boundary points and evaluate ``generate_data`` at them.

        Every coordinate is independently pinned to a face (low or high, each with
        probability 1/2) with probability 1/2; rows where no coordinate was pinned get one
        randomly chosen axis pinned, so every returned point lies on the boundary.

        :param num: number of points
        :param key: JAX PRNG key
        :return: ``(x, y)`` with ``x`` of shape (num, dim) and ``y = vmap(generate_data)(x)``
        '''
        low, high = self.interval[0], self.interval[1]
        keys = random.split(key, 3)
        x = jax.random.uniform(keys[0], shape=(num, self.dim), minval=low, maxval=high)
        # Candidate face value per coordinate: exactly `low` or `high`.
        boundary = jax.random.randint(keys[1], (num, self.dim), minval=0, maxval=2) * (high - low) + low
        pinned = jax.random.randint(keys[2], (num, self.dim), minval=0, maxval=2).astype(bool)
        # Without this, a fraction 2**-dim of rows (25% for dim=2) would have no pinned
        # coordinate, i.e. be interior points. fold_in derives an extra key without
        # changing the draws above.
        face_axis = jax.random.randint(random.fold_in(keys[2], 1), (num,), minval=0, maxval=self.dim)
        forced = face_axis[:, None] == jnp.arange(self.dim)
        pinned = pinned | (forced & ~pinned.any(axis=1, keepdims=True))
        x = jnp.where(pinned, boundary, x)
        y = vmap(self.generate_data, 0)(x)
        return x, y


def net(model, *x):
    '''
    Evaluate ``model`` at the point built from the scalar coordinates ``x`` and return
    the first output component.

    The coordinates are separate positional arguments so ``grad(..., argnums=i + 1)``
    can differentiate with respect to a single coordinate.
    '''
    point = jnp.stack(x)
    outputs = model(point)
    return outputs[0]


def residual(model, x, r_s):
    r'''
    Pointwise residual of the Poisson equation \Delta u = f.

    :param model: network passed through to ``net`` (a PirateNet here)
    :param x: coordinates of one point, shape (dim,)
    :param r_s: value of the right-hand side f at ``x``
    :return: scalar \Delta u(x) - f(x)
    '''
    dim = x.shape[0]
    # Laplacian = sum of pure second derivatives. ``net`` takes the model first and then
    # one scalar per coordinate, hence argnums=i + 1 for coordinate i.
    second_derivatives = jnp.stack(
        [grad(grad(net, argnums=i + 1), argnums=i + 1)(model, *x) for i in range(dim)]
    )
    return jnp.sum(second_derivatives) - r_s

def compute_loss(model, ob_x, ob_sup):
    '''
    PINN loss: mean squared PDE residual plus 100 x the mean squared data misfit.

    :param model: network mapping one point of shape (dim,) to an output whose first entry is u
    :param ob_x: collocation set; columns [:-1] are coordinates, column -1 is f at that point
    :param ob_sup: supervision set; columns [:-1] are coordinates, column -1 is the target u
    :return: scalar loss
    '''
    coords_in, rhs = ob_x[:, :-1], ob_x[:, -1]
    residuals = vmap(residual, (None, 0, 0))(model, coords_in, rhs)
    pde_loss = (residuals ** 2).mean()

    # Pointwise fit u(x_b) = g(x_b) at the supervision points.
    coords_b, targets = ob_sup[:, :-1], ob_sup[:, -1]
    preds = vmap(model, (0,))(coords_b)
    data_loss = ((preds[:, 0] - targets) ** 2).mean()
    return pde_loss + 100 * data_loss


# Returns (loss, grads) for compute_loss; grads mirror the model (first argument) and are
# passed straight to the optimiser in make_step.
compute_loss_and_grads = eqx.filter_value_and_grad(compute_loss)


@eqx.filter_jit
def make_step(model, ob_x, ob_sup, optim, opt_state):
    '''
    Run one jitted optimisation step.

    :return: (loss evaluated before the update, updated model, updated optimiser state)
    '''
    loss_value, gradients = compute_loss_and_grads(model, ob_x, ob_sup)
    trainable = eqx.filter(model, eqx.is_array)
    updates, new_state = optim.update(gradients, opt_state, trainable)
    new_model = eqx.apply_updates(model, updates)
    return loss_value, new_model, new_state


def _sample_training_data(x_in_set, x_b_set, rhs_fn, n_interior, n_boundary, key_in, key_b):
    '''
    Draw a fresh batch of interior collocation points and boundary points.

    :param x_in_set: sampler with ``sample(num, key) -> points``
    :param x_b_set: sampler with ``sample(num, key) -> (points, targets)``
    :param rhs_fn: maps one point to the PDE right-hand side at that point
    :param n_interior: number of interior points
    :param n_boundary: number of boundary points
    :param key_in: PRNG key for the interior points
    :param key_b: PRNG key for the boundary points
    :return: ``(ob_x, ob_sup)``; each stacks coordinates with the value (f or target u)
             in the last column
    '''
    input_points = x_in_set.sample(n_interior, key_in)
    rhs = vmap(rhs_fn, 0)(input_points).reshape(-1, 1)
    ob_x = jnp.concatenate([input_points, rhs], -1)
    x_b, y_b = x_b_set.sample(n_boundary, key_b)
    ob_sup = jnp.concatenate([x_b, y_b.reshape(-1, 1)], -1)
    return ob_x, ob_sup


def _relative_l2_error(model, x_test, u_test):
    '''Return ``(||u_pred - u_test|| / ||u_test||, u_pred)`` on the test points.'''
    u_pred = vmap(model, (0,))(x_test)
    error = jnp.linalg.norm(u_pred.flatten() - u_test.flatten()) / jnp.linalg.norm(u_test.flatten())
    return error, u_pred


def _save_checkpoint(model, u_pred, u_test, errors, x_test):
    '''Serialise the model and save predictions / error history under ``results/``.'''
    stem = f'results/pinn_{args.dim}_{args.seed}_{args.n_valid}'
    # Create the output directory so the first checkpoint does not fail with FileNotFoundError.
    os.makedirs(os.path.dirname(stem), exist_ok=True)
    eqx.tree_serialise_leaves(f'{stem}.eqx', model)
    np.savez(f'{stem}.npz', u_pred=u_pred, u_test=u_test, errors=errors, x_test=x_test)


def train(key):
    '''
    Train a PirateNet PINN on the problem selected by ``args``.

    Loads reference data from ``../data/test_point_high_{dim}_{n_valid}.npz``, runs
    ``args.ite * args.epochs`` optimiser steps, resamples the training points and reports
    the relative L2 test error every ``args.epochs`` steps, and keeps the best model and
    predictions in ``results/pinn_{dim}_{seed}_{n_valid}.eqx`` / ``.npz``.

    :param key: JAX PRNG key for the Fourier projection, model initialisation and sampling
    '''
    keys = random.split(key, 4)
    # Hyperparameters
    interval = [0, 2*np.pi]
    dim = args.dim
    N_interior = args.n_interior
    N_b = args.n_boundary
    N_epochs = args.epochs
    ite = args.ite
    learning_rate = args.lr

    print('Generating data')
    spectral_scale = 1
    # The NpzFile keeps the archive open; the context manager closes it after reading.
    with np.load(f'../data/test_point_high_{dim}_{args.n_valid}.npz') as spectral_data:
        x_test = spectral_data['x_test']
        u_test = spectral_data['u_test']*spectral_scale
        k_full = spectral_data['k_full']
        f_nd_coef = spectral_data['f_nd_coef']*spectral_scale
        u_nd_coef = spectral_data['u_nd_coef']*spectral_scale

    # Exact solution and right-hand side, evaluated pointwise from their Fourier coefficients.
    generate_data_by_points = lambda x: high_ifft(u_nd_coef, k_full, x, batch_size=50)
    right_hand_side_by_points = lambda x: high_ifft(f_nd_coef, k_full, x, batch_size=50)
    x_b_set = boundary_points(dim=dim, generate_data=generate_data_by_points, interval=interval)
    x_in_set = interior_points(dim=dim, interval=interval)
    input_dim = dim
    output_dim = 1

    # Model: Fourier features of width 2*hidden_dim feed the PirateNet.
    hidden_dim = int(args.features/2)
    B = random.normal(keys[0], (input_dim, hidden_dim))/dim
    N_features = [hidden_dim*2, args.features, output_dim]
    N_layers = args.layers
    model = PirateNet(N_features, N_layers, B, keys[2])
    total_params, _ = count_model_parameters(model)
    print(f'Total params: {total_params:.2e}')

    # Adam with an exponentially decaying learning rate (factor gamma per N_drop steps).
    N_drop = 50000
    gamma = 0.9
    sc = optax.exponential_decay(learning_rate, N_drop, gamma)
    optim = optax.adam(learning_rate=sc)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    keys = random.split(keys[-1], 3)
    ob_x, ob_sup = _sample_training_data(x_in_set, x_b_set, right_hand_side_by_points,
                                         N_interior, N_b, keys[0], keys[1])

    errors = []
    error_min = 10
    print('starting training')
    T1 = time.time()
    for j in range(ite * N_epochs):
        loss, model, opt_state = make_step(model, ob_x, ob_sup, optim, opt_state)
        if j % N_epochs == 0:
            # Resample the training points, then evaluate and checkpoint on improvement.
            keys = random.split(keys[-1], 3)
            ob_x, ob_sup = _sample_training_data(x_in_set, x_b_set, right_hand_side_by_points,
                                                 N_interior, N_b, keys[0], keys[1])
            relative_error, u_pred = _relative_l2_error(model, x_test, u_test)
            errors.append(relative_error)
            print(f'epochs:{int(j/N_epochs)}, error_u: {relative_error:.2e}, loss: {loss:.2e}')
            print('++++++++++++++++++++++++')
            if relative_error < error_min:
                _save_checkpoint(model, u_pred, u_test, errors, x_test)
                error_min = relative_error
    T2 = time.time()
    execution_time = T2 - T1

    relative_error, u_pred = _relative_l2_error(model, x_test, u_test)
    print(f'u: {relative_error:.2e}')

    # save model and results
    if relative_error < error_min:
        _save_checkpoint(model, u_pred, u_test, errors, x_test)
        error_min = relative_error
    print(f'finial result u: {error_min:.2e},time:{execution_time:.6f}')


if __name__ == "__main__":
    seed = args.seed
    key = random.PRNGKey(seed)
    # NumPy's global RNG is seeded from the same --seed value as the explicit JAX key.
    np.random.seed(seed)
    train(key)
