import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np
import matplotlib.pyplot as plt
import math
import os
import sys
from utils.utils import *

from data_gen.utils.dataset_util_2D import *

# NOTE(review): this runs after the imports above, so it cannot influence how
# they are resolved; it only affects imports executed later.
sys.path.append("..")

# Command-line options; `parser` presumably comes from one of the star imports
# above -- confirm.
args = parser.parse_args()

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Upper spatial boundary for both x and y (NaivePINN.boundary_loss compares
# coordinates against 0 and SPATIAL_SIZE).
# "pi" -> pi, "1" -> 1, any other value -> 2*pi.
SPATIAL_SIZE = {"pi": math.pi, "1": 1}.get(args.spatial_size, 2 * math.pi)

class PINN(nn.Module):
    """Fully connected tanh network mapping a point ``(x, y, t)`` to a scalar ``u``.

    Layout: a 3 -> 128 input layer, six 128 -> 128 hidden layers, and a
    128 -> 1 linear output head. Every layer except the head is followed by
    ``tanh``. Weights are Xavier-uniform initialised; biases keep the
    ``nn.Linear`` default initialisation.

    The attribute names skip ``fc7`` on purpose. That layer was removed, and
    keeping the remaining names unchanged keeps the parameter (state_dict)
    keys stable.
    """

    def __init__(self):
        super().__init__()
        width = 128
        self.fc1 = nn.Linear(3, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, width)
        self.fc5 = nn.Linear(width, width)
        self.fc6 = nn.Linear(width, width)
        # fc7 intentionally absent (see class docstring).
        self.fc8 = nn.Linear(width, width)
        self.fc9 = nn.Linear(width, 1)
        self.activation = nn.Tanh()

        self._initialize_weights()

    def _linear_layers(self):
        """Return all linear layers ordered from input to output."""
        return (
            self.fc1, self.fc2, self.fc3, self.fc4,
            self.fc5, self.fc6, self.fc8, self.fc9,
        )

    def _initialize_weights(self):
        """Apply Xavier-uniform init to every weight matrix, input layer first."""
        for layer in self._linear_layers():
            init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Return the network output, one value per input row."""
        *hidden_layers, output_layer = self._linear_layers()
        for layer in hidden_layers:
            x = self.activation(layer(x))
        return output_layer(x)


class NaivePINN:
    """PINN trainer for a 2-D convection-diffusion-reaction equation.

    The network ``u(x, y, t)`` is fitted so that

        u_t + beta*u_x + beta_y*u_y - nu*u_xx - nu_y*u_yy
            - rho*u*(1 - u) - epsilon*(u - u**3) - theta*(u**2 - u**3) = 0

    holds at collocation points. In addition, ``u`` must match an initial
    profile on the initial-condition points, and outputs on opposite spatial
    edges must agree (periodic boundary). All coefficients default to 0, so
    only the terms that are explicitly enabled contribute.
    """

    def __init__(self, beta=0, beta_y=0, nu=0, nu_y=0, rho=0, epsilon=0, theta=0):
        """Build the network on ``device`` and an L-BFGS optimizer over its weights.

        Args:
            beta, beta_y: convection coefficients along x and y.
            nu, nu_y: diffusion coefficients along x and y.
            rho: coefficient of the ``u * (1 - u)`` reaction term.
            epsilon: coefficient of the ``u - u**3`` reaction term.
            theta: coefficient of the ``u**2 - u**3`` reaction term.
        """
        self.model = PINN().to(device)

        self.beta = beta
        self.beta_y = beta_y
        self.nu = nu
        self.nu_y = nu_y
        self.rho = rho
        self.epsilon = epsilon
        self.theta = theta

        # L-BFGS: each optimizer.step(closure) runs up to max_iter inner iterations.
        self.optimizer = torch.optim.LBFGS(
            self.model.parameters(), lr=1e-02, max_iter=100, tolerance_grad=1e-5
        )

    def residual_loss(self, x, y, t):
        """Mean squared PDE residual at the collocation points.

        Args:
            x, y, t: column tensors of shape (N, 1) holding the collocation
                coordinates. They must have ``requires_grad=True``, because the
                derivatives of ``u`` are taken with respect to them.

        Returns:
            Scalar tensor ``mean(residual**2)``. The derivatives are built with
            ``create_graph=True``, so the loss can be back-propagated to the
            network weights.
        """
        x = x.to(device)
        y = y.to(device)
        t = t.to(device)
        u = self.model(torch.cat((x, y, t), dim=1))

        u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
        u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
        u_y = torch.autograd.grad(u, y, grad_outputs=torch.ones_like(u), create_graph=True)[0]
        u_yy = torch.autograd.grad(u_y, y, grad_outputs=torch.ones_like(u_y), create_graph=True)[0]
        u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]

        residual = (self.beta * u_x + self.beta_y * u_y     # Convection terms
            - self.nu * u_xx - self.nu_y * u_yy               # Diffusion terms
            - self.rho * u * (1 - u)                        # Reaction term 1
            - self.epsilon * (u - u**3)                     # Reaction term 2
            - self.theta * (u**2 - u**3)                    # Reaction term 3
            + u_t
        )
        return torch.mean(residual ** 2)

    def initial_loss(self, x, y, t):
        """Mean squared error between the network and the initial profile.

        The target is ``function('sin')(x, y)``. ``function`` comes from a star
        import, so its exact formula is not visible here (presumably a sine
        profile; confirm in the utility module).
        """
        x = x.to(device)
        y = y.to(device)
        t = t.to(device)
        init_func = function('sin')
        u_initial = init_func(x, y).to(device, dtype=torch.float)
        u_pred = self.model(torch.cat((x, y, t), dim=1))
        return torch.mean((u_pred - u_initial)**2)

    def boundary_loss(self, x, y, t):
        """Periodic boundary loss: outputs on opposite domain edges must agree.

        Points with ``x == 0`` are paired, by position, with points with
        ``x == SPATIAL_SIZE``, and likewise for ``y``. The loss is the mean
        squared output difference over the x-pairs plus the same quantity over
        the y-pairs. The two axes may therefore hold different numbers of
        pairs.

        NOTE(review): pairing is purely positional. It assumes the boundary data
        lists matching points (same other coordinate and t) in the same order on
        both edges; confirm against the data generator.

        An axis with no boundary points at all yields NaN, because
        ``torch.mean`` of an empty tensor is NaN.

        Raises:
            ValueError: if the two edges of one axis hold different numbers of
                points.
        """
        x = x.to(device)
        y = y.to(device)
        t = t.to(device)
        b_out = self.model(torch.cat((x, y, t), dim=1))
        # Averaging each axis separately equals the former joint mean when the
        # counts match, and no longer breaks when the x and y counts differ.
        return self._edge_pair_mse(b_out, x) + self._edge_pair_mse(b_out, y)

    @staticmethod
    def _edge_pair_mse(b_out, coord):
        """Mean squared difference of ``b_out`` between the ``coord == 0`` and ``coord == SPATIAL_SIZE`` rows."""
        # Exact float equality: boundary coordinates must be stored exactly as
        # 0 and SPATIAL_SIZE to be selected.
        lower = torch.nonzero(coord == 0, as_tuple=True)[0]
        upper = torch.nonzero(coord == SPATIAL_SIZE, as_tuple=True)[0]
        if lower.numel() != upper.numel():
            # Without this check, a single point on one edge would silently
            # broadcast against every point on the opposite edge.
            raise ValueError(
                f"periodic boundary mismatch: {lower.numel()} points at 0 vs "
                f"{upper.numel()} points at {SPATIAL_SIZE}"
            )
        return torch.mean((b_out[lower] - b_out[upper]) ** 2)

    def loss_function(self, x_res, y_res, t_res, x_i, y_i, t_i, x_b, y_b, t_b):
        """Unweighted sum of the residual, initial-condition and boundary losses."""
        loss_residual = self.residual_loss(x_res, y_res, t_res)
        loss_initial = self.initial_loss(x_i, y_i, t_i)
        loss_boundary = self.boundary_loss(x_b, y_b, t_b)

        return loss_residual + loss_initial + loss_boundary

    def train(self, domain_list, epochs):
        """Fit the network with L-BFGS on one fixed random batch of collocation points.

        Args:
            domain_list: indexable whose entries 1, 2 and 3 are tensors of
                collocation, initial-condition and boundary points with columns
                ``[x, y, t]``. Entry 0 is not used here. At most
                ``args.PINN_batch`` collocation points are sampled once, up front.
            epochs: number of ``optimizer.step`` calls.

        Returns:
            -1 if a logged loss is NaN. Otherwise 1, either once the logged loss
            drops below ``args.threshold`` or after all epochs. The loss is
            checked every 10 epochs.
        """
        res_points = domain_list[1]
        batch_indices = torch.randperm(res_points.shape[0])[:args.PINN_batch]
        x_res, y_res, t_res = (
            res_points[batch_indices, c:c + 1].to(device, dtype=torch.float).requires_grad_(True)
            for c in range(3)
        )
        x_i, y_i, t_i = (domain_list[2][:, c:c + 1].to(device, dtype=torch.float) for c in range(3))
        x_b, y_b, t_b = (domain_list[3][:, c:c + 1].to(device, dtype=torch.float) for c in range(3))

        def closure():
            # L-BFGS may call this several times per step. Each call rebuilds the
            # graph from scratch, so the graph does not need to be retained.
            self.optimizer.zero_grad()
            loss = self.loss_function(x_res, y_res, t_res, x_i, y_i, t_i, x_b, y_b, t_b)
            loss.backward()
            return loss

        for epoch in range(epochs):
            self.optimizer.step(closure)

            if epoch % 10 == 0:
                # Re-evaluate after the step for logging. .item() releases the
                # autograd graph right away.
                loss = self.loss_function(x_res, y_res, t_res, x_i, y_i, t_i, x_b, y_b, t_b).item()
                if math.isnan(loss):
                    print(f'Epoch {epoch}, Loss is NaN. Returning -1.')
                    return -1
                print(f'Epoch {epoch}, Loss: {loss}')
                if loss < args.threshold:
                    return 1
        return 1

    def predict(self, x, y, t):
        """Evaluate the network at the given points without tracking gradients.

        Returns a tensor on ``device`` with one output value per input row.
        """
        with torch.no_grad():
            x, y, t = x.to(device), y.to(device), t.to(device)
            x_y_t = torch.cat((x, y, t), dim=1)
            return self.model(x_y_t)