import numpy as np
import torch
from xmeta.utils.higher_grad import higher_grad
from scipy.linalg import block_diag
import copy


def _entropy_weight(vec: np.ndarray):
    # min_val = vec.min()
    # assert min_val >= 0, f'vec: {vec}'
    # if min_val <= 0:
    #     # avoid negative values raised by float rounding
    #     small_val = vec.max() * 1e-8
    #     # assert np.abs(min_val) < small_val
    #     vec = np.maximum(vec, small_val)
    vec = np.maximum(vec, 0.)
    normalization = vec.sum(axis=0)
    normalization = np.where(normalization <= 0, np.inf, normalization)
    vec = vec / normalization
    # [Note] 0 <= vi <= 1/len(vi) for each col vi=vec[:, i]
    entropy = - np.nan_to_num(vec * np.log(vec), nan=0)
    entropy = np.where(vec.max(axis=0) <= 0, np.inf, entropy.sum(axis=0))
    assert not np.isnan(entropy).any(), 'some entropy values are nan'
    return entropy


class FactorizedMatrix:
    """Matrix kept in factored form ``col_vecs @ row_vecs``.

    The matrix represented right after construction is ``v0 @ m @ v1``. The
    core ``m`` is decomposed by SVD and absorbed into the outer factors, so
    only two arrays are stored:

    * ``col_vecs`` with shape ``(n_rows, k)``
    * ``row_vecs`` with shape ``(k, n_cols)``

    where ``k`` (see ``num_elements``) is the number of rank-one terms.
    """

    def __init__(self, v0: np.ndarray, m: np.ndarray, v1: np.ndarray):
        """Build the factorization of ``v0 @ m @ v1``.

        Args:
            v0: Left factor. Either 2-D with ``v0.shape[-1] == m.shape[0]``
                or 1-D, in which case it is used as a single column and
                ``m`` must have exactly one row.
            m: 2-D core matrix (may be rectangular).
            v1: Right factor. Either 2-D with ``v1.shape[0] == m.shape[1]``
                or 1-D, in which case it is used as a single row and ``m``
                must have exactly one column.

        Raises:
            AssertionError: If the dimensions above are not satisfied.
        """
        assert v0.ndim == 1 or v0.ndim == 2
        assert m.ndim == 2
        assert v1.ndim == 1 or v1.ndim == 2
        if v0.ndim == 1:
            assert m.shape[0] == 1
            v0 = v0[:, None]
        else:
            assert v0.shape[-1] == m.shape[0]
        if v1.ndim == 1:
            assert m.shape[1] == 1
            v1 = v1[None, :]
        else:
            assert v1.shape[0] == m.shape[1]
        self.col_vecs = None
        self.row_vecs = None
        self._set_factors(v0, m, v1)

    def _set_factors(self, v0: np.ndarray, m: np.ndarray, v1: np.ndarray):
        """Compute ``col_vecs`` and ``row_vecs`` from 2-D ``v0``, ``m``, ``v1``.

        The singular values of ``m`` are folded into ``row_vecs``.
        """
        # Reduced SVD: u0 is (M, K), u1 is (K, N) with K = min(M, N). With
        # full_matrices=True the factor shapes only line up for square m;
        # for square m both variants return the same decomposition.
        u0, s, u1 = np.linalg.svd(m, full_matrices=False)
        self.col_vecs = v0.dot(u0)
        # Row-wise scaling is diag(s) @ u1 without building the K x K matrix.
        self.row_vecs = (s[:, None] * u1).dot(v1)

    def dot(self, v: np.ndarray):
        """Multiply the represented matrix by ``v`` from the right.

        Args:
            v: 1-D vector or 2-D matrix whose first dimension equals the
                number of columns of this matrix.

        Returns:
            ``col_vecs @ (row_vecs @ v)`` with all singleton dimensions
            removed by ``squeeze()``.
        """
        assert v.ndim == 1 or v.ndim == 2
        assert v.shape[0] == self.row_vecs.shape[-1]
        if v.ndim == 1:
            v = v[:, None]

        # Multiply right-to-left so the full matrix is never materialized.
        return self.col_vecs.dot(self.row_vecs.dot(v)).squeeze()

    def add(self, mat):
        """Add ``mat`` in place by appending its rank-one terms.

        Concatenating the columns of ``col_vecs`` and the rows of
        ``row_vecs`` makes the represented matrix the sum of both operands.

        Args:
            mat: Another factorized matrix exposing ``col_vecs`` and
                ``row_vecs``.
        """
        # Sanity check on the internal consistency of the incoming factors.
        assert mat.col_vecs.shape[-1] == mat.row_vecs.shape[0]
        self.col_vecs = np.concatenate([self.col_vecs, mat.col_vecs], axis=1)
        self.row_vecs = np.concatenate([self.row_vecs, mat.row_vecs], axis=0)

    def transpose(self):
        """Transpose the represented matrix in place by swapping the factors."""
        self.col_vecs, self.row_vecs = self.row_vecs.T, self.col_vecs.T

    def delete_elements(self, indexes):
        """Drop the rank-one terms selected by ``indexes`` from both factors."""
        self.col_vecs = np.delete(self.col_vecs, indexes, axis=1)
        self.row_vecs = np.delete(self.row_vecs, indexes, axis=0)

    def scale(self, s):
        """Multiply the represented matrix by ``s`` in place.

        Only ``col_vecs`` is scaled; scaling one factor scales the product.
        """
        self.col_vecs = s * self.col_vecs

    def __len__(self):
        """Return the number of rows of the represented matrix."""
        return self.col_vecs.shape[0]

    @property
    def shape(self):
        """Shape ``(n_rows, n_cols)`` of the represented matrix."""
        return (self.col_vecs.shape[0], self.row_vecs.shape[-1])

    @property
    def num_elements(self):
        """Number of rank-one terms currently stored."""
        return self.col_vecs.shape[-1]


