# Importing Libraries

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils
import torch.utils.data
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from math import pi
from PINN import NeuralNet
from PINN import init_xavier
import time
import argparse
from sklearn.metrics import mean_squared_error
import csv
import os
import scipy.io
from matplotlib.colors import Normalize
import matplotlib.ticker as ticker

# Causality strength: scales the exponential weights exp(-eps * cumulative earlier loss)
# applied to each time step's PDE residual loss during training.
eps = 5

# initial condition

def initial_condition(x):
    """Return the initial profile u(x, 0) = -sin(pi * x), evaluated elementwise on tensor ``x``."""
    wave = torch.sin(pi * x)
    return torch.neg(wave)

# Number of training points per condition
initial_pts = 200
left_boundary_pts = 200
right_boundary_pts = 200
residual_pts = 40000

# Type of optimizer ("ADAM" or "LBFGS")
opt_type = "LBFGS"

# First condition: initial data u(x, 0) at random x in [-1, 1)
x_init = 2*torch.rand((initial_pts, 1)) - 1
t_init = 0*x_init                                      # t = 0
init = torch.cat([x_init, t_init], 1)                  # columns (x, t)
u_init = initial_condition(init[:, 0]).reshape(-1, 1)

# Second condition: left boundary x = -1 at random t in [0, 1), target u = 0
xb_left = -torch.ones((left_boundary_pts, 1))
tb_left = torch.rand((left_boundary_pts, 1))
b_left = torch.cat([xb_left, tb_left], 1)
u_b_l = 0*xb_left

# Third condition: right boundary x = 1 at random t in [0, 1), target u = 0
xb_right = torch.ones((right_boundary_pts, 1))
tb_right = torch.rand((right_boundary_pts, 1))
b_right = torch.cat([xb_right, tb_right], 1)
u_b_r = 0*xb_right

# Fourth condition: interior collocation points on a regular n_x-by-n_t grid
# strictly inside (-1, 1) x (0, 1). t is the slow index and x the fast one, so a
# per-point residual vector reshaped to (n_t, n_x) has one row per time step;
# the causal loss in fit() relies on this ordering.
n_x = 200
n_t = 200
x_int = torch.linspace(-1, 1, n_x + 2)[1:-1]   # drop the boundary nodes x = -1, 1
t_int = torch.linspace(0, 1, n_t + 2)[1:-1]    # drop t = 0 and t = 1
x_interior = x_int.tile((n_t,)).reshape(-1, 1)
t_interior = t_int.repeat_interleave(n_x).reshape(-1, 1)
interior = torch.cat([x_interior, t_interior], 1)
if interior.shape[0] != residual_pts:
    raise ValueError(
        f"interior grid has {interior.shape[0]} points but residual_pts = {residual_pts}"
    )

# Causality matrix: strictly lower-triangular ones (zero diagonal), so
# (W @ L)[i] = sum_{k < i} L[k], the accumulated loss of all earlier time steps.
n = n_t  # number of time steps
W = torch.tril(torch.ones(n, n), diagonal=-1)

training_set = DataLoader(torch.utils.data.TensorDataset(init, u_init, b_left, b_right), batch_size=200, shuffle=False)
    
# Network mapping (x, t) -> u; input/output widths come from the training tensors.
my_network = NeuralNet(
    input_dimension=init.shape[1],
    output_dimension=u_init.shape[1],
    n_hidden_layers=9,
    neurons=20,
)

# Random seed handed to the Xavier weight initialization
retrain = 128
init_xavier(my_network, retrain)


def _make_optimizer(kind):
    """Build a fresh optimizer of type ``kind`` ("ADAM" or "LBFGS") over my_network's parameters."""
    if kind == "ADAM":
        return optim.Adam(my_network.parameters(), lr=0.001)
    if kind == "LBFGS":
        return optim.LBFGS(
            my_network.parameters(),
            lr=0.1,
            max_iter=1,
            max_eval=50000,
            tolerance_change=1.0 * np.finfo(float).eps,
        )
    raise ValueError("Optimizer not recognized")


# Optimizer selected by opt_type
optimizer_ = _make_optimizer(opt_type)

# Separate Adam and LBFGS optimizers, each with its own independent state
adam_optimizer = _make_optimizer("ADAM")
lbfgs_optimizer = _make_optimizer("LBFGS")


