import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils
import torch.utils.data
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from math import pi
from PINN import NeuralNet
from PINN import init_xavier
import time
import time
import argparse
from sklearn.metrics import mean_squared_error
import csv
import os
import matplotlib.ticker as ticker


def exact_solution(x, t):
    """Return the manufactured solution u(x, t) = sin(x) * cos(4*pi*t).

    ``x`` and ``t`` are tensors that broadcast against each other; the result
    has their broadcast shape.
    """
    spatial_part = torch.sin(x)
    temporal_part = torch.cos(4 * pi * t)
    return spatial_part * temporal_part

def initial_condition(x):
    """Return the initial displacement u(x, 0) = sin(x) (the exact solution at t = 0)."""
    return x.sin()

def initial_condition_time(x):
    """Return the initial velocity u_t(x, 0), which is identically zero.

    Differentiating the exact solution sin(x) * cos(4*pi*t) in t gives
    -4*pi * sin(x) * sin(4*pi*t), which vanishes at t = 0 for every x.
    The result has the same shape, dtype and device as ``x``.
    """
    return torch.zeros_like(x)

# Number of training points sampled for each term of the loss
initial_pts = 2000
left_boundary_pts = 2000
right_boundary_pts = 2000
residual_pts = 10000

# Type of optimizer (ADAM or LBFGS)
opt_type = "LBFGS"

# NOTE(review): all sampling below happens before the only torch.manual_seed
# call in this file (inside init_xavier), so the training points differ from
# run to run unless the RNG is seeded elsewhere first.

# Initial points: x uniform in [0, pi), t = 0
x_init = pi*torch.rand((initial_pts,1)) # x samples on [0, pi)
t_init =  0*x_init
init =  torch.cat([x_init, t_init],1)
u_init = initial_condition(init[:,0]).reshape(-1, 1)
u_init_t = initial_condition_time(init[:,0]).reshape(-1, 1) # target u_t(x, 0)

# Boundary points at x = 0 and x = pi, with t uniform in [0, 1)

xb_left = torch.zeros((left_boundary_pts, 1)) # left spatial boundary x = 0
tb_left = torch.rand((left_boundary_pts, 1)) # left boundary times in [0, 1)
b_left = torch.cat([xb_left, tb_left ],1)
# Zero targets for u and u_xx on the left boundary.
# NOTE(review): fit() compares the predictions with 0 directly and does not read these.
u_b_l = 0*xb_left
u_b_l_xx = 0*xb_left



xb_right = pi*torch.ones((right_boundary_pts, 1)) # right spatial boundary x = pi
tb_right = torch.rand((right_boundary_pts, 1)) # right boundary times in [0, 1)
b_right = torch.cat([xb_right, tb_right ],1)
# Zero targets for u and u_xx on the right boundary (also unused by fit()).
u_b_r = 0*xb_right
u_b_r_xx = 0*xb_right


# Collocation (interior) points where the PDE residual is enforced:
# x uniform in [0, pi), t uniform in [0, 1)
x_interior = pi*torch.rand((residual_pts, 1))
t_interior = torch.rand((residual_pts, 1))
interior = torch.cat([x_interior, t_interior],1)

# Training set: TensorDataset needs all five tensors to have the same number of
# rows (2000 here). batch_size=2000 therefore yields a single batch holding
# every initial/boundary point per epoch. Interior points are passed to fit()
# separately and are not batched.
training_set = DataLoader(torch.utils.data.TensorDataset(init, u_init, u_init_t, b_left, b_right), batch_size=2000, shuffle=False)

# neural network Class

class NeuralNet(nn.Module):
    """Fully connected tanh network.

    Layout: input ``nn.Linear`` -> tanh, then ``n_hidden_layers`` blocks of
    ``nn.Linear(neurons, neurons)`` -> tanh, then an output ``nn.Linear`` with
    no activation.
    """

    def __init__(self, input_dimension, output_dimension, n_hidden_layers, neurons):
        super().__init__()
        # Hyper-parameters are kept on the instance for later inspection.
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        self.neurons = neurons
        self.n_hidden_layers = n_hidden_layers
        self.activation = nn.Tanh()

        # The attribute names below determine the state_dict keys
        # (input_layer.*, hidden_layers.<i>.*, output_layer.*), so saved
        # checkpoints depend on them. Keep the names and the creation order.
        width = self.neurons
        self.input_layer = nn.Linear(self.input_dimension, width)
        self.hidden_layers = nn.ModuleList(
            nn.Linear(width, width) for _ in range(n_hidden_layers)
        )
        self.output_layer = nn.Linear(width, self.output_dimension)

    def forward(self, x):
        """Map ``x`` (last dimension = input_dimension) to the network output."""
        hidden = self.activation(self.input_layer(x))
        for layer in self.hidden_layers:
            hidden = self.activation(layer(hidden))
        return self.output_layer(hidden)
