
import sys
sys.path.append('../../../')
sys.path.append('../../../src/')
from swimpde import Domain
from swimpde import BasicAnsatz
from swimpde import Diffusion_Solver
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import matplotlib.cm as cm
import time
cmap = cm.jet
from examples.utils import *
from mpl_toolkits.mplot3d import Axes3D
#from exautils import *

# Global random state: the legacy NumPy RNG and a Generator instance.
np.random.seed(2)
rng = np.random.default_rng(seed=123)
print(sys.path)

# Everything printed from here on is written to '<name>.txt' instead of the console.
name = '100d_diffusion'
sys.stdout = open(f'{name}.txt', 'wt')

# Problem size and number of held-out test points.
d = 100            # spatial dimension
n_b_test = 2000    # boundary test points
n_int_test = 8000  # interior test points (16000 was used previously)

# Solver defaults; the experiment loop later rebinds reg_const, rtol and atol.
reg_const = 1e-8   # regularization constant
svd_cutoff = 1e-8
rtol = 1e-6
atol = 1e-6

def _uniform_sampler(x, rng, r_m, n_neurons):
    """Draw hidden-layer weights and biases i.i.d. from U(-r_m, r_m).

    Parameters:
    x : ndarray of shape (N, D) - input points; only D = x.shape[1] is used.
    rng : numpy Generator used for both draws (weights first, then biases).
    r_m : float - half-width of the uniform interval.
    n_neurons : int - number of hidden neurons.

    Returns:
    weights (D, n_neurons), biases (1, n_neurons), None, None
        The two trailing Nones take the place of the pair indices that
        data-driven samplers return.
    """
    weights = rng.uniform(low=-1. * r_m, high=r_m, size=(x.shape[1], n_neurons))
    biases = rng.uniform(low=-1. * r_m, high=r_m, size=(1, n_neurons))
    return weights, biases, None, None


def sampler_1(x, _, rng, n_neurons=None):
    """Data-agnostic sampler with weights and biases drawn from U(-0.005, 0.005).

    The second positional argument (function values) is ignored.
    `n_neurons` defaults to the module-level `width` (the loop variable of the
    experiment sweep), which was the only option before the parameter existed.
    """
    if n_neurons is None:
        n_neurons = width
    return _uniform_sampler(x, rng, 0.005, n_neurons)


def sampler_2(x, _, rng, n_neurons=None):
    """Data-agnostic sampler with weights and biases drawn from U(-0.05, 0.05).

    The second positional argument (function values) is ignored.
    `n_neurons` defaults to the module-level `width` (the loop variable of the
    experiment sweep), which was the only option before the parameter existed.
    """
    if n_neurons is None:
        n_neurons = width
    return _uniform_sampler(x, rng, 0.05, n_neurons)

#####################################################################################################
def gram_schmidt(B, eps=1e-12):
    """
    Orthonormalizes two basis vectors per sample using the Gram-Schmidt process.

    The input array is not modified. Integer input is promoted to float.

    Parameters:
    B : ndarray of shape (N, 2, D) - Two basis vectors for each of N samples
    eps : float - A vector whose norm (before normalization) is <= eps is
          treated as degenerate and returned as a zero vector instead of NaN.
          This happens when the first vector is zero or when the two vectors
          are (numerically) parallel.

    Returns:
    B_orth : ndarray of shape (N, 2, D) - Orthonormalized basis vectors.
        B_orth[:, 0] is the normalized first vector; B_orth[:, 1] is the
        normalized component of the second vector orthogonal to it.
    """
    B = np.asarray(B)
    if not np.issubdtype(B.dtype, np.floating):
        B = B.astype(float)

    def _normalize(v):
        norms = np.linalg.norm(v, axis=1, keepdims=True)
        # Out-of-place divide: writes into a fresh array (never into B), and
        # leaves degenerate rows at zero instead of producing NaN.
        return np.divide(v, norms, out=np.zeros_like(v), where=norms > eps)

    u1 = _normalize(B[:, 0])  # First vector (N, D)
    # Remove the projection of the second vector on u1, then normalize.
    u2 = _normalize(B[:, 1] - np.einsum('ij,ij->i', B[:, 1], u1)[:, np.newaxis] * u1)

    return np.stack((u1, u2), axis=1)  # (N, 2, D)

