# Imports
import sys
sys.path.append('../../../')
sys.path.append('../../')
sys.path.append('../../../src')
from swimpde import Domain
from swimpde import BasicAnsatz
from swimpde import Diffusion_Solver
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import matplotlib.cm as cm
import time
cmap = cm.jet
from examples import utils
from mpl_toolkits.mplot3d import Axes3D
from utils import *

# Base name of the log file; from the next line on, everything written to
# sys.stdout (e.g. the print calls in this script) goes to '<name>_3d.txt'
# instead of the console. The file handle is never explicitly closed here.
name = 'swim_projection'
sys.stdout = open(name + '_3d.txt','wt')

# Seed NumPy's legacy global RNG and create a separately seeded Generator.
# NOTE(review): the module-level `rng` is not referenced by the visible code
# (the sampler functions take their own `rng` argument) -- confirm it is needed.
np.random.seed(2)
rng = np.random.default_rng(seed=123)
n_b_train = 4000  # Number of boundary points requested for training
n_b_test = 2000  # Number of boundary points requested for testing
n_int_train = 16000  # Number of interior points requested for training
n_int_test = 8000  # Number of interior points requested for testing

# initial condition
def u0(x):
    """Initial condition: cosine of the mean coordinate of each point.

    Parameters:
    x : array of shape (N, d) - One spatial point per row

    Returns:
    ndarray of shape (N,) - cos(sum_k x_k / d) for every row
    """
    n_dims = np.shape(x)[1]
    coordinate_sum = np.sum(x, axis=1)
    return np.cos(coordinate_sum / n_dims)

# forcing
def forcing(x, t):
    """Forcing term exp(-t) * cos(sum_k x_k / d) * (1/d - 1).

    Parameters:
    x : array of shape (N, d) - One spatial point per row
    t : scalar or array broadcastable against an (N,) array - Time(s)

    Returns:
    ndarray - Forcing values, broadcast over `t` and the N points
    """
    inv_dim = 1. / np.shape(x)[1]
    decay = np.exp(-t)
    profile = np.cos(np.sum(x, axis=1) * inv_dim)
    return decay * profile * (inv_dim - 1)

# Boundary-condition type; forwarded unchanged to Diffusion_Solver through its
# `boundary_condition` keyword argument further down in this script.
boundary_condition = "dirichlet"

# Analytical solution
def analytical_sol(x, t):
    """Closed-form reference solution exp(-t) * cos(sum_k x_k / d).

    Parameters:
    x : array of shape (N, d) - One spatial point per row
    t : scalar or array broadcastable against an (N,) array - Time(s)

    Returns:
    ndarray - Solution values, broadcast over `t` and the N points
    """
    decay = np.exp(-t)
    profile = np.cos(np.sum(x, axis=1) / np.shape(x)[1])
    return decay * profile

def du_dt(x, t):
    """Time derivative of the reference solution: -exp(-t) * cos(sum_k x_k / d).

    Parameters:
    x : array of shape (N, d) - One spatial point per row
    t : scalar or array broadcastable against an (N,) array - Time(s)

    Returns:
    ndarray - Derivative values, broadcast over `t` and the N points
    """
    negative_decay = -np.exp(-t)
    profile = np.cos(np.sum(x, axis=1) / np.shape(x)[1])
    return negative_decay * profile

# Time-step indices, presumably intended for visualizing the true solution
# (t_eval below has 100 steps, so 99 would be the last index).
# NOTE(review): not referenced anywhere in the visible code -- confirm it is still needed.
timesteps = [0, 30, 60, 99]