# Model definition: input width = number of columns of `init` (x, t), output
# width = number of columns of `u_init`; 4 hidden layers of 20 neurons each.
my_network = NeuralNet(init.shape[1], u_init.shape[1], n_hidden_layers=4, neurons=20)

def init_xavier(model, retrain_seed):
    """Re-seed torch's global RNG, then Xavier-initialise the model's linear layers.

    Parameters
    ----------
    model : nn.Module
        Network whose ``nn.Linear`` sub-modules (exact type only) are
        re-initialised in place, in ``model.apply`` traversal order.
    retrain_seed : int
        Seed passed to ``torch.manual_seed``. This reseeds the *global*
        generator, so it also fixes every later random draw.

    Weights get ``xavier_uniform_`` with the tanh gain, and biases are zeroed.
    A layer is skipped when its weight (or its bias, if present) has
    ``requires_grad=False``. Layers built with ``bias=False`` have only their
    weight initialised; they no longer fail on ``m.bias.requires_grad``.
    """
    torch.manual_seed(retrain_seed)
    # The gain is identical for every layer, so compute it once.
    gain = nn.init.calculate_gain('tanh')

    def init_weights(m):
        # Exact type check, as before: Linear subclasses such as nn.LazyLinear
        # may hold uninitialised parameters that xavier_uniform_ cannot fill.
        if type(m) is not nn.Linear or not m.weight.requires_grad:
            return
        has_bias = m.bias is not None
        if has_bias and not m.bias.requires_grad:
            return
        nn.init.xavier_uniform_(m.weight, gain=gain)
        if has_bias:
            nn.init.zeros_(m.bias)

    model.apply(init_weights)

# Random seed for weight initialization. init_xavier passes it to
# torch.manual_seed, so it also fixes every later draw from torch's global RNG.
retrain = 128
# Xavier weight initialization (uses the init_xavier defined in this file,
# which shadows the one imported from PINN)
init_xavier(my_network, retrain)

# Build the optimizer selected by opt_type.
if opt_type == "ADAM":
    optimizer_ = optim.Adam(my_network.parameters(), lr=0.001)
elif opt_type == "LBFGS":
    # max_iter=1: each optimizer.step() call performs a single L-BFGS iteration,
    # so the number of epochs passed to fit() controls the iteration count.
    # tolerance_change is float64 machine epsilon (np.finfo(float).eps).
    optimizer_ = optim.LBFGS(my_network.parameters(), lr=0.1, max_iter=1, max_eval=50000, tolerance_change=1.0 * np.finfo(float).eps)
else:
    raise ValueError("Optimizer not recognized")
  
def _grad(outputs, inputs):
    """Return d(sum(outputs))/d(inputs), keeping the graph for higher derivatives.

    With all-ones ``grad_outputs``, row i of the result is the gradient of
    ``outputs[i]`` w.r.t. ``inputs[i]`` whenever each output depends only on
    its own input row, as it does for a network applied row-wise.
    """
    return torch.autograd.grad(
        outputs, inputs, grad_outputs=torch.ones_like(outputs), create_graph=True
    )[0]


