# Imports
import sys
sys.path.append('../../../')
sys.path.append('../../../src')
from swimpde import Domain
from swimpde import BasicAnsatz
from swimpde import Reaction_Diffusion_Solver
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import matplotlib.cm as cm
import time
cmap = cm.jet
from examples.utils import *
from mpl_toolkits.mplot3d import Axes3D

# Set seeds. Both NumPy's legacy global RNG and a Generator are seeded here;
# they are seeded again right before the test points are drawn, so keep the
# statement order of this section unchanged to reproduce existing results.
np.random.seed(2)
rng = np.random.default_rng(seed=123)

name = '10d_verify'
# Redirect all print() output of the script to '<name>.txt'. buffering=1
# selects line buffering, so every printed line reaches the file immediately:
# partial results survive a crash or a killed job, and the log can be
# followed while the sweep is running.
sys.stdout = open(name + '.txt', 'wt', buffering=1)

# Problem dimension and number of boundary points for training / testing
d = 10  # Dimensions
n_b_train = 1000
n_b_test = 1000

# Hyper-parameter sweep
n_int_train_list = [1000, 2000, 20000]  # numbers of interior training points
seeds = [1, 2, 3]                       # passed as random_state to BasicAnsatz
experiments = []
widths = [100, 400, 700]                # numbers of hidden neurons (n_neurons)
reg_consts = [1e-12]                    # regularization; also sets svd_cutoff, rtol, atol
scale_bdry = [100, 1000, 10000]         # boundary-correction scalings (lambda_b)

# Train and test boundary points. sample_boundary_lhs comes from examples.utils
# (not visible here); presumably a Latin-hypercube sample on the boundary of
# [-1, 1]^d — confirm there.
X_b_train, boundary_labels = sample_boundary_lhs(d, n_b_train, bounds=(-1, 1))
X_b_test, boundary_labels_test = sample_boundary_lhs(d, n_b_test, bounds=(-1, 1))

# Test interior points: a 100 x 100 grid in (x_0, x_1); the remaining d - 2
# coordinates are drawn uniformly from [-1, 1].
x_0 = np.linspace(-1, 1, 100, endpoint=True)
x_1 = np.linspace(-1, 1, 100, endpoint=True)
xx, yy = np.meshgrid(x_0, x_1)
# Re-seed so the test points depend only on the fixed seeds, not on how many
# random draws the boundary sampling above consumed.
np.random.seed(2)
rng = np.random.default_rng(seed=123)
X_test = rng.uniform(low=-1, high=1, size=(100 * 100, d))
X_test[:, 0] = xx.reshape(-1)
X_test[:, 1] = yy.reshape(-1)

import math
# initial condition: u(x, 0) = 2 sin(pi x_0 / 2) cos(pi x_1 / 2)
def u0(x):
    """Evaluate the initial condition at every row of ``x``.

    Only the first two columns of ``x`` are read, so any 2-D array with at
    least two columns is accepted; the result has one value per row.
    """
    half_pi = 0.5 * np.pi
    return 2. * np.sin(half_pi * x[:, 0]) * np.cos(half_pi * x[:, 1])

# forcing term f(x, t). With u = 2 sin(pi x_0/2) cos(pi x_1/2) e^{-t} (see
# analytical_sol) this equals (pi^2/2 - 1) u - u^2, which is consistent with
# u_t = Laplacian(u) + u^2 + f — inferred from the formulas only; confirm the
# equation form against Reaction_Diffusion_Solver.
def forcing(x, t):
    """Evaluate the forcing at the rows of ``x`` and time(s) ``t``.

    ``t`` is broadcast against the per-row values, exactly as numpy broadcasts
    it in the arithmetic below.
    """
    s = np.sin(0.5 * np.pi * x[:, 0])
    c = np.cos(0.5 * np.pi * x[:, 1])
    linear_part = (np.pi**2. - 2.) * np.exp(-t) * s * c
    quadratic_part = 4. * np.exp(-2. * t) * (s**2.) * (c**2.)
    return linear_part - quadratic_part

