# Importing Libraries

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from math import pi
from PINN import NeuralNet, init_xavier
import time
from sklearn.metrics import mean_squared_error
import csv
import os
import pylab as p
import matplotlib.cm as cm

import sys
sys.path.append('../../')
sys.path.append('../../src')
# from swimpde import Reaction_Diffusion_Solver
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import matplotlib.cm as cm
import time
cmap = cm.jet
from utils import *
from mpl_toolkits.mplot3d import Axes3D

device = torch.device('cpu')
dtype = torch.float32

# Load the pre-generated point sets from .npy files.
train_interior = np.load('X_int_train.npy')
train_boundary = np.load('X_b_train.npy')
test_interior = np.load('X_test.npy')
test_boundary = np.load('X_b_test.npy')
# Only interior points are used for testing. To also test on boundary points,
# use np.concatenate((test_interior, test_boundary)) instead.
test_tensor = test_interior

print("train_interior shape", train_interior.shape)
# Report the boundary array's own shape (previously printed the interior shape by mistake).
print("train_boundary shape", train_boundary.shape)
print("test shape", test_interior.shape)

def exact_solution(t, x1, x2, x3, x4, x5):
    """Manufactured solution u(t, x) = 2 sin(pi*x1/2) cos(pi*x2/2) exp(-t).

    Only t, x1 and x2 enter the formula; x3, x4 and x5 are accepted so the
    signature matches the (t, x1, ..., x5) input layout used by the script.
    """
    half_pi = pi / 2
    spatial_part = torch.sin(half_pi * x1) * torch.cos(half_pi * x2)
    return 2 * spatial_part * torch.exp(-t)


def BC(t, x1, x2, x3, x4, x5):
    """Dirichlet boundary data: the exact solution evaluated at the given points."""
    boundary_values = exact_solution(t, x1, x2, x3, x4, x5)
    return boundary_values

def initial_condition(x1, x2, x3, x4, x5):
    """Initial state u(0, x) = 2 sin(pi*x1/2) cos(pi*x2/2).

    This equals the exact solution at t = 0 (exp(-0) = 1). x3, x4 and x5 are
    accepted for signature symmetry but do not affect the value.
    """
    angle_scale = pi / 2
    return 2 * (torch.sin(angle_scale * x1) * torch.cos(angle_scale * x2))

# Which optimizer to train with: "ADAM" or "LBFGS".
opt_type = "LBFGS"

# numpy arrays -> torch tensors in the working precision
train_interior_tensor = torch.tensor(train_interior, dtype=dtype)
train_boundary_tensor = torch.tensor(train_boundary, dtype=dtype)
test_tensor = torch.tensor(test_tensor, dtype=dtype)

# Dirichlet targets at the boundary points; columns 0..5 are passed as (t, x1, ..., x5).
train_boundary_value_tensor = BC(*(train_boundary_tensor[:, col] for col in range(6))).reshape(-1, 1)

boundary_dataset = torch.utils.data.TensorDataset(train_boundary_tensor.to(device),
                                                  train_boundary_value_tensor.to(device))
training_set = DataLoader(boundary_dataset, batch_size=1000, shuffle=False)

my_network = NeuralNet(input_dimension=6, output_dimension=1, n_hidden_layers=4, neurons=20)

# Xavier weight initialization (currently disabled)
# init_xavier(my_network, retrain=1)

# Map each supported optimizer name to a factory taking the parameter iterable.
_optimizer_builders = {
    "ADAM": lambda params: optim.Adam(params, lr=0.001),
    "LBFGS": lambda params: optim.LBFGS(params, lr=0.1, max_iter=1, max_eval=50000,
                                        tolerance_change=1.0 * np.finfo(float).eps),
}
if opt_type not in _optimizer_builders:
    raise ValueError("Optimizer not recognized")
optimizer_ = _optimizer_builders[opt_type](my_network.parameters())
    
    

# def fit(model, training_set, interior, num_epochs, optimizer, p, verbose=True):
#     history = list()

