'''
  u_x=f
'''
import sys
sys.path.append('../')

import jax.numpy as jnp
import equinox as eqx
import numpy as np
import optax
import time
from jax.nn import gelu, silu, tanh
from jax.lax import scan, stop_gradient, complex
from jax import random, jit, vmap, grad
import os
import scipy
import matplotlib.pyplot as plt
import argparse
import jax
from utils import find_common_vectors

# jax.config.update("jax_enable_x64", True)

# Command-line configuration. It is parsed at import time so that `args` is
# available as a module-level global to the functions defined below.
parser = argparse.ArgumentParser(description="SINN")
parser.add_argument("--datatype", type=str, default='approximation', help="type of data")
parser.add_argument("--n_interior", type=int, default=8000,
                    help="number of training samples drawn at each resampling")
parser.add_argument("--dim", type=int, default=2, help="dimension of the problem")
parser.add_argument("--ite", type=int, default=20,
                    help="number of outer iterations (training samples are redrawn once per iteration)")
parser.add_argument("--epochs", type=int, default=50000,
                    help="number of optimisation steps per iteration")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--activation", type=str, default='tanh',
                    help="activation function: 'tanh', 'silu' or 'gelu'")
parser.add_argument("--features", type=int, default=100, help="width of the network")
parser.add_argument("--layers", type=int, default=5, help="depth of the network")
parser.add_argument("--n_fourier", type=int, default=4,
                    help="number of Fourier frequencies per input dimension in the feature embedding")
parser.add_argument("--gamma", type=float, default=-1.0,
                    help="exponent gamma of the frequency sampling weight")
parser.add_argument("--alp", type=float, default=10,
                    help="scale applied to the normalised frequency weight before the ELU")
parser.add_argument("--nmax", type=int, default=50,
                    help="cut-off N_max of the frequency sampling weight")
parser.add_argument("--percentage", type=float, default=1.0,
                    help="fraction of the frequency set used for training")
parser.add_argument("--spectral_scale", type=int, default=1,
                    help="factor multiplying the reference coefficients and the target solution")
parser.add_argument("--device", type=int, default=2,
                    help="device index written to CUDA_VISIBLE_DEVICES")
args = parser.parse_args()

# NOTE(review): this only takes effect if no JAX computation has initialised
# the GPU backend before this line runs — confirm `utils` does not do so on import.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)

def get_k_possibility(k_set,gam,N_max,alp):
    """Return normalised sampling weights for the rows of ``k_set``.

    For every row k the quantity
    ``prod(|k|+1) * max(|k|+1)**(-gam) - N_max**(1-gam)`` is computed,
    rescaled so that its maximum equals ``alp``, passed through an ELU and
    mapped to ``1 / (value + 2)``. The resulting weights are divided by
    their sum, so they add up to one.

    Args:
        k_set: 2-D array with one index vector per row.
        gam: exponent applied to the largest shifted component.
        N_max: cut-off entering the subtracted offset.
        alp: scale of the rescaled quantity before the ELU.

    Returns:
        Array of weights, one per row of ``k_set``.
    """
    shifted = np.abs(k_set) + 1
    weight = np.prod(shifted, axis=1) * np.max(shifted, axis=1) ** (-gam)
    dis = weight - N_max ** (1 - gam)
    # Rescale so the largest entry equals alp, then squash the negative side.
    dis = jax.nn.elu(alp * dis / dis.max())
    k_normalized = 1 / (dis + 2)
    return np.squeeze(k_normalized) / np.sum(k_normalized)

@jit
def fourier_polynomial(k, X,L):
    """Evaluate the rescaled Fourier polynomials exp(i * k . x).

    ``X`` is first divided by ``L`` and multiplied by pi, then shifted by pi
    (per the original comments: from [-pi, pi] to [0, 2*pi]) before the
    phases ``k @ X.T`` are formed.

    Args:
        k: frequency vectors, one per row.
        X: evaluation points, one per row.
        L: scale dividing the points.

    Returns:
        Complex array ``exp(1j * (k @ X_rescaled.T))``.
    """
    rescaled = X / L * jnp.pi + jnp.pi
    phase = k @ rescaled.T
    return jnp.exp(1j * phase)

