import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np
import matplotlib.pyplot as plt
import math
import os
import sys
from system import *
from utils import *
from config import *

# Add the relative path ".." to the module search path.
# NOTE(review): this runs after the `system` / `utils` / `config` imports
# above, so it cannot help resolve those modules — confirm whether it is
# meant to precede them or only serves later imports.
sys.path.append("..")

# Parse command-line options at import time.
# NOTE(review): `parser` is not defined in this file; presumably it comes from
# one of the star-imports above — TODO confirm which module provides it.
args = parser.parse_args()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PINN(nn.Module):
    """Fully connected network mapping a 2-feature input to a scalar output.

    Architecture: 2 -> 64 -> 64 -> 64 -> 1 with ``tanh`` after each of the
    first three linear layers and no activation on the output layer.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as fc1..fc4 so state_dict keys stay stable.
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, 1)
        self.activation = nn.Tanh()

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise every linear layer's weight matrix with Xavier-uniform.

        Biases keep the default ``nn.Linear`` initialisation.
        """
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Return the network output for input ``x`` (last dimension of size 2)."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.activation(layer(hidden))
        return self.fc4(hidden)


class NaivePINN:
    """Physics-informed trainer built around a :class:`PINN` network ``N``.

    The solution is represented by the trial function
    ``u(x, t) = t * N(x, t) + 1 + sin(x)``, so ``u(x, 0) = 1 + sin(x)`` holds
    by construction.  Training minimises the mean squared PDE residual

        u_t + beta*u_x - nu*u_xx - rho*u*(1 - u)
            + epsilon*(u**3 - u) - theta*(u**2 - u**3)

    plus an initial-condition term and a periodic boundary term.

    Relies on the module-level names ``device``, ``args`` and ``function``.

    Args:
        beta: Coefficient of the ``u_x`` term.
        nu: Coefficient of the ``u_xx`` term.
        rho: Coefficient of the ``u * (1 - u)`` term.
        epsilon: Coefficient of the ``u**3 - u`` term.
        theta: Coefficient of the ``u**2 - u**3`` term.
    """

    def __init__(self, beta=0, nu=0, rho=0, epsilon=0, theta=0):
        self.model = PINN().to(device)

        self.beta = beta
        self.nu = nu
        self.rho = rho
        self.epsilon = epsilon
        self.theta = theta

        # L-BFGS re-evaluates the loss several times per step via a closure.
        self.optimizer = torch.optim.LBFGS(
            self.model.parameters(), lr=1e-02, max_iter=100, tolerance_grad=1e-5
        )

    @property
    def _device(self):
        """Device the network parameters live on; every input is moved here."""
        return next(self.model.parameters()).device

    def _trial_solution(self, x, t):
        """Return ``t * N(x, t) + 1 + sin(x)`` for column tensors ``x`` and ``t``."""
        return t * self.model(torch.cat((x, t), dim=1)) + 1 + torch.sin(x)

    def residual_loss(self, x, t):
        """Return the mean squared PDE residual at the points ``(x, t)``.

        Both ``x`` and ``t`` must require gradients, because the derivatives
        ``u_x``, ``u_xx`` and ``u_t`` are obtained through autograd.
        """
        dev = self._device
        x = x.to(dev)
        t = t.to(dev)
        u = self._trial_solution(x, t)

        ones = torch.ones_like(u)
        # create_graph=True keeps the derivatives differentiable so the loss
        # can be back-propagated to the network parameters.
        u_x = torch.autograd.grad(u, x, grad_outputs=ones, create_graph=True)[0]
        u_xx = torch.autograd.grad(
            u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True
        )[0]
        u_t = torch.autograd.grad(u, t, grad_outputs=ones, create_graph=True)[0]

        residual = (
            self.beta * u_x
            - self.nu * u_xx
            - self.rho * u * (1 - u)
            + self.epsilon * (u**3 - u)
            - self.theta * (u**2 - u**3)
            + u_t
        )
        return torch.mean(residual ** 2)

    def initial_loss(self, x, t):
        """Return the mean squared gap between the network output and ``u0(x)``.

        ``u0`` is built from ``args.u0_str`` with the module-level ``function``.

        NOTE(review): this penalises the raw network output ``N(x, t)``, not
        the trial solution used by ``residual_loss`` and ``predict``.  Kept
        unchanged to preserve the training objective; confirm it is intended.
        """
        dev = self._device
        x = x.to(dev)
        t = t.to(dev)
        init_func = function(args.u0_str)
        u_initial = init_func(x).to(dev, dtype=torch.float)
        u_pred = self.model(torch.cat((x, t), dim=1))
        return torch.mean((u_pred - u_initial) ** 2)

    def boundary_loss(self, x, t):
        """Return the periodic boundary penalty between ``x == 0`` and ``x == 2*pi``.

        Rows with ``x == 0`` are paired, in order of appearance, with rows
        with ``x == 2*pi``; the result is the mean squared difference of the
        raw network outputs of each pair.
        """
        # Move inputs to the model's device, matching the other loss terms.
        dev = self._device
        x = x.to(dev)
        t = t.to(dev)

        b_out = self.model(torch.cat((x, t), dim=1))
        # Exact float comparison: boundary points must be stored as exactly
        # 0 and 2*pi (in the tensor's dtype) to be selected.
        min_indices = torch.nonzero(x[:, 0] == 0, as_tuple=True)[0]
        max_indices = torch.nonzero(x[:, 0] == 2 * math.pi, as_tuple=True)[0]
        return torch.mean((b_out[min_indices] - b_out[max_indices]) ** 2)

    def loss_function(self, x_res, t_res, x_i, t_i, x_b, t_b):
        """Return the unweighted sum of residual, initial and boundary losses."""
        loss_residual = self.residual_loss(x_res, t_res)
        loss_initial = self.initial_loss(x_i, t_i)
        loss_boundary = self.boundary_loss(x_b, t_b)

        return loss_residual + loss_initial + loss_boundary

    def train(self, domain_list, epochs, n_collocation=1000):
        """Fit the network with L-BFGS.

        Args:
            domain_list: Indexable whose entries 1, 2 and 3 are 2-column
                tensors ``[x, t]`` of residual, initial and boundary points.
            epochs: Number of outer ``optimizer.step`` calls.
            n_collocation: Maximum number of residual points, sampled once
                without replacement before training starts.

        Returns:
            ``-1`` if a monitored loss is NaN, otherwise ``1`` (either the
            loss dropped below ``args.threshold`` or all epochs completed).
        """
        dev = self._device

        def _columns(points, rows=slice(None)):
            # Split a 2-column point set into float column tensors on `dev`.
            return (
                points[rows, 0:1].to(dev, dtype=torch.float),
                points[rows, 1:2].to(dev, dtype=torch.float),
            )

        batch_indice = torch.randperm(domain_list[1].shape[0])[:n_collocation]
        x_res, t_res = _columns(domain_list[1], batch_indice)
        # Collocation points are plain data: detach them from any upstream
        # graph and track gradients so the PDE derivatives can be taken.
        x_res = x_res.detach().requires_grad_(True)
        t_res = t_res.detach().requires_grad_(True)
        x_i, t_i = _columns(domain_list[2])
        x_b, t_b = _columns(domain_list[3])

        def closure():
            # The graph is rebuilt on every call, so it need not be retained.
            self.optimizer.zero_grad()
            loss = self.loss_function(x_res, t_res, x_i, t_i, x_b, t_b)
            loss.backward()
            return loss

        for epoch in range(epochs):
            self.optimizer.step(closure)

            if epoch % 10 == 0:
                loss = self.loss_function(x_res, t_res, x_i, t_i, x_b, t_b)
                if torch.isnan(loss):
                    print(f'Epoch {epoch}, Loss is NaN. Returning -1.')
                    return -1
                print(f'Epoch {epoch}, Loss: {loss.item()}')
                if (loss < args.threshold):
                    return 1
        return 1

    def predict(self, x, t):
        """Evaluate the trial solution at ``(x, t)`` without tracking gradients."""
        with torch.no_grad():
            dev = self._device
            x, t = x.to(dev), t.to(dev)
            return self._trial_solution(x, t)


