import sys
sys.path.append('../../../../')
sys.path.append('../../../../src')
from swimpde import Domain
from swimpde import BasicAnsatz
from swimpde import BurgersSolver
import numpy as np
from sklearn.metrics import mean_squared_error
import scipy.io
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'jet'
import argparse
import time

def _load_reference_solution(path):
    """Load the reference Burgers solution from a MATLAB ``.mat`` file.

    Args:
        path: Path to a ``.mat`` file containing ``t``, ``x`` and ``usol``.

    Returns:
        tuple: ``(t_eval, x_eval, u_true)`` where ``t_eval`` and ``x_eval`` are
        column vectors and ``u_true`` is the real part of ``usol``, transposed
        and flattened into a column vector.
    """
    data = scipy.io.loadmat(path)
    t_eval = data['t'].flatten()[:, None]
    x_eval = data['x'].flatten()[:, None]
    # NOTE(review): presumably usol is stored as (len(x), len(t)), so the
    # transpose puts time on the slow (row) axis - TODO confirm.
    u_exact = np.real(data['usol']).T
    return t_eval, x_eval, u_exact.flatten()[:, None]


def _error_metrics(u_true, u_pred):
    """Return ``(rmse, relative_l2_error)`` of ``u_pred`` against ``u_true``.

    ``u_pred`` is flattened (C order) into a column vector before comparison,
    so it must contain as many entries as ``u_true``.
    """
    u_pred = np.asarray(u_pred).flatten()[:, None]
    rmse = np.sqrt(mean_squared_error(u_true, u_pred))
    rel_err = np.linalg.norm(u_true - u_pred, 2) / np.linalg.norm(u_true, 2)
    return rmse, rel_err