def project_on_gradient_plane(X1, X2, grad_f):
    """
    Project each displacement X2[n] - X1[n] onto the plane spanned by
    grad_f(X1[n]) and grad_f(X2[n]).

    Parameters:
    X1 : ndarray of shape (N, D) - start points, one per row
    X2 : ndarray of shape (N, D) - end points, one per row
    grad_f : callable - batched gradient; maps an (N, D) array to an (N, D) array

    Returns:
    ndarray of shape (N, D) - the orthogonal projection of each displacement
    onto its own gradient plane.
    """
    # Orthonormal basis of each gradient plane, gradients at X1 first.
    basis = gram_schmidt(np.stack((grad_f(X1), grad_f(X2)), axis=1))  # (N, 2, D)

    displacement = X2 - X1  # (N, D)

    # Coordinates in the plane, then mapped back to R^D.
    in_plane = np.einsum('nij,nj->ni', basis, displacement)   # (N, 2)
    return np.einsum('ni,nij->nj', in_plane, basis)           # (N, D)

# Example batch gradient function
def gradient_u0(x):
    """Batched gradient of the initial condition u0 for the 100-d problem.

    u0(x) = |x|^2 / (2 d) has gradient x / d; the factor 0.01 hard-codes
    d = 100. Rows of x are points, and the result has the same shape as x.
    """
    return x * 0.01

def sample_parameters(x, y, rng):
    """Sample `width` data-driven directions between pairs of points of (x, y).

    Candidate pairs (start, end) with start != end are drawn at random, and
    each displacement is projected onto the plane spanned by gradient_u0 at
    both endpoints. `width` pairs are then resampled with replacement, with
    probability proportional to weight_probabilities(dy, lengths).

    Returns:
        directions: (width, D) unit vectors along the projected displacements.
        dists: (width, 1) projected lengths, clipped below at 1e-10.
        idx_from, idx_to: (width,) start / end indices into x.
    """
    n_points = x.shape[0]

    # When more neurons than points are requested, draw proportionally more
    # candidate pairs.
    n_rep = max(1, int(np.ceil(width / n_points)))
    n_candidates = n_points * n_rep

    # A non-zero offset modulo n_points guarantees start != end for every pair.
    start = rng.integers(low=0, high=n_points, size=n_candidates)
    offset = rng.integers(low=1, high=n_points, size=n_candidates)
    end = (start + offset) % n_points

    # Displacements restricted to the plane spanned by the gradients at both ends.
    proj = project_on_gradient_plane(x[start, ...], x[end, ...], gradient_u0)
    lengths = np.clip(np.linalg.norm(proj, axis=1, keepdims=True), a_min=1e-10, a_max=None)
    unit = proj / lengths

    # Favour pairs with a steep change in y; sampling with replacement avoids
    # being forced into low-probability pairs.
    y_change = y[end, :] - y[start, :]
    chosen = rng.choice(
        lengths.shape[0],
        size=width,
        replace=True,
        p=weight_probabilities(y_change, lengths),
    )

    return unit[chosen], lengths[chosen], start[chosen], end[chosen]

def weight_probabilities(dy, dists, sample_uniformly=False):
    """Compute the probability that each candidate weight is chosen for the network.

    All probabilities are computed at once, without removing selected weights
    one by one. A candidate's score is its largest absolute function change
    over all outputs divided by the distance between its base points, so
    steep pairs are favoured for every output simultaneously.

    Args:
        dy: function differences, shape (N, n_outputs). A 1-D array of length
            N is treated as a single output.
        dists: distances between the base points, shape (N, 1) or (N,).
        sample_uniformly: if True, ignore dy/dists and return the uniform
            distribution.

    Returns:
        probabilities: 1-D array of length N summing to 1.
    """
    dy = np.asarray(dy)
    if dy.ndim == 1:
        dy = dy[:, np.newaxis]

    # Flatten dists so (N,) and (N, 1) inputs both yield N scores; dividing an
    # (N, 1) column by a 1-D (N,) array would otherwise broadcast to (N, N).
    gradients = np.max(np.abs(dy), axis=1) / np.ravel(dists)

    if sample_uniformly or np.sum(gradients) < 1e-10:
        # When all gradients are tiny, avoid dividing by a near-zero total
        # and fall back to the uniform distribution.
        probabilities = np.ones_like(gradients) / len(gradients)
    else:
        probabilities = gradients / np.sum(gradients)

    return probabilities