def _pinn_loss(model, initial, u_initial, u_initial_t, bd_left, bd_right, interior, p):
    """Compute the physics-informed loss for one batch.

    Column 0 of every point tensor is x and column 1 is t. The loss has three parts:

    * initial: u and u_t at the ``initial`` points vs. ``u_initial``/``u_initial_t``;
    * PDE residual on ``interior``: u_tt + u_xxxx - (1 - 16*pi**2) sin(x) cos(4*pi*t),
      weighted by 0.1;
    * boundary: u and u_xx pushed to 0 on ``bd_left`` and ``bd_right``.

    Each part is the mean of |error|**p. The absolute value keeps odd ``p``
    meaningful and does not change the value or gradient for p = 2.
    Sets ``requires_grad`` on the point tensors in place.
    """
    def power_mean(err):
        return torch.mean(torch.abs(err.reshape(-1)) ** p)

    # Initial condition: value and first time derivative.
    initial.requires_grad_(True)
    u_ini = model(initial)
    u_ini_t = _grad(u_ini, initial)[:, 1]

    # Boundaries: value and second spatial derivative.
    bd_left.requires_grad_(True)
    bd_right.requires_grad_(True)
    u_left = model(bd_left)
    u_right = model(bd_right)
    u_left_xx = _grad(_grad(u_left, bd_left)[:, 0], bd_left)[:, 0]
    u_right_xx = _grad(_grad(u_right, bd_right)[:, 0], bd_right)[:, 0]

    # Interior: u_tt and u_xxxx by repeated autograd differentiation.
    interior.requires_grad_(True)
    u_hat = model(interior)
    grad_u = _grad(u_hat, interior)
    u_x, u_t = grad_u[:, 0], grad_u[:, 1]
    u_tt = _grad(u_t, interior)[:, 1]
    u_xx = _grad(u_x, interior)[:, 0]
    u_xxx = _grad(u_xx, interior)[:, 0]
    u_xxxx = _grad(u_xxx, interior)[:, 0]
    forcing = (1 - 16 * pi**2) * torch.sin(interior[:, 0]) * torch.cos(4 * pi * interior[:, 1])

    loss_ini = power_mean(u_ini.reshape(-1) - u_initial.reshape(-1)) + power_mean(u_ini_t - u_initial_t.reshape(-1))
    loss_p = power_mean(u_tt + u_xxxx - forcing)
    loss_bd = power_mean(u_left) + power_mean(u_right) + power_mean(u_left_xx) + power_mean(u_right_xx)
    return loss_ini + 0.1 * loss_p + loss_bd


def fit(model, training_set, interior, num_epochs, optimizer, p, verbose=True):
    """Train ``model`` on the PINN loss and return the per-epoch loss history.

    Parameters
    ----------
    model : nn.Module
        Network mapping (x, t) rows to u.
    training_set : iterable
        Yields batches ``(initial, u_initial, u_initial_t, bd_left, bd_right)``,
        e.g. a DataLoader over a TensorDataset. Any batch size is supported.
    interior : Tensor
        Collocation points (column 0 = x, column 1 = t), reused for every batch.
        Its ``requires_grad`` flag is set to True in place.
    num_epochs : int
        Number of passes over ``training_set``.
    optimizer : torch.optim.Optimizer
        Stepped once per batch with a closure, so both Adam and LBFGS work.
    p : float
        Exponent of the error norm in every loss term (p = 2 is a mean squared error).
    verbose : bool
        When True, print the epoch banner and the epoch loss.

    Returns
    -------
    list of float
        Mean batch loss per epoch. Each batch contributes the value returned by
        ``optimizer.step``, which is the loss evaluated before that step's
        update. With a single batch this equals the summed loss.
    """
    history = []

    for epoch in range(num_epochs):
        if verbose:
            print("################################ ", epoch, " ################################")

        running_loss = 0.0
        n_batches = 0
        for initial, u_initial, u_initial_t, bd_left, bd_right in training_set:

            def closure():
                # optimizer.step consumes this closure within the same loop
                # iteration, so the late-bound loop variables are the current batch.
                optimizer.zero_grad()
                loss = _pinn_loss(model, initial, u_initial, u_initial_t, bd_left, bd_right, interior, p)
                loss.backward()
                return loss

            # LBFGS may evaluate the closure several times per step; counting the
            # value step() returns records each batch exactly once.
            running_loss += float(optimizer.step(closure=closure))
            n_batches += 1

        # Guard against an empty loader instead of dividing by zero.
        epoch_loss = running_loss / max(n_batches, 1)
        if verbose:
            print('Loss: ', epoch_loss)
        history.append(epoch_loss)

    return history

#n_epochs = 15000
#start_time = time.time()
#history = fit(my_network, training_set, interior, n_epochs, optimizer_, p=2, verbose=True )
#end_time = time.time()
#total_time = end_time - start_time
#print("Training time: {:.2f} seconds".format(total_time))

### RMSE and relative error on a uniform evaluation grid

# Number of points for x and t
num_points_x = 256
num_points_t = 100

# Uniformly spaced values covering x in [0, pi] and t in [0, 1]
x_test = torch.linspace(0, pi, num_points_x).reshape(-1, 1)
t_test = torch.linspace(0, 1, num_points_t).reshape(-1, 1)

# 'ij' indexing: x varies along dim 0 and t along dim 1, so flattened row k is
# (x[k // num_points_t], t[k % num_points_t]).
x_grid, t_grid = torch.meshgrid(x_test.squeeze(), t_test.squeeze(), indexing='ij')

