import sys

sys.path.append('../')
import jax.numpy as jnp
import equinox as eqx
import numpy as np
import optax
import time
from jax.nn import gelu, silu, tanh
from jax.lax import scan, stop_gradient, complex
from jax import random, jit, vmap, grad
import os
import scipy
import matplotlib.pyplot as plt
import argparse
import jax
from utils import chebyshev_shift_sparse

# Command-line configuration. The parsed ``args`` namespace is read as a module
# global by the training code below. Flags, types and defaults are unchanged;
# only the help texts were corrected to describe how each value is used.
parser = argparse.ArgumentParser(description="SINN")
parser.add_argument("--datatype", type=str, default='convection', help="type of data")
parser.add_argument("--n_interior", type=int, default=5000,
                    help="number of interior wavenumber points sampled in each resampling round")
parser.add_argument("--n_t", type=int, default=10,
                    help="number of time samples drawn in each resampling round")
parser.add_argument("--dim", type=int, default=2, help="dim of the problem")
parser.add_argument("--T", type=float, default=1e-2, help="terminal time")
parser.add_argument("--ite", type=int, default=20,
                    help="number of resampling rounds (total optimizer steps = ite * epochs)")
parser.add_argument("--epochs", type=int, default=50000,
                    help="number of optimizer steps between resampling/validation")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--activation", type=str, default='tanh', help='the activation function')
parser.add_argument("--features", type=int, default=100, help='width of the network')
parser.add_argument("--layers", type=int, default=5, help='depth of the network')
parser.add_argument("--n_fourier", type=int, default=4,
                    help='Fourier feature multiplier (the network uses dim * n_fourier - 1 frequencies)')
parser.add_argument("--max_k", type=int, default=3, help='max number of k')
parser.add_argument("--gamma", type=float, default=-2.0,
                    help="exponent gamma in the wavenumber sampling weight")
parser.add_argument("--alp", type=float, default=100,
                    help="scale applied to the sampling distance before the ELU")
parser.add_argument("--nmax", type=int, default=50,
                    help="N_max threshold in the wavenumber sampling weight")
parser.add_argument("--percentage", type=float, default=1.0,
                    help="fraction of initial-condition coefficients used for training")
parser.add_argument("--spectral_scale", type=int, default=1, help="scale help training")
parser.add_argument("--device", type=int, default=0, help="cuda number")
args = parser.parse_args()

# Restrict the process to the requested GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)

def get_k_possibility(k_set, gam, N_max, alp):
    """Convert a set of multi-indices into normalized sampling weights.

    For every row ``k`` of ``k_set`` a distance
    ``prod(|k|+1) * max(|k|+1)**(-gam) - N_max**(1-gam)`` is computed, rescaled
    so that its maximum equals ``alp``, passed through ELU, and mapped to the
    weight ``1 / (elu + 2)``. The weights are then divided by their sum.

    :param k_set: (N, d) array of multi-indices
    :param gam: exponent applied to the largest shifted index of each row
    :param N_max: threshold entering the distance as ``N_max**(1-gam)``
    :param alp: value the largest distance is rescaled to before the ELU
    :return: (N,) weights that sum to one
    """
    shifted = np.abs(k_set) + 1
    threshold = N_max ** (1 - gam)
    score = np.prod(shifted, axis=1) * np.max(shifted, axis=1) ** (-gam) - threshold
    score = jax.nn.elu(alp * score / score.max())
    # ELU is bounded below by -1, so the denominator never drops below 1.
    weights = 1 / (score + 2)
    return np.squeeze(weights) / np.sum(weights)

def generate_test_k(k_set):
    """Return the distinct rows of ``k_set`` whose Euclidean norm is non-zero.

    :param k_set: 2-D array of multi-indices, one per row
    :return: the unique rows (as ordered by ``jnp.unique``) with zero rows removed
    """
    distinct_rows = jnp.unique(k_set, axis=0)
    keep = jnp.linalg.norm(distinct_rows, axis=1) != 0
    return distinct_rows[keep]