def fit(model, training_set, interior, num_epochs, optimizer, p, verbose=True,
        *, causal_weights=None, causal_eps=None):
    """Train ``model`` on the Burgers equation with a causality-weighted PDE loss.

    Every optimizer step minimises

        mean(|u(x, 0) - u0|^p) + loss_pde + mean(|u(-1, t)|^p) + mean(|u(1, t)|^p)

    The PDE residual r = u_t + u * u_x - (0.01 / pi) * u_xx is evaluated on
    ``interior``. |r|^p is averaged per time step into L (shape (n_t, 1)), and

        loss_pde = mean(exp(-causal_eps * (causal_weights @ L)) * L)

    With the default strictly lower-triangular matrix, each time step is
    down-weighted by the accumulated loss of the earlier steps. The weights are
    detached, so no gradient flows through them.

    Args:
        model: network mapping (N, 2) inputs (x, t) to (N, 1) predictions.
        training_set: iterable of (initial, u_initial, bd_left, bd_right) batches.
        interior: (N, 2) collocation points ordered time-major, i.e. a length-N
            vector reshaped to (n_t, N // n_t) gives one row per time step.
            The caller's tensor is not modified.
        num_epochs: number of passes over ``training_set``.
        optimizer: torch optimizer whose ``step(closure)`` returns the closure
            loss (Adam, LBFGS, ...).
        p: exponent applied to the absolute residuals.
        verbose: print an epoch header when True.
        causal_weights: (n_t, n_t) causality matrix; defaults to module-level ``W``.
        causal_eps: causality strength; defaults to module-level ``eps``.

    Returns:
        list with, for each epoch, the mean over batches of the loss reported
        by ``optimizer.step``.

    Raises:
        ValueError: if ``training_set`` yields no batches.
    """
    if causal_weights is None:
        causal_weights = W
    if causal_eps is None:
        causal_eps = eps
    n_time_steps = causal_weights.shape[0]

    # Detached leaf copy: enables input gradients without mutating the caller's tensor.
    interior = interior.detach().requires_grad_(True)

    history = []
    for epoch in range(num_epochs):
        if verbose:
            print("################################ ", epoch, " ################################")

        running_loss = 0.0
        n_batches = 0

        for initial, u_initial, bd_left, bd_right in training_set:

            def closure():
                optimizer.zero_grad()
                u_initial_pred = model(initial)
                u_bd_left_pred = model(bd_left)
                u_bd_right_pred = model(bd_right)

                # PDE residual on the collocation points
                u_hat = model(interior)
                grad_u_hat = torch.autograd.grad(
                    u_hat, interior, grad_outputs=torch.ones_like(u_hat), create_graph=True
                )[0]
                u_x = grad_u_hat[:, 0]
                u_t = grad_u_hat[:, 1]
                u_xx = torch.autograd.grad(
                    u_x, interior, grad_outputs=torch.ones_like(u_x), create_graph=True
                )[0][:, 0]
                residual = u_t + u_hat.reshape(-1) * u_x - (0.01 / pi) * u_xx

                # Mean residual per time step, shape (n_t, 1)
                loss_at_time_steps = (torch.abs(residual) ** p).reshape(n_time_steps, -1).mean(dim=1, keepdim=True)

                # Causal weights from the (detached) losses of earlier time steps
                weights = torch.exp(-causal_eps * (causal_weights @ loss_at_time_steps.detach()))
                loss_pde = torch.mean(weights * loss_at_time_steps)

                loss = (
                    torch.mean(torch.abs(u_initial_pred.reshape(-1) - u_initial.reshape(-1)) ** p)
                    + loss_pde
                    + torch.mean(torch.abs(u_bd_left_pred.reshape(-1)) ** p)
                    + torch.mean(torch.abs(u_bd_right_pred.reshape(-1)) ** p)
                )
                loss.backward()
                return loss

            # step() returns the closure loss once per step, even if LBFGS
            # re-evaluates the closure internally, so the epoch loss is not over-counted.
            step_loss = optimizer.step(closure=closure)
            running_loss += float(step_loss)
            n_batches += 1

        if n_batches == 0:
            raise ValueError("training_set yielded no batches")

        epoch_loss = running_loss / n_batches
        print('Loss: ', epoch_loss)
        history.append(epoch_loss)

    return history

# Training parameters
#num_epochs_adam = 5000
#num_epochs_lbfgs = 10000

# Train with Adam optimizer
#print("Training with Adam optimizer")
#start_time = time.time()
#history_adam = fit(my_network, training_set, interior, num_epochs_adam, adam_optimizer, p=2, verbose=True)
#end_time = time.time()
#total_time_adam = end_time - start_time
#print("Training time with Adam: {:.2f} seconds".format(total_time_adam))

# Train with LBFGS optimizer
#print("Training with LBFGS optimizer")
#start_time = time.time()
#history_lbfgs = fit(my_network, training_set, interior, num_epochs_lbfgs, lbfgs_optimizer, p=2, verbose=True)
#end_time = time.time()
#total_time_lbfgs = end_time - start_time
#print("Training time with LBFGS: {:.2f} seconds".format(total_time_lbfgs))

# Concatenate the histories
#total_history = history_adam + history_lbfgs

# Optionally, save the total history or analyze it
# Example: print the total combined training time
#total_time = total_time_adam + total_time_lbfgs
#print("Total training time: {:.2f} seconds".format(total_time))

# Load the reference solution.
# The variable names in the .mat file become keys in the loaded dictionary.
mat_data = scipy.io.loadmat('burgers_shock.mat')

