import time
import numpy as np
import torch

from src.runner import RunnerBase


class InversePDERunner(RunnerBase):
    """
    Runner class for the inverse problems.

    Jointly trains two networks, ``self.model_u`` and ``self.model_a``, with a
    single Adam optimizer. Subclasses must implement ``init_model()``, which is
    presumably responsible for creating ``self.model_u`` and ``self.model_a``
    (they are used right after ``init_model()`` is called in ``setup()``).

    Attributes read here but not defined in this class (presumably provided
    by ``RunnerBase`` -- confirm): ``self.config``, ``self.logger``,
    ``self.pde_data``, ``self.device``.
    """
    def init_model(self):
        """Create the models; must be overridden by subclasses."""
        raise NotImplementedError("init_model() is not implemented.")

    def _to_tensor(self, array):
        """Copy ``array`` into a new float32 tensor on ``self.device``."""
        return torch.tensor(array, dtype=torch.float32, device=self.device)

    def setup(self):
        """
        Prepare the logger, PDE data, models and training/test data.

        Training inputs/targets are converted to float32 tensors on
        ``self.device``; test data is kept as returned by ``pde_data``.
        """
        print("Start setting up the runner...")
        self.logger.setup(self.config)
        self.logger.log_config()
        self.pde_data.setup(self.config)
        self.init_model()

        # Points at which the PDE loss is evaluated in train().
        self.X = self._to_tensor(self.pde_data.get_coords())

        # Observations used for the data loss on model_u.
        train_X, train_Y = self.pde_data.get_train_data()
        self.train_X = self._to_tensor(train_X)
        self.train_Y = self._to_tensor(train_Y)

        # Observations used for the boundary loss on model_a.
        train_X_b, train_Y_b = self.pde_data.get_boundary_data()
        self.train_X_b = self._to_tensor(train_X_b)
        self.train_Y_b = self._to_tensor(train_Y_b)

        # Note: test_X, test_Y will be numpy arrays instead of tensors.
        self.test_X, self.test_Y = self.pde_data.get_test_data()

        print("Finish setting up the runner.")

    def update_precond(self):
        """Ask the PDE data object to refresh its preconditioner."""
        self.pde_data.update_precond()

    def train(self):
        """
        Perform one optimization step on both models.

        Gradients from up to three losses are accumulated before a single
        ``optimizer.step()``:

        1. PDE loss (only after ``"Cold Start Iterations"`` epochs): the loss
           and its gradients w.r.t. the predictions are computed by
           ``pde_data.compute_loss_grad()`` and back-propagated into the
           networks with a vector-Jacobian product.
        2. MSE data loss of ``model_u`` on ``(train_X, train_Y)``.
        3. MSE boundary loss of ``model_a`` on ``(train_X_b, train_Y_b)``.

        Returns:
            The sum of the loss values computed in this step.
        """
        self.model_u.train()
        self.model_a.train()
        self.optimizer.zero_grad()

        loss_val = 0.0

        if self.epoch > self.config["Cold Start Iterations"]:
            pred_u = self.model_u(self.X)
            pred_a = self.model_a(self.X)

            # Hand the current predictions to the PDE data object.
            self.pde_data.set_u_data(pred_u)
            self.pde_data.set_a_data(pred_a)

            # Loss value plus gradients w.r.t. pred_u / pred_a, computed
            # outside of autograd.
            loss_val_, grad_u, grad_a = self.pde_data.compute_loss_grad()
            loss_val += loss_val_

            grad_u = self._to_tensor(grad_u)
            grad_a = self._to_tensor(grad_a)

            # Vector-Jacobian product r^{T} @ (\partial x / \partial w):
            # backward() with an explicit gradient accumulates the parameter
            # gradients of each network. grad_* must match pred_* in shape.
            pred_u.backward(grad_u)
            pred_a.backward(grad_a)

        # Data loss
        pred_u = self.model_u(self.train_X)
        loss = torch.mean((pred_u - self.train_Y)**2)
        loss.backward()
        loss_val += loss.item()

        # Boundary loss
        pred_a = self.model_a(self.train_X_b)
        loss = torch.mean((pred_a - self.train_Y_b)**2)
        loss.backward()
        loss_val += loss.item()

        self.optimizer.step()

        return loss_val

    def test(self):
        """
        Evaluate both models on the test set.

        Predictions of ``model_u`` and ``model_a`` are concatenated along the
        last axis and only column 1 is compared with ``test_Y[:, 1]``.
        NOTE(review): if ``model_u`` outputs a single column, column 1 is the
        first output of ``model_a`` -- confirm against the data layout.

        Returns:
            Tuple ``(mae, mse, l1re, l2re)``: mean absolute error, mean squared
            error, relative L1 error and relative L2 error.
        """
        self.model_u.eval()
        self.model_a.eval()
        X = self._to_tensor(self.test_X)
        # Inference only: no_grad() avoids building an autograd graph for
        # the (possibly large) test set.
        with torch.no_grad():
            pred_Y_u = self.model_u(X).detach().cpu().numpy()
            pred_Y_a = self.model_a(X).detach().cpu().numpy()
        pred_Y = np.concatenate([pred_Y_u, pred_Y_a], axis=-1)

        test_Y = self.test_Y[:, 1]
        pred_Y = pred_Y[:, 1]

        diff = pred_Y - test_Y
        mae = np.abs(diff).mean()
        mse = (diff**2).mean()
        l1re = mae / np.abs(test_Y).mean()
        l2re = np.sqrt(mse) / np.sqrt((test_Y**2).mean())

        return mae, mse, l1re, l2re

    def run(self):
        """
        Train both models for ``config["Iterations"]`` epochs.

        When ``"Use Preconditioner"`` is set, the preconditioner is refreshed
        every ``"Preconditioner Update Interval"`` epochs once the cold-start
        phase is over. Test metrics are logged every ``"Log Interval"``
        epochs and at the final epoch.
        """
        # Jointly train model_u and model_a
        self.optimizer = torch.optim.Adam(
            list(self.model_u.parameters()) +
            list(self.model_a.parameters()),
            lr=self.config["Learning Rate"])
        print("Start Training...")
        self.logger.log_train_start()
        # perf_counter() is monotonic and meant for measuring durations.
        t1 = time.perf_counter()
        for self.epoch in range(self.config["Iterations"]):
            if self.config["Use Preconditioner"] and \
                self.epoch > self.config["Cold Start Iterations"] and \
                self.epoch % self.config["Preconditioner Update Interval"] == 0:
                self.update_precond()
            self.loss_val = self.train()
            if self.epoch % self.config["Log Interval"] == 0 or \
                self.epoch == self.config["Iterations"] - 1:
                mae, mse, l1re, l2re = self.test()
                self.logger.log_train(self.epoch, self.loss_val, mae,
                    mse, l1re, l2re, time.perf_counter() - t1)
        print("Finish Training.")
        self.logger.log_train_end(time.perf_counter() - t1)
        # NOTE(review): the logger is closed before the history is saved;
        # confirm save_history_* do not need an open logger.
        self.logger.close()
        self.logger.save_history_data()
        self.logger.save_history_plot()
