# Importing Libraries

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils
import torch.utils.data
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from math import pi
from PINN import NeuralNet
from PINN import init_xavier
import time
import argparse
from sklearn.metrics import mean_squared_error
import csv
import os
import pylab as p
import matplotlib.cm as cm

# Run everything on the CPU in single precision.
device = torch.device('cpu')
dtype = torch.float32

# Training point sets, read from disk in the order interior, initial, boundary.
train_interior, train_initial, train_boundary = (
    np.load(npy_path)
    for npy_path in (
        'Data/HEAT_HD/3d/Interior_train_points.npy',
        'Data/HEAT_HD/3d/Initial_train_points.npy',
        'Data/HEAT_HD/3d/boundary_train_points.npy',
    )
)

# Test point sets; interior rows are stacked on top of the boundary rows to
# form one evaluation array.
test_interior, test_boundary = (
    np.load(npy_path)
    for npy_path in (
        'Data/HEAT_HD/3d/interior_test_points.npy',
        'Data/HEAT_HD/3d/boundary_test_points.npy',
    )
)
test_tensor = np.concatenate([test_interior, test_boundary], axis=0)



def f(t, x, y, z):
    """
    Evaluate f(t, x, y, z) = (1/d - 1) * cos((1/d) * (x + y + z)) * exp(-t)
    with d = 3 spatial dimensions.

    The arguments are combined elementwise with torch operations, so they are
    expected to be tensors of compatible shapes. `fit` subtracts this value
    in the PDE residual.
    """
    n_dims = 3  # number of spatial dimensions
    inv_dims = 1 / n_dims
    phase = inv_dims * (x + y + z)
    return (inv_dims - 1) * torch.cos(phase) * torch.exp(-t)

def g(t, x, y, z):
    """
    Evaluate g(t, x, y, z) = cos((1/d) * (x + y + z)) * exp(-t) with d = 3.

    Used by the script to produce the target values at the boundary training
    points. All arguments are combined elementwise with torch operations.
    """
    n_dims = 3  # number of spatial dimensions
    phase = (1 / n_dims) * (x + y + z)
    spatial_part = torch.cos(phase)
    temporal_part = torch.exp(-t)
    return spatial_part * temporal_part

def h(t, x, y, z):
    """
    Evaluate h(x, y, z) = cos((1/d) * (x + y + z)) with d = 3.

    `t` is accepted so the signature matches `f` and `g`, but it does not
    enter the result. Used by the script to produce the target values at the
    initial-condition training points.
    """
    n_dims = 3  # number of spatial dimensions
    phase = (1 / n_dims) * (x + y + z)
    return torch.cos(phase)

def exact_solution(t, x, y, z):
    """
    Evaluate u(t, x, y, z) = cos((1/d) * (x + y + z)) * exp(-t) with d = 3.

    This is the reference the script compares the network prediction against
    on the test points. All arguments are combined elementwise with torch
    operations.
    """
    n_dims = 3  # number of spatial dimensions
    phase = (1 / n_dims) * (x + y + z)
    spatial_part = torch.cos(phase)
    decay = torch.exp(-t)
    return spatial_part * decay


    
# Optimizer selection: "ADAM" or "LBFGS".
opt_type = "LBFGS"

# Convert every loaded NumPy array to a torch tensor of the working dtype.
(
    train_interior_tensor,
    train_boundary_tensor,
    train_initial_tensor,
    test_interior_tensor,
    test_boundary_tensor,
    test_tensor,
) = (
    torch.tensor(points, dtype=dtype)
    for points in (
        train_interior,
        train_boundary,
        train_initial,
        test_interior,
        test_boundary,
        test_tensor,
    )
)

# Target values: columns 0..3 of each point set are passed as (t, x, y, z).
train_boundary_value_tensor = g(
    *(train_boundary_tensor[:, col] for col in range(4))
).reshape(-1, 1)
train_initial_value_tensor = h(
    *(train_initial_tensor[:, col] for col in range(4))
).reshape(-1, 1)

# Boundary and initial points (with their targets) are batched together;
# the interior points are handed to `fit` separately.
training_set = DataLoader(
    torch.utils.data.TensorDataset(
        train_boundary_tensor.to(device),
        train_boundary_value_tensor.to(device),
        train_initial_tensor.to(device),
        train_initial_value_tensor.to(device),
    ),
    batch_size=4000,
    shuffle=False,
)