def inverse_psdmat(m, min_eigen_ratio: float = 0.01):
    """Return a truncated inverse of ``m`` in factored form.

    A deep copy of ``m`` is orthogonalized; the product of the norm of each
    column of ``col_vecs`` and the matching row of ``row_vecs`` is then used
    as the eigenvalue of that term. Terms whose eigenvalue is not positive,
    or lies below ``min_eigen_ratio`` times the mean of the positive
    eigenvalues, are zeroed out. The remaining terms are replaced by their
    unit vectors scaled by ``sqrt(1 / eigenvalue)`` on each side.

    Args:
        m: The ``PSDMatrix`` to invert. It is not modified.
        min_eigen_ratio: Relative threshold below which terms are dropped.

    Returns:
        A new ``PSDMatrix`` holding the inverted factors. Dropped terms stay
        in the factors as all-zero vectors.

    Raises:
        AssertionError: If ``m`` is not a ``PSDMatrix``.
    """
    assert isinstance(m, PSDMatrix)
    m = copy.deepcopy(m)
    m.orthogonalize()
    col_norms = np.linalg.norm(m.col_vecs, axis=0)
    row_norms = np.linalg.norm(m.row_vecs, axis=1)
    eigenvalues = col_norms * row_norms

    positive = eigenvalues > 0
    # The mean of an empty selection is nan (plus a warning); with no
    # positive eigenvalue there is nothing to keep anyway.
    if positive.any():
        min_ev = eigenvalues[positive].mean() * min_eigen_ratio
    else:
        min_ev = 0.
    # Zero eigenvalues are never inverted, even when the threshold is 0.
    keep = positive & (eigenvalues >= min_ev)
    inv_eigenvalues = np.zeros_like(eigenvalues)
    inv_eigenvalues[keep] = 1 / eigenvalues[keep]
    scales = np.sqrt(inv_eigenvalues)

    # A zero-norm vector is all zeros and its scale is 0, so dividing by 1
    # instead of 0 leaves it at 0 rather than turning it into nan.
    safe_col_norms = np.where(col_norms > 0, col_norms, 1.)
    safe_row_norms = np.where(row_norms > 0, row_norms, 1.)
    m.col_vecs = m.col_vecs * scales[None, :] / safe_col_norms[None, :]
    m.row_vecs = m.row_vecs * scales[:, None] / safe_row_norms[:, None]

    return m


