from typing import Union

import numpy as np

from lib.utils.numpy_metrics import masked_mse, masked_mae
from .imputer import Imputer


class MatrixFactorization(Imputer):
    """Matrix Factorization main class.

    The data matrix ``x`` of shape ``(n_steps, n_nodes)`` is approximated by
    ``V @ U.T``, with ``U`` of shape ``(n_nodes, rank)`` (node factors) and
    ``V`` of shape ``(n_steps, rank)`` (step factors). Subclasses implement
    ``update_u`` / ``update_v`` (and optionally ``update_auxiliary``); this
    class drives the alternating optimisation loop, the validation split,
    optional early stopping and the final imputation.
    """

    SEED: int = 2022

    def __init__(self, rank: int,
                 lambda_u: float = 5000,
                 lambda_v: float = 5000,
                 max_iter: int = 100,
                 verbose: bool = True):
        super(MatrixFactorization, self).__init__(is_deterministic=False,
                                                  in_sample=True)
        self.rank = rank
        # hyperparams (regularisation weights, consumed by subclasses)
        self.lambda_u = lambda_u
        self.lambda_v = lambda_v
        # params
        self.U = None  # nodes matrix, (n_nodes, rank)
        self.V = None  # steps matrix, (n_steps, rank)
        # private attrs
        self.verbose = verbose
        self.random = np.random.default_rng(self.SEED)
        self.max_iter = max_iter
        # fraction of observed entries held out for validation
        self.val_perc = 0.1
        # early-stopping patience in iterations; None disables early stopping
        self.patience = None
        self._opt_mae = np.inf
        self._opt_step = 0

    def _weights_initializer(self, *shape):
        """Draw a float32 array of the given shape uniformly from [0, 1)."""
        return self.random.random(shape).astype(np.float32)

    def reset_weights(self, input_shape):
        """Re-seed the RNG and draw fresh ``U`` and ``V``.

        Args:
            input_shape: ``(n_steps, n_nodes)`` shape of the data matrix.
        """
        n_steps, n_nodes = input_shape
        self.random = np.random.default_rng(self.SEED)
        self.U = self._weights_initializer(n_nodes, self.rank)
        self.V = self._weights_initializer(n_steps, self.rank)

    def params(self):
        """Return the model hyperparameters as a dict."""
        return dict(rank=self.rank,
                    lambda_u=self.lambda_u,
                    lambda_v=self.lambda_v)

    def update_u(self, x):
        """Update ``U`` in place given the training matrix (NaN = missing)."""
        raise NotImplementedError()

    def update_v(self, x):
        """Update ``V`` in place given the training matrix (NaN = missing)."""
        raise NotImplementedError()

    def update_auxiliary(self, x_train):
        """Hook for updating extra parameters once per iteration (no-op)."""
        pass

    def predict(self, x, mask):
        """Fit the factorisation on the observed entries of ``x`` and impute.

        A random subset (about ``val_perc``, fixed seed 8) of the observed
        entries is held out as a validation set. It is used for logging and,
        when ``self.patience`` is not None, for early stopping on the
        validation MAE.

        Args:
            x: data matrix of shape ``(n_steps, n_nodes)``.
            mask: array of the same shape, truthy where ``x`` is observed.

        Returns:
            Array with ``x`` at training entries and the model reconstruction
            elsewhere (held-out validation entries included), clipped from
            above at the 96th percentile of ``x`` (NaNs ignored).
        """
        # create validation set
        eval_mask = np.random.default_rng(8).random(x.shape) < self.val_perc
        eval_mask = (mask & eval_mask).astype(int)
        # train params
        self.reset_weights(x.shape)
        # reset early-stopping state so that repeated calls (possibly on data
        # of a different shape) never compare against / restore a stale model
        self._opt_mae = np.inf
        self._opt_step = 0
        has_best = False
        train_mask = mask - eval_mask
        x_train = np.where(train_mask, x, np.nan)
        for step in range(self.max_iter):
            # Update spatial matrix U
            self.update_u(x_train)
            # Update temporal matrix V
            self.update_v(x_train)
            # Update further params
            self.update_auxiliary(x_train)
            # Evaluation
            x_hat = self.x_hat()
            loss = self.loss(x_hat, x, train_mask)
            train_mae = masked_mae(x_hat, x, train_mask)
            # masked mean squared error times the number of entries, i.e.
            # presumably the sum of squared errors (logged as "train_mse")
            train_mse = masked_mse(x_hat, x, train_mask) * train_mask.sum()
            val_mae = masked_mae(x_hat, x, eval_mask)
            self._log_step(step, train_loss=loss,
                           train_mse=train_mse,
                           train_mae=train_mae, val_mae=val_mae)
            if self.patience is not None:
                # an improvement must lower the validation MAE by > 0.01
                if val_mae - self._opt_mae < -0.01:
                    self._save_best_model(step, val_mae)
                    has_best = True
                elif step >= self._opt_step + self.patience:
                    break
        # predict (only restore a best model if one was actually saved)
        if self.patience is not None and has_best:
            self._load_best_model()
        x_hat = np.where(train_mask, x, self.x_hat())
        # nanpercentile: a single NaN in x must not make the threshold (and
        # therefore every output entry) NaN.
        # NOTE(review): the clip is applied after merging, so observed
        # training values above the threshold are clipped too — confirm
        # this is intended.
        x_hat = np.minimum(x_hat, np.nanpercentile(x, 96))
        return x_hat

    def _save_best_model(self, step, mae):
        """Snapshot the current factors as the best model so far."""
        self._opt_U = self.U.copy()
        self._opt_V = self.V.copy()
        self._opt_mae = mae
        self._opt_step = step

    def _load_best_model(self):
        """Restore the factors saved by ``_save_best_model``."""
        self.U = self._opt_U
        self.V = self._opt_V

    def _log_step(self, step, **metrics):
        """Print one line of metrics for the given iteration if verbose."""
        if self.verbose:
            log = [f"Iteration n. {step}/{self.max_iter}"]
            for name, value in metrics.items():
                log.append(f"{name}: {value:.4f}")
            log = "  -  ".join(log)
            print(log)

    # Inference

    def x_hat(self):
        """Return the reconstruction ``V @ U.T`` of shape (n_steps, n_nodes)."""
        return self.V @ self.U.T

    def loss(self, x_hat, x, mask=None):
        """Masked squared-error loss (masked MSE times number of entries).

        Args:
            x_hat: reconstruction to score.
            x: target matrix.
            mask: entries to score; defaults to the non-NaN entries of ``x``.
        """
        if mask is None:
            # parentheses matter: ``~`` must negate the boolean array, not
            # bit-invert its integer cast (which would yield -1 / -2)
            mask = (~np.isnan(x)).astype(int)
        loss = masked_mse(x_hat, x, mask) * mask.sum()
        return loss