def main(args):
    """Fit the SWIM Burgers solver and report its error against reference data.

    Loads the reference solution from ``../../../../data/burgers_shock.mat``
    (relative to the current working directory) and builds a random 1-D
    domain on [-1, 1]. For each seed it fits ``BurgersSolver`` over time
    blocks, with gradient-based re-sampling of the collocation points, then
    prints the wall time, RMSE and relative L2 error.

    Args:
        args: ``argparse.Namespace`` with attributes ``n_s`` (number of sample
            points), ``n_c`` (number of collocation points), ``r`` (ratio
            n_c / width), ``t_b`` (number of time blocks), ``r_c``
            (regularization constant, also used as the SVD cutoff) and
            ``tol`` (rtol/atol for the ODE solver). If ``args`` is ``None``,
            a message is printed and ``None`` is returned.

    Returns:
        list | None: ``[n_s, n_c, width, t_b, r_c, svd_cutoff, mean_time,
        mean_rmse, mean_rel_err]``, or ``None`` when ``args`` is ``None``.

    Raises:
        ValueError: If ``args.r`` is not positive or ``n_c / r`` is below 1.
    """
    # Check the arguments we were actually handed (not sys.argv), so the
    # function also works when called programmatically with a Namespace.
    if args is None:
        print("No arguments provided.")
        return None

    if args.r <= 0:
        raise ValueError(f"r (n_col/width) must be positive, got {args.r}")
    width = int(args.n_c / args.r)
    if width < 1:
        raise ValueError(
            f"n_c / r must be at least 1 to give a non-empty network, "
            f"got n_c={args.n_c}, r={args.r}"
        )

    # Ground truth on the reference (x, t) grid.
    t_eval, x_eval, u_true = _load_reference_solution(
        '../../../../data/burgers_shock.mat'
    )

    def u0(x):
        """Initial condition u(x, 0) = -sin(pi * x)."""
        return -1 * np.sin(np.pi * x)

    def forcing(x, t):
        """Zero forcing term: one zero per row of ``x``."""
        return np.zeros(x.shape[0])

    boundary_condition = "zero dirichlet"

    # Spatial points for the first time block.
    n_points_1d = 4000
    x_lim = [-1, 1]
    rng = np.random.default_rng(seed=123)
    x_space = rng.uniform(x_lim[0], x_lim[1], n_points_1d).reshape((-1, 1))
    # NOTE(review): x_space is an unsorted random draw, so [1:-1] drops two
    # arbitrary points rather than the domain ends. Kept unchanged so the RNG
    # stream, and hence the results, stay reproducible.
    interior_points = x_space[1:-1]

    # The two boundary points x = -1 and x = 1 as a (2, 1) column.
    # np.vstack replaces the deprecated np.row_stack alias.
    boundary_points = np.vstack([x_lim[0], x_lim[1]])

    svd_co = args.r_c
    info = [
        args.n_s,  # no. of sampling points (for computing gradients)
        args.n_c,  # no. of collocation points (re-sampled)
        width,     # network width
        args.t_b,  # number of time blocks
        args.r_c,  # regularization constant
        svd_co,    # SVD cutoff (same value as r_c)
    ]
    print(svd_co)

    def collocation_points_probabilities(df_dx):
        """Turn gradient magnitudes into a (re-)sampling distribution.

        A floor of 1% of the maximum magnitude is added so that every point
        keeps a non-zero probability. The result sums to 1.
        """
        gradients = np.abs(df_dx)
        gradients = gradients + 0.01 * np.max(gradients)
        return gradients / np.sum(gradients)

    # Sample points kept 1e-4 away from the boundaries (boundary excluded).
    sample_test_points = np.sort(
        rng.uniform(x_lim[0] + 1e-4, x_lim[1] - 1e-4, args.n_s)
    ).reshape((-1, 1))
    domain = Domain(
        interior_points=interior_points,
        boundary_points=boundary_points,
        sample_points=sample_test_points,
    )

    seeds = [2]
    time_swim = np.ones((len(seeds),))
    rmse_swim = np.ones((len(seeds),))
    rel_err_swim = np.ones((len(seeds),))

    for j, seed in enumerate(seeds):
        ansatz_swim = BasicAnsatz(
            n_neurons=width,
            activation="tanh",
            random_state=seed,
            regularization_scale=args.r_c,
            parameter_sampler='tanh',
        )
        burgers_solver_swim = BurgersSolver(
            domain=domain,
            ansatz=ansatz_swim,
            u0=u0,
            boundary_condition=boundary_condition,
            forcing=forcing,
            regularization_scale=args.r_c,
            c=(0.01 / np.pi),  # presumably the viscosity coefficient - confirm
            ode_solver='LSODA',
        )

        # perf_counter is the monotonic clock intended for elapsed time.
        t_swim_start = time.perf_counter()
        _, solver_status_swim = burgers_solver_swim.fit_time_blocks(
            t_span=[0, np.max(t_eval)],
            rtol=args.tol,
            atol=args.tol,
            svd_cutoff=svd_co,
            time_blocks=args.t_b,
            prob_distr_resampling=collocation_points_probabilities,
            n_col=args.n_c,
            outer_basis=False,
        )
        time_swim[j] = time.perf_counter() - t_swim_start

        # Evaluate the fitted model on the reference grid.
        pred_swim = burgers_solver_swim.evaluate_blocks(
            x_eval=x_eval,
            t_eval=t_eval,
            time_blocks=args.t_b,
            solver_status=solver_status_swim,
        )
        rmse_swim[j], rel_err_swim[j] = _error_metrics(u_true, pred_swim.T)
        print("rmse_swim, re_swim, time")
        print(rmse_swim[j], rel_err_swim[j], time_swim[j])

    info.extend([np.mean(time_swim), np.mean(rmse_swim), np.mean(rel_err_swim)])

    # Header matches the nine entries of `info` one-to-one.
    print("n_s, n_c, width, time_blocks, r_c, svd_cutoff, time_swim, rmse_swim, re_swim")
    print(info)
    return info

if __name__ == "__main__":
    # Command-line interface: every option except --tol is mandatory.
    parser = argparse.ArgumentParser(description="Processing arguments..")
    required_options = (
        ("--n_s", int, "Number of sample points"),
        ("--t_b", int, "Time blocks"),
        ("--n_c", int, "Number of collocation pts"),
        ("--r", float, "Ratio: n_col/width"),
        ("--r_c", float, "regularization constant"),
    )
    for flag, value_type, help_text in required_options:
        parser.add_argument(flag, type=value_type, required=True, help=help_text)
    parser.add_argument(
        "--tol", type=float, default=1e-3, help="atol and rtol for the ODE solver"
    )
    args = parser.parse_args()

    # Echo the run configuration before starting the experiment.
    print(args.n_s, args.n_c, args.t_b, args.r, args.r_c, args.tol)
    print(args.r_c)
    main(args)




                            
                            
                            