# def shrink_k(k,scale):
#     return jnp.around(k/scale)

def extend_arccos(n, x):
    """Evaluate ``cos(n * arccos(x))`` with a continuation outside ``[-1, 1]``.

    Three expressions are computed and selected element-wise:

    * ``-1 <= x <= 1``: ``cos(n * arccos(x))`` (trigonometric form)
    * ``x > 1``: ``cosh(n * arccosh(x))`` (hyperbolic form)
    * ``x < -1``: ``(-1)**n * cosh(n * arccosh(-x))`` (hyperbolic form with a
      parity sign)

    :param n: order(s); broadcast against ``x``
    :param x: evaluation point(s)
    :return: array with the broadcast shape of ``n`` and ``x``
    """
    trig_branch = jnp.cos(n * jnp.arccos(x))
    upper_branch = jnp.cosh(n * jnp.arccosh(x))
    lower_branch = (-1) ** n * jnp.cosh(n * jnp.arccosh(-x))

    # Start from the in-range formula, then overwrite the two outer regions.
    # Every branch is evaluated everywhere; the selects discard the entries
    # that fall outside each branch's own domain.
    result = jnp.where(x < -1, lower_branch, trig_branch)
    result = jnp.where(x > 1, upper_branch, result)
    return result

def high_icheby(c_pred, k_set, x_test, batch_size=50000):
    """Sum a multi-dimensional Chebyshev-type expansion at a single point.

    Each term is ``c_pred[i] * prod_j extend_arccos(k_set[i, j], x_test[j])``.
    The terms are processed in slices of at most ``batch_size`` rows and the
    partial sums are accumulated in order.

    :param c_pred: (N,) coefficients
    :param k_set: (N, d) multi-indices
    :param x_test: (d,) evaluation point
    :param batch_size: maximum number of terms handled per slice
    :return: real part of the accumulated sum
    """
    n_terms = k_set.shape[0]
    n_batches = (n_terms + batch_size - 1) // batch_size
    bounds = [(b * batch_size, min((b + 1) * batch_size, n_terms)) for b in range(n_batches)]

    total = 0.0
    for lo, hi in bounds:
        # Product over dimensions of the 1-D basis values, one entry per term.
        basis = jnp.prod(extend_arccos(k_set[lo:hi], x_test), axis=1)
        total = total + jnp.sum(c_pred[lo:hi] * basis)
    return jnp.real(total)