class PSDMatrix(FactorizedMatrix):
    """Factorized matrix whose core is split evenly between both factors.

    Unlike the base class, the square root of each singular value of ``m``
    is folded into ``col_vecs`` and into ``row_vecs``.
    """

    def __init__(self, force_psd: bool = False, **kwargs):
        """Store ``force_psd`` and forward ``kwargs`` to the base constructor.

        Args:
            force_psd: When true, ``m`` must be square and its left/right
                singular vectors are symmetrized (see ``_set_factors``).
            **kwargs: Keyword arguments for ``FactorizedMatrix.__init__``.
        """
        # Must be set first: the base constructor calls ``_set_factors``,
        # which reads this flag.
        self.force_psd = force_psd
        super().__init__(**kwargs)

    def _set_factors(self, v0: np.ndarray, m: np.ndarray, v1: np.ndarray):
        """Compute the symmetric split of ``v0 @ m @ v1``."""
        left, sing_vals, right = np.linalg.svd(m, full_matrices=True)
        if not self.force_psd:
            assert sing_vals.min() >= 0.
        else:
            assert m.shape[0] == m.shape[1]
            sing_vals = np.maximum(sing_vals, 0.)
            # Use the average of the left vectors and the transposed right
            # vectors on both sides.
            # NOTE(review): presumably meant to cancel directions where the
            # two disagree in sign for a symmetric m -- confirm.
            left = (left + right.T) / 2.
            right = left.T

        root = np.diag(np.sqrt(sing_vals))
        self.col_vecs = np.dot(v0, np.dot(left, root))
        self.row_vecs = np.dot(np.dot(root, right), v1)

    def orthogonalize(self, output=False):
        """Rotate both factors using the SVD of ``row_vecs @ col_vecs``.

        The small ``k x k`` product is decomposed with ``hermitian=True``,
        i.e. it is assumed to be symmetric. ``row_vecs`` is multiplied from
        the left by the transposed first SVD factor and ``col_vecs`` from
        the right by the transposed last SVD factor.

        Args:
            output: When true, return the raw SVD result.

        Returns:
            ``(u, s, v)`` as returned by ``np.linalg.svd`` if ``output`` is
            true, otherwise ``None``.
        """
        inner = np.dot(self.row_vecs, self.col_vecs)
        u, s, v = np.linalg.svd(inner, full_matrices=True, hermitian=True)
        self.row_vecs = np.dot(u.real.T, self.row_vecs)
        self.col_vecs = np.dot(self.col_vecs, v.real.T)

        return (u, s, v) if output else None