#     for epoch in range(num_epochs):
#         if verbose: print("################################ ", epoch, " ################################")

#         running_loss = list([0])

#         for j, (bd, bd_val) in enumerate(training_set):
#             def closure():
#                 optimizer.zero_grad()
#                 u_bd_pred_ = model(bd)

#                 interior.requires_grad = True
#                 u_hat = model(train_interior_tensor)
#                 ones = torch.ones(interior.shape[0], 1).to(device)
#                 grad_u_hat = torch.autograd.grad(u_hat, interior, grad_outputs=ones, create_graph=True)[0]

#                 u_t, u_x1, u_x2, u_x3, u_x4, u_x5 = grad_u_hat[:,0], grad_u_hat[:,1], grad_u_hat[:,2], grad_u_hat[:,3], grad_u_hat[:,4], grad_u_hat[:,5]

#                 u_xx1 = torch.autograd.grad(u_x1, interior, grad_outputs=ones.squeeze(), create_graph=True)[0][:,1]
#                 u_xx2 = torch.autograd.grad(u_x2, interior, grad_outputs=ones.squeeze(), create_graph=True)[0][:,2]
#                 u_xx3 = torch.autograd.grad(u_x3, interior, grad_outputs=ones.squeeze(), create_graph=True)[0][:,3]
#                 u_xx4 = torch.autograd.grad(u_x4, interior, grad_outputs=ones.squeeze(), create_graph=True)[0][:,4]
#                  u_xx5 = torch.autograd.grad(u_x5, interior, grad_outputs=ones.squeeze(), create_graph=True)[0][:,5]

#                 x1 = interior[:,1]
#                 x2 = interior[:,2]
#                 x3 = interior[:,3]
#                 x4 = interior[:,4]
                
# def fit(model, training_set, interior, num_epochs, optimizer, p, verbose=True):
#     history = list()

#     # Filter points where t = 0 for initial condition
#     initial_points = interior[torch.isclose(interior[:, 0], torch.tensor(0.0, dtype=dtype))]
#     initial_values = initial_condition(initial_points[:, 1], initial_points[:, 2], initial_points[:, 3], initial_points[:, 4]).reshape(-1, 1)

#     for epoch in range(num_epochs):
#         if verbose: print("################################ ", epoch, " ################################")

#         running_loss = list([0])

#         for j, (bd, bd_val) in enumerate(training_set):
#             def closure():
#                 optimizer.zero_grad()
#                 u_bd_pred_ = model(bd)

#                 interior.requires_grad = True
#                 u_hat = model(interior)
#                 ones = torch.ones(interior.shape[0], 1).to(device)
#                 grad_u_hat = torch.autograd.grad(u_hat, interior, grad_outputs=ones, create_graph=True)[0]

#                 u_t, u_x1, u_x2, u_x3, u_x4 = grad_u_hat[:,0], grad_u_hat[:,1], grad_u_hat[:,2], grad_u_hat[:,3], grad_u_hat[:,4]

#                 u_xx1 = torch.autograd.grad(u_x1, interior, grad_outputs=ones.squeeze(), create_graph=True)[0][:,1]
#                 u_xx2 = torch.autograd.grad(u_x2, interior, grad_outputs=ones.squeeze(), create_graph=True)[0][:,2]
#                 u_xx3 = torch.autograd.grad(u_x3, interior, grad_outputs=ones.squeeze(), create_graph=True)[0][:,3]
#                 u_xx4 = torch.autograd.grad(u_x4, interior, grad_outputs=ones.squeeze(), create_graph=True)[0][:,4]

#                 x1 = interior[:,1]
#                 x2 = interior[:,2]
#                 x3 = interior[:,3]
#                 x4 = interior[:,4]
#                 t = interior[:,0]

#                 laplacian = u_xx1 + u_xx2 + u_xx3 + u_xx4
#                 f_term = ((pi**2 - 2) * torch.sin((pi/2)*x1) * torch.cos((pi/2)*x2) * torch.exp(-t) -
#                           4 * torch.sin((pi/2)*x1)**2 * torch.cos((pi/2)*x1)**2 * torch.cos((pi/2)*x2)**2 * torch.exp(-2*t))