class MLP(eqx.Module):
    """Dense network multiplied by a Fourier-feature factor and an exponential envelope.

    ``__call__`` returns ``prod(c) * f * scale`` where

    * ``f`` is the output of the dense stack (``matrices`` / ``biases``) applied
      to the full input vector,
    * ``c`` comes from a bias-free stack (``matrices_t``) applied to cos/sin
      features of the spatial part of the input,
    * ``scale = exp(-||alpha**2 * x||_1 / dim)`` with ``x`` the spatial part.
    """
    # Weights of the dense stack, one (fan_in, fan_out) matrix per layer.
    matrices: list
    # Biases of the dense stack, one (fan_out,) vector per layer.
    biases: list
    # True when the last input entry is excluded from the Fourier branch.
    is_t: bool
    # Bias-free weights of the Fourier-feature branch.
    matrices_t: list
    # dim * N_fourier; frequencies 1 .. N_fourier-1 are used.
    N_fourier: int
    # Number of input entries treated as spatial.
    dim: int
    # Per-dimension envelope coefficient, initialised to ones.
    alpha: jnp.ndarray
    # Nonlinearity used between layers of the dense stack.
    activation: jax.nn

    def __init__(self, N_features, N_layers, N_fourier, activation, depend_t, key):
        """Initialise all weights from ``key``.

        :param N_features: sequence whose entries ``[0]``, ``[1]`` and ``[-1]``
            are read as input, hidden and output width
        :param N_layers: number of layers in the dense stack
        :param N_fourier: multiplier; the stored ``N_fourier`` is ``dim * N_fourier``
        :param activation: one of ``'silu'``, ``'tanh'``, ``'gelu'``
        :param depend_t: if true, ``dim`` is the input width minus one
        :param key: JAX PRNG key
        """
        n_in, n_hidden, n_out = N_features[0], N_features[1], N_features[-1]
        widths = [n_in] + [n_hidden] * (N_layers - 1) + [n_out]
        layer_shapes = list(zip(widths[:-1], widths[1:]))

        # Key chain: each stage splits the last key of the previous stage.
        weight_keys = random.split(key, N_layers + 1)
        self.matrices = [
            random.normal(k, (fan_in, fan_out)) / jnp.sqrt((fan_in + fan_out) / 2)
            for (fan_in, fan_out), k in zip(layer_shapes, weight_keys)
        ]
        bias_keys = random.split(weight_keys[-1], N_layers + 1)
        self.biases = [
            random.normal(k, (fan_out,))
            for (_, fan_out), k in zip(layer_shapes, bias_keys)
        ]
        self.is_t = depend_t

        dim = n_in - 1 if depend_t else n_in
        fourier_keys = random.split(bias_keys[-1], dim + 1)
        self.N_fourier = dim * N_fourier
        self.dim = dim

        # cos and sin of N_fourier-1 frequencies feed the first Fourier layer;
        # at most two hidden layers are inserted, fewer for dim < 3.
        n_hidden_t = int(min(2, dim - 1))
        widths_t = [2 * (self.N_fourier - 1)] + [n_hidden] * n_hidden_t + [n_out]
        self.matrices_t = [
            random.normal(k, (fan_in, fan_out)) / jnp.sqrt((fan_in + fan_out) / 2)
            for fan_in, fan_out, k in zip(widths_t[:-1], widths_t[1:], fourier_keys)
        ]
        self.alpha = jnp.ones(dim)

        known_activations = {'silu': silu, 'tanh': tanh, 'gelu': gelu}
        assert activation in known_activations, f"unknown activation {activation}"
        self.activation = known_activations[activation]

    def __call__(self, inputs):
        """Evaluate the network on one 1-D input vector.

        :param inputs: 1-D array; when ``is_t`` is set its last entry is left
            out of the Fourier features and the envelope
        :return: output of the dense stack scaled by the Fourier product and
            the envelope
        """
        spatial = inputs[:-1] if self.is_t else inputs

        # Column of frequencies pi*m/N_fourier, m = 1 .. N_fourier-1, broadcast
        # against the spatial coordinates.
        freqs = jnp.pi / self.N_fourier * jnp.arange(1, self.N_fourier).reshape(-1, 1)
        phases = freqs * spatial
        envelope = jnp.exp(-jnp.linalg.norm(self.alpha ** 2 * spatial, ord=1) / self.dim)
        trig = jnp.concatenate([jnp.cos(phases), jnp.sin(phases)], axis=0).T

        # Fourier branch: the first matrix is applied without a nonlinearity,
        # every later matrix is followed by tanh.
        c = trig @ self.matrices_t[0]
        for weight in self.matrices_t[1:]:
            c = tanh(c @ weight)

        # Dense stack on the full input; activation precedes every layer but the first.
        f = inputs @ self.matrices[0] + self.biases[0]
        for weight, bias in zip(self.matrices[1:], self.biases[1:]):
            f = self.activation(f) @ weight + bias

        return jnp.prod(c) * f * envelope
        # return jnp.prod(c)*f

def net(model, x, t):
    """Call ``model`` on ``x`` and ``t`` stacked into one vector; return output entry 0.

    :param model: callable taking the stacked vector
    :param x: leading part of the input
    :param t: trailing part of the input
    :return: first component of ``model``'s output
    """
    stacked = jnp.hstack((x, t))
    output = model(stacked)
    return output[0]


