import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np
import matplotlib.pyplot as plt

# Compute device shared by the whole module: CUDA when usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class PINN(nn.Module):
    """Fully connected network mapping a 2-feature input to a scalar output.

    Architecture: three hidden layers of width 64 with ``tanh`` activations
    followed by a linear read-out layer. All weight matrices are
    Xavier-uniform initialised; biases keep the ``nn.Linear`` default.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, 1)
        self.activation = nn.Tanh()

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise every layer's weight matrix with Xavier-uniform."""
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Return the network output for input ``x`` (last dimension 2)."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.activation(layer(hidden))
        # The output layer is linear: no activation on the final value.
        return self.fc4(hidden)


class AdaptivePINN:
    """PINN solver with per-point self-adaptive loss weights.

    The solution is represented by the trial function
    ``u(x, t) = t * N(x, t) + 1 + sin(x)`` where ``N`` is the network, so
    ``u(x, 0) = 1 + sin(x)`` holds by construction. The PDE residual that is
    driven to zero is::

        u_t + beta*u_x - nu*u_xx - rho*u*(1-u) + epsilon*(u^3-u) - theta*(u^2-u^3)

    Network parameters are optimised with L-BFGS; the per-point weights
    ``lambda_r`` / ``lambda_b`` are updated with Adam.
    """

    def __init__(self, beta=0, nu=0, rho=0, epsilon=0, theta=0):
        """Build the network, the PDE coefficients, the weights and optimizers."""
        # Keep the device on the instance so every method agrees on it.
        self.device = device
        self.model = PINN().to(self.device)

        self.beta = beta
        self.nu = nu
        self.rho = rho
        self.epsilon = epsilon
        self.theta = theta

        # Self-adaptive weights, one per residual / boundary point. These are
        # default sizes; they are re-allocated when the data size differs.
        self.lambda_r = torch.ones(1000, requires_grad=True, device=self.device)
        self.lambda_b = torch.ones(256, requires_grad=True, device=self.device)

        self.optimizer = torch.optim.LBFGS(
            self.model.parameters(), lr=1e-02, max_iter=100, tolerance_grad=1e-5
        )
        self.lambda_optimizer = torch.optim.Adam([self.lambda_r, self.lambda_b], lr=1e-2)

    def _ensure_weight_sizes(self, n_res, n_b):
        """Make ``lambda_r`` / ``lambda_b`` hold exactly one weight per point.

        A weight vector whose length does not match the data is replaced by a
        fresh vector of ones, and the Adam optimizer is rebuilt so it tracks
        the new tensors (its running statistics are reset in that case).
        """
        if self.lambda_r.numel() == n_res and self.lambda_b.numel() == n_b:
            return
        if self.lambda_r.numel() != n_res:
            self.lambda_r = torch.ones(n_res, requires_grad=True, device=self.device)
        if self.lambda_b.numel() != n_b:
            self.lambda_b = torch.ones(n_b, requires_grad=True, device=self.device)
        self.lambda_optimizer = torch.optim.Adam([self.lambda_r, self.lambda_b], lr=1e-2)

    def _trial_solution(self, x, t):
        """Evaluate ``u(x, t) = t * N(x, t) + 1 + sin(x)`` for column tensors."""
        return t * self.model(torch.cat((x, t), dim=1)) + 1 + torch.sin(x)

    def residual_loss(self, x, t):
        """Return the pointwise PDE residual at the collocation points.

        ``x`` and ``t`` must be column tensors that require gradients, since
        the derivatives ``u_x``, ``u_xx`` and ``u_t`` are taken with autograd.
        """
        x = x.to(self.device)
        t = t.to(self.device)
        u = self._trial_solution(x, t)

        # create_graph=True keeps the derivatives differentiable so the loss
        # built from them can be back-propagated to the network parameters.
        u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
        u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
        u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]

        residual = (
            self.beta * u_x
            - self.nu * u_xx
            - self.rho * u * (1 - u)
            + self.epsilon * (u**3 - u)
            - self.theta * (u**2 - u**3)
            + u_t
        )
        return residual

    def boundary_loss(self, x, t, u_boundary):
        """Return the pointwise squared error of the solution at boundary points.

        The error is measured on the same trial solution used by
        ``residual_loss`` and ``predict`` rather than on the raw network
        output, so the penalty applies to the quantity actually predicted.
        """
        x = x.to(self.device)
        t = t.to(self.device)
        u_boundary = u_boundary.to(self.device)
        u_pred = self._trial_solution(x, t)
        return (u_pred - u_boundary)**2

    def loss_function(self, x_res, t_res, x_b, t_b, u_b):
        """Return ``(total_loss, unweighted_residual_mse)``.

        The total loss is the sum of the weighted residual term, the weighted
        boundary term and a ``-mean(log(lambda))`` term on both weight vectors.
        """
        sq_residual = self.residual_loss(x_res, t_res) ** 2
        sq_boundary = self.boundary_loss(x_b, t_b, u_b)
        self._ensure_weight_sizes(sq_residual.shape[0], sq_boundary.shape[0])

        # Reshape the weights to columns: multiplying an (N,) vector by an
        # (N, 1) tensor would broadcast to (N, N) instead of weighting
        # point by point.
        weights_r = torch.exp(self.lambda_r).view(-1, 1)
        weights_b = torch.exp(self.lambda_b).view(-1, 1)
        loss_r = torch.mean(weights_r * sq_residual)
        loss_b = torch.mean(weights_b * sq_boundary)

        # Unweighted residual MSE, used only for progress reporting.
        loss_for_plot = torch.mean(sq_residual)

        # Logarithmic term on the weights; the 1e-6 offset avoids log(0).
        lambda_loss = (
            -torch.mean(torch.log(self.lambda_r + 1e-6))
            - torch.mean(torch.log(self.lambda_b + 1e-6))
        )

        return loss_r + loss_b + lambda_loss, loss_for_plot

    def train(self, x_res, t_res, x_b, t_b, u_b, epochs=5000):
        """Run ``epochs`` rounds of one L-BFGS step followed by one Adam step.

        The unweighted residual MSE is printed every 10 epochs.
        """
        # Work on fresh leaf tensors. Inputs built as ``rand(...) * c`` carry
        # their own autograd history, which would otherwise be traversed again
        # by every backward pass and force ``retain_graph=True``.
        x_res = x_res.detach().to(self.device).requires_grad_(True)
        t_res = t_res.detach().to(self.device).requires_grad_(True)
        x_b, t_b, u_b = x_b.to(self.device), t_b.to(self.device), u_b.to(self.device)

        self._ensure_weight_sizes(x_res.shape[0], x_b.shape[0])

        def closure():
            # L-BFGS may call this several times per step; each call rebuilds
            # the graph, so nothing has to be retained between calls.
            self.optimizer.zero_grad()
            self.lambda_optimizer.zero_grad()
            loss, _ = self.loss_function(x_res, t_res, x_b, t_b, u_b)
            loss.backward()
            return loss

        for epoch in range(epochs):
            self.optimizer.step(closure)
            # Adam uses the weight gradients left by the last closure call.
            self.lambda_optimizer.step()
            # Keep the weights positive so log(lambda + 1e-6) stays finite.
            with torch.no_grad():
                self.lambda_r.clamp_(min=1e-6)
                self.lambda_b.clamp_(min=1e-6)

            if epoch % 10 == 0:  # Print loss every 10 epochs
                _, loss_plot = self.loss_function(x_res, t_res, x_b, t_b, u_b)
                print(f'Epoch {epoch}, Loss: {loss_plot.item()}')

    def predict(self, x, t):
        """Return the trial solution ``u(x, t)`` without tracking gradients."""
        with torch.no_grad():
            x, t = x.to(self.device), t.to(self.device)
            return self._trial_solution(x, t)