#                 # Initial condition prediction
#                 u_ic_pred = model(initial_points)
#                 loss_ic = torch.mean((u_ic_pred - initial_values.to(device))**2)

#                 # Total loss
#                 loss = (
#                     torch.mean((u_bd_pred_.reshape(-1,) - bd_val.reshape(-1,))**p) +
#                     0.01*torch.mean((u_t - laplacian - u_hat**2 - f_term)**2) +
#                      loss_ic
#                 )

#                 loss.backward()
#                 running_loss[0] += loss.item()
#                 return loss

#             optimizer.step(closure=closure)

#         print('Loss: ', (running_loss[0] / len(training_set)))
#         history.append(running_loss[0])

#     return history


def _forcing_term(t, x1, x2):
    """Return the manufactured source f(t, x) of u_t - Laplacian(u) - u**2 = f.

    f is chosen so that u = 2 sin(pi*x1/2) cos(pi*x2/2) exp(-t) solves the PDE:
    u_t = -u and Laplacian(u) = -(pi**2 / 2) u, hence f = (pi**2/2 - 1) u - u**2.
    """
    s1 = torch.sin((pi / 2) * x1)
    c2 = torch.cos((pi / 2) * x2)
    # u**2 = 4 sin^2(pi x1/2) cos^2(pi x2/2) exp(-2t); no cos^2(pi x1/2) factor.
    return (pi**2 - 2) * s1 * c2 * torch.exp(-t) - 4 * s1**2 * c2**2 * torch.exp(-2 * t)


def _pde_residual(model, x):
    """Return u_t - sum_k u_{x_k x_k} - u**2 - f at collocation points x.

    x must have requires_grad=True. Column 0 is t and columns 1.. are the
    spatial coordinates; the Laplacian sums over all spatial columns.
    """
    u = model(x).reshape(-1)
    grad_u = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    u_t = grad_u[:, 0]

    laplacian = torch.zeros_like(u)
    for k in range(1, x.shape[1]):
        u_xk = grad_u[:, k]
        second = torch.autograd.grad(u_xk, x, grad_outputs=torch.ones_like(u_xk),
                                     create_graph=True, allow_unused=True)[0]
        # None means u_xk does not depend on x (zero second derivative).
        if second is not None:
            laplacian = laplacian + second[:, k]

    f = _forcing_term(x[:, 0], x[:, 1], x[:, 2])
    return u_t - laplacian - u**2 - f


def fit(model, training_set, interior, num_epochs, optimizer, p, verbose=True):
    """Train a PINN on boundary data plus the PDE residual at interior points.

    Per batch the loss is
        mean(|u(bd) - bd_val|**p) + 0.01 * mean(|residual(interior)|**p),
    where residual = u_t - Laplacian(u) - u**2 - f.

    Args:
        model: callable network mapping (N, D) points (t, x1, ...) to (N, 1) values.
        training_set: iterable of (bd, bd_val) boundary batches supporting len().
        interior: (N, D) tensor of interior collocation points. It is not modified.
        num_epochs: number of passes over training_set.
        optimizer: torch optimizer; step() is called with a closure, so LBFGS works.
        p: exponent of the loss norm.
        verbose: print an epoch banner when True.

    Returns:
        list of per-epoch sums of the losses returned by the closure.
    """
    history = []
    # Differentiate w.r.t. a detached copy so the caller's tensor is not mutated.
    x_int = interior.detach().requires_grad_(True)
    n_batches = len(training_set)

    for epoch in range(num_epochs):
        if verbose: print("################################ ", epoch, " ################################")

        running_loss = [0.0]

        for bd, bd_val in training_set:
            def closure():
                optimizer.zero_grad()

                # Boundary (data) term
                u_bd_pred = model(bd).reshape(-1)
                loss_bd = torch.mean((u_bd_pred - bd_val.reshape(-1)).abs() ** p)

                # Physics term at the interior collocation points
                residual = _pde_residual(model, x_int)
                loss_pde = torch.mean(residual.abs() ** p)

                loss = loss_bd + 0.01 * loss_pde
                loss.backward()
                running_loss[0] += loss.item()
                return loss

            optimizer.step(closure=closure)

        print('Loss: ', (running_loss[0] / n_batches if n_batches else 0.0))
        history.append(running_loss[0])

    return history