def gram_schmidt(B, tol=1e-10):
    """
    Orthonormalizes two basis vectors per sample using the Gram-Schmidt process.

    The input is never modified. Degenerate vectors are mapped to the zero
    vector instead of producing NaN or amplified round-off noise: a zero
    first vector stays zero, and a second vector that is (numerically)
    parallel to the first one is replaced by zeros.

    Parameters:
    B : array-like of shape (N, 2, d) - Two basis vectors for each sample
    tol : float - Relative threshold; the second vector is treated as linearly
          dependent on the first when the norm of its orthogonal residual is
          at most tol times its own norm

    Returns:
    B_orth : float ndarray of shape (N, 2, d) - For every sample the rows are
             either unit length and mutually orthogonal, or exactly zero
    """
    # Work in floating point and only ever read from B, so the caller's
    # array is left untouched (the division below allocates new arrays).
    B = np.asarray(B, dtype=float)
    v1 = B[:, 0]
    v2 = B[:, 1]

    # Normalize the first vector; dividing by a dummy 1.0 where the norm is
    # zero avoids a 0/0 and the outer where() then writes zeros there.
    norm1 = np.linalg.norm(v1, axis=1, keepdims=True)
    has_u1 = norm1 != 0
    u1 = np.where(has_u1, v1 / np.where(has_u1, norm1, 1.0), 0.0)

    # Remove the component of the second vector along u1.
    residual = v2 - np.einsum('ij,ij->i', v2, u1)[:, np.newaxis] * u1

    # Normalize the residual unless it is negligible relative to v2, in which
    # case it only contains round-off noise and must not be scaled up to unit
    # length. The negated "<=" keeps NaN inputs propagating to the output.
    norm2 = np.linalg.norm(residual, axis=1, keepdims=True)
    has_u2 = ~(norm2 <= tol * np.linalg.norm(v2, axis=1, keepdims=True))
    u2 = np.where(has_u2, residual / np.where(has_u2, norm2, 1.0), 0.0)

    return np.stack((u1, u2), axis=1)  # Shape (N, 2, d)

def project_on_gradient_plane(X1, X2, grad_f):
    """
    Projects the displacements X2 - X1 onto the span of the gradients at X1 and X2.

    For every pair of points, the gradients at both end points are stacked
    into a two-row basis, passed through gram_schmidt, and the displacement is
    expanded in that basis. The result is an orthogonal projection only when
    gram_schmidt returns orthonormal rows.

    Parameters:
    X1 : ndarray of shape (N, d)  - Start points (one point per row)
    X2 : ndarray of shape (N, d)  - End points
    grad_f : callable             - Batch gradient; maps (N, d) points to (N, d) gradients

    Returns:
    v_proj : ndarray of shape (N, d) - Displacements expressed through the two basis rows
    """
    # Rows of the per-sample basis: gradient at X1, then gradient at X2.
    raw_basis = np.stack((grad_f(X1), grad_f(X2)), axis=1)  # (N, 2, d)
    basis = gram_schmidt(raw_basis)  # (N, 2, d)

    displacement = X2 - X1  # (N, d)

    # Coefficient of each displacement along each basis row, then recombine.
    coefficients = np.einsum('nij,nj->ni', basis, displacement)  # (N, 2)
    return np.einsum('ni,nij->nj', coefficients, basis)  # (N, d)

# Example batch gradient function
def gradient_u0(x):
    """Batch gradient of u0(x) = cos(sum_k x_k / d).

    Every component of the gradient is the same: -sin(sum_k x_k / d) / d.

    Parameters:
    x : ndarray of shape (N, d) - One spatial point per row

    Returns:
    ndarray of shape (N, d) - Gradient at every point
    """
    n_dims = np.shape(x)[1]
    mean_coordinate = np.sum(x, axis=1) / n_dims  # (N,)
    # One value per sample, kept as a column so it broadcasts across all d components.
    component = -np.sin(mean_coordinate)[:, np.newaxis] / n_dims  # (N, 1)
    return np.ones_like(x) * component  # (N, d)

"""
def gradient_u0(x):
    grad = np.ones_like(x)
    d = np.shape(x)[1]
    for i in np.arange(d):
        grad[:, i] = -(1./d) * np.sin(np.sum(x, axis=1)/d) 
    return grad
"""