def sample_parameters_tanh(x, y, rng):
    """Turn sampled point pairs into tanh-neuron weights and biases.

    `scale` equals artanh(1/2), so tanh(+-scale) = +-1/2. Each neuron's
    pre-activation w.x + b is -scale at its start point x[idx_from]. When
    `directions * dists` is the orthogonal projection of x[idx_to] - x[idx_from],
    as produced by sample_parameters, it is +scale at x[idx_to].

    Returns:
        weights: (D, width), biases: (1, width), idx_from, idx_to.
    """
    scale = 0.5 * (np.log(1 + 1/2) - np.log(1 - 1/2))

    directions, dists, idx_from, idx_to = sample_parameters(x, y, rng)

    # One row per neuron; the returned weight matrix is its transpose.
    w_rows = 2 * scale * directions / dists
    # Shift so that every neuron's pre-activation at its start point is -scale.
    biases = -(x[idx_from, :] * w_rows).sum(axis=-1)[np.newaxis, :] - scale

    return w_rows.T, biases, idx_from, idx_to

#####################################################################################################
# Hyper-parameter grid swept by the experiment loop.
experiments = []
n_int_train_list = [500, 1000, 2000, 5000]
seeds = [1, 2, 3]
widths = [200, 400, 600, 800]  # previously tried: [300, 500, 800], [2000, 3000, 4000, 5000, 8000]
reg_consts = [1e-4, 1e-6, 1e-8, 1e-10]
tols = [1e-3, 1e-4, 1e-6]
param_samplers = [sample_parameters_tanh]  # alternatives: sampler_1, sampler_2, 'tanh'

# Held-out test set: interior samples stacked above boundary samples
# (generated by the LHS ball helpers from examples.utils).
X_int_test = sample_interior_lhs_ball(d, n_int_test)
X_b_test = sample_boundary_lhs_ball(d, n_b_test)
X_test = np.vstack((X_int_test, X_b_test))

# initial condition
def u0(x):
    """Initial condition u0(x) = |x|^2 / (2 d), where d is the number of columns of x.

    For an (N, d) input the result is an (N, 1) column.
    """
    n_dims = np.shape(x)[1]
    squared_norm = np.sum(x ** 2, axis=1, keepdims=True)
    return squared_norm / (2. * n_dims)

# forcing
def forcing(x, t):
    """Source term of the diffusion problem: identically zero.

    Returns the scalar 0.0 regardless of x and t.
    """
    return 0.0


# Boundary-condition type handed to the diffusion solver.
boundary_condition = "dirichlet"

# Analytical solution
def analytical_sol(x, t):
    """Exact solution u(x, t) = t + |x|^2 / (2 d), with d the number of columns of x.

    The Laplacian of |x|^2 / (2 d) is 1, which equals u_t, so u solves
    u_t = Laplace(u) with zero forcing and u(x, 0) = |x|^2 / (2 d).
    `t` may be a scalar or an array that broadcasts against the (N, 1)
    spatial part, e.g. shape (T, 1, 1), which gives (T, N, 1).
    """
    n_dims = np.shape(x)[1]
    spatial_part = np.sum(x ** 2, axis=1, keepdims=True) / (2. * n_dims)
    return t + spatial_part
    
# Reference solution on the test set at 100 equally spaced times in [0, 1].
t_eval = np.linspace(0, 1, 100).reshape(-1, 1, 1)  # (100, 1, 1): broadcasts against (N, 1)
u_true = analytical_sol(X_test, t_eval)  # (100, N, 1)
u_true = u_true.reshape(u_true.shape[0], u_true.shape[1])  # drop trailing singleton -> (100, N)

# Sweep the hyper-parameter grid. Each configuration is fitted once per seed
# and summarized by the mean (and std) of its per-seed metrics.
info = []         # flat log [time, rmse, time, rmse, ...] over every run of every configuration
experiments = []  # one [time, rmse, ...] list per configuration, in loop order
lambda_b = 1e4    # passed to the solver as scale_boundary_correction

# RMS of the reference solution. It is loop-invariant, so compute it once and
# reuse it as the relative-error denominator.
u_true_rms = np.sqrt(mean_squared_error(u_true, np.zeros_like(u_true)))