def generate_data(n_res=1000, n_b=100):
    """Create collocation and initial-condition points for training.

    Args:
        n_res: Number of random collocation (residual) points.
        n_b: Number of equally spaced initial-condition points.

    Returns:
        Tuple ``(x_res, t_res, x_b, t_b, u_b)`` of column tensors:
        ``x_res`` uniform in [0, 2*pi) and ``t_res`` uniform in [0, 1), both
        leaf tensors with ``requires_grad=True``; ``x_b`` equally spaced on
        [0, 2*pi], ``t_b`` all zeros, and ``u_b = 1 + sin(x_b)``.
    """
    # Residual points. Scale first and enable gradients afterwards so the
    # returned tensors are autograd leaves rather than results of a multiply.
    x_res = (torch.rand(n_res, 1) * 2 * np.pi).requires_grad_(True)  # x in [0, 2*pi)
    t_res = torch.rand(n_res, 1).requires_grad_(True)  # t in [0, 1)

    # Initial-condition points at t = 0.
    x_b = torch.linspace(0, 2 * np.pi, n_b).unsqueeze(1)  # x in [0, 2*pi]
    t_b = torch.zeros(n_b, 1)
    u_b = 1 + torch.sin(x_b)  # u(x, 0) = 1 + sin(x)

    return x_res, t_res, x_b, t_b, u_b


# Plot the predicted solution
def plot_solution(pinn, num_points=100):
    """Draw a filled contour plot of the predicted solution.

    The solution is evaluated on a ``num_points`` x ``num_points`` grid with
    x in [0, 2*pi] and t in [0, 1].

    Args:
        pinn: Object exposing ``predict(x, t)`` that accepts two column
            tensors and returns one value per row.
        num_points: Number of grid points along each axis.
    """
    x = torch.linspace(0, 2 * np.pi, num_points).to(device)
    t = torch.linspace(0, 1, num_points).to(device)
    # "ij" indexing matches the historical default and avoids the warning
    # emitted when the argument is omitted.
    X, T = torch.meshgrid(x, t, indexing="ij")

    u_pred = pinn.predict(X.reshape(-1, 1), T.reshape(-1, 1))
    # Reshape to the grid's own shape so the element count always matches.
    u_pred = u_pred.reshape(X.shape).cpu().numpy()

    plt.figure(figsize=(5, 5))
    plt.contourf(X.cpu().numpy(), T.cpu().numpy(), u_pred, levels=100, cmap='jet')
    plt.colorbar(label='u(x,t)')
    plt.xlabel('x')
    plt.ylabel('t')
    plt.title('Predicted Solution of Custom PDE')
    plt.show()


if __name__ == "__main__":
    # PDE coefficient sets to solve, keyed by case label.
    cases = {
        'a': dict(beta=0, nu=0, rho=1, epsilon=1, theta=0),
        'b': dict(beta=0, nu=0, rho=0, epsilon=1, theta=1),
        'c': dict(beta=0, nu=0, rho=1, epsilon=0, theta=1),
    }

    # Each case gets a freshly initialised solver and freshly sampled points.
    for case in cases:
        params = cases[case]
        print(f"Training case {case} with params: {params}")
        pinn = AdaptivePINN(**params)
        x_res, t_res, x_b, t_b, u_b = generate_data()
        pinn.train(x_res, t_res, x_b, t_b, u_b, epochs=100)
        plot_solution(pinn)

