import time
import torch

from src.runner import RunnerBase
from utils.utils import scipy_to_torch_sparse


class NonLinearPDERunner(RunnerBase):
    """
    Runner class for the non-linear PDEs.

    Every ``config["Newton Interval"]`` epochs the current model prediction
    on ``self.X`` is handed to ``pde_data.setup_newton``. The matrix,
    right-hand side and preconditioner matrices are then fetched from
    ``pde_data`` and stored as torch tensors for the training step.
    """
    def setup(self):
        """
        Prepare logger, PDE data, model, training routine and input data.

        Binds ``self.train`` to ``train_sparse`` or ``train_dense``, chosen by
        ``config["Use Sparse Solver"]``. Stores the coordinates from
        ``pde_data.get_coords()`` as a float32 tensor ``self.X`` on
        ``self.device``. Stores the test data as ``self.test_X`` and
        ``self.test_Y``.
        """
        print("Start setting up the runner...")
        self.logger.setup(self.config)
        self.logger.log_config()
        self.pde_data.setup(self.config)
        self.init_model()

        # Choose the training step once so run() does not branch every epoch.
        if self.config["Use Sparse Solver"]:
            self.train = self.train_sparse
        else:
            self.train = self.train_dense

        self.X = torch.tensor(self.pde_data.get_coords(),
            dtype=torch.float32, device=self.device)

        # Note: test_X, test_Y will be numpy arrays instead of tensors.
        self.test_X, self.test_Y = self.pde_data.get_test_data()

        print("Finish setting up the runner.")

    def setup_newton(self):
        """
        Setup for the Newton iteration.

        Steps:
        - Evaluate the current model on ``self.X`` and pass the prediction
          (as a numpy array) to ``pde_data.setup_newton``.
        - Log preconditioner diagnostics as enabled in the config.
        - Store the results on the runner: ``self.A`` (matrix), ``self.b``
          (float32 right-hand side) and ``self.precond`` (list of matrices).

        All three are torch tensors on ``self.device``. The matrices are
        sparse or dense according to ``config["Use Sparse Solver"]``.
        """
        # The prediction is only used as data, so skip building an autograd
        # graph for this forward pass.
        with torch.no_grad():
            u_data = self.model(self.X).cpu().numpy()
        self.pde_data.setup_newton(u_data)

        use_precond = self.config["Use Preconditioner"]
        if use_precond:
            self.logger.log_precond_time(self.pde_data.precond_time)
        if self.config["Log Precondition Quality"]:
            self.logger.log_original_cond(self.pde_data.original_cond)
            if use_precond:
                self.logger.log_precond_res(
                    self.pde_data.precond_error,
                    self.pde_data.after_cond)

        self.A = self._matrix_to_torch(self.pde_data.get_matrix())
        self.b = torch.tensor(self.pde_data.get_rhs(),
            dtype=torch.float32, device=self.device)
        self.precond = [
            self._matrix_to_torch(M) for M in self.pde_data.get_precond()]

    def _matrix_to_torch(self, M):
        """
        Convert a (presumably scipy sparse) matrix to a torch tensor.

        Uses ``scipy_to_torch_sparse`` when ``config["Use Sparse Solver"]`` is
        set. Otherwise builds a dense float32 tensor. The result is on
        ``self.device``.
        """
        if self.config["Use Sparse Solver"]:
            return scipy_to_torch_sparse(M, self.device)
        # toarray() yields a plain ndarray; todense() would yield np.matrix.
        return torch.tensor(M.toarray(),
            dtype=torch.float32, device=self.device)

    def run(self):
        """
        Train the model with Adam and save the results through the logger.

        A Newton update (``setup_newton``) runs at epoch 0, which creates the
        system before the first ``train`` call. It then runs every
        ``config["Newton Interval"]`` epochs.

        Test metrics are logged every ``config["Log Interval"]`` epochs and at
        the final epoch. History, predictions, plots and the model are saved
        at the end.
        """
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.config["Learning Rate"])
        n_iters = self.config["Iterations"]
        print("Start Training...")
        self.logger.log_train_start()
        # perf_counter is monotonic, so elapsed times are immune to
        # wall-clock adjustments.
        t_start = time.perf_counter()
        # The epoch counter is kept on self; presumably other methods
        # (e.g. train_*) read it -- confirm before changing.
        for self.epoch in range(n_iters):
            if self.epoch % self.config["Newton Interval"] == 0:
                print("Start Newton iteration...")
                self.logger.log_newton_update()
                self.setup_newton()
            self.loss_val = self.train()
            is_last = self.epoch == n_iters - 1
            if self.epoch % self.config["Log Interval"] == 0 or is_last:
                mae, mse, l1re, l2re = self.test()
                self.logger.log_train(self.epoch, self.loss_val, mae,
                    mse, l1re, l2re, time.perf_counter() - t_start)
        print("Finish Training.")
        self.logger.log_train_end(time.perf_counter() - t_start)
        # NOTE(review): the logger is closed before the save_* calls below;
        # confirm none of them depend on resources released by close().
        self.logger.close()
        self.logger.save_history_data()
        self.logger.save_prediction(
            self.test_X, self.test_Y, self.model)
        self.logger.save_history_plot()
        self.logger.save_model(self.model)
