import time
import numpy as np
import torch

from src.pdedata import PDEDataBase
from src.logger import Logger
from utils.utils import scipy_to_torch_sparse
from utils.fourier_mlp import FourierMLP


class RunnerBase:
    """
    Base class for running the solver 
    (training and testing).

    The model x(X) is trained to minimise the mean squared (optionally
    preconditioned) residual of the linear system A @ x = b, where
    A, b and the optional preconditioner come from ``self.pde_data``.
    Subclasses are expected to implement ``load_config()`` and to
    assign ``self.pde_data`` before ``setup()`` is called.
    """
    def __init__(self, device="cuda") -> None:
        # Run configuration; populated by load_config()/set_config().
        self.config: dict | None = None
        # PDE data provider; expected to be assigned by subclasses.
        self.pde_data: PDEDataBase | None = None
        self.logger = Logger()
        # Neural network approximating the solution; built in init_model().
        self.model = None
        # Torch device used for the model and all training tensors.
        self.device = device

    def load_config(self):
        """
        Load the configuration from file.

        Must be implemented by subclasses.
        """
        raise NotImplementedError("load_config() is not implemented.")

    def get_config(self):
        """
        Return the configuration.
        """
        return self.config
    
    def set_config(self, config: dict):
        """
        Set the configuration.
        """
        self.config = config

    def init_model(self):
        """
        Initialize the model.

        Problems with more than two spatial/temporal input dimensions get a
        deeper and wider FourierMLP (5 layers, 128 hidden units); otherwise
        FourierMLP's own defaults are used.
        """
        input_dim = self.config["Spatial Temporal Dimension"]
        extra_kwargs = {}
        if input_dim > 2:
            extra_kwargs = {"n_layers": 5, "n_hidden": 128}
        self.model = FourierMLP(
            input_dim=input_dim,
            output_dim=self.config["Output Dimension"],
            **extra_kwargs).to(self.device)

    def setup(self):
        """
        Setup the runner.

        Sets up the logger and PDE data, logs preconditioner statistics when
        requested, builds the model, selects the training routine (sparse or
        dense), and moves coordinates, system matrix, right-hand side and
        preconditioner factors to ``self.device`` as float32 tensors.
        """
        print("Start setting up the runner...")
        self.logger.setup(self.config)
        self.logger.log_config()
        self.pde_data.setup(self.config)
        if self.config["Use Preconditioner"]:
            self.logger.log_precond_time(self.pde_data.precond_time)
        if self.config["Log Precondition Quality"]:
            self.logger.log_original_cond(self.pde_data.original_cond)
            if self.config["Use Preconditioner"]:
                self.logger.log_precond_res(
                    self.pde_data.precond_error,
                    self.pde_data.after_cond)
        self.init_model()

        use_sparse = self.config["Use Sparse Solver"]
        self.train = self.train_sparse if use_sparse else self.train_dense

        self.X = torch.tensor(self.pde_data.get_coords(),
            dtype=torch.float32, device=self.device)
        self.A = self.pde_data.get_matrix()
        self.b = torch.tensor(self.pde_data.get_rhs(),
            dtype=torch.float32, device=self.device)
        self.precond = self.pde_data.get_precond()

        # NOTE(review): A and the preconditioner factors are assumed to be
        # scipy.sparse matrices (they are passed to scipy_to_torch_sparse or
        # densified with .todense()) — confirm against PDEDataBase.
        if use_sparse:
            # Convert A, precond to sparse matrix
            self.A = scipy_to_torch_sparse(self.A, self.device)
            self.precond = [
                scipy_to_torch_sparse(M, self.device)
                for M in self.precond]
        else:
            # Convert A, precond to dense matrix
            self.A = torch.tensor(self.A.todense(),
                dtype=torch.float32, device=self.device)
            self.precond = [
                torch.tensor(M.todense(), 
                    dtype=torch.float32, device=self.device)
                for M in self.precond]

        # Note: test_X, test_Y will be numpy arrays instead of tensors.
        self.test_X, self.test_Y = self.pde_data.get_test_data()

        print("Finish setting up the runner.")

    def train_sparse(self):
        """
        Train the model (using sparse matrix).
        If the grid size is very small, this method will
        be significantly slower than train_dense().

        Minimises the same objective as ``loss_fn()``:
        ``mean(r**2)`` with ``r = P^{-1} (A @ x - b)`` (or ``r = A @ x - b``
        without a preconditioner), where ``P^{-1} = U^{-1} L^{-1} Pr``.
        The gradient with respect to the network output x,
        ``(2 / r.numel()) * A^T P^{-T} r``, is formed explicitly with sparse
        operations and fed to autograd as a vector-Jacobian product.

        Returns:
            float: the loss value before the optimizer step.
        """
        self.model.train()
        self.optimizer.zero_grad()

        use_precond = self.config["Use Preconditioner"]
        if use_precond:
            Pr, L, U = self.precond
        pred_Y = self.model(self.X)
        if self.config["Output Dimension"] > 1:
            pred_Y = self.pde_data.rearrange_u_data(pred_Y)

        with torch.no_grad():
            # [Compute r := A @ x - b]
            r = torch.sparse.mm(self.A, pred_Y) - self.b

            # [Compute r := P^{-1} @ r]
            if use_precond:
                # r := Pr @ r
                r = torch.sparse.mm(Pr, r)
                # r := L^{-1} @ r
                r = torch.triangular_solve(r, L, upper=False).solution
                # r := U^{-1} @ r
                r = torch.triangular_solve(r, U, upper=True).solution

            # [Loss value: mean(r^2)] — the same (preconditioned) objective
            # that loss_fn()/train_dense() report, so both paths log
            # comparable values.
            n_elems = r.numel()
            loss_val = torch.mean(r**2).item()

            # [Compute r := A^{T} @ (P^{-1})^{T} @ r]
            if use_precond:
                # r := U^{-1}^{T} @ r
                r = torch.triangular_solve(r, U, upper=True, transpose=True).solution
                # r := L^{-1}^{T} @ r
                r = torch.triangular_solve(r, L, upper=False, transpose=True).solution
                # r := Pr^{T} @ r
                r = torch.sparse.mm(Pr.t(), r)
            # r := (2 / n) * A^{T} @ r — exact gradient of mean(r^2), which
            # matches what autograd produces in train_dense().
            r = torch.sparse.mm(self.A.t(), r) * (2.0 / n_elems)

        # Vector-Jacobian product r^{T} @ (\partial x / \partial w)
        pred_Y.backward(r)
        self.optimizer.step()

        return loss_val

    def loss_fn(self):
        """
        Return the loss function.

        Computes ``mean(r**2)`` with ``r = A @ x - b``, or
        ``r = U^{-1} L^{-1} Pr (A @ x - b)`` when the preconditioner is
        enabled, using dense matrices. The result is a scalar tensor that is
        differentiable with respect to the model parameters.
        """
        pred_Y = self.model(self.X)
        if self.config["Output Dimension"] > 1:
            pred_Y = self.pde_data.rearrange_u_data(pred_Y)
        # r := Ax - b
        r = torch.mm(self.A, pred_Y) - self.b
        if self.config["Use Preconditioner"]:
            Pr, L, U = self.precond
            # r := Pr @ r
            r = torch.mm(Pr, r)
            # r := L^{-1} @ r
            r = torch.linalg.solve_triangular(L, r, upper=False)
            # r := U^{-1} @ r
            r = torch.linalg.solve_triangular(U, r, upper=True)
        # loss := 1/n * ||r||^2
        loss = torch.mean(r**2)

        return loss

    def train_dense(self):
        """
        Train the model (using dense matrix).
        If the grid size is too large, this method will
        consume too much memory. Consider using train_sparse()
        instead.

        Returns:
            float: the loss value before the optimizer step.
        """
        self.model.train()
        self.optimizer.zero_grad()
        loss = self.loss_fn()
        loss_val = loss.item()
        loss.backward()
        self.optimizer.step()

        return loss_val

    def test(self):
        """
        Test the model.

        Evaluates the model on ``self.test_X`` and compares with
        ``self.test_Y``.

        Returns:
            tuple: (mae, mse, l1re, l2re) — mean absolute error, mean squared
            error, relative L1 error and relative L2 error.
        """
        self.model.eval()
        X = torch.tensor(self.test_X,
            dtype=torch.float32, device=self.device)
        # Evaluation only: skip building an autograd graph.
        with torch.no_grad():
            pred_Y = self.model(X).cpu().numpy()
        # NOTE(review): assumes test_Y has the same shape as the raw model
        # output; a 1-D test_Y against (N, 1) predictions would broadcast to
        # (N, N) — confirm against PDEDataBase.get_test_data().
        err = pred_Y - self.test_Y
        mae = np.abs(err).mean()
        mse = (err**2).mean()
        l1re = mae / np.abs(self.test_Y).mean()
        l2re = np.sqrt(mse) / np.sqrt((self.test_Y**2).mean())

        return mae, mse, l1re, l2re

    def run(self):
        """
        Run the solver.

        Trains with Adam for ``config["Iterations"]`` steps. Every
        ``config["Log Interval"]`` steps, and at the final step, it evaluates
        and logs metrics. Afterwards it saves history, predictions, plots and
        the model through the logger.
        """
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.config["Learning Rate"])
        n_iters = self.config["Iterations"]
        log_interval = self.config["Log Interval"]
        print("Start Training...")
        self.logger.log_train_start()
        t1 = time.time()
        for self.epoch in range(n_iters):
            self.loss_val = self.train()
            if self.epoch % log_interval == 0 or \
                self.epoch == n_iters - 1:
                mae, mse, l1re, l2re = self.test()
                self.logger.log_train(self.epoch, self.loss_val, mae, 
                    mse, l1re, l2re, time.time() - t1)
        print("Finish Training.")
        self.logger.log_train_end(time.time() - t1)
        self.logger.close()
        self.logger.save_history_data()
        self.logger.save_prediction(
            self.test_X, self.test_Y, self.model)
        self.logger.save_history_plot()
        self.logger.save_model(self.model)
