import sys
sys.path.append('../../../../')
sys.path.append('../../../../src')
from swimpde import Domain
from swimpde import BasicAnsatz
from swimpde import Diffusion_Solver
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import matplotlib.cm as cm
import time
cmap = cm.jet
from examples.utils import *
from mpl_toolkits.mplot3d import Axes3D
#from exautils import *

# %% [markdown]
# ### Problem Setup

# %%
# List of all hyper-parameters
ratio = 5  # Total number of training points per hidden-layer neuron (data points / width)
# Each constant is used both as the regularization scale and as the SVD cutoff
reg_svd_constants = [1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-12]
dimensions = [10]  # Spatial dimensions
widths = [4000]  # Hidden layer widths
# Half-width r of the uniform [-r, r] distribution of the ELM weights and biases
rms = [0.05, 0.005, 0.5, 0.1, 0.2, 1.]
rtol = 1e-6  # Relative tolerance for the ODE solver
atol = 1e-6  # Absolute tolerance for the ODE solver
filename = "../data_0/cs_svd_10.npy"  # Output file for the raw experiment table

# Alternative settings from earlier runs (kept for reference):
# reg_const = 1e-12
# reg_svd_constants = [1e-6, 1e-9, 1e-12]
# reg_svd_constants = [1e-6, 1e-8, 1e-10]
# dimensions = [5]
# widths = [500, 1000, 2000, 4000]
# rms = [0.05]
# rms = [0.05, 0.005, 0.1]

# initial condition
def u0(x):
    """Initial condition u(x, 0) = cos(sum_i x_i / space_dim).

    Args:
        x: 2-D array-like of points, one point per row; the number of
            columns is taken as the spatial dimension.

    Returns:
        1-D array with one value per row of ``x``.
    """
    dim = np.shape(x)[1]
    mean_coordinate = np.sum(x, axis=1) / dim
    return np.cos(mean_coordinate)

# forcing
def forcing(x, t):
    """Source term f(x, t) = exp(-t) * cos(sum_i x_i / d) * (1/d - 1).

    This matches ``analytical_sol`` assuming the solver integrates
    u_t = Laplacian(u) + f with unit diffusivity (TODO confirm against
    Diffusion_Solver).

    Args:
        x: 2-D array of points, one point per row (d = number of columns).
        t: time, scalar or array broadcastable against ``x.sum(axis=1)``.
    """
    inv_dim = 1. / np.shape(x)[1]
    phase = np.sum(x, axis=1) * inv_dim
    return np.exp(-t) * np.cos(phase) * (inv_dim - 1)

# Kind of boundary condition handed to the diffusion solvers
boundary_condition: str = "dirichlet"

# Analytical solution
def analytical_sol(x, t):
    """Exact solution u(x, t) = exp(-t) * cos(sum_i x_i / d).

    Args:
        x: 2-D array of points, one point per row (d = number of columns).
        t: time; a scalar, or an array such as shape (n_times, 1, 1), which
            broadcasts the result to (n_times, 1, n_points).
    """
    dim = np.shape(x)[1]
    spatial_part = np.cos(np.sum(x, axis=1) / dim)
    return np.exp(-t) * spatial_part

# Test data: sizes of the evaluation point sets (interior, boundary)
n_int_test, n_b_test = 8000, 2000

# %% [markdown]
# ### Experiments: SWIM vs. ELM over regularization and sampling ranges

# %%