def residual_boundary(model, x, t, matrices, ob_b):
    """Stack the interior and boundary residuals for one time value.

    The model is evaluated at every row of ``x`` with the same ``t``. The
    interior residual is ``diff_matrices @ u - du/dt`` and the boundary
    residual is ``diff_matrices_b @ u - ob_b``.

    :param model: network evaluated through ``net``
    :param x: batch of points, mapped over axis 0
    :param t: time value shared by all points
    :param matrices: dict with keys ``'diff_matrices'`` and ``'diff_matrices_b'``
    :param ob_b: boundary targets for this time value
    :return: 1-D array, interior residuals followed by boundary residuals
    """
    interior_op = matrices['diff_matrices']
    boundary_op = matrices['diff_matrices_b']

    evaluate = vmap(net, (None, 0, None))
    time_derivative = vmap(grad(net, argnums=2), (None, 0, None))

    u_cheby = evaluate(model, x, t)
    u_t = time_derivative(model, x, t)

    interior = jnp.dot(interior_op, u_cheby).squeeze() - u_t.squeeze()
    boundary = jnp.dot(boundary_op, u_cheby).squeeze() - ob_b.squeeze()
    return jnp.hstack([interior, boundary])

# def boundary(model, x,t,matrices,ob_b):
#     diff_matrices_b=matrices['diff_matrices_b']
#     u_cheby = vmap(net, (None, 0, None))(model, x, t)
#     rb=jnp.dot(diff_matrices_b,u_cheby).squeeze() - ob_b.squeeze()
#     return rb

def compute_loss(model, input_points, input_time, ob_0, ob_b, matrices):
    """Total training loss: mean squared residual plus weighted initial-data misfit.

    :param model: network being trained
    :param input_points: points shared by every time sample
    :param input_time: time samples, mapped over axis 0
    :param ob_0: array whose last column is the target and whose remaining
        columns are the model input
    :param ob_b: boundary targets, one row per time sample
    :param matrices: dict passed through to ``residual_boundary``
    :return: scalar ``mean(residual**2) + 100 * mean(misfit**2)``
    """
    per_time_residual = vmap(residual_boundary, (None, None, 0, None, 0))
    residuals = per_time_residual(model, input_points, input_time, matrices, ob_b)
    pde_loss = (residuals ** 2).mean()

    predicted = vmap(model, (0,))(ob_0[:, :-1])
    misfit = predicted[:, 0] - ob_0[:, -1]
    data_loss = (misfit ** 2).mean()

    return pde_loss + 100 * data_loss


# Value-and-gradient wrapper around compute_loss, built with Equinox's filtered
# transform; make_step calls it to obtain (loss, grads) for the model.
compute_loss_and_grads = eqx.filter_value_and_grad(compute_loss)


@eqx.filter_jit
def make_step(model, input_points, input_time, ob_sup, ob_b, matrices, optim, opt_state):
    """Perform one optimizer update and return the loss measured before it.

    :param model: current network
    :param input_points: points forwarded to ``compute_loss``
    :param input_time: time samples forwarded to ``compute_loss``
    :param ob_sup: supervised initial data, passed as ``ob_0`` of ``compute_loss``
    :param ob_b: boundary targets
    :param matrices: dict of operator matrices
    :param optim: optax gradient transformation
    :param opt_state: current optimizer state
    :return: ``(loss, updated_model, updated_opt_state)``
    """
    loss, gradients = compute_loss_and_grads(
        model, input_points, input_time, ob_sup, ob_b, matrices
    )
    # Only the array leaves of the model are handed to the optimizer as params.
    params = eqx.filter(model, eqx.is_array)
    param_updates, opt_state = optim.update(gradients, opt_state, params)
    model = eqx.apply_updates(model, param_updates)
    return loss, model, opt_state


