import torch
from .cca_np import get_cca_similarity

def svd_reduction(tensor: torch.Tensor, accept_rate=0.8):
    """
    Project `tensor` onto its leading right-singular directions (the "SV" step of SVCCA).

    Keeps the smallest number k of singular directions whose singular values sum to at
    least `accept_rate` of the total singular-value sum. At least one direction is
    always kept, so the result is never an empty matrix.

    :param tensor: 2-D data matrix (presumably [data, neurons], as in svcca_distance)
    :param accept_rate: fraction of the summed singular values to retain
    :return: tensor @ V[:, :k], i.e. a [tensor.size(0), k] matrix
    """
    _, singular_values, right = torch.svd(tensor)
    singular_values = singular_values.abs()
    ratio = singular_values.cumsum(dim=0) / singular_values.sum()
    # Directions whose cumulative share is still strictly below the threshold ...
    below = int((ratio < accept_rate).sum())
    # ... plus the one that crosses it (same "+1" convention as svcca_distance).
    # Without the +1 a dominant first singular value would yield zero columns.
    num = min(below + 1, singular_values.numel())
    return tensor @ right[:, :num]


def zero_mean(tensor: torch.Tensor, dim):
    """
    Center `tensor` along `dim`.

    :param tensor: input tensor
    :param dim: dimension over which the mean is taken
    :return: tensor minus its mean along `dim` (the mean keeps `dim` so it broadcasts back)
    """
    centre = torch.mean(tensor, dim, keepdim=True)
    return torch.sub(tensor, centre)


def _svd_cca(x, y):
    u_1, s_1, v_1 = x.svd()
    u_2, s_2, v_2 = y.svd()
    uu = u_1.t() @ u_2
    try:
        u, diag, v = (uu).svd()
    except RuntimeError as e:
        raise e
    a = v_1 @ s_1.reciprocal().diag() @ u
    b = v_2 @ s_2.reciprocal().diag() @ v
    return a, b, diag


def _cca(x, y, method):
    """
    Canonical Correlation Analysis,
    cf. Press 2011 "Cannonical Correlation Clarified by Singular Value Decomposition"
    :param x: data matrix [data, neurons]
    :param y: data matrix [data, neurons]
    :param method: computational method "svd"  or "qr"
    :return: _cca vectors for input x, _cca vectors for input y, canonical correlations
    """
    assert x.size(0) == y.size(0), f"Number of data needs to be same but {x.size(0)} and {y.size(0)}"
    assert x.size(0) >= x.size(1) and y.size(0) >= y.size(1), f"data[0] should be larger than data[1]"
    assert method in ("svd", "qr"), "Unknown method"

    return _svd_cca(x, y)


def svcca_distance(x, y, method="svd", accept_rate=0.98, reduction_rate=0.8):
    """
    SVCCA proposed in Raghu et al. 2017.

    Despite the name, the result is a similarity: the mean of the leading canonical
    correlations (higher means more similar).

    :param x: data matrix [data, neurons]
    :param y: data matrix [data, neurons]
    :param method: computational method "svd" (default) or "qr"
    :param accept_rate: cumulative share of the canonical correlations to average over;
        the leading correlations are kept up to and including the one that crosses it
    :param reduction_rate: accept rate of the preceding SVD reduction of x and y
        (0.8, the previously hard-coded value, by default)
    :return: 0-dim tensor, mean of the kept canonical correlations
    """
    x = svd_reduction(x, reduction_rate)
    y = svd_reduction(y, reduction_rate)
    _, _, diag = _cca(x, y, method=method)

    diag = diag.abs()
    ratio = diag.cumsum(dim=0) / diag.sum()
    # Count correlations strictly below the threshold; "+1" keeps the crossing one.
    num = int((ratio < accept_rate).sum())
    return diag[:num + 1].mean()