@jit
def process_chunk(chunk, test_points, coeffs_chunk,L):
    """Sum the contribution of one chunk of Fourier modes at the test points.

    Args:
        chunk: frequency vectors of this chunk, one per row.
        test_points: evaluation points, one per row.
        coeffs_chunk: one coefficient per row of ``chunk``.
        L: scale forwarded to ``fourier_polynomial``.

    Returns:
        Array with one entry per test point holding
        ``sum_k coeffs_chunk[k] * phi_k(test_point)``; complex64 zeros when
        the chunk is empty.
    """
    # len() of a traced array is its static leading dimension, so this
    # branch is decided while tracing, not at run time.
    if len(chunk) == 0:
        return jnp.zeros(test_points.shape[0], dtype=jnp.complex64)

    basis = fourier_polynomial(chunk, test_points,L)
    weighted = coeffs_chunk[:, None] * basis
    return jnp.sum(weighted, axis=0)


def compute_u(test_points, coeffs, kdq, chunk_size,L):
    """Evaluate the Fourier series with coefficients ``coeffs`` at ``test_points``.

    The modes in ``kdq`` are processed ``chunk_size`` at a time through
    ``process_chunk`` and the partial sums are accumulated. Progress is
    printed for every tenth chunk.

    Args:
        test_points: evaluation points, one per row.
        coeffs: one coefficient per row of ``kdq``.
        kdq: frequency vectors, one per row.
        chunk_size: number of modes handled per call to ``process_chunk``.
        L: scale forwarded to ``process_chunk``.

    Returns:
        Accumulated series values, one per test point (complex64 zeros when
        ``kdq`` is empty).
    """
    n_k = len(kdq)
    total = jnp.zeros(test_points.shape[0], dtype=jnp.complex64)

    for chunk_no, start in enumerate(range(0, n_k, chunk_size)):
        stop = min(start + chunk_size, n_k)
        total = total + process_chunk(
            kdq[start:stop], test_points, coeffs[start:stop],L
        )

        if chunk_no % 10 == 0:
            print(f"Processed {stop}/{n_k} k values")

    return total

