import time
import copy
import numpy as np
from scipy.interpolate import interp1d
import torch

from src.runner import RunnerBase
from utils.fourier_mlp import FourierMLP
from utils.utils import scipy_to_torch_sparse


class TimeSteppingRunner(RunnerBase):
    """
    Class for running the time stepping-based 
    PDE solver.
    """
    def setup(self):
        """
        Prepare everything the time-stepping loop needs.

        In order: sets up the logger and the PDE data from
        `self.config`, logs the preconditioner statistics that the
        config asks for, builds the model, derives the time-stepping
        constants, and converts the collocation points, system
        matrices, right-hand side and initial condition(s) into
        float32 tensors (or sparse tensors when "Use Sparse Solver"
        is set) on `self.device`.
        """
        print("Start setting up the runner...")
        cfg = self.config
        data = self.pde_data

        self.logger.setup(cfg)
        self.logger.log_config()
        data.setup(cfg)

        # Report preconditioner statistics as requested by the config.
        use_precond = cfg["Use Preconditioner"]
        if use_precond:
            self.logger.log_precond_time(data.precond_time)
        if cfg["Log Precondition Quality"]:
            self.logger.log_original_cond(data.original_cond)
            if use_precond:
                self.logger.log_precond_res(
                    data.precond_error, data.after_cond)

        self.init_model()

        # Time-stepping constants: the last two entries of the
        # bounding box are used as the start and end time.
        bounding_box = cfg["Bounding Box"]
        self.time_start = bounding_box[-2]
        self.time_end = bounding_box[-1]
        total_time = self.time_end - self.time_start
        self.interval_len = total_time / cfg["Sub-Time Intervals"]
        self.dt = self.interval_len / cfg["Grid Size (t-direction)"]

        def as_float_tensor(array):
            # Dense float32 tensor on the runner's device.
            return torch.tensor(
                array, dtype=torch.float32, device=self.device)

        self.X = as_float_tensor(data.get_coords())

        self.A = data.get_matrix()
        self.b = as_float_tensor(data.get_rhs())
        self.precond = data.get_precond()
        self.K = data.get_mass_matrix()

        # Pick the training step and the matching matrix format.
        # NOTE(review): `train_dense` is not defined in this class;
        # presumably it is provided by RunnerBase -- confirm.
        if cfg["Use Sparse Solver"]:
            self.train_sub_interval = self.train_sparse

            def convert(matrix):
                return scipy_to_torch_sparse(matrix, self.device)
        else:
            self.train_sub_interval = self.train_dense

            def convert(matrix):
                return as_float_tensor(matrix.todense())

        self.A = convert(self.A)
        self.precond = [convert(factor) for factor in self.precond]
        self.K = convert(self.K)

        multi_output = cfg["Output Dimension"] > 1

        def as_initial_state(raw):
            # Add a leading axis of size one and, for multi-output
            # problems, pass the data through `rearrange_u_data`.
            state = as_float_tensor(raw).unsqueeze(0)
            if multi_output:
                state = data.rearrange_u_data(
                    state.permute(1, 0, 2)).permute(1, 0, 2)
            return state

        # Initial condition (plus the previous step for 2nd order).
        self.u0 = as_initial_state(data.get_u0())
        if cfg["Time Derivative Order"] == 2:
            self.u0_prev = as_initial_state(data.get_u0_prev())

        # Note: test_X, test_Y stay numpy arrays instead of tensors.
        self.test_X, self.test_Y = data.get_test_data()
        self.u0_test_grid = self.test_Y[0]

        print("Finish setting up the runner.")

    def init_model(self):
        """
        Create the FourierMLP and move it to `self.device`.

        The network maps spatial coordinates to all time steps of
        one sub-interval at once, so its output width is
        "Output Dimension" * "Grid Size (t-direction)". Problems
        whose "Spatial Temporal Dimension" exceeds 2 get an explicit
        5-layer, 128-unit configuration; otherwise the FourierMLP
        defaults are used.
        """
        cfg = self.config
        low_dimensional = cfg["Spatial Temporal Dimension"] <= 2

        model_kwargs = {
            "input_dim": cfg["Spatial Dimension"],
            "output_dim": cfg["Output Dimension"]
                * cfg["Grid Size (t-direction)"],
        }
        if not low_dimensional:
            model_kwargs["n_layers"] = 5
            model_kwargs["n_hidden"] = 128

        self.model = FourierMLP(**model_kwargs).to(self.device)

    def interp_prediction(self, pred_Y, test_t):
        EPSILON = 1e-8
        # Append the initial condition
        pred_Y = pred_Y.reshape(-1, 
            self.config["Grid Size (t-direction)"], 
            self.config["Output Dimension"])
        pred_Y = np.concatenate([
            self.u0_test_grid, pred_Y], axis=1)
        # The time step for the prediction
        pred_t = np.linspace(
            self.time_start + self.interval_len * self.sub_interval - EPSILON,
            self.time_start + self.interval_len * (self.sub_interval + 1) + EPSILON,
            self.config["Grid Size (t-direction)"] + 1
        )
        f = interp1d(pred_t, pred_Y, axis=1)
        pred_Y = f(test_t)

        return pred_Y

    def test_sub_interval(self):
        test_t = self.test_X[1][self.sub_interval]
        test_Y = self.test_Y[1][self.sub_interval]
        if test_Y is None:
            return None, None, None, None

        self.model.eval()
        X = torch.tensor(self.test_X[0],
            dtype=torch.float32, device=self.device)
        pred_Y = self.model(X).cpu().detach().numpy()
        pred_Y = self.interp_prediction(pred_Y, test_t)

        mae = np.abs(pred_Y - test_Y).mean()
        mse = ((pred_Y - test_Y)**2).mean()
        l1re = mae / np.abs(test_Y).mean()
        l2re = np.sqrt(mse) / np.sqrt((test_Y**2).mean())

        return mae, mse, l1re, l2re

    def test_sub_interval_final(self):
        preds = np.concatenate(self.sub_interval_pred, axis=1)
        test_Y = [u for u in self.test_Y[1] if u is not None]
        tests = np.concatenate(test_Y, axis=1)

        mae = np.abs(preds - tests).mean()
        mse = ((preds - tests)**2).mean()
        l1re = mae / np.abs(tests).mean()
        l2re = np.sqrt(mse) / np.sqrt((tests**2).mean())

        return mae, mse, l1re, l2re

    def update_tm_rhs(self, ts):
        """
        Update the time-dependent rhs `self.b` of the 
        form of `f(x, t)` for the next sub-interval.

        Note:
            `ts` is the time steps (1D numpy array) 
            for the next sub-interval.
            This function uses FeniCS to compute the rhs,
            which is not differentiable. If your rhs is
            of the form `f(x, t, u)`, you can use the
            `self.pde_data.compute_tm_rhs` function to
            compute the rhs at each training iteration.
        """
        pass

    def update_sub_interval(self):
        """
        Update the system equation for the next sub-interval.
        """
        def post_process(u):
            if self.config["Output Dimension"] > 1:
                u = self.pde_data.rearrange_u_data(u)
            u = u.permute(1, 0, 2)
            return u

        self.model.eval()
        pred_Y = self.model(self.X).reshape(-1,
            self.config["Grid Size (t-direction)"],
            self.config["Output Dimension"])
        if self.config["Time Derivative Order"] == 2:
            if self.config["Grid Size (t-direction)"] == 1:
                self.u0_prev = self.u0
            else:
                self.u0_prev = pred_Y[:, -2:-1, :].detach()
                self.u0_prev = post_process(self.u0_prev)
        self.u0 = pred_Y[:, -1:, :].detach()
        self.u0 = post_process(self.u0)
        # Update u0 on the test grid
        X = torch.tensor(self.test_X[0],
            dtype=torch.float32, device=self.device)
        pred_Y = self.model(X).cpu().detach().numpy().reshape(-1,
            self.config["Grid Size (t-direction)"],
            self.config["Output Dimension"])
        self.u0_test_grid = pred_Y[:, -1:, :]
        # Update the time-dependent bc
        ts = np.linspace(
            self.time_start + self.interval_len * (self.sub_interval + 1) + self.dt,
            self.time_start + self.interval_len * (self.sub_interval + 2),
            self.config["Grid Size (t-direction)"]
        )
        self.update_tm_rhs(ts)

    def compute_residual(self):
        raw_pred_Y = self.model(self.X)
        raw_pred_Y = raw_pred_Y.reshape(-1,
            self.config["Grid Size (t-direction)"],
            self.config["Output Dimension"])
        pred_Y = raw_pred_Y
        if self.config["Output Dimension"] > 1:
            pred_Y = self.pde_data.rearrange_u_data(pred_Y)
        pred_Y = pred_Y.permute(1, 0, 2)

        # Compute rhs for the each time
        if self.config["Time Derivative Order"] == 1:
            # Two-point BDF (Backward difference formula)
            # u_t \approx (u_{t} - u_{t-1}) / dt
            # The rhs would be u_{t-1}
            b = torch.concat([self.u0, pred_Y[:-1]], dim=0)
        elif self.config["Time Derivative Order"] == 2:
            # Three-point BDF (Backward difference formula)
            # u_tt \approx (u_{t} - 2u_{t-1} + u_{t-2}) / dt^2
            # The rhs would be 2u_{t-1} - u_{t-2}
            b = torch.concat([self.u0, pred_Y[:-1]], dim=0)
            b = 2 * b - torch.concat([self.u0_prev, b[:-1]], dim=0)
        else:
            raise NotImplementedError("Only support 1st and 2nd order.")
        # Time-dependent source term
        t = torch.linspace(
            self.time_start + self.interval_len * self.sub_interval + self.dt,
            self.time_start + self.interval_len * (self.sub_interval + 1),
            self.config["Grid Size (t-direction)"],
            device=self.device
        )
        source_term = self.pde_data.compute_tm_rhs(
            self.X, t, self.dt, raw_pred_Y.permute(1, 0, 2))
        if source_term is not None:
            if self.config["Output Dimension"] > 1:
                source_term = self.pde_data.rearrange_u_data(
                    source_term.permute(1, 0, 2)).permute(1, 0, 2)
            b += source_term
        # Apply mass matrix
        b = torch.matmul(self.K, b)
        # Apply boundary condition
        if len(self.b.shape) != 0:
            b += self.b

        r = torch.matmul(self.A, pred_Y) - b

        return r, pred_Y, b

    def loss_fn(self):
        # Get residual
        r, _, _ = self.compute_residual()

        if self.config["Use Preconditioner"]:
            # Preconditioning
            Pr, L, U = self.precond
            # r := Pr @ r
            r = torch.matmul(Pr, r)
            # r := L^{-1} @ r
            r = r.squeeze(-1).permute(1, 0)
            r = torch.linalg.solve_triangular(L, r, upper=False)
            # r := U^{-1} @ r
            r = torch.linalg.solve_triangular(U, r, upper=True)
            r = r.permute(1, 0)
        else:
            r = r.squeeze(-1)
        # r := 1/n * ||r||^2
        r = torch.mean(r ** 2, dim=1)

        # r := sum_i w_i * r_i
        EPS = 1e2
        w = torch.ones_like(r)
        w[1:] = torch.cumsum(r, dim=0)[:-1]
        w[1:] = torch.exp(-EPS * w[1:])

        r = torch.dot(w, r)

        return r
    
    def train_sparse(self):
        """
        Run one optimizer step using a hand-written gradient.

        Instead of back-propagating through the (sparse) linear
        algebra, the gradient of the residual loss with respect to
        the prediction and the right-hand side is assembled
        explicitly under `torch.no_grad()` and then pushed into the
        network with two vector-Jacobian products.

        Returns:
            float: mean squared value of the raw residual (before
            preconditioning and annealing weights).
        """
        self.model.train()
        self.optimizer.zero_grad()

        if self.config["Use Preconditioner"]:
            Pr, L, U = self.precond

        # [Compute r := A @ x - b]
        r, pred_Y, b = self.compute_residual()

        with torch.no_grad():
            # [Annealing]
            # One weight per leading index of r (presumably the time
            # step): w_0 = 1, w_i = exp(-EPS * sum_{k<i} mean(r_k^2)),
            # computed from the raw (unpreconditioned) residual.
            EPS = 1e2
            w = torch.ones((r.shape[0]), device=self.device)
            r_val = torch.mean(r.squeeze(-1)**2, dim=1)
            w[1:] = torch.cumsum(r_val, dim=0)[:-1]
            w[1:] = torch.exp(-EPS * w[1:])
            # The square root of the weight is what scales the
            # gradients below; reshaped to broadcast over rows.
            w = torch.sqrt(w).reshape(1, -1)

            # [Loss value: 1/n * ||r||^2]
            loss_val = torch.mean(r**2).item()

            # Drop the trailing axis and swap the remaining two so the
            # matrix products / triangular solves act on the first axis.
            r = r.squeeze(-1).permute(1, 0)

            if self.config["Use Preconditioner"]:
                # NOTE(review): torch.triangular_solve is deprecated in
                # recent PyTorch releases in favour of
                # torch.linalg.solve_triangular (different argument
                # order); confirm sparse-matrix support before migrating.
                # [Compute r := P^{-1} @ r]
                # r := Pr @ r
                r = torch.sparse.mm(Pr, r)
                # r := L^{-1} @ r
                r = torch.triangular_solve(r, L, upper=False).solution
                # r := U^{-1} @ r
                r = torch.triangular_solve(r, U, upper=True).solution

                # [Compute r := (P^{-1})^{T} @ r]
                # r := U^{-1}^{T} @ r
                r = torch.triangular_solve(r, U, upper=True, transpose=True).solution
                # r := L^{-1}^{T} @ r
                r = torch.triangular_solve(r, L, upper=False, transpose=True).solution
                # r := Pr^{T} @ r
                r = torch.sparse.mm(Pr.t(), r)
            # NOTE(review): the constant factor 2 from differentiating
            # the squared residual is not included in either gradient.
            # grad_x := (A^{T} @ r) / n
            grad_x = torch.sparse.mm(self.A.t(), r) / r.shape[0]
            # grad_b := -r / n
            grad_b = -r / r.shape[0]

        # Vector-Jacobian product grad_x^{T} @ (\partial x / \partial w)
        # retain_graph keeps the graph alive for the second backward
        # call on b below.
        pred_Y = pred_Y.squeeze(-1).permute(1, 0)
        pred_Y.backward(grad_x * w, retain_graph=True)
        # Vector-Jacobian product grad_b^{T} @ (\partial b / \partial w)
        b = b.squeeze(-1).permute(1, 0)
        b.backward(grad_b * w)

        # Gradient clipping
        # Only applied on retry attempts (cur_attempt is set by `run`).
        if self.cur_attempt > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1e-3)

        self.optimizer.step()

        return loss_val

    def record_sub_interval_pred(self):
        """
        Record the prediction for the current sub-interval.
        """
        test_t = self.test_X[1][self.sub_interval]
        if test_t is None:
            return

        self.model.eval()
        X = torch.tensor(self.test_X[0],
            dtype=torch.float32, device=self.device)
        pred_Y = self.model(X).cpu().detach().numpy()
        pred_Y = self.interp_prediction(pred_Y, test_t)

        self.sub_interval_pred.append(pred_Y)

    def save_model_state(self):
        """
        Temporarily save the model state
        in the memory (CPU).
        """
        # Remember to use deepcopy
        self.model_state = copy.deepcopy(self.model.state_dict())
    
    def load_model_state(self):
        """
        Load the temporarily saved model state.
        """
        self.model.load_state_dict(self.model_state)

    def run(self):
        """
        Train the model sub-interval by sub-interval and log the
        results.

        For every sub-interval the model state is snapshotted, a
        fresh Adam optimizer is created and the selected training
        step is run for the configured number of iterations, with
        periodic evaluation. If the final loss is not below
        EXPLODE_THRESHOLD times the first loss of the attempt, the
        snapshot is restored (and training is retried while attempts
        remain). Afterwards the prediction is recorded and the
        solver state is advanced. Once all sub-intervals are done,
        the overall errors, the predictions and the model are handed
        to the logger.
        """
        print("Start Training...")
        self.logger.log_train_start()
        train_t0 = time.time()
        self.sub_interval_pred = []
        # In case the loss value explodes
        EXPLODE_THRESHOLD = 1e8
        MAX_ATTEMPTS = 1

        for interval_idx in range(self.config["Sub-Time Intervals"]):
            self.sub_interval = interval_idx
            self.logger.log_sub_interval_start(self.sub_interval)
            interval_t0 = time.time()

            # The first sub-interval may get its own iteration budget.
            n_iters = self.config["Iterations"]
            cold_start = self.sub_interval == 0
            if cold_start and self.config["Cold Start Iterations"] > 0:
                n_iters = self.config["Cold Start Iterations"]
            last_epoch = n_iters - 1

            # Snapshot to fall back to if training diverges.
            self.save_model_state()

            for cur_attempt in range(MAX_ATTEMPTS):
                self.cur_attempt = cur_attempt
                init_loss_val = None
                # Every attempt starts with a fresh optimizer.
                self.optimizer = torch.optim.Adam(
                    self.model.parameters(),
                    lr=self.config["Learning Rate"])

                for epoch in range(n_iters):
                    loss_val = self.train_sub_interval()
                    if init_loss_val is None:
                        init_loss_val = loss_val

                    periodic = epoch % self.config["Log Interval"] == 0
                    if not (periodic or epoch == last_epoch):
                        continue
                    metrics = self.test_sub_interval()
                    if metrics[0] is None:
                        self.logger.log_train_empty_sub_interval(
                            epoch, loss_val, time.time() - train_t0)
                    else:
                        self.logger.log_train(
                            epoch, loss_val, *metrics,
                            time.time() - train_t0)

                # Training is accepted unless the loss exploded.
                if loss_val < EXPLODE_THRESHOLD * init_loss_val:
                    break

                if cur_attempt < MAX_ATTEMPTS - 1:
                    # Another attempt is left: warn and retry.
                    self.logger.log_warning(
                        "Loss value explodes. "
                        "Reset the model and optimizer.")
                    self.logger.log_warning("Attempt {}/{}".format(
                        cur_attempt + 1, MAX_ATTEMPTS))
                else:
                    # Out of attempts: give up on this sub-interval.
                    self.logger.log_error(
                        "Loss value explodes. "
                        "Max attempt reached.")
                    self.logger.log_error(
                        "Skip the training of this sub-interval.")
                # Reset the model
                self.load_model_state()

            self.record_sub_interval_pred()
            self.update_sub_interval()
            self.logger.log_sub_interval_end(time.time() - interval_t0)

        print("Finish Training.")
        final_metrics = self.test_sub_interval_final()
        self.logger.log_sub_interval_final_result(*final_metrics)
        self.logger.log_train_end(time.time() - train_t0)
        self.logger.close()

        references = [ref for ref in self.test_Y[1] if ref is not None]
        self.logger.save_sub_interval_prediction(
            np.concatenate(references, axis=1),
            np.concatenate(self.sub_interval_pred, axis=1)
        )
        # Only save the model in the last sub-interval
        self.logger.save_model(self.model)