def sample_parameters(x, y, rng):
    """
    Sample `width` directions between pairs of points of the dataset (x, y).

    Candidate pairs are drawn at random, their displacement is passed through
    project_on_gradient_plane (using gradient_u0), and `width` candidates are
    then drawn with replacement, weighted by weight_probabilities. The number
    of samples is read from the module-level variable `width`.

    Parameters:
    x : ndarray of shape (N, d) - Input points
    y : 2-D ndarray with N rows - Function values at the points
    rng : numpy Generator       - Source of randomness (integers / choice are used)

    Returns:
    directions : ndarray (width, d) - Projected displacements divided by their clipped norm
    dists : ndarray (width, 1)      - Norms of the projected displacements, clipped below at 1e-10
    idx_from : ndarray (width,)     - Row index in x of the start point of each pair
    idx_to : ndarray (width,)       - Row index in x of the end point of each pair
    """
    n_points = x.shape[0]

    # If more samples than data points are requested, repeat the candidate
    # drawing so that enough pairs are available.
    n_repetitions = max(1, int(np.ceil(width / n_points))) * 1

    # The offset lies in [1, n_points - 1], so after the modulo the end index
    # can never coincide with the start index of the same candidate.
    candidates_idx_from = rng.integers(low=0, high=n_points, size=n_points * n_repetitions)
    offsets = rng.integers(low=1, high=n_points, size=candidates_idx_from.shape[0])
    candidates_idx_to = (candidates_idx_from + offsets) % n_points

    # Displacement of every candidate pair, restricted to the gradient plane.
    start_points = x[candidates_idx_from, ...]
    end_points = x[candidates_idx_to, ...]
    projected = project_on_gradient_plane(start_points, end_points, gradient_u0)

    # Clip the lengths away from zero before normalizing so that a vanishing
    # projection does not cause a division by zero.
    lengths = np.linalg.norm(projected, axis=1, keepdims=True)
    lengths = np.clip(lengths, a_min=1e-10, a_max=None)
    unit_directions = projected / lengths

    # Change of the function values along every candidate pair.
    dy = y[candidates_idx_to, :] - y[candidates_idx_from, :]

    # Always sample with replacement to avoid being forced into low densities.
    probabilities = weight_probabilities(dy, lengths)
    selected_idx = rng.choice(lengths.shape[0], size=width, replace=True, p=probabilities)

    return (
        unit_directions[selected_idx],
        lengths[selected_idx],
        candidates_idx_from[selected_idx],
        candidates_idx_to[selected_idx],
    )

def weight_probabilities(dy, dists, sample_uniformly=False):
    """Turn finite-difference slopes of candidate pairs into sampling probabilities.

    All probabilities are computed at once. The slope of a candidate is the
    largest absolute change over all output components divided by its distance.

    Args:
        dy: 2-D array of function differences, one row per candidate pair
        dists: distances between the base points, shape (n_candidates, 1)
        sample_uniformly: if True, ignore the slopes and return a uniform distribution

    Returns:
        probabilities: 1-D array with one probability per candidate.
    """
    # Largest change over all y directions, so good gradients are found for every output.
    slopes = (np.max(np.abs(dy), axis=1, keepdims=True) / dists).ravel()
    total_slope = np.sum(slopes)

    # When all slopes are tiny, avoid dividing by a small number and fall
    # back to the uniform distribution.
    if sample_uniformly or total_slope < 1e-10:
        return np.ones_like(slopes) / len(slopes)

    return slopes / total_slope


def sample_parameters_tanh(x, y, rng):
    """Build weights and biases for tanh neurons from sampled point pairs.

    Directions, distances and pair indices come from sample_parameters. The
    constant 0.5 * (log(1 + 1/2) - log(1 - 1/2)) is the inverse hyperbolic
    tangent of 1/2.

    Parameters:
    x : ndarray of shape (N, d) - Input points
    y : 2-D ndarray with N rows - Function values at the points
    rng : numpy Generator       - Source of randomness, forwarded to sample_parameters

    Returns:
    weights : ndarray (d, n_samples) - One weight column per sampled pair
    biases : ndarray (1, n_samples)  - One bias per sampled pair
    idx_from, idx_to : index arrays of the start and end point of each pair
    """
    scale = 0.5 * (np.log(1 + 1/2) - np.log(1 - 1/2))
    directions, dists, idx_from, idx_to = sample_parameters(x, y, rng)

    # One weight vector per row; the returned matrix stores them as columns.
    weight_rows = 2 * scale * directions / dists
    weights = weight_rows.T

    # Negative dot product of each start point with its weight vector, shifted by -scale.
    start_projection = np.sum(x[idx_from, :] * weight_rows, axis=-1)
    biases = -start_projection.reshape(1, -1) - scale
    return weights, biases, idx_from, idx_to


# Experiment driver: for every spatial dimension, fit the SWIM-based diffusion
# solver once per seed and report timing and error statistics over the seeds.
seeds = [1, 2]
experiments = []  # One list [time, rmse, time, rmse, ...] per dimension
width = 1000 # Width (number of neurons; also read by sample_parameters)
reg_const = 1e-10 # Regularization constant
svd_cutoff = 1e-10
rtol = 1e-5
atol = 1e-5