# Run SWIM and ELM for every (dimension, width, reg_svd, r_m) combination.
# Each row appended to `experiments` (and hence each row of the saved `res`)
# has exactly these 12 columns:
#   0 d, 1 w, 2 time_swim, 3 rmse_swim, 4 rel_err_swim,
#   5 time_elm, 6 rmse_elm, 7 rel_err_elm, 8 reg_svd, 9 r_m,
#   10 outer width of the ELM solution  (sol_elm._V_a.shape[0]),
#   11 outer width of the SWIM solution (sol_swim._V_a.shape[0])
experiments = []
for d in dimensions:
    for w in widths:
        for reg_svd in reg_svd_constants:
            # Reset the global seed so every configuration sees the same points
            # (presumably the LHS samplers in examples.utils draw from the
            # global NumPy RNG -- TODO confirm).
            np.random.seed(2)

            seed = 2
            # The same constant serves as regularization scale and SVD cutoff
            reg_const = reg_svd
            svd_cutoff = reg_svd

            # Train: interior and boundary points.
            # ratio * w points in total: 20 % on the boundary, 80 % inside.
            n_training = ratio * w
            n_b_train = n_training // 5
            n_int_train = n_training - n_b_train
            X_b_train, boundary_labels = sample_boundary_lhs(d, n_b_train, bounds=(-1, 1))
            X_int_train = sample_interior_lhs(d, n_int_train, bounds=(-1, 1))

            # Test data (sampled after the training data to keep the RNG stream unchanged)
            X_b_test, boundary_labels_test = sample_boundary_lhs(d, n_b_test, bounds=(-1, 1))
            X_int_test = sample_interior_lhs(d, n_int_test, bounds=(-1, 1))
            X_test = np.vstack((X_int_test, X_b_test))  # Internal + boundary points
            t_eval = np.linspace(0, 1, 100).reshape(-1, 1, 1)  # time domain, shape (100, 1, 1)

            # Ground truth: analytical_sol broadcasts to (n_times, 1, n_test_points);
            # drop the singleton axis to get (n_times, n_test_points).
            u_true = analytical_sol(X_test, t_eval)
            u_true = np.reshape(u_true, (np.shape(u_true)[0], np.shape(u_true)[2]))
            # RMS of the exact solution: denominator of every relative error below
            u_true_rms = np.sqrt(mean_squared_error(u_true, np.zeros_like(u_true)))

            ansatz_swim = BasicAnsatz(
                n_neurons=w,
                activation="tanh",
                random_state=seed,
                regularization_scale=reg_const,
                parameter_sampler="tanh",
            )
            # NOTE(review): the boundary points themselves are passed as normal
            # vectors; with Dirichlet conditions they are presumably unused -- confirm.
            normal_vectors = X_b_train.copy()

            # Domain
            domain = Domain(
                interior_points=X_int_train,
                boundary_points=X_b_train,
                normal_vectors=normal_vectors,
            )
            diffusion_solver_swim = Diffusion_Solver(
                domain=domain,
                ansatz=ansatz_swim,
                u0=u0,
                boundary_condition=boundary_condition,
                forcing=forcing,
                regularization_scale=reg_const,
                scale_boundary_correction=1.,
                boundary_condition_true=analytical_sol,
            )
            ic_eval = u0(domain.interior_points)

            # SWIM (independent of r_m, so it is fitted once per reg_svd)
            t_swim_start = time.perf_counter()
            sol_swim, solver_status_swim = diffusion_solver_swim.fit(
                t_span=[0, np.max(t_eval)],
                rtol=rtol, atol=atol, svd_cutoff=svd_cutoff,
                outer_basis=False,
                init_cond=ic_eval,
            )
            time_swim = time.perf_counter() - t_swim_start

            # Evaluate on test data and compute metrics
            u_swim_test = diffusion_solver_swim.evaluate(x_eval=X_test, t_eval=t_eval).T
            rmse_swim = np.sqrt(mean_squared_error(u_true, u_swim_test))
            rel_err_swim = rmse_swim / u_true_rms

            for r_m in rms:
                # ELM parameter sampler: weights and biases are both drawn
                # uniformly from [-r_m, r_m]. It closes over r_m and w and is
                # presumably invoked during fit() in this same iteration.
                def sample_parameters_randomly(x, _, rng):
                    weights = rng.uniform(low=-1. * r_m, high=r_m, size=(x.shape[1], w))
                    biases = rng.uniform(low=-1. * r_m, high=r_m, size=(1, w))
                    return weights, biases, None, None

                ansatz_elm = BasicAnsatz(
                    n_neurons=w,
                    activation="tanh",
                    random_state=seed,
                    regularization_scale=reg_const,
                    parameter_sampler=sample_parameters_randomly,
                )

                diffusion_solver_elm = Diffusion_Solver(
                    domain=domain,
                    ansatz=ansatz_elm,
                    u0=u0,
                    boundary_condition=boundary_condition,
                    forcing=forcing,
                    regularization_scale=reg_const,
                    scale_boundary_correction=1.,
                    boundary_condition_true=analytical_sol,
                )

                # ELM
                t_elm_start = time.perf_counter()
                sol_elm, solver_status_elm = diffusion_solver_elm.fit(
                    t_span=[0, np.max(t_eval)],
                    rtol=rtol, atol=atol, svd_cutoff=svd_cutoff,
                    outer_basis=False,
                    init_cond=ic_eval,
                )
                time_elm = time.perf_counter() - t_elm_start

                # Evaluate on test data and compute metrics
                u_elm_test = diffusion_solver_elm.evaluate(x_eval=X_test, t_eval=t_eval).T
                rmse_elm = np.sqrt(mean_squared_error(u_true, u_elm_test))
                rel_err_elm = rmse_elm / u_true_rms

                # A fresh row per (reg_svd, r_m) pair. Reusing one list across
                # r_m iterations would append the same growing object repeatedly.
                info = [
                    d, w, time_swim, rmse_swim, rel_err_swim,
                    time_elm, rmse_elm, rel_err_elm,
                    reg_svd, r_m,
                    sol_elm._V_a.shape[0], sol_swim._V_a.shape[0],
                ]
                print('dim = ',d,'width = ',w,  'reg_svd = ', reg_svd, 'r_m', r_m)
                print('w_outer_swim = ',sol_swim._V_a.shape[0],  'w_outer_elm = ', sol_elm._V_a.shape[0])
                print('rmse_swim =', rmse_swim, 'rel_err_swim =',rel_err_swim,'time =', time_swim)
                print('rmse_elm =', rmse_elm, 'rel_err_elm = ',rel_err_elm, 'time =', time_elm)
                experiments.append(info)