my_network = NeuralNet(
    input_dimension=4,
    output_dimension=1,
    n_hidden_layers=4,
    neurons=20,
)

# Value handed to init_xavier; per the original note it acts as the random
# seed for the Xavier weight initialization.
retrain = 3
init_xavier(my_network, retrain)

# Reject unknown optimizer names up front, then build the requested one.
if opt_type not in ("ADAM", "LBFGS"):
    raise ValueError("Optimizer not recognized")

if opt_type == "ADAM":
    optimizer_ = optim.Adam(my_network.parameters(), lr=0.001)
else:
    optimizer_ = optim.LBFGS(
        my_network.parameters(),
        lr=0.1,
        max_iter=1,
        max_eval=50000,
        tolerance_change=1.0 * np.finfo(float).eps,
    )


def fit(model, training_set, interior, num_epochs, optimizer, p, verbose=True):
    """Train `model` on boundary, initial-condition and PDE-residual losses.

    The loss minimised in every optimizer step is the sum of three mean
    |error|**p terms:

    * prediction vs. target on the boundary batch,
    * prediction vs. target on the initial-condition batch,
    * the residual ``u_t - u_xx - u_yy - u_zz - f`` evaluated on all
      `interior` points.

    Args:
        model: Callable network mapping a batch of points to predictions.
        training_set: Sized iterable yielding 4-tuples
            ``(bd, bd_val, initial, initial_value)``.
        interior: Tensor of interior points. Columns 0..3 are used as
            ``(t, x, y, z)``. It is switched to ``requires_grad=True`` in
            place so the network output can be differentiated with respect
            to it.
        num_epochs: Number of passes over `training_set`.
        optimizer: Torch optimizer whose ``step`` accepts a closure.
        p: Exponent applied to the absolute errors.
        verbose: When true, print an epoch banner before each epoch.

    Returns:
        A list with one entry per epoch: the loss summed over every closure
        evaluation of that epoch. The value printed each epoch is that sum
        divided by ``len(training_set)``.
    """
    history = []

    # The residual needs derivatives of the network output with respect to
    # the interior points, so autograd has to track them.
    interior.requires_grad_(True)

    # The source term does not depend on the network parameters: evaluate it
    # once and detach it so it is not rebuilt on every closure call.
    source = f(
        interior[:, 0], interior[:, 1], interior[:, 2], interior[:, 3]
    ).detach().reshape(-1)

    def _lp_mean(diff):
        # abs() keeps the loss non-negative and finite for odd or
        # fractional p; for p = 2 the value is unchanged.
        return torch.mean(torch.abs(diff) ** p)

    def _second_derivative(grad_u, axis):
        # Differentiate the `axis`-th first derivative once more and keep the
        # matching column, i.e. the unmixed second derivative.
        first = grad_u[:, axis]
        grad_first = torch.autograd.grad(
            first,
            interior,
            grad_outputs=torch.ones_like(first),
            create_graph=True,
        )[0]
        return grad_first[:, axis]

    for epoch in range(num_epochs):
        if verbose:
            print("################################ ", epoch, " ################################")

        running_loss = 0

        for bd, bd_val, initial, initial_value in training_set:

            def closure():
                nonlocal running_loss

                optimizer.zero_grad()

                # Data terms: boundary and initial condition.
                u_bd_pred = model(bd)
                u_initial_pred = model(initial)

                # PDE residual on the interior points. ones_like() sizes the
                # grad_outputs to whatever the model returns, on the same
                # device and dtype.
                u_hat = model(interior)
                grad_u = torch.autograd.grad(
                    u_hat,
                    interior,
                    grad_outputs=torch.ones_like(u_hat),
                    create_graph=True,
                )[0]

                u_t = grad_u[:, 0]
                u_xx = _second_derivative(grad_u, 1)
                u_yy = _second_derivative(grad_u, 2)
                u_zz = _second_derivative(grad_u, 3)

                residual = (
                    u_t.reshape(-1)
                    - u_xx.reshape(-1)
                    - u_yy.reshape(-1)
                    - u_zz.reshape(-1)
                    - source
                )

                loss = (
                    _lp_mean(u_bd_pred.reshape(-1) - bd_val.reshape(-1))
                    + _lp_mean(u_initial_pred.reshape(-1) - initial_value.reshape(-1))
                    + _lp_mean(residual)
                )

                loss.backward()
                # Accumulate over every closure evaluation in this epoch.
                running_loss += loss.item()
                return loss

            # The optimizer decides how often to call the closure.
            optimizer.step(closure=closure)

        print('Loss: ', (running_loss / len(training_set)))
        history.append(running_loss)

    return history