class CCAHook:
    """
    Capture the output of one named sub-module of a model via a forward hook and
    compare the captured activations of two hooks with (SV)CCA.

    Every forward pass through the hooked module appends its output (detached, on CPU)
    to an internal buffer. `clear()` empties the buffer and `remove()` unregisters the hook.
    """

    def __init__(self, model, name, accept_rate=0.98, device=0):
        """
        :param model: torch.nn.Module containing the module to hook
        :param name: qualified sub-module name as produced by model.named_modules()
        :param accept_rate: cumulative-correlation threshold used by svcca_similar / cca_similar
        :param device: passed to Tensor.to() before the torch-based CCA; an int is a
            CUDA device index in torch, so the default presumably requires a GPU
        :raises NameError: if `name` is not a sub-module of `model`
        """
        self.model = model
        self.name = name
        self.accept_rate = accept_rate
        self.device = device

        modules = dict(self.model.named_modules())
        if self.name not in modules:
            raise NameError(f"No such name ({self.name}) in the model")

        self._module = modules[self.name]
        # Captured batches, concatenated lazily in get_hooked_value(): appending is O(1)
        # per forward, whereas concatenating inside the hook re-copied every earlier batch.
        self._hooked_values = []
        self._hook_handle = None
        self._register_hook()

    def clear(self):
        """
        clear the hooked tensor
        """
        self._hooked_values = []

    def remove(self):
        """
        Unregister the forward hook; activations captured so far are kept.
        """
        if self._hook_handle is not None:
            self._hook_handle.remove()
            self._hook_handle = None

    def _register_hook(self):

        def hook(_, __, output):
            # Detach before the device copy so the copy is not recorded by autograd.
            self._hooked_values.append(output.detach().cpu())

        self._hook_handle = self._module.register_forward_hook(hook)

    def get_hooked_value(self):
        """
        Return every activation captured since construction or the last clear(),
        concatenated along dim 0.

        :raises RuntimeError: if nothing has been captured yet
        """
        if not self._hooked_values:
            raise RuntimeError("Please do model.forward() before CCA!")
        if len(self._hooked_values) > 1:
            # Cache the concatenation so repeated calls do not copy again.
            self._hooked_values = [torch.cat(self._hooked_values, dim=0)]
        return self._hooked_values[0]

    def _paired_values(self, other):
        """
        Fetch the captured activations of `self` and `other`. 4-D outputs (presumably
        [batch, channels, H, W] feature maps - confirm) are global-average-pooled and
        flattened to 2-D.

        :raises RuntimeError: if either hook is empty or the tensor ranks differ
        """
        x = self.get_hooked_value()
        y = other.get_hooked_value()

        if x.dim() != y.dim():
            raise RuntimeError("tensor dimensions are incompatible!")

        if x.dim() == 4:
            x = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)), 1)
            y = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(y, (1, 1)), 1)
        return x, y

    def _centered_pair(self, other):
        """
        Paired activations moved to `self.device` and zero-meaned over dim 0.
        """
        x, y = self._paired_values(other)
        x = zero_mean(x.to(self.device), dim=0)
        y = zero_mean(y.to(self.device), dim=0)
        return x, y

    def svcca_similar(self, other):
        """
        SVCCA similarity between this hook's activations and `other`'s (see svcca_distance).
        """
        x, y = self._centered_pair(other)
        return svcca_distance(x, y, accept_rate=self.accept_rate)

    def cca_similar_np(self, other):
        """
        CCA similarity computed by the numpy implementation `get_cca_similarity`.
        """
        x, y = self._paired_values(other)
        # NOTE(review): x and y are passed as [data, neurons]. If get_cca_similarity
        # follows Google's svcca cca_core convention it expects [neurons, data];
        # confirm whether a transpose is needed.
        return get_cca_similarity(x.numpy(), y.numpy())['mean'][0]

    def cca_similar(self, other):
        """
        Mean of the leading canonical correlations (plain CCA, no SVD reduction).

        The leading correlations are kept up to and including the one at which their
        cumulative share crosses `self.accept_rate`.
        :return: 0-dim tensor
        """
        x, y = self._centered_pair(other)

        _, _, diag = _cca(x, y, 'svd')

        diag = diag.abs()
        ratio = diag.cumsum(dim=0) / diag.sum()
        num = int((ratio < self.accept_rate).sum())
        return diag[:num + 1].mean()