# boundary condition: type string handed to Reaction_Diffusion_Solver in the
# sweep; the boundary values are presumably taken from the
# boundary_condition_true=analytical_sol argument given there — confirm in the solver.
boundary_condition = "dirichlet"

# Analytical solution u(x, t) = 2 sin(pi x_0 / 2) cos(pi x_1 / 2) e^{-t}
def analytical_sol(x, t):
    """Exact solution at the rows of ``x`` and time(s) ``t``.

    The spatial factor has one entry per row of ``x`` and is broadcast against
    ``exp(-t)``; e.g. ``t`` of shape (T, 1, 1) yields an array of shape (T, 1, N).
    """
    spatial = 2. * np.sin(0.5 * np.pi * x[:, 0]) * np.cos(0.5 * np.pi * x[:, 1])
    return spatial * np.exp(-t)

# Test data
# Time grids shaped (T, 1, 1) so analytical_sol broadcasts to (T, 1, N):
# t_eval holds 100 times on [0, 1]; t_eval_test holds only the final time t = 1.
t_eval = np.linspace(0, 1, 100).reshape(-1, 1, 1)
t_eval_test = np.array([1]).reshape(-1, 1, 1)

# Reference solution at t = 1 on the interior test points and on the train /
# test boundary points, each squeezed from (1, 1, N) to (1, N).
u_true, u_true_b_train, u_true_b_test = [
    np.squeeze(analytical_sol(points, t_eval_test), axis=1)
    for points in (X_test, X_b_train, X_b_test)
]

# Time-step indices for visualizing the true solution (not used in this section).
timesteps = [0, 30, 60, 99]

def gram_schmidt(B, tol=1e-12):
    """
    Orthonormalize two vectors per sample with the Gram-Schmidt process.

    Parameters:
    B : array_like of shape (N, 2, d) - two vectors for each of N samples, any
        dimension d. The input is not modified; integer input is converted to float.
    tol : float - relative tolerance. A vector whose norm is at most
        ``tol`` times the longer of the two input vectors of its sample is
        treated as zero, so no division by (almost) zero takes place.

    Returns:
    B_orth : ndarray of shape (N, 2, d) - for each sample, rows forming an
        orthonormal basis of span{B[n, 0], B[n, 1]}. When that span has
        dimension < 2 (a zero vector or two parallel vectors), the surplus
        rows are zero. Projecting onto the returned rows is therefore still
        the orthogonal projection onto the span, and no NaNs are produced.
    """
    B = np.asarray(B)
    if not np.issubdtype(B.dtype, np.inexact):
        B = B.astype(float)

    # Per-sample reference length (N, 1): the longer of the two input vectors.
    scale = np.linalg.norm(B, axis=2).max(axis=1, keepdims=True)
    threshold = tol * scale

    def _unit_or_zero(v):
        # Normalize each row of v; rows at or below the threshold become zero.
        norms = np.linalg.norm(v, axis=1, keepdims=True)
        keep = norms > threshold
        return np.where(keep, v / np.where(keep, norms, 1.0), 0.0)

    # Fresh arrays, so the caller's B is never written to.
    u1 = _unit_or_zero(B[:, 0])
    # Remove the component along u1 before normalizing.
    u2 = _unit_or_zero(B[:, 1] - np.einsum('ij,ij->i', B[:, 1], u1)[:, np.newaxis] * u1)

    return np.stack((u1, u2), axis=1)