def _draw_training_batch(key, k_set, k_normalized, k_full, u0_nd_cheb, diff_matrices_b,
                         n_interior, n_t, T, pin_terminal_time):
    """Sample interior points and times, and build the matching boundary targets.

    :param key: PRNG key; it is split in three and the last sub-key is returned
        so the caller can continue the key chain
    :param k_set: (N, d) candidate points to sample from
    :param k_normalized: (N,) sampling probabilities for the rows of ``k_set``
    :param k_full: multi-indices handed to ``chebyshev_shift_sparse``
    :param u0_nd_cheb: initial coefficients handed to ``chebyshev_shift_sparse``
    :param diff_matrices_b: boundary operator applied to the shifted coefficients
    :param n_interior: number of points to draw without replacement
    :param n_t: number of time samples drawn uniformly from ``[0, T)``
    :param T: terminal time
    :param pin_terminal_time: if true, the last time sample is replaced by ``T``
    :return: ``(next_key, input_points, input_time, ob_b)``
    """
    subkeys = random.split(key, 3)
    try:
        input_points = random.choice(subkeys[0], k_set, shape=(n_interior,), replace=False,
                                     p=k_normalized)
    except ValueError:
        # Sampling without replacement fails when more points are requested
        # than k_set contains; train on every point in that case.
        input_points = k_set
    input_time = random.uniform(subkeys[1], shape=(n_t, 1), minval=0, maxval=T)
    if pin_terminal_time:
        input_time = input_time.at[-1].set(T)

    ob_b = []
    for TT in input_time:
        c_shifted = chebyshev_shift_sparse(k_full, u0_nd_cheb, TT[0])
        u_nd_cheb_new = np.array(list(c_shifted.values()), dtype=np.float32)
        ob_b.append(diff_matrices_b @ u_nd_cheb_new)
    return subkeys[-1], input_points, input_time, np.vstack(ob_b)


def _evaluate_model(model, k_test_set, x_test, u_test, u_nd_cheb_T):
    """Return ``(c_pred, u_pred, error_c, relative_error)`` for the current model.

    ``error_c`` is the relative L2 error of the predicted coefficients against
    ``u_nd_cheb_T``; ``relative_error`` is the relative L2 error of the
    reconstructed values against ``u_test``.
    """
    c_pred = vmap(model, (0,))(k_test_set)
    error_c = np.linalg.norm(c_pred.squeeze() - u_nd_cheb_T.squeeze()) / np.linalg.norm(u_nd_cheb_T)
    u_pred = vmap(high_icheby, (None, None, 0))(c_pred.squeeze(), k_test_set[:, :-1], x_test)
    relative_error = jnp.linalg.norm(u_pred.flatten() - u_test.flatten()) / jnp.linalg.norm(u_test.flatten())
    return c_pred, u_pred, error_c, relative_error


def _save_checkpoint(model, u_pred, u_test, c_pred, errors):
    """Write the model leaves and the prediction arrays under ``results/``."""
    # Create the output directory up front so a missing folder cannot abort the run.
    os.makedirs('results', exist_ok=True)
    path = f'results/convection_cheby_{args.dim}_{args.seed}_{args.percentage}.eqx'
    eqx.tree_serialise_leaves(path, model)
    path = f'results/convection_cheby_{args.dim}_{args.seed}_{args.percentage}.npz'
    np.savez(path, u_pred=u_pred, u_test=u_test, c_pred=c_pred, errors=errors)