# Reshape the grids to create the (num_points_x * num_points_t, 2) test tensor
x_flat = x_grid.reshape(-1, 1)
t_flat = t_grid.reshape(-1, 1)
test = torch.cat([x_flat, t_flat], dim=1)

# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine.
my_network.load_state_dict(torch.load('Trained_Models/model3.pth', map_location='cpu'))
my_network = my_network.cpu()
# Inference only: no autograd graph is needed for evaluation.
with torch.no_grad():
    u_test_pred = my_network(test).reshape(-1, 1)

# torch.save(my_network.state_dict(), 'Trained_Models/model3.pth')

print(x_flat.shape)
u_exact = exact_solution(x_flat, t_flat)

# Root-mean-square error between the exact solution and the prediction on the grid
RMSE = np.sqrt(mean_squared_error(u_exact.detach().numpy(), u_test_pred.detach().numpy()))
print("RMSE Test: ", RMSE)

u_test_pred = u_test_pred.reshape(-1, )
u_exact = u_exact.reshape(-1, )

# Relative L2 error ||u_pred - u_exact||_2 / ||u_exact||_2 (the 1/N factors cancel)
relative_error_test = torch.sqrt(torch.mean((u_test_pred - u_exact)**2) / torch.mean(u_exact**2))
print("Relative Error Test: ", relative_error_test.detach().numpy())

# Function to append result to CSV file
def append_to_csv(file_name, result):
    """Append ``result`` as one row of the CSV ``file_name``.

    A ``'Result'`` header row is written first whenever the file is new or
    empty. This is checked through the write position after opening in
    append mode, so there is no separate exists-check racing with the open.
    Missing parent directories are created, so paths like
    ``'Results/RMSE_results.csv'`` work even before ``Results/`` exists.
    """
    parent = os.path.dirname(file_name)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(file_name, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        # Append mode starts at end-of-file, so position 0 means the file is empty.
        if file.tell() == 0:
            writer.writerow(['Result'])
        writer.writerow([result])
        
# One CSV file per reported metric (consumed by append_to_csv)
RMSE_file, Rel_err_file, Time_file = (
    'Results/RMSE_results.csv',
    'Results/Rel_err_results.csv',
    'Results/Time_results.csv',
)

# Append results to the respective CSV files (disabled; total_time is only
# defined when the commented-out training code is enabled)
#append_to_csv(RMSE_file, RMSE)
#append_to_csv(Rel_err_file, relative_error_test.detach().numpy())
#append_to_csv(Time_file, total_time)

# Detach and convert tensors to numpy arrays
u_test_pred = u_test_pred.detach().numpy()
x_test = x_test.detach().numpy()
t_test = t_test.detach().numpy()
u_exact = u_exact.detach().numpy()

# Pointwise absolute error laid back out on the evaluation grid. The grid was
# built with meshgrid(..., indexing='ij'), so rows follow x and columns follow t.
# Use the grid sizes rather than a hard-coded 256 x 100 so a resolution change
# does not break the reshape.
abs_error = np.abs(u_test_pred - u_exact)
abs_error = abs_error.reshape(num_points_x, num_points_t)


fontsize = 24
aspect = 0.09  # data aspect ratio passed to imshow
# Visualize the absolute error: t on the horizontal axis, x on the vertical axis
fig, ax = plt.subplots(1, 1, figsize=(6, 5), constrained_layout=True)
extent = [t_test.min(), t_test.max(), x_test.min(), x_test.max()]

# origin='lower' puts row 0 (smallest x) at the bottom, matching `extent`
sol_img1 = ax.imshow(abs_error, extent=extent, origin='lower', aspect=aspect, cmap='jet')

# Colorbar below the plot, with scientific-notation labels
cb = fig.colorbar(sol_img1, ax=ax, location='bottom', aspect=20)
cb.formatter = ticker.FuncFormatter(lambda x, pos: f'{x:.1e}')
cb.update_ticks()

# Only three colorbar ticks: min, midpoint, max
cb.set_ticks([abs_error.min(), (abs_error.min() + abs_error.max()) / 2, abs_error.max()])
cb.ax.tick_params(labelsize=fontsize)

# Axis labels and ticks
ax.set_xlabel('t', fontsize=fontsize)
ax.set_ylabel('x', fontsize=fontsize)
ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
ax.tick_params(axis='both', labelsize=fontsize)

# savefig does not create directories, so make sure the output folder exists
os.makedirs('Results/Figures', exist_ok=True)
plt.savefig('Results/Figures/Absolute_error.pdf')
plt.show()