class NNMatrixFactorization(MatrixFactorization):
    """Matrix factorisation fitted with multiplicative updates.

    Each row of ``U`` and ``V`` is rescaled by the elementwise ratio of a
    numerator and a denominator computed over the observed entries only.
    Because the factors are initialised in [0, 1), they stay nonnegative as
    long as the observed data are nonnegative.
    """

    def update_u(self, x):
        """Multiplicatively update each row of ``U``.

        Args:
            x: training matrix ``(n_steps, n_nodes)`` with NaN at missing
                entries.

        Components with a zero denominator (e.g. a node with no observed
        entries, where the original ratio would be 0/0 = NaN) are left
        unchanged.
        """
        _, n_nodes = x.shape
        for i in range(n_nodes):
            valid_idxs = np.where(~np.isnan(x[:, i]))[0]
            x_i = x[valid_idxs, i]
            V = self.V[valid_idxs, :]
            num = x_i @ V
            den = self.U[i] @ V.T @ V
            self.U[i] *= self._safe_ratio(num, den)

    def update_v(self, x):
        """Multiplicatively update each row of ``V``.

        Args:
            x: training matrix ``(n_steps, n_nodes)`` with NaN at missing
                entries.

        A step with no observed entries copies the previous step's factors.
        At ``t == 0`` the current factors are kept instead, because
        ``V[t - 1]`` would wrap around to the last row.
        """
        n_steps, _ = x.shape
        for t in range(n_steps):
            valid_idxs = np.where(~np.isnan(x[t]))[0]
            if len(valid_idxs):
                x_t = x[t, valid_idxs]
                U = self.U[valid_idxs]
                num = x_t @ U
                den = self.V[t] @ U.T @ U
                self.V[t] *= self._safe_ratio(num, den)
            elif t > 0:
                self.V[t] = self.V[t - 1].copy()

    @staticmethod
    def _safe_ratio(num, den):
        """Return ``num / den`` elementwise, with 1 wherever ``den == 0``."""
        num = np.asarray(num, dtype=np.float64)
        return np.divide(num, den, out=np.ones_like(num), where=den != 0)


