import copy
import os
import torch
import numpy as np

# Select deepxde's PyTorch backend. This must run before `import deepxde`
# below, because deepxde chooses its backend when it is first imported.
# NOTE(review): recent deepxde docs name this variable DDE_BACKEND; DDEBACKEND
# is presumably still honoured as a legacy alias — verify for the pinned version.
os.environ["DDEBACKEND"] = "pytorch"
import deepxde as dde
from deepxde.callbacks import Callback

from src import Wave1D_CRunner
from utils.mlp import MLP
from scripts.error_landscape_utils import *


class PCPINNRunner(Wave1D_CRunner):
    """Wave1D_CRunner variant that snapshots the model weights on every test.

    The snapshots accumulate in ``self.weights`` (built with ``get_weights``
    from ``scripts.error_landscape_utils``); ``eval_loss`` / ``eval_error``
    expose scalar loss / error values for the current weights.
    """

    def __init__(self, device="cuda") -> None:
        super().__init__(device=device)
        # One deep-copied weight snapshot is appended per call to ``test``.
        self.weights = []

    def test(self):
        """Evaluate on the held-out test set and record a weight snapshot.

        Returns:
            Tuple ``(mae, mse, l1re, l2re)``: mean absolute error, mean squared
            error, and the L1 / L2 errors relative to ``self.test_Y``.
        """
        # Remember the caller's mode so evaluating mid-training does not
        # silently leave the model in eval mode afterwards.
        was_training = self.model.training
        self.model.eval()
        X = torch.tensor(self.test_X,
            dtype=torch.float32, device=self.device)
        # Pure inference: no autograd graph is needed for the test batch.
        with torch.no_grad():
            pred_Y = self.model(X).cpu().numpy()
        mae = np.abs(pred_Y - self.test_Y).mean()
        mse = ((pred_Y - self.test_Y)**2).mean()
        l1re = mae / np.abs(self.test_Y).mean()
        l2re = np.sqrt(mse) / np.sqrt((self.test_Y**2).mean())

        # A single deep copy already decouples the snapshot from the live model.
        self.weights.append(copy.deepcopy(get_weights(self.model)))
        self.model.train(was_training)

        return mae, mse, l1re, l2re

    def eval_loss(self):
        """Return the scalar value of ``self.loss_fn()`` for the current weights.

        Leaves the model in eval mode. Not wrapped in ``torch.no_grad`` —
        NOTE(review): presumably ``loss_fn`` needs autograd for PDE derivatives.
        """
        self.model.to(self.device)
        self.model.eval()
        loss = self.loss_fn()
        loss_val = loss.item()

        return loss_val

    def eval_error(self):
        """Return the test L2 relative error (also records a weight snapshot)."""
        self.model.to(self.device)
        error = self.test()[-1]
        return error