# Train the network
n_epochs = 10000
# perf_counter is the monotonic high-resolution clock intended for measuring durations.
start_time = time.perf_counter()
history = fit(my_network, training_set, train_interior_tensor, n_epochs, optimizer_, p=2, verbose=True)
end_time = time.perf_counter()
total_time = end_time - start_time
print("Training time: {:.2f} seconds".format(total_time))

# Reference values of the exact solution at the test points
u_exact = exact_solution(test_tensor[:,0], test_tensor[:,1], test_tensor[:,2], test_tensor[:,3], test_tensor[:,4], test_tensor[:,5]).reshape(-1,1)

# # Load trained model
# my_network.load_state_dict(torch.load('Trained_Models/sine/model3.pth'))

my_network = my_network.cpu()
# Inference only: no autograd graph is needed for the predictions.
with torch.no_grad():
    u_test_pred = my_network(test_tensor).reshape(-1,1)

# Compute RMSE (sklearn signature is mean_squared_error(y_true, y_pred))
RMSE = np.sqrt(mean_squared_error(u_exact.detach().numpy(), u_test_pred.detach().numpy()))
print("RMSE Test: ", RMSE)

u_test_pred = u_test_pred.reshape(-1, )
u_exact = u_exact.reshape(-1, )

# Relative L2 error: ||u_pred - u_exact|| / ||u_exact||
relative_error_test = torch.sqrt(torch.mean((u_test_pred - u_exact)**2) / torch.mean(u_exact**2))
print("Relative Error Test: ", relative_error_test.detach().numpy())

# # Visualization
# abs_err = torch.abs(u_test_pred - u_exact)

# fontsize = 12
# cmap = cm.jet

# fig = p.figure(figsize=(4, 3))
# ax = fig.add_subplot(projection='2d')
# rel_err_plot = ax.scatter(test_tensor[:,1].detach().numpy(), test_tensor[:,2].detach().numpy(), c=abs_err.detach().numpy(),
#                           marker='o', cmap=cmap)
# cb = fig.colorbar(rel_err_plot , ax=ax , location='bottom', fraction=0.046, format='%.0e')
# cb.ax.tick_params(labelsize=fontsize)
# plt.tick_params(axis='both', labelsize=fontsize)
# plt.savefig('Results/Figures/abs_err_lbfgs.pdf', bbox_inches="tight")
# plt.show()

def append_to_csv(file_name, result):
    """Append a single value as one row of a one-column CSV file.

    A ``Result`` header row is written first when the file is new or empty.
    The parent directory is created if needed, so the first run does not fail
    with FileNotFoundError (e.g. when ``Results/`` does not exist yet).

    Args:
        file_name: path of the CSV file to append to.
        result: value written as the single cell of the new row.
    """
    parent = os.path.dirname(file_name)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # An existing but empty file also needs the header.
    needs_header = not os.path.isfile(file_name) or os.path.getsize(file_name) == 0

    with open(file_name, mode='a', newline='') as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(['Result'])
        writer.writerow([result])
        
# Destination CSV files; each run appends one value to each.
RMSE_file = 'Results/RMSE_results.csv'
Rel_err_file = 'Results/Rel_err_results.csv'
Time_file = 'Results/Time_results.csv'

# Record this run's metrics, in the same order as before.
for _csv_path, _metric in (
    (RMSE_file, RMSE),
    (Rel_err_file, relative_error_test.detach().numpy()),
    (Time_file, total_time),
):
    append_to_csv(_csv_path, _metric)