class CrossEntropyHessian(PSDMatrix):
    """Factorized matrix ``J.T @ M @ J`` built from softmax model outputs.

    ``M`` is block diagonal with one block ``diag(p) - p p^T`` per data row
    (``p`` being the softmax of that row), divided by the number of rows.
    ``J`` is the result of ``higher_grad(model_output, parameters)``
    flattened to two dimensions.
    NOTE(review): ``J`` is presumably the Jacobian of the outputs with
    respect to the parameters -- confirm against ``higher_grad``.

    Optionally the number of stored rank-one terms is capped at
    ``max_num_elements`` by discarding the terms with the smallest
    ``trace_elements`` score.
    """

    def __init__(self, model_output: torch.Tensor, parameters: torch.Tensor,
                 max_num_elements: int = None,
                 ortho_vectors: bool = False,
                 entropy_weight: bool = False):
        """Build the factorization for one batch.

        Args:
            model_output: 2-D tensor of pre-softmax outputs with at least
                one row.
            parameters: Passed to ``higher_grad`` together with
                ``model_output``.
            max_num_elements: Upper bound on the number of stored terms, or
                ``None`` for no bound.
            ortho_vectors: When true, ``add`` re-orthogonalizes the factors
                and uses the singular values as scores.
            entropy_weight: When true, the scores are multiplied by the
                negated result of ``_entropy_weight``.

        Raises:
            AssertionError: If ``model_output`` is not 2-D or is empty.
        """
        assert model_output.dim() == 2
        n_data = len(model_output)
        assert n_data > 0
        self.n_input = model_output.shape[0] * model_output.shape[1]

        self.max_num_elements = max_num_elements
        self.trace_elements = None
        self.ortho_vectors = ortho_vectors
        self.entropy_weight = entropy_weight

        v1 = higher_grad(model_output, parameters,
                         retain_graph=True, create_graph=False, to_cpu=True
                         ).detach().to('cpu').numpy()
        # Collapse all leading dimensions into one.
        v1 = v1.reshape(-1, v1.shape[-1])

        prob_vecs = torch.nn.Softmax(dim=1)(model_output).detach().to('cpu').numpy()
        # One p p^T block per data row.
        blocks = [np.outer(vec, vec) for vec in prob_vecs]

        m = (np.diag(prob_vecs.flatten()) - block_diag(*blocks)) / n_data
        super().__init__(v0=v1.T, m=m, v1=v1, force_psd=True)

        if self.max_num_elements is not None:
            self.calc_trace_elements()
            if self.num_elements > self.max_num_elements:
                self.delete_small_elements(self.num_elements - self.max_num_elements)

    def calc_trace_elements(self):
        """Compute one score per stored term into ``trace_elements``.

        The score is the sum over the element-wise product of a column of
        ``col_vecs`` and the matching row of ``row_vecs``. With
        ``entropy_weight`` enabled it is multiplied by the negated entropy of
        those element-wise products; nan results become ``-inf``.
        """
        vecs = (self.col_vecs * self.row_vecs.T)
        self.trace_elements = vecs.sum(axis=0)
        if self.entropy_weight:
            # 0 * inf yields nan for empty terms; that nan is mapped to -inf
            # right away, so the warning it triggers is silenced.
            with np.errstate(invalid='ignore'):
                weighted = self.trace_elements * (- _entropy_weight(vecs))
            self.trace_elements = np.nan_to_num(weighted, nan=-np.inf)

    def delete_small_elements(self, n_del: int):
        """Remove the ``n_del`` terms with the smallest ``trace_elements``."""
        assert self.trace_elements is not None
        indexes = self.trace_elements.argsort()[: n_del]
        self.trace_elements = np.delete(self.trace_elements, indexes, axis=0)
        self.delete_elements(indexes)

    def add(self, mat):
        """Add ``mat`` in place and re-apply the ``max_num_elements`` cap.

        Args:
            mat: Another ``CrossEntropyHessian``. Its ``trace_elements`` are
                computed on demand, which modifies ``mat``.
        """
        super().add(mat)
        self.n_input = self.n_input + mat.n_input

        if self.ortho_vectors:
            _, s, _ = self.orthogonalize(output=True)
            self.trace_elements = s
        elif self.trace_elements is not None:
            if mat.trace_elements is None:
                mat.calc_trace_elements()
            self.trace_elements = np.append(self.trace_elements, mat.trace_elements)
            assert len(self.trace_elements) ==\
                self.col_vecs.shape[-1] == self.row_vecs.shape[0],\
                (f'{len(self.trace_elements)} == {self.col_vecs.shape[-1]}'
                 f' == {self.row_vecs.shape[0]} not satisfied')
        # No cap configured means nothing to prune; comparing an int with
        # None would raise TypeError.
        if (self.max_num_elements is not None
                and self.num_elements > self.max_num_elements):
            self.delete_small_elements(self.num_elements - self.max_num_elements)

    def normalize(self):
        """Scale the matrix by ``n_input`` divided by the number of terms."""
        # num_elements equals len(trace_elements) whenever scores exist, and
        # is also available when they were never computed.
        self.scale(self.n_input / self.num_elements)