def project_on_gradient_plane(X1, X2, grad_f):
    """
    Project each displacement X2 - X1 onto the plane spanned by the gradients
    of ``grad_f`` at its two end points.

    Parameters:
    X1 : ndarray of shape (N, d) - start points, one per row
    X2 : ndarray of shape (N, d) - end points, paired row-wise with X1
    grad_f : callable - batched gradient mapping an (N, d) array to (N, d)

    Returns:
    v_proj : ndarray of shape (N, d) - the projected displacement vectors
    """
    # Per-row basis (N, 2, d): both gradients, orthonormalized.
    basis = gram_schmidt(np.stack((grad_f(X1), grad_f(X2)), axis=1))

    displacement = X2 - X1
    # Coordinates of every displacement in its own basis (N, 2) ...
    coords = np.einsum('nij,nj->ni', basis, displacement)
    # ... mapped back to R^d.
    return np.einsum('ni,nij->nj', coords, basis)

# Batched analytical gradient of u0, used to build the projection planes.
def gradient_u0(x):
    """Return the gradient of ``u0`` at each row of ``x``.

    u0 depends only on x[:, 0] and x[:, 1], so every other column of the
    result stays zero. The result has the same shape and dtype as ``x``.
    """
    half_pi = 0.5 * np.pi
    a = half_pi * x[:, 0]
    b = half_pi * x[:, 1]
    grad = np.zeros_like(x)
    grad[:, 0] = np.pi * np.cos(a) * np.cos(b)
    grad[:, 1] = -np.pi * np.sin(a) * np.sin(b)
    return grad


# Generate random sample points
# Small smoke example of project_on_gradient_plane on random 5-D points;
# V_proj is not used again in this section of the script.
# NOTE: np.random.randn draws from NumPy's global RNG (seeded with 2 above), and
# the sweep only re-seeds it after sample_interior_lhs has been called. If that
# helper (examples.utils, not visible here) uses the global RNG, its samples
# depend on these draws — keep these lines to reproduce existing results.
N = 3  # Number of pairs
X1 = np.random.randn(N, 5)  # First set of points
X2 = np.random.randn(N, 5)  # Second set of points

# Compute projected vectors
V_proj = project_on_gradient_plane(X1, X2, gradient_u0)

def sample_parameters(x, y, rng, n_neurons=None):
    """
    Sample hidden-layer directions from random pairs of points in (x, y).

    Ordered pairs (idx_from, idx_to) with idx_from != idx_to are drawn at
    random. Each displacement x[idx_to] - x[idx_from] is projected onto the
    plane spanned by ``gradient_u0`` at the two points and then normalized.
    The pairs are resampled with replacement, with probabilities from
    ``weight_probabilities`` (proportional to max|dy| / projected distance).

    Args:
        x: points, array of shape (n_points, d); at least two rows are needed.
        y: values at ``x``, 2-D with one row per point.
        rng: ``numpy.random.Generator`` used for every random draw.
        n_neurons: number of directions to return. Defaults to the module-level
            ``width`` (the current sweep value), which is what the
            three-argument call uses.

    Returns:
        directions: (n_neurons, d) vectors of unit length, except where the
            projected displacement is shorter than 1e-10 (then shorter).
        dists: (n_neurons, 1) projected distances, clipped below at 1e-10.
        idx_from, idx_to: (n_neurons,) indices into ``x`` of each selected pair.

    Raises:
        ValueError: if ``x`` has fewer than two points (no pair can be formed).
    """
    if n_neurons is None:
        n_neurons = width  # module-level sweep variable
    n_points = x.shape[0]
    if n_points < 2:
        raise ValueError(f"sample_parameters needs at least 2 points, got {n_points}.")

    # Repeat the candidate draw when more neurons than points are requested,
    # so the weighted resampling below has enough candidate pairs.
    n_repetitions = max(1, int(np.ceil(n_neurons / n_points)))

    # Adding a random offset delta in [1, n_points - 1] modulo n_points guarantees that:
    # (a) every ordered pair with idx_from != idx_to (n_points * (n_points - 1)
    #     of them) can be drawn;
    # (b) idx_from[k] != idx_to[k] for every k.
    candidates_idx_from = rng.integers(low=0, high=n_points, size=n_points * n_repetitions)
    delta = rng.integers(low=1, high=n_points, size=candidates_idx_from.shape[0])
    candidates_idx_to = (candidates_idx_from + delta) % n_points

    X1 = x[candidates_idx_from, ...]
    X2 = x[candidates_idx_to, ...]
    directions = project_on_gradient_plane(X1, X2, gradient_u0)
    dists = np.linalg.norm(directions, axis=1, keepdims=True)
    # Guard the division below against (near-)zero projected displacements.
    dists = np.clip(dists, a_min=1e-10, a_max=None)
    directions = directions / dists
    dy = y[candidates_idx_to, :] - y[candidates_idx_from, :]

    # Always sample with replacement, so low-probability pairs are never
    # forced into the selection.
    probabilities = weight_probabilities(dy, dists)
    selected_idx = rng.choice(dists.shape[0], size=n_neurons, replace=True, p=probabilities)

    directions = directions[selected_idx]
    dists = dists[selected_idx]
    idx_from = candidates_idx_from[selected_idx]
    idx_to = candidates_idx_to[selected_idx]

    return directions, dists, idx_from, idx_to