x_test = torch.from_numpy(mat_data['x'].astype(np.float32))
t_test = torch.from_numpy(mat_data['t'].astype(np.float32))
u_exact = torch.from_numpy(mat_data['usol'].astype(np.float32))

# The comparison below flattens u_exact row-major and pairs it with an x-major
# (x, t) grid, so 'usol' must be laid out as (len(x), len(t)). Fail loudly
# instead of silently comparing mismatched points.
if tuple(u_exact.shape) != (x_test.numel(), t_test.numel()):
    raise ValueError(
        f"usol has shape {tuple(u_exact.shape)}, expected "
        f"({x_test.numel()}, {t_test.numel()}) = (len(x), len(t))"
    )

# Create a grid of x and t values ('ij': x varies slowest, t fastest)
x_grid, t_grid = torch.meshgrid(x_test.squeeze(), t_test.squeeze(), indexing='ij')

# Reshape the grids into the (N, 2) test tensor with columns (x, t)
x_flat = x_grid.reshape(-1, 1)
t_flat = t_grid.reshape(-1, 1)
test = torch.cat([x_flat, t_flat], dim=1)

# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine
my_network.load_state_dict(torch.load('Trained_Models/model1.pth', map_location='cpu'))
#my_network = my_network.cpu()

# Pure inference: no autograd graph is needed for the predictions
with torch.no_grad():
    u_test_pred = my_network(test).reshape(-1, 1)

#torch.save(my_network.state_dict(), 'Trained_Models/model3.pth')

# Compute the generalization error on the flattened grid
u_test_pred = u_test_pred.reshape(-1, )
u_exact = u_exact.reshape(-1, )

RMSE = np.sqrt(mean_squared_error(u_test_pred.detach().numpy(), u_exact.detach().numpy()))
print("RMSE Test: ", RMSE)

# Relative L2 error: ||u_pred - u_exact|| / ||u_exact||
relative_error_test = torch.sqrt(torch.mean((u_test_pred - u_exact)**2) / torch.mean(u_exact**2))
print("Relative Error Test: ", relative_error_test.detach().numpy())

# Function to append result to CSV file
def append_to_csv(file_name, result):
    """Append ``result`` as a new row of the single-column CSV ``file_name``.

    A ``Result`` header row is written first whenever the file does not exist
    yet or is empty. Missing parent directories are created.

    Args:
        file_name: path of the CSV file to append to.
        result: value written as the row's only cell (via ``csv.writer``).
    """
    parent = os.path.dirname(file_name)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # An existing but empty file still needs its header.
    needs_header = not os.path.isfile(file_name) or os.path.getsize(file_name) == 0

    with open(file_name, mode='a', newline='') as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(['Result'])
        writer.writerow([result])
        
# CSV files that accumulate one result row per run
RMSE_file, Rel_err_file, Time_file = (
    'Results/RMSE_results.csv',
    'Results/Rel_err_results.csv',
    'Results/Time_results.csv',
)

# Append results to respective CSV files
#append_to_csv(RMSE_file, RMSE)
#append_to_csv(Rel_err_file, relative_error_test.detach().numpy())
#append_to_csv(Time_file, total_time)


# Detach and convert tensors to numpy arrays
u_test_pred = u_test_pred.detach().numpy()
x_test = x_test.detach().numpy()
t_test = t_test.detach().numpy()
u_exact = u_exact.detach().numpy()

# Pointwise absolute error laid back out on the evaluation grid: rows follow
# x_test, columns follow t_test. This assumes the predictions were flattened
# x-major, as produced by meshgrid(x, t, indexing='ij'). The shape is derived
# from the grids rather than hard-coded (previously 256 x 100).
abs_error = np.abs(u_test_pred - u_exact).reshape(x_test.size, t_test.size)

fontsize = 22
aspect = 0.2

# Visualize the error: time on the horizontal axis, space on the vertical axis
fig, ax = plt.subplots(1, 1, figsize=(6, 5), constrained_layout=True)
extent = [t_test.min(), t_test.max(), x_test.min(), x_test.max()]
sol_img1 = ax.imshow(abs_error, extent=extent, origin='lower', aspect=aspect, cmap='jet')

# Horizontal colorbar below the plot, at most 3 ticks in scientific notation
cb = fig.colorbar(sol_img1, ax=ax, location='bottom', aspect=20)
tick_locator = ticker.MaxNLocator(nbins=3)
cb.locator = tick_locator
cb.formatter = ticker.FuncFormatter(lambda x, pos: f'{x:.1e}')
cb.update_ticks()
cb.ax.tick_params(labelsize=fontsize)  # colorbar tick label size

ax.set_xlabel('t', fontsize=fontsize)
ax.set_ylabel('x', fontsize=fontsize)
ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
ax.tick_params(axis='both', labelsize=fontsize)

# Save the figure, creating the output directory if it does not exist yet
os.makedirs('Results/Figures', exist_ok=True)
plt.savefig('Results/Figures/causal_abs_err_burgers.png', dpi=300, bbox_inches="tight")
plt.show()