class GRMF(NNMatrixFactorization):
    """Graph Regularized Matrix Factorization (GRMF).

    Adds a graph-Laplacian penalty ``lambda_u * tr(U^T L U)`` on the node
    factors, with ``L = D - adj`` and ``D`` the diagonal matrix of the row
    sums of ``adj``. The multiplicative ``U`` update includes this term.
    """

    short_name = 'grmf'
    SEED: int = 2022

    def __init__(self, rank: int, adj: np.ndarray,
                 **kwargs):
        super(GRMF, self).__init__(rank, **kwargs)
        # node adjacency / weight matrix, presumably (n_nodes, n_nodes)
        self.adj = adj

    @property
    def D(self):
        """Diagonal degree matrix built from the row sums of ``adj``."""
        return np.diag(np.sum(self.adj, axis=1))

    @property
    def L(self):
        """Graph Laplacian ``D - adj``."""
        return self.D - self.adj

    def update_u(self, x):
        """Multiplicatively update each row of ``U`` with graph smoothing.

        Args:
            x: training matrix ``(n_steps, n_nodes)`` with NaN at missing
                entries.

        NOTE(review): the numerator uses column ``i`` of ``adj`` while ``D``
        uses row sums; the two agree only for a symmetric ``adj``.
        """
        _, n_nodes = x.shape
        # D is diagonal, so D[i] @ U == degree[i] * U[i]. Computing the degree
        # vector once avoids building an (n_nodes, n_nodes) matrix per node.
        degree = np.sum(self.adj, axis=1)
        for i in range(n_nodes):
            valid_idxs = np.where(~np.isnan(x[:, i]))[0]
            x_i = x[valid_idxs, i]
            V = self.V[valid_idxs, :]
            num = x_i @ V + self.lambda_u * self.adj[:, i] @ self.U
            den = (self.U[i] @ V.T @ V
                   + self.lambda_u * degree[i] * self.U[i])
            num = np.asarray(num, dtype=np.float64)
            # leave components with a zero denominator (e.g. an isolated,
            # fully unobserved node) unchanged instead of producing NaN
            ratio = np.divide(num, den, out=np.ones_like(num), where=den != 0)
            self.U[i] *= ratio

    # Inference

    def graph_reg(self):
        """Return the Laplacian smoothness term ``tr(U^T L U)``."""
        return np.trace(self.U.T @ self.L @ self.U)

    def loss(self, x_hat, x, mask=None):
        """Masked squared-error loss plus ``lambda_u`` times ``graph_reg``."""
        loss = super(GRMF, self).loss(x_hat, x, mask)
        loss += self.lambda_u * self.graph_reg()
        return loss