class MLP(eqx.Module):
    """Fully connected network multiplied by a Fourier-feature gate.

    The output is the product of three terms:

    * ``f``: a standard MLP applied to the raw input (``matrices``/``biases``
      with the chosen activation between layers);
    * ``prod(c, axis=0)``: a purely linear stack (``matrices_t``, no biases,
      no activation) applied per input coordinate to cos/sin features of
      that coordinate, then multiplied across coordinates;
    * ``fourier_scale``: ``exp(-||alpha**2 * x||_1 / dim)``.

    When ``depend_t`` is true the last input component is excluded from the
    Fourier features and the decay term (it is still fed to the MLP).
    """
    # Field order is left untouched: it determines the order of the pytree
    # leaves, which presumably matters for models already written with
    # eqx.tree_serialise_leaves.
    matrices: list      # weight matrices of the main MLP
    biases: list        # bias vectors of the main MLP
    is_t: bool          # True if the last input component is treated separately
    matrices_t: list    # weight matrices of the linear Fourier branch
    N_fourier: int      # dim * N_fourier as passed to __init__
    dim: int            # number of input components used for the Fourier features
    alpha: jnp.ndarray  # per-coordinate decay parameter, initialised to ones
    activation: jax.nn  # activation callable (silu, tanh or gelu)

    def __init__(self, N_features, N_layers, N_fourier,activation,depend_t,key):
        """Initialise all weights from ``key``.

        Args:
            N_features: ``[input_dim, hidden_width, output_dim]``.
            N_layers: number of weight matrices in the main MLP.
            N_fourier: number of Fourier frequencies per coordinate; the
                stored ``self.N_fourier`` is this value times ``dim``.
            activation: one of ``'silu'``, ``'tanh'``, ``'gelu'``.
            depend_t: if true, the last input component is excluded from the
                Fourier features.
            key: JAX PRNG key.

        Raises:
            ValueError: if ``activation`` is not a known name.
        """
        # Validate with an exception rather than `assert`, which is removed
        # when Python runs with -O.
        known_activations = {'silu': silu, 'tanh': tanh, 'gelu': gelu}
        if activation not in known_activations:
            raise ValueError(f"unknown activation {activation}")

        keys = random.split(key, N_layers + 1)
        features = [N_features[0], ] + [N_features[1], ] * (N_layers - 1) + [N_features[-1], ]
        self.matrices = [random.normal(k, (f_in, f_out)) / jnp.sqrt((f_in + f_out) / 2)
                         for f_in, f_out, k in zip(features[:-1], features[1:], keys)]
        # The last key of each split is not consumed by the zip above, so it
        # is used to derive the next batch of keys.
        keys = random.split(keys[-1], N_layers+1)
        self.biases = [random.normal(k, (f_out,)) for f_out, k in zip(features[1:], keys)]
        self.is_t=depend_t

        dim = N_features[0] - 1 if depend_t else N_features[0]
        keys = random.split(keys[-1], dim+1)
        self.N_fourier = dim*N_fourier
        self.dim = dim
        # At most two hidden layers in the Fourier branch (none when dim == 1).
        n_hidden_t = min(2, dim - 1)
        features_t = [2*(self.N_fourier-1), ] + [N_features[1], ] * n_hidden_t + [N_features[-1], ]
        self.matrices_t = [random.normal(k, (f_in, f_out)) / jnp.sqrt((f_in + f_out) / 2)
                           for f_in, f_out, k in zip(features_t[:-1], features_t[1:], keys)]
        self.alpha = jnp.ones(dim)
        self.activation = known_activations[activation]

    def __call__(self, inputs):
        """Evaluate the network at a single input vector.

        Args:
            inputs: 1-D input vector (use ``vmap`` for batches).

        Returns:
            1-D array with one entry per output feature.
        """
        # Components that enter the Fourier features and the decay term.
        spatial = inputs[:-1] if self.is_t else inputs

        x_set = jnp.pi/self.N_fourier*jnp.arange(1,self.N_fourier).reshape(-1,1)
        fourier_features = x_set*spatial
        fourier_scale = jnp.exp(-jnp.linalg.norm(self.alpha**2*spatial,ord=1)/self.dim)
        # One row of cos/sin features per coordinate.
        fourier_input = jnp.concatenate([jnp.cos(fourier_features),jnp.sin(fourier_features)],axis=0).T

        # Linear branch: deliberately no bias and no activation between layers.
        c = fourier_input
        for matrix_t in self.matrices_t:
            c = c @ matrix_t

        f = inputs @ self.matrices[0] + self.biases[0]
        for matrix, bias in zip(self.matrices[1:], self.biases[1:]):
            f = self.activation(f)
            f = f @ matrix + bias
        return jnp.prod(c,axis=0)*f*fourier_scale

def real_net(model, x):
    """Return component 0 of ``model(x)`` (used as the real part of the prediction)."""
    output = model(x)
    return output[0]

def imag_net(model, x):
    """Return component 1 of ``model(x)`` (used as the imaginary part of the prediction)."""
    output = model(x)
    return output[1]

def compute_loss(model, ob_x):
    """Mean-squared error between the model output and the target coefficients.

    Args:
        model: callable mapping one input vector to an object whose
            components 0 and 1 are the predicted real and imaginary parts.
        ob_x: 2-D array; all columns but the last two are the model input,
            column -2 is the real target and column -1 the imaginary target.

    Returns:
        Scalar: mean squared real-part error plus mean squared
        imaginary-part error.
    """
    def both_parts(x):
        # Evaluate the network once and return both components, instead of
        # running a separate forward pass for each of them.
        out = model(x)
        return out[0], out[1]

    pred_x_real, pred_x_imag = vmap(both_parts)(ob_x[:, :-2])
    real_term = ((pred_x_real - ob_x[:, -2]) ** 2).mean()
    imag_term = ((pred_x_imag - ob_x[:, -1]) ** 2).mean()
    return real_term + imag_term