for width in widths:  # NOTE: sampler_1 / sampler_2 / sample_parameters read this global `width`
    for reg_const in reg_consts:
        for tol in tols:
            rtol = tol
            atol = tol
            for param_sampler in param_samplers:
                for n_int_train in n_int_train_list:
                    n_b_train = n_int_train

                    # Per-seed metrics, pre-filled with 100 so unfilled slots stand out.
                    rmse_elm = 100 * np.ones(len(seeds))
                    rel_err_elm = 100 * np.ones(len(seeds))
                    time_elm = 100 * np.ones(len(seeds))
                    # This configuration's own log. Appending the shared `info`
                    # list instead would make every entry of `experiments` alias
                    # one ever-growing list.
                    config_info = []

                    for j, seed in enumerate(seeds):
                        # NOTE(review): the global seeds are reset to the same values
                        # for every `seed`. Any randomness drawn from np.random (possibly
                        # the LHS samplers below; not visible here) therefore repeats
                        # across seeds, and only the ansatz's random_state changes.
                        # Confirm this is intended.
                        np.random.seed(2)
                        rng = np.random.default_rng(seed=123)
                        ansatz_elm = BasicAnsatz(
                            n_neurons=width,
                            activation="tanh",
                            random_state=seed,
                            regularization_scale=reg_const,
                            parameter_sampler=param_sampler,
                        )
                        X_int_train = sample_interior_lhs_ball(d, n_int_train)
                        X_b_train = sample_boundary_lhs_ball(d, n_b_train)

                        # Boundary normals are taken to be the boundary points themselves.
                        # This is the outward unit normal only if the boundary points lie
                        # on the unit sphere centred at the origin. That is assumed here;
                        # confirm it in sample_boundary_lhs_ball.
                        normal_vectors = X_b_train.copy()

                        domain = Domain(
                            interior_points=X_int_train,
                            boundary_points=X_b_train,
                            normal_vectors=normal_vectors,
                        )
                        ic_eval = u0(domain.interior_points)

                        diffusion_solver_elm = Diffusion_Solver(
                            domain=domain,
                            ansatz=ansatz_elm,
                            u0=u0,
                            boundary_condition=boundary_condition,
                            forcing=forcing,
                            regularization_scale=reg_const,
                            scale_boundary_correction=lambda_b,
                            boundary_condition_true=analytical_sol,
                        )

                        # Fit over the whole time span; perf_counter is monotonic,
                        # unlike time.time, so it is the right clock for intervals.
                        t_elm_start = time.perf_counter()
                        sol_elm, solver_status_elm = diffusion_solver_elm.fit(
                            t_span=[0, np.max(t_eval)],
                            rtol=rtol, atol=atol, svd_cutoff=svd_cutoff,
                            outer_basis=False,
                            init_cond=ic_eval,
                        )
                        time_elm[j] = time.perf_counter() - t_elm_start

                        # Evaluate on the test data. The .T presumably aligns the result
                        # with u_true's (time, point) layout; confirm against evaluate().
                        u_elm_test = diffusion_solver_elm.evaluate(x_eval=X_test, t_eval=t_eval).T

                        # Compute metrics
                        rmse_elm[j] = np.sqrt(mean_squared_error(u_true, u_elm_test))
                        rel_err_elm[j] = rmse_elm[j] / u_true_rms
                        info.extend((time_elm[j], rmse_elm[j]))
                        config_info.extend((time_elm[j], rmse_elm[j]))
                        print('time=', time_elm[j], 'rmse_elm=', rmse_elm[j], 'rel_err_elm=',rel_err_elm[j])

                    # Summary of this configuration over all seeds
                    print('-------------------------------------------------------------------------')
                    print('Width: ', width, 'param_sampler: ', param_sampler, 'reg_const', reg_const, 'atol', atol)
                    print('Boundary scaling: ', lambda_b, 'n_int_train: ', n_int_train)

                    print('-------------------------------------------------------------------------')
                    print('Train: elm-ode time = ', np.mean(time_elm))
                    print('Test: rmse elm-ode = ',np.mean(rmse_elm), '+-', np.std(rmse_elm))
                    print('Test: rel l-2 error elm-ode = ',np.mean(rel_err_elm), '+-', np.std(rel_err_elm))
                    print('-------------------------------------------------------------------------\n')
                    experiments.append(config_info)