class TRMF(MatrixFactorization):
    """Temporal Regularized Matrix Factorization (TRMF).

    The step factors ``V`` are regularised towards an autoregressive model
    over ``lags``: ``V[t] ~ sum_k theta[k] * V[t - lags[k]]`` (elementwise
    per latent dimension). ``U``, ``V`` and ``theta`` are updated in turn by
    regularised least squares, following the transdim reference notebook
    linked in ``update_u`` / ``update_v``.
    """

    short_name = 'trmf'
    SEED: int = 2022

    def __init__(self, rank: int, lags: Union[set, np.ndarray],
                 lambda_theta: float = 5000,
                 eta: float = 0.03,
                 **kwargs):
        super(TRMF, self).__init__(rank, **kwargs)
        self.lags = np.asarray(sorted(lags))
        # every update indexes lags[0] / lags[-1] and V[t - lag]: reject
        # inputs that would fail later with an obscure IndexError or fit an
        # AR model on V[t] itself
        if self.lags.size == 0 or np.any(self.lags < 1):
            raise ValueError("lags must be a non-empty collection of positive "
                             f"integers, got {self.lags.tolist()}")
        # hyperparams
        self.lambda_theta = lambda_theta
        self.eta = eta
        # params
        self.theta = None  # AR params matrix, (len(lags), rank)

    def reset_weights(self, input_shape):
        """Re-draw ``U``, ``V`` and the AR coefficients ``theta``."""
        super(TRMF, self).reset_weights(input_shape)
        self.theta = self._weights_initializer(len(self.lags), self.rank)

    def params(self):
        """Return the model hyperparameters, including the TRMF-specific ones."""
        params = super(TRMF, self).params()
        params.update(lags=self.lags,
                      lambda_theta=self.lambda_theta,
                      eta=self.eta)
        return params

    def update_u(self, x):
        """Ridge-regression update of each row of ``U``.

        Based on
        https://nbviewer.org/github/xinychen/transdim/blob/master/experiments/Imputation-TRMF.ipynb"""
        _, n_nodes = x.shape
        ridge = self.lambda_u * np.eye(self.rank)
        for i in range(n_nodes):
            valid_idxs = np.where(~np.isnan(x[:, i]))[0]  # idxs with no missing
            V = self.V[valid_idxs, :]
            # solve (V^T V + lambda_u I) u_i = V^T x_i without forming the inverse
            self.U[i] = np.linalg.solve(V.T @ V + ridge,
                                        V.T @ x[valid_idxs, i])

    def update_v(self, x):
        """Update each row of ``V`` from its observations and the AR prior.

        Based on
        https://nbviewer.org/github/xinychen/transdim/blob/master/experiments/Imputation-TRMF.ipynb"""
        n_steps, _ = x.shape
        eta_eye = self.eta * np.eye(self.rank)
        for t in range(n_steps):
            valid_idxs = np.where(~np.isnan(x[t]))[0]  # idxs with no missing
            Wt = self.U[valid_idxs]
            Mt = np.zeros((self.rank, self.rank))
            Nt = np.zeros(self.rank)
            # AR term in which V[t] is the predicted step
            if t < self.lags[-1]:  # max lag
                Pt = np.zeros((self.rank, self.rank))
                Qt = np.zeros(self.rank)
            else:
                Pt = np.eye(self.rank)
                Qt = np.einsum('ij, ij -> j', self.theta, self.V[t - self.lags])
            # AR terms in which V[t] is a regressor of a later step t + lag
            if t < n_steps - self.lags[0]:  # min lag
                if self.lags[-1] <= t < n_steps - self.lags[-1]:
                    idx = np.arange(len(self.lags))
                else:
                    idx = np.where((self.lags[-1] <= t + self.lags) &
                                   (t + self.lags < n_steps))[0]
                for k in idx:
                    Ak = self.theta[k]
                    Mt += np.diag(Ak ** 2)
                    theta0 = self.theta.copy()
                    theta0[k, :] = 0
                    update = np.einsum('ij, ij -> j',
                                       theta0,
                                       self.V[t + self.lags[k] - self.lags])
                    Nt += np.multiply(Ak, self.V[t + self.lags[k]] - update)
            vec0 = Wt.T @ x[t, valid_idxs] + self.lambda_v * (Nt + Qt)
            mat0 = Wt.T @ Wt + self.lambda_v * (Mt + Pt + eta_eye)
            # solve the linear system directly instead of inverting mat0
            self.V[t, :] = np.linalg.solve(mat0, vec0)

    def update_auxiliary(self, x_train):
        """Update the AR coefficients once per iteration."""
        self.update_theta()

    def update_theta(self):
        """Update the AR coefficients ``theta`` one lag at a time.

        For lag ``k`` the part of ``V`` not explained by the other lags
        (``VarPi``) is regressed, per latent dimension, on ``V`` shifted by
        ``lags[k]`` with ridge weight ``lambda_theta / lambda_v``. The normal
        matrix is diagonal, so the solve reduces to an elementwise division.
        If the series is not longer than the largest lag, there are no AR
        targets and ``theta`` becomes zero.
        """
        n_steps, _ = self.V.shape
        max_lag = self.lags[-1]
        n_fit = max(n_steps - max_lag, 0)  # steps with a full lag history
        ridge = self.lambda_theta / self.lambda_v
        for k in range(len(self.lags)):
            theta0 = self.theta.copy()
            theta0[k, :] = 0
            mat0 = np.zeros((n_fit, self.rank))
            for j, lag in enumerate(self.lags):
                beg = max_lag - lag
                mat0 += self.V[beg:beg + n_fit] * theta0[j]
            VarPi = self.V[max_lag:max_lag + n_fit] - mat0
            beg_k = max_lag - self.lags[k]
            B = self.V[beg_k:beg_k + n_fit]  # rows V[t - lags[k]]
            # accumulate in float64 like the former per-step loop did
            var1 = np.sum(B * B, axis=0, dtype=np.float64) + ridge
            var2 = np.sum(B * VarPi, axis=0)
            self.theta[k, :] = var2 / var1

    def _save_best_model(self, step, mae):
        """Snapshot ``U``, ``V`` and ``theta`` as the best model so far."""
        super(TRMF, self)._save_best_model(step, mae)
        self._opt_theta = self.theta.copy()

    def _load_best_model(self):
        """Restore ``U``, ``V`` and ``theta`` from the best snapshot."""
        super(TRMF, self)._load_best_model()
        self.theta = self._opt_theta