def train(key):
    """Train the network and checkpoint the best validated model.

    Hyperparameters are read from the module-level ``args``. Data are loaded
    from ``../data/test_point_cheby_t_{dim}.npz``. Every ``args.epochs`` steps
    (including step 0) the training points and times are resampled, the model
    is validated at the terminal time, and it is saved whenever the relative
    error improves on the best value so far (initially 10).

    :param key: JAX PRNG key driving model initialisation and all sampling
    """
    keys = random.split(key, 4)
    # Get hyperparameters
    dim = args.dim
    N_interior = args.n_interior
    N_epochs = args.epochs
    N_t = args.n_t
    ite = args.ite
    learning_rate = args.lr
    T = args.T
    spectral_scale = args.spectral_scale
    print('Generating data')
    data = jnp.load(f"../data/test_point_cheby_t_{dim}.npz")
    x_test = data["x_test"]
    u_test = data["u_test"] * spectral_scale
    diff_matrices = data["diff_matrices"]
    diff_matrices_b = data["diff_matrices_b"]
    k_full = data["k_full"]
    u0_nd_cheb = data["u0_nd_cheb"] * spectral_scale
    u_nd_cheb_T = data["u_nd_cheb_T"] * spectral_scale
    matrices = {'diff_matrices': diff_matrices, 'diff_matrices_b': diff_matrices_b}

    input_dim = dim + 1
    output_dim = 1

    N_features = [input_dim, args.features, output_dim]
    N_layers = args.layers
    N_fourier = args.n_fourier
    # Choose the model
    model = MLP(N_features, N_layers, N_fourier, args.activation, True, keys[2])

    # parameters of optimizer
    N_drop = 50000
    gamma = 0.9
    sc = optax.exponential_decay(learning_rate, N_drop, gamma)
    optim = optax.adam(learning_rate=sc)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    k_set = k_full.astype('float32')
    # Validation inputs: every wavenumber paired with the terminal time.
    k_test_set = np.concatenate([k_set, T * jnp.ones_like(k_set[:, 0:1])], axis=-1)
    k_normalized = get_k_possibility(k_set, args.gamma, args.nmax, args.alp)

    # Supervised initial data: columns are (k, t=0, target coefficient).
    ob_0 = jnp.concatenate([k_set, jnp.zeros_like(k_set[:, 0:1]), u0_nd_cheb.reshape(-1, 1)], axis=-1)
    chain_key = keys[-1]
    if args.percentage < 1:
        subset_keys = random.split(chain_key, 2)
        indices = random.choice(subset_keys[0], k_set.shape[0],
                                shape=(np.round(k_set.shape[0] * args.percentage).astype('int32'),),
                                replace=False)
        np.savez(f'indices_{dim}_{args.percentage}_{args.seed}.npz', indices=indices)
        ob_0 = ob_0[indices]
        chain_key = subset_keys[-1]

    def resample(rng, pin_terminal_time):
        return _draw_training_batch(rng, k_set, k_normalized, k_full, u0_nd_cheb, diff_matrices_b,
                                    N_interior, N_t, T, pin_terminal_time)

    # Only this initial draw forces the last time sample to equal T; the
    # in-loop resampling below keeps purely uniform times.
    chain_key, input_points, input_time, ob_b = resample(chain_key, True)

    print('done! Training')
    errors = []
    error_min = 10
    # Defined up front so the final report also works when no step is run.
    j, loss = 0, float('nan')
    T1 = time.time()
    for j in range(ite * N_epochs):
        loss, model, opt_state = make_step(model, input_points, input_time, ob_0, ob_b, matrices, optim, opt_state)
        if j % N_epochs == 0:
            chain_key, input_points, input_time, ob_b = resample(chain_key, False)

            # valid
            c_pred, u_pred, error_c, relative_error = _evaluate_model(
                model, k_test_set, x_test, u_test, u_nd_cheb_T)
            errors.append(relative_error)
            print(f'epochs:{int(j/N_epochs)}, error_u: {relative_error:.2e}, loss:{loss:.2e}, error_c:{error_c:.2e}')
            print('++++++++++++++++++++++++')
            if relative_error < error_min:
                _save_checkpoint(model, u_pred, u_test, c_pred, errors)
                error_min = relative_error
    T2 = time.time()
    execution_time = T2 - T1

    c_pred, u_pred, error_c, relative_error = _evaluate_model(
        model, k_test_set, x_test, u_test, u_nd_cheb_T)
    errors.append(relative_error)
    last_epoch = int(j / N_epochs) if N_epochs else 0
    print(f'epochs:{last_epoch}, error_u: {relative_error:.2e}, loss:{loss:.2e}, error_c:{error_c:.2e}')

    # save model and results
    if relative_error < error_min:
        _save_checkpoint(model, u_pred, u_test, c_pred, errors)
        error_min = relative_error
    print(f'finial result u: {error_min:.2e},time:{execution_time:.6f}')

if __name__ == "__main__":
    # One CLI seed drives both the JAX key passed to train() and NumPy's global RNG.
    seed = args.seed
    key = random.PRNGKey(seed)
    np.random.seed(seed)
    train(key)