res = np.vstack(experiments)  # shape (n_runs, 12)
np.save(filename, res)  # save



# Column indices into each row of `res`
# (0 d, 1 w, 2 time_swim, 3 rmse_swim, 4 rel_err_swim, 5 time_elm, 6 rmse_elm,
#  7 rel_err_elm, 8 reg_svd, 9 r_m, 10 ELM outer width, 11 SWIM outer width)
COL_REL_ERR_SWIM = 4
COL_REL_ERR_ELM = 7
COL_W_OUTER_ELM = 10
COL_W_OUTER_SWIM = 11

# Axes: (dimension, width, reg_svd constant, r_m value, metric column)
res_reshape = res.reshape((len(dimensions), len(widths), len(reg_svd_constants), len(rms), -1))

cmap = plt.get_cmap("jet", len(dimensions))  # one color per spatial dimension
fontsize = 12
fig, ax = plt.subplots(1, 1, figsize=(3, 3), sharey=True)

# For each dimension and width, keep the minimum relative error over all
# (reg_svd, r_m) combinations, separately for SWIM and ELM, together with the
# outer-basis width of the run that achieved that minimum.
for d, dim in enumerate(dimensions):
    re_swim_w = np.ones((len(widths), ))
    re_elm_w = np.ones((len(widths), ))
    w_o_swim_w = np.ones((len(widths), ))
    w_o_elm_w = np.ones((len(widths), ))
    filename_widths_errors = f'../data_0/{dim}d_svd_finetuning.npy'
    for w in range(len(widths)):
        errs_swim = res_reshape[d, w, :, :, COL_REL_ERR_SWIM]
        best_swim = np.unravel_index(np.argmin(errs_swim), errs_swim.shape)
        errs_elm = res_reshape[d, w, :, :, COL_REL_ERR_ELM]
        best_elm = np.unravel_index(np.argmin(errs_elm), errs_elm.shape)

        re_swim_w[w] = errs_swim[best_swim]
        re_elm_w[w] = errs_elm[best_elm]
        w_o_swim_w[w] = res_reshape[d, w, best_swim[0], best_swim[1], COL_W_OUTER_SWIM]
        w_o_elm_w[w] = res_reshape[d, w, best_elm[0], best_elm[1], COL_W_OUTER_ELM]

    info = [widths, re_swim_w, re_elm_w, w_o_swim_w, w_o_elm_w]

    # Convergence plot: dashed = SWIM, solid = ELM
    ax.semilogy(widths, re_swim_w, linestyle='--', color=cmap(d), label=f"SWIM, d = {dim}")
    ax.semilogy(widths, re_elm_w, color=cmap(d), label=f"ELM, d = {dim}")

    np.save(filename_widths_errors, np.vstack(info))

    print(info)

ax.set_xlabel('Width of hidden layer', fontsize=fontsize)
ax.set_ylabel(r'Relative  ' + r'$\mathbb{L}_{2}$ ' + r'error', fontsize=fontsize)
ax.tick_params(axis='both', labelsize=fontsize)
ax.legend()
fig.tight_layout()