def weight_probabilities(dy, dists, sample_uniformly=False):
    """Compute the probability of choosing each candidate pair as a network weight.

    All probabilities are computed at once; pairs are not removed one by one
    as they get selected.

    Args:
        dy: function differences, 2-D with one row per candidate pair; the
            largest absolute value over each row is used.
        dists: distances between the two points of each pair, broadcast
            against the per-pair maxima of shape (n_pairs, 1).
        sample_uniformly: if True, ignore ``dy`` and ``dists`` and return the
            uniform distribution.

    Returns:
        probabilities: 1-D array with one entry per pair, summing to 1.
    """
    # Use the largest change over all outputs, so the sampled directions are
    # good for every output.
    gradients = (np.max(np.abs(dy), axis=1, keepdims=True) / dists).ravel()
    total = np.sum(gradients)

    if sample_uniformly or total < 1e-10:
        # When all gradients are (almost) zero, avoid dividing by a tiny
        # number and fall back to the uniform distribution.
        return np.ones_like(gradients) / len(gradients)
    return gradients / total


def sample_parameters_tanh(x, y, rng):
    """Turn sampled point pairs into tanh hidden-layer weights and biases.

    Returns ``(weights, biases, idx_from, idx_to)``. ``weights`` has shape
    (d, n) and ``biases`` has shape (1, n), where n is the number of pairs
    returned by ``sample_parameters``. With scale = atanh(1/2), the bias puts
    the pre-activation at x[idx_from] at -scale (neuron output -1/2). The
    pre-activation at x[idx_to] is presumably +scale, assuming the directions
    come from an orthogonal projection and the distance was not clipped.
    """
    half = 1 / 2
    scale = 0.5 * (np.log(1 + half) - np.log(1 - half))  # equals atanh(1/2)
    directions, dists, idx_from, idx_to = sample_parameters(x, y, rng)
    # One weight vector per row here; transposed to (d, n) on return.
    weight_rows = 2 * scale * directions / dists
    bias_row = -np.sum(x[idx_from, :] * weight_rows, axis=-1).reshape(1, -1) - scale
    return weight_rows.T, bias_row, idx_from, idx_to

def sample_parameters_randomly(x, _, rng):
    """Data-agnostic sampler: weights ~ N(0, 1), biases ~ U(-2*pi, 2*pi).

    Shapes are (x.shape[1], width) for the weights and (1, width) for the
    biases, with ``width`` read from the module-level sweep variable. The data
    argument is ignored and no pair indices exist, so both are returned as None.
    NOTE: a second definition with this same name further down in the file
    replaces this one as soon as the module runs.
    """
    n_in = x.shape[1]
    weight_matrix = rng.normal(loc=0, scale=1, size=(n_in, width))
    bias_row = rng.uniform(low=-2 * np.pi, high=2 * np.pi, size=(1, width))
    return weight_matrix, bias_row, None, None