n_epochs = 1000

# Wall-clock timing of the complete training run.
start_time = time.time()
history = fit(
    my_network,
    training_set,
    train_interior_tensor,
    n_epochs,
    optimizer_,
    p=2,
    verbose=True,
)
end_time = time.time()
total_time = end_time - start_time
print("Training time: {:.2f} seconds".format(total_time))

# Reference values on the test points; columns 0..3 are passed as (t, x, y, z).
u_exact = exact_solution(
    *(test_tensor[:, col] for col in range(4))
).reshape(-1, 1)

# my_network.load_state_dict(torch.load('Trained_Models/heat_hd/model1.pth'))

# Network prediction on the same test points, evaluated on the CPU.
my_network = my_network.cpu()
u_test_pred = my_network(test_tensor).reshape(-1, 1)

#torch.save(my_network.state_dict(), 'Trained_Models/sine/model3.pth')

# Root-mean-square error between prediction and reference.
RMSE = np.sqrt(
    mean_squared_error(
        u_test_pred.detach().numpy(),
        u_exact.detach().numpy(),
    )
)
print("RMSE Test: ", RMSE)

# Work with flat vectors from here on.
u_test_pred = u_test_pred.flatten()
u_exact = u_exact.flatten()

# Pointwise relative error; not guarded against zeros in u_exact.
relative = ((u_test_pred - u_exact) / u_exact).abs()

# Relative L2 error norm (generalization error).
relative_error_test = (
    ((u_test_pred - u_exact) ** 2).mean() / (u_exact ** 2).mean()
).sqrt()
print("Relative Error Test: ", relative_error_test.detach().numpy())

# Function to append result to CSV file
def append_to_csv(file_name, result):
    """Append `result` as a single-cell row to the CSV file `file_name`.

    The parent directory is created when it does not exist. A ``Result``
    header row is written first when the file is new or exists but is empty.

    Args:
        file_name: Path of the CSV file to append to.
        result: Value written as the only cell of the new row.
    """
    parent_dir = os.path.dirname(file_name)
    if parent_dir:
        # Without this, opening the file fails when the results folder has
        # not been created yet.
        os.makedirs(parent_dir, exist_ok=True)

    # A zero-byte file needs the header just as much as a missing one.
    needs_header = (
        not os.path.isfile(file_name) or os.path.getsize(file_name) == 0
    )

    with open(file_name, mode='a', newline='') as file:
        writer = csv.writer(file)

        if needs_header:
            writer.writerow(['Result'])

        writer.writerow([result])
        
# File names for storing results
# RMSE_file = 'Results/heat_3d/RMSE_results.csv'
# Rel_err_file = 'Results/heat_3d/Rel_err_results.csv'
# Time_file = 'Results/heat_3d/Time_results.csv'

# # Append results to respective CSV files
# append_to_csv(RMSE_file, RMSE)
# append_to_csv(Rel_err_file, relative_error_test.detach().numpy())
# append_to_csv(Time_file, total_time)

# u_test_pred = u_test_pred.reshape(-1,)
# relative = relative.reshape(-1, )

# abs_err = torch.abs(u_test_pred - u_exact)

# ### contour plotting

# fontsize = 12
# cmap =cm.jet


# # Plot of ground truth
# fontsize = 12
# cmap =cm.jet

# # Plot of absolute error
# fig = p.figure(figsize=(4, 3))
# ax =fig.add_subplot(projection='3d')#
# rel_err_plot = ax.scatter(test_tensor[:,0].detach().numpy(), test_tensor[:,1].detach().numpy(), test_tensor[:,2].detach().numpy(), c=abs_err.detach().numpy(), 
#                           marker='o', cmap=cmap) 
# cb = fig.colorbar(rel_err_plot , ax=ax , location='bottom',fraction=0.046, format='%.0e')
# cb.ax.tick_params(labelsize=fontsize)
# plt.tick_params(axis='both', labelsize=fontsize)
# plt.savefig('Results/Figures/sine/sine_abs_err.pdf', bbox_inches="tight")
# plt.show()