# Wrapped version of compute_loss that returns (loss, gradients); used by make_step.
# Per Equinox's filter_value_and_grad, gradients are taken with respect to the
# first argument (the model).
compute_loss_and_grads = eqx.filter_value_and_grad(compute_loss)


@eqx.filter_jit
def make_step(model, ob_x, optim, opt_state):
    """Perform one optimiser update on ``model`` using the batch ``ob_x``.

    Args:
        model: the network being trained.
        ob_x: training batch passed to ``compute_loss``.
        optim: Optax gradient transformation.
        opt_state: current optimiser state.

    Returns:
        Tuple ``(loss, updated_model, updated_opt_state)`` where ``loss`` is
        the value computed before the update.
    """
    loss, grads = compute_loss_and_grads(model, ob_x)
    # Only the array leaves of the model are handed to the optimiser as parameters.
    params = eqx.filter(model, eqx.is_array)
    updates, next_state = optim.update(grads, opt_state, params)
    return loss, eqx.apply_updates(model, updates), next_state

def train(key):
    """Train the coefficient network and report/save its reconstruction error.

    Steps:
      1. Load the reference data (frequencies ``k_full``, coefficients
         ``c_test``, target ``u_target``, grid and sample indices).
      2. Draw a subset of the frequencies without replacement, weighted by
         ``get_k_possibility``; these are the training modes, the remaining
         ones are the held-out modes.
      3. Fit an ``MLP`` mapping a frequency vector to the (real, imag) parts
         of its coefficient with Adam and an exponentially decaying
         learning rate.
      4. Every ``args.epochs`` steps (and once more after training) resample
         the training batch, predict the held-out coefficients, rebuild the
         solution with ``compute_u`` and save the model and results whenever
         the relative error improves.

    Args:
        key: JAX PRNG key driving sampling and initialisation.
    """
    keys = random.split(key, 4)
    # Hyperparameters from the command line
    dim = args.dim
    N_epochs = args.epochs
    ite = args.ite
    learning_rate = args.lr
    N_interior = args.n_interior
    spectral_scale = args.spectral_scale
    chunk_size = 50000  # number of modes per compute_u chunk
    data_dir = '/workspace/mnt/local/data/tianchi.yu/SINN_high/schrodinger'

    # Load reference data; the context manager closes the archive once the
    # required arrays have been read into memory.
    with np.load(f"{data_dir}/{dim}D_harmonic_oscillator.npz") as data:
        k_full = data['k_full']
        c_test = data['c_test'] * spectral_scale
        u_target = data['u_target'] * spectral_scale
        gridx = data['gridx']
        sample_index = data['sample_index']
    sample_x = gridx[sample_index]
    L = abs(gridx[0])

    u_pred_high = compute_u(sample_x, c_test, k_full, chunk_size, L)
    error_high = np.linalg.norm(u_target - u_pred_high) / np.linalg.norm(u_target)
    print(f"error with high resolution: {error_high:.2e}")

    # Weighted draw of the training modes; everything else is held out.
    k_normalized = get_k_possibility(k_full, args.gamma, args.nmax, args.alp)
    n_train = np.round(len(k_full) * args.percentage).astype('int32')
    indices = random.choice(keys[0], len(k_full), shape=(n_train,),
                            replace=False, p=k_normalized)
    # Convert once to NumPy: both arrays are used to index NumPy arrays below.
    indices = np.asarray(indices)
    mask = np.asarray(~jnp.isin(jnp.arange(len(k_full)), indices))

    kdq_low = k_full[indices]
    kdq_low_c = c_test[indices]

    u_pred_low = compute_u(sample_x, kdq_low_c, kdq_low, chunk_size, L)
    error_low = np.linalg.norm(u_target - u_pred_low) / np.linalg.norm(u_target)
    print(f"error with low resolution: {error_low:.2e}")

    kdq_test = k_full[mask].astype('float32')
    kdq_test_c = c_test[mask]

    print('construct model')
    N_features = [dim, args.features, 2]  # two outputs: real and imaginary part
    model = MLP(N_features, args.layers, args.n_fourier, args.activation, False, keys[2])

    # Optimiser: Adam with exponentially decaying learning rate
    N_drop = 50000
    decay_rate = 0.9
    sc = optax.exponential_decay(learning_rate, N_drop, decay_rate)
    optim = optax.adam(learning_rate=sc)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    k_set = kdq_low.astype('float32')
    keys = random.split(keys[-1], 3)
    # Each row: frequency vector, real part, imaginary part of the coefficient.
    training_points = jnp.concatenate(
        [k_set, jnp.real(kdq_low_c).reshape(-1, 1), jnp.imag(kdq_low_c).reshape(-1, 1)], -1)

    model_path = f'{data_dir}/schrodinger_fft_{args.dim}_{args.seed}.eqx'
    result_path = f'{data_dir}/schrodinger_fft_{args.dim}_{args.seed}.npz'
    errors = []

    def sample_batch(batch_key):
        # Use the full training set when it is smaller than the requested
        # batch; sampling without replacement is impossible in that case.
        # This explicit check replaces a bare `except:` that hid every error.
        if N_interior > training_points.shape[0]:
            return training_points
        return random.choice(batch_key, training_points, shape=(N_interior,), replace=False)

    def evaluate(current_model, current_loss, epoch_label, best_error):
        # Predict the held-out coefficients, rebuild u, log the metrics and
        # save model/results if the error on u improved. Returns the best error.
        # `complex` is jax.lax.complex, imported at the top of the file.
        c_pred_partial = complex(vmap(real_net, (None, 0,))(current_model, kdq_test),
                                 vmap(imag_net, (None, 0,))(current_model, kdq_test))
        relative_error_mask = (np.linalg.norm(c_pred_partial.flatten() - kdq_test_c.flatten())
                               / np.linalg.norm(kdq_test_c.flatten()))
        # Training modes keep their reference coefficients; only the
        # held-out ones are replaced by the prediction.
        c_pred = c_test.copy()
        c_pred[mask] = np.asarray(c_pred_partial)
        relative_error_c = (jnp.linalg.norm(c_pred.flatten() - c_test.flatten())
                            / jnp.linalg.norm(c_test.flatten()))
        u_pred_sinn = compute_u(sample_x, c_pred, k_full, chunk_size, L)
        relative_error = np.linalg.norm(u_pred_sinn - u_target) / np.linalg.norm(u_target)
        errors.append(relative_error)
        print(f'epochs:{epoch_label}, error_u: {relative_error:.2e}, loss:{current_loss:.2e}, error_c:{relative_error_c:.2e}, '
              f'relative_error_mask:{relative_error_mask:.2e}')
        if relative_error < best_error:
            eqx.tree_serialise_leaves(model_path, current_model)
            np.savez(result_path, errors=errors, c_pred=c_pred, c_test=c_test, kdq_low=kdq_low,
                     kdq_test=kdq_test, indices=indices, mask=mask)
            best_error = relative_error
        return best_error

    input_points = sample_batch(keys[0])

    error_min = 1
    loss = float('nan')  # defined even if the training loop runs zero steps
    print('done! Training')
    T1 = time.time()
    for j in range(ite * N_epochs):
        loss, model, opt_state = make_step(model, input_points, optim, opt_state)
        if j % N_epochs == 0:
            keys = random.split(keys[-1], 2)
            input_points = sample_batch(keys[0])
            error_min = evaluate(model, loss, int(j / N_epochs), error_min)
            print('++++++++++++++++++++++++')
    T2 = time.time()
    execution_time = T2 - T1

    # Final evaluation (not included in the reported training time).
    error_min = evaluate(model, loss, int(ite), error_min)
    print(f'finial result u: {error_min:.2e},time:{execution_time:.6f}')


if __name__ == "__main__":
    # Seed NumPy's global RNG and derive the JAX key from the same CLI seed.
    np.random.seed(args.seed)
    train(random.PRNGKey(args.seed))