# Flat log of (fit time, test RMSE) values, appended for every seed run of the
# sweep below; it is never reset between configurations.
info = []

def sample_parameters_randomly(x, _, rng):
    """Data-agnostic sampler: weights and biases drawn i.i.d. from U[0, 1).

    ``weights`` has shape (x.shape[1], width) and ``biases`` has shape
    (1, width), where ``width`` is the module-level sweep variable. The data
    argument is ignored and there are no pair indices, so both are None.
    """
    shape_w = (x.shape[1], width)
    shape_b = (1, width)
    weights = rng.uniform(low=0, high=1, size=shape_w)
    biases = rng.uniform(low=0, high=1, size=shape_b)
    return weights, biases, None, None

def _rmse_and_rel_err(y_true, y_pred):
    """Return ``(rmse, rel_err)`` for a prediction.

    ``rmse`` is the root-mean-squared error between ``y_true`` and ``y_pred``.
    ``rel_err`` is that value divided by the root-mean-square of ``y_true``,
    i.e. the relative l-2 error reported in the summaries below.
    """
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    rel_err = rmse / np.sqrt(mean_squared_error(y_true, np.zeros_like(y_true)))
    return rmse, rel_err


# Hyper-parameter sweep: width x regularization x boundary scaling x sampler x
# number of interior points; every configuration is fitted once per seed.
param_samplers = [sample_parameters_tanh]
# `width` must stay a module-level loop variable: the parameter samplers read it.
for width in widths:
    for reg_const in reg_consts:
        svd_cutoff = reg_const
        rtol = 1e6 * reg_const
        atol = 1e6 * reg_const
        for lambda_b in scale_bdry:
            for param_sampler in param_samplers:
                for n_int_train in n_int_train_list:
                    # Interior training points, shared by all seeds of this configuration.
                    X_int_train = sample_interior_lhs(d, n_int_train, bounds=(-1, 1))
                    x_train = X_int_train  # space domain

                    # Reference solution on the training points at every time in t_eval.
                    u_true_train = analytical_sol(x_train, t_eval)
                    u_true_train = np.reshape(u_true_train, (np.shape(u_true_train)[0], np.shape(u_true_train)[2]))

                    # Per-seed metrics, one entry per element of `seeds`.
                    n_seeds = len(seeds)
                    rmse_swim = np.ones((n_seeds,))
                    rel_err_swim = np.ones((n_seeds,))
                    rmse_swim_train = np.ones((n_seeds,))
                    rel_err_swim_train = np.ones((n_seeds,))
                    rmse_swim_train_b = np.ones((n_seeds,))
                    rel_err_swim_train_b = np.ones((n_seeds,))
                    rmse_swim_test_b = np.ones((n_seeds,))
                    rel_err_swim_test_b = np.ones((n_seeds,))
                    time_swim = np.ones((n_seeds,))
                    # (fit time, test rmse) values of this configuration only.
                    run_info = []

                    for j, seed in enumerate(seeds):
                        # Re-seed the global RNGs so every run starts from the same
                        # state; only `seed` (BasicAnsatz random_state) differs per run.
                        np.random.seed(2)
                        rng = np.random.default_rng(seed=123)
                        # Network whose hidden weights/biases are drawn by `param_sampler`.
                        ansatz_swim = BasicAnsatz(
                            n_neurons=width,
                            activation="tanh",
                            random_state=seed,
                            regularization_scale=reg_const,
                            parameter_sampler=param_sampler,
                        )
                        # NOTE(review): the boundary points themselves are passed as
                        # normal vectors; on the cube [-1, 1]^d these are generally not
                        # the outward face normals. Presumably harmless for the
                        # Dirichlet condition used here — confirm with Domain/solver.
                        normal_vectors = X_b_train.copy()

                        # Domain
                        domain = Domain(
                            interior_points=X_int_train,
                            boundary_points=X_b_train,
                            normal_vectors=normal_vectors,
                            sample_points=X_int_train,
                        )

                        reaction_diffusion_solver_swim = Reaction_Diffusion_Solver(
                            domain=domain,
                            ansatz=ansatz_swim,
                            u0=u0,
                            boundary_condition=boundary_condition,
                            forcing=forcing,
                            regularization_scale=reg_const,
                            scale_boundary_correction=lambda_b,
                            boundary_condition_true=analytical_sol,
                        )

                        # Fit (timed) from the initial condition on all domain points.
                        ic_eval = u0(domain.all_points)
                        t_swim_start = time.time()
                        sol_swim, solver_status_swim = reaction_diffusion_solver_swim.fit(
                            t_span=[0, np.max(t_eval)],
                            rtol=rtol, atol=atol, svd_cutoff=svd_cutoff,
                            outer_basis=False,
                            init_cond=ic_eval,
                        )
                        t_swim_stop = time.time()
                        time_swim[j] = t_swim_stop - t_swim_start

                        # Evaluate on test / train / boundary data
                        u_swim_test = reaction_diffusion_solver_swim.evaluate(x_eval=X_test, t_eval=t_eval_test).T
                        u_swim_train = reaction_diffusion_solver_swim.evaluate(x_eval=x_train, t_eval=t_eval).T
                        u_swim_boundary_train = reaction_diffusion_solver_swim.evaluate(x_eval=X_b_train, t_eval=t_eval_test).T
                        u_swim_boundary_test = reaction_diffusion_solver_swim.evaluate(x_eval=X_b_test, t_eval=t_eval_test).T

                        # Compute metrics
                        rmse_swim[j], rel_err_swim[j] = _rmse_and_rel_err(u_true, u_swim_test)
                        rmse_swim_train[j], rel_err_swim_train[j] = _rmse_and_rel_err(u_true_train, u_swim_train)
                        rmse_swim_train_b[j], rel_err_swim_train_b[j] = _rmse_and_rel_err(u_true_b_train, u_swim_boundary_train)
                        rmse_swim_test_b[j], rel_err_swim_test_b[j] = _rmse_and_rel_err(u_true_b_test, u_swim_boundary_test)

                        # Log: the global `info` keeps everything, run_info only this configuration.
                        entry = [time_swim[j], rmse_swim[j]]
                        info.extend(entry)
                        run_info.extend(entry)
                        print('time=', time_swim[j], 'rmse_swim=', rmse_swim[j], 'rel_err_swim=', rel_err_swim[j])

                    # Train
                    print('-------------------------------------------------------------------------')
                    print('Width: ', width, 'param_sampler: ', param_sampler.__name__, 'reg_const', reg_const, 'atol', atol)
                    print('Boundary scaling: ', lambda_b, 'n_int_train: ', n_int_train)

                    print('-------------------------------------------------------------------------')
                    print('Train: Frozen-pinn-swim time = ', np.mean(time_swim))
                    print('Train: rmse Frozen-pinn-swim = ', np.mean(rmse_swim_train), '+-', np.std(rmse_swim_train))
                    print('Train: rel l-2 error Frozen-pinn-swim = ', np.mean(rel_err_swim_train), '+-', np.std(rel_err_swim_train))
                    print('Train: rel l-2 error Frozen-pinn-swim (boundary) = ', np.mean(rel_err_swim_train_b), '+-', np.std(rel_err_swim_train_b))

                    # Test
                    print('Test: rmse Frozen-pinn-swim = ', np.mean(rmse_swim), '+-', np.std(rmse_swim))
                    print('Test: rel l-2 error Frozen-pinn-swim = ', np.mean(rel_err_swim), '+-', np.std(rel_err_swim))
                    print('Test: rel l-2 error Frozen-pinn-swim (boundary) = ', np.mean(rel_err_swim_test_b), '+-', np.std(rel_err_swim_test_b))
                    print('-------------------------------------------------------------------------\n')
                    # One entry per configuration (a separate list each time, not the shared `info`).
                    experiments.append(run_info)