class PINNRunner(Wave1D_CRunner):
    """Train the 1-D wave problem as a standard PINN through deepxde.

    Every call to ``test`` appends a deep copy of the current weights to
    ``self.weights``; ``PINNTracker`` calls ``test`` after every epoch, so the
    list records the training trajectory.
    """

    def __init__(self, device="cuda") -> None:
        # Created before ``super().__init__``. NOTE(review): presumably because
        # the base constructor may already run code that records weights.
        self.weights = []
        super().__init__(device=device)

    def init_model(self):
        """Build a 5 x 100 hidden-layer MLP from the configured in/out sizes."""
        self.model = MLP(
            layer_sizes=[self.config["Spatial Temporal Dimension"]]
            + [100] * 5
            + [self.config["Output Dimension"]]
        ).to(self.device)

    def setup(self):
        """Build geometry, PDE, boundary/initial conditions and the deepxde model.

        Solves ``u_tt = C**2 * u_xx`` on ``[0, scale] x [0, scale]`` (input
        column 0 is x, column 1 is t) with ``u = 0`` at both spatial ends,
        ``u(x, 0)`` taken from the reference solution and zero initial velocity.
        """
        C = 2  # wave speed; ref_sol below only solves the PDE when C == 2
        scale = 1  # side length of the square (x, t) domain
        a = 4  # frequency multiplier of the second mode in ref_sol
        self.pde_data.setup(self.config)
        self.init_model()

        # output dim
        self.output_dim = 1
        # Domain bounds as [xmin, xmax, tmin, tmax].
        self.bbox = [0, scale, 0, scale]
        self.geom = dde.geometry.Rectangle(
            xmin=[self.bbox[0], self.bbox[2]],
            xmax=[self.bbox[1], self.bbox[3]],
        )

        def wave_pde(x, u):
            """PDE residual u_tt - C^2 u_xx."""
            u_xx = dde.grad.hessian(u, x, i=0, j=0)
            u_tt = dde.grad.hessian(u, x, i=1, j=1)

            return u_tt - C**2 * u_xx

        def ref_sol(x):
            """Analytic solution: two standing-wave modes (frequencies 1 and a).

            Each mode's temporal frequency is twice its spatial one, which is
            why the residual vanishes exactly for C == 2.
            """
            x = x / scale
            return (
                np.sin(np.pi * x[:, 0:1]) * np.cos(2 * np.pi * x[:, 1:2])
                + 0.5 * np.sin(a * np.pi * x[:, 0:1])
                * np.cos(2 * a * np.pi * x[:, 1:2])
            )

        self.ref_sol = ref_sol

        def boundary_x0(x, on_boundary):
            # Both spatial ends: x == xmin or x == xmax.
            return on_boundary and (np.isclose(x[0], self.bbox[0]) or np.isclose(x[0], self.bbox[1]))

        def boundary_t0(x, on_boundary):
            # Initial time slice: t == tmin.
            return on_boundary and np.isclose(x[1], self.bbox[2])

        # Zero normal derivative on t == tmin, i.e. zero initial velocity.
        bc1 = dde.NeumannBC(
            self.geom, lambda _: 0, boundary_t0, component=0)
        # Initial displacement u(x, 0) from the reference solution.
        bc2 = dde.DirichletBC(
            self.geom, ref_sol, boundary_t0, component=0)
        # Homogeneous Dirichlet condition at both spatial ends.
        bc3 = dde.DirichletBC(
            self.geom, lambda _: 0, boundary_x0, component=0)
        bcs = [bc1, bc2, bc3]

        # Note: test_X, test_Y will be numpy arrays instead of tensors.
        self.test_X, self.test_Y = self.pde_data.get_test_data()

        # Prepare trainer
        data = dde.data.PDE(
            self.geom,
            wave_pde,
            bcs,
            num_domain=8192,
            num_boundary=2048
        )
        # deepxde's PyTorch Model reads ``net.regularizer``.
        # NOTE(review): MLP presumably does not define it itself.
        self.model.regularizer = None
        self.dde_model = dde.Model(data, self.model)
        self.dde_model.compile("adam", lr=self.config["Learning Rate"])
        self.callback = PINNTracker(self)

    def test(self):
        """Evaluate on the held-out test set and record a weight snapshot.

        The model is put back into train mode afterwards, because this runs
        from a training callback.

        Returns:
            Tuple ``(mae, mse, l1re, l2re)``: mean absolute error, mean squared
            error, and the L1 / L2 errors relative to ``self.test_Y``.
        """
        self.model.eval()
        X = torch.tensor(self.test_X,
            dtype=torch.float32, device=self.device)
        # Pure inference: no autograd graph is needed for the test batch.
        with torch.no_grad():
            pred_Y = self.model(X).cpu().numpy()
        mae = np.abs(pred_Y - self.test_Y).mean()
        mse = ((pred_Y - self.test_Y)**2).mean()
        l1re = mae / np.abs(self.test_Y).mean()
        l2re = np.sqrt(mse) / np.sqrt((self.test_Y**2).mean())

        # A single deep copy already decouples the snapshot from the live model.
        self.weights.append(copy.deepcopy(get_weights(self.model)))
        self.model.train()

        return mae, mse, l1re, l2re

    def run(self):
        """Train with deepxde and copy the tracked histories into the logger."""
        self.dde_model.train(epochs=self.config["Iterations"],
            callbacks=[self.callback],
            display_every=self.config["Log Interval"])

        self.logger.loss_history = self.callback.loss_history
        self.logger.l2re_history = self.callback.l2re_history

    def eval_loss(self):
        """Return the summed deepxde loss terms on the current training points.

        Uses deepxde's private ``Model._outputs_losses`` with ``training=False``.
        NOTE(review): private API — its signature may differ across versions.
        """
        self.model.to(self.device)
        self.model.eval()
        _, loss_val = self.dde_model._outputs_losses(
            False,
            self.dde_model.train_state.X_train,
            self.dde_model.train_state.y_train,
            self.dde_model.train_state.train_aux_vars,
        )
        loss_val = np.sum(loss_val)

        return loss_val

    def eval_error(self):
        """Return the test L2 relative error (also records a weight snapshot)."""
        self.model.to(self.device)
        error = self.test()[-1]
        return error


class PINNTracker(Callback):
    """deepxde callback recording the test L2 relative error and training loss.

    After each epoch it calls ``runner.test()``, which also appends a weight
    snapshot to ``runner.weights``. It then extends ``l2re_history`` and
    ``loss_history`` by one entry each.
    """

    def __init__(self, runner: PINNRunner):
        super().__init__()
        self.runner = runner
        self.reset()

    def reset(self):
        """Drop everything recorded so far."""
        self.loss_history, self.l2re_history = [], []

    def on_epoch_end(self):
        """Append the current test L2RE and the summed training loss."""
        metrics = self.runner.test()
        self.l2re_history.append(metrics[3])
        # NOTE(review): deepxde presumably refreshes train_state.loss_train only
        # on display steps, so consecutive entries may repeat a stale value.
        summed_loss = np.sum(self.model.train_state.loss_train)
        self.loss_history.append(summed_loss)