# Per-seed metric buffers; they are overwritten for every dimension.
rmse_swim = np.ones((len(seeds), ))
rel_err_swim = np.ones((len(seeds), ))
time_swim = np.ones((len(seeds), ))
info = []
dimensions = [3] # , 5, 7, 10

for d in dimensions:
    # Fresh record per dimension. Re-using a single list would make every
    # entry of `experiments` alias the same, ever-growing object.
    info = []

    # Train and test boundary points
    X_b_train, boundary_labels = utils.sample_boundary_lhs(d, n_b_train, bounds=(-1,1))
    X_b_test, boundary_labels_test = utils.sample_boundary_lhs(d, n_b_test, bounds=(-1,1))

    # Train and test interior points
    X_int_train = utils.sample_interior_lhs(d, n_int_train, bounds=(-1,1))
    X_int_test = utils.sample_interior_lhs(d, n_int_test, bounds=(-1,1))
    X_test = np.vstack((X_int_test, X_b_test))
    t_eval = np.linspace(0, 1, 100).reshape(-1, 1, 1) # time domain
    x_train = X_int_train # space domain

    # Reference solution: broadcasting gives (n_times, 1, n_points); drop the
    # singleton axis to obtain (n_times, n_points).
    u_true =  analytical_sol(X_test, t_eval)
    u_true = np.reshape(u_true, (np.shape(u_true)[0], np.shape(u_true)[2]))

    # Root-mean-square of the reference solution, used to normalize the error.
    # It does not depend on the seed, so compute it once per dimension.
    u_true_rms = np.sqrt(mean_squared_error(u_true, np.zeros_like(u_true)))

    for j, seed in enumerate(seeds):
        ansatz_swim = BasicAnsatz(
            n_neurons=width,
            activation="tanh",
            random_state=seed,
            regularization_scale=reg_const,
            parameter_sampler = sample_parameters_tanh
        )
        # A copy of the boundary points is handed over as normal vectors.
        # NOTE(review): these are not unit normals in general -- confirm the
        # solver does not rely on them for Dirichlet boundary conditions.
        normal_vectors = X_b_train.copy()

        # Domain
        domain = Domain(
            interior_points=X_int_train,
            boundary_points=X_b_train,
            normal_vectors=normal_vectors
        )

        diffusion_solver_swim = Diffusion_Solver(
            domain=domain,
            ansatz=ansatz_swim,
            u0=u0,
            boundary_condition=boundary_condition,
            forcing=forcing,
            regularization_scale=reg_const,
            scale_boundary_correction=1000.,
            boundary_condition_true=analytical_sol,
            ode_solver='LSODA'
        )

        # Compute weights and biases of the SWIM network (timed)
        time_blocks = 1
        ic_eval = u0(domain.interior_points)
        t_swim_start = time.time()
        sol_swim, solver_status_swim = diffusion_solver_swim.fit(t_span=[0, np.max(t_eval)],
                                                rtol = rtol, atol = atol, svd_cutoff= svd_cutoff,
                                                outer_basis=False,
                                                init_cond=ic_eval,
                                                du_dt = du_dt)
        t_swim_stop = time.time()
        time_swim[j] = t_swim_stop - t_swim_start

        # Evaluate on test data; transposed so that it can be compared with u_true
        u_swim_test = diffusion_solver_swim.evaluate(x_eval=X_test, t_eval = t_eval).T #, solver_status=solver_status

        # Compute metrics: RMSE and RMSE relative to the RMS of the true solution
        rmse_swim[j] = np.sqrt(mean_squared_error(u_true, u_swim_test))
        rel_err_swim[j] = rmse_swim[j]/u_true_rms
        info.append(time_swim[j])
        info.append(rmse_swim[j])
        #print('time=', time_swim[j], 'rmse_swim=', rmse_swim[j], 'rel_err_swim=',rel_err_swim[j])
    print('dimensions:', d)
    print('swim-ode time = ', np.mean(time_swim))
    print('rmse swim-ode = ',np.mean(rmse_swim), '+-', np.std(rmse_swim))
    print('rel l-2 error swim-ode = ',np.mean(rel_err_swim), '+-', np.std(rel_err_swim))
    # Push the results out now in case sys.stdout is a buffered file, so a
    # later failure in a long run does not lose them.
    sys.stdout.flush()
    experiments.append(info)




