import math

import torch
from torch import nn
import torch.nn.functional as F

from .operation import Operation
from .operation import ALL_OPS
from ..utils import scatter_mean

# Values of `module.kfac` that `Linear.preprocess_in_data` and
# `Linear.preprocess_out_grads` compare against.
# 'expand': extra (weight-sharing) dimensions are flattened into the batch
#           dimension.
# 'reduce': extra (weight-sharing) dimensions are aggregated away, so there is
#           one row per sample (or per graph for GNN layers).
KFAC_EXPAND = 'expand'
KFAC_REDUCE = 'reduce'


class Linear(Operation):
    """
    Per-layer statistics (gradients, covariance / Kronecker factors, Gram
    matrices, ...) for a linear module.

    module.weight: f_out x f_in
    module.bias: f_out x 1

    Argument shapes
    in_data: n x f_in
    out_grads: n x f_out

    Inputs with additional weight-sharing dimensions, or GNN node/edge
    batches (modules carrying an `n_graph` attribute), are handled by
    `preprocess_in_data` / `preprocess_out_grads` according to `module.kfac`.
    """
    supported_operations = set(ALL_OPS)

    @staticmethod
    def _reduce_per_graph(module, data):
        """Aggregate node- or edge-level rows of a GNN batch into one row per graph.

        Reads `module.n_graph`, `module.n_node` and `module.n_edge`. The first
        dimension of `data` must equal either `n_node.sum()` (checked first)
        or `n_edge.sum()`; rows are then combined per graph with
        `scatter_mean(..., sqrt=True)`.

        Raises:
            ValueError: if the first dimension of `data` matches neither the
                total number of nodes nor the total number of edges.
        """
        batch_size = data.shape[0]
        n_graph = module.n_graph
        graph_idx = torch.arange(n_graph, device=data.device)
        n_node = module.n_node
        n_edge = module.n_edge
        if batch_size == n_node.sum():
            counts = n_node
        elif batch_size == n_edge.sum():
            counts = n_edge
        else:
            raise ValueError(
                f'Something is wrong with current batch of size {batch_size}.')
        # Graph index of every row: graph i is repeated counts[i] times.
        row_gr_idx = torch.repeat_interleave(graph_idx, counts, dim=0)
        return scatter_mean(
            data, row_gr_idx, dim=0, dim_size=n_graph, sqrt=True)

    @staticmethod
    def preprocess_in_data(module, in_data, out_data):
        """Bring layer inputs into the 2D form expected by the other methods.

        * More than two dimensions (non-GNN modules only):
          'expand' flattens all leading dimensions into the batch dimension
          and divides by the square root of the number of weight-sharing
          positions; 'reduce' averages over the weight-sharing dimensions.
          For any other `module.kfac` value the data is left untouched.
        * GNN modules (those with an `n_graph` attribute) whose batch size
          differs from `n_graph`: 'expand' rescales the rows, 'reduce'
          aggregates them per graph.

        `out_data` is accepted but not used here.

        Raises:
            ValueError: if a GNN module receives data with more than two
                dimensions, or if its batch size matches neither the number
                of nodes nor the number of edges in 'reduce' mode.
        """
        if in_data.ndim > 2 and hasattr(module, 'n_graph'):
            raise ValueError('GNN layer in_data should not have more than two dimensions.')
        if in_data.ndim > 2:
            if module.kfac == KFAC_EXPAND:
                # Assumes that the first dimension is the mini-batch dimension.
                # Number of weight-sharing positions per sample.
                sqrt_scale = math.sqrt(in_data[0, ..., 0].numel())
                # n x * x f_in -> n* x f_in
                in_data = in_data.flatten(end_dim=in_data.ndim - 2) / sqrt_scale
            elif module.kfac == KFAC_REDUCE:
                # Assumes that the first dimension is the mini-batch dimension.
                weight_sharing_dims = tuple(range(1, in_data.ndim - 1))
                # n x * x f_in -> n x f_in
                in_data = in_data.mean(weight_sharing_dims)
        # Check if module is GNN layer.
        if hasattr(module, 'n_graph'):
            current_batch_size = in_data.shape[0]
            n_graph = module.n_graph
            if module.kfac == KFAC_EXPAND and current_batch_size != n_graph:
                # This is just one way of scaling; final scaling will be 1 / \sum_{n=1}^N R_n.
                scale = current_batch_size / n_graph
                in_data = in_data / math.sqrt(scale)
            elif module.kfac == KFAC_REDUCE and current_batch_size != n_graph:
                # This is just one way of scaling; in_data is scaled by 1 / \sqrt(R_n).
                in_data = Linear._reduce_per_graph(module, in_data)
        return in_data

    @staticmethod
    def extend_in_data(in_data):
        """Append a column of ones to `in_data` along dim 1.

        linear: n x f_in -> n x (f_in + 1)
        """
        ones_shape = list(in_data.shape)
        ones_shape[1] = 1
        ones = in_data.new_ones(ones_shape)  # same dtype/device as in_data
        return torch.cat((in_data, ones), dim=1)

    @staticmethod
    def preprocess_out_grads(module, out_grads):
        """Bring output gradients into the 2D form expected by the other methods.

        * More than two dimensions (non-GNN modules only):
          'expand' flattens all leading dimensions into the batch dimension
          and divides by the square root of the number of weight-sharing
          positions; 'reduce' sums over the weight-sharing dimensions.
          For any other `module.kfac` value the data is left untouched.
        * GNN modules (those with an `n_graph` attribute) in 'reduce' mode
          whose batch size differs from `n_graph`: rows are aggregated per
          graph. No GNN-specific rescaling is applied in 'expand' mode.

        Raises:
            ValueError: if a GNN module receives gradients with more than two
                dimensions, or if its batch size matches neither the number
                of nodes nor the number of edges in 'reduce' mode.
        """
        if out_grads.ndim > 2 and hasattr(module, 'n_graph'):
            raise ValueError('GNN layer out_grads should not have more than two dimensions.')
        if out_grads.ndim > 2:
            if module.kfac == KFAC_EXPAND:
                # Assumes that the first dimension is the mini-batch dimension.
                # Number of weight-sharing positions per sample.
                sqrt_scale = math.sqrt(out_grads[0, ..., 0].numel())
                # n x * x f_out -> n* x f_out
                out_grads = out_grads.flatten(end_dim=out_grads.ndim - 2) / sqrt_scale
            elif module.kfac == KFAC_REDUCE:
                # Assumes that the first dimension is the mini-batch dimension.
                weight_sharing_dims = tuple(range(1, out_grads.ndim - 1))
                # n x * x f_out -> n x f_out
                out_grads = out_grads.sum(weight_sharing_dims)
        # Check if module is GNN layer.
        if hasattr(module, 'n_graph'):
            current_batch_size = out_grads.shape[0]
            n_graph = module.n_graph
            if module.kfac == KFAC_REDUCE and current_batch_size != n_graph:
                # This is just one way of scaling; out_grads is scaled by 1 / sqrt(R_n).
                out_grads = Linear._reduce_per_graph(module, out_grads)
        return out_grads

    @staticmethod
    def batch_grads_weight(
        module: nn.Module, in_data: torch.Tensor, out_grads: torch.Tensor
    ):
        """Return the per-sample weight gradients as outer products."""
        return torch.bmm(
            out_grads.unsqueeze(2), in_data.unsqueeze(1)
        )  # n x f_out x f_in

    @staticmethod
    def batch_grads_bias(module, out_grads):
        """Return the per-sample bias gradients, i.e. `out_grads` itself."""
        return out_grads  # n x f_out

    @staticmethod
    def grad_weight(module: nn.Module, in_data: torch.Tensor, out_grads: torch.Tensor):
        """Return the weight gradient summed over the batch."""
        return torch.matmul(out_grads.T, in_data)  # f_out x f_in

    @staticmethod
    def grad_bias(module: nn.Module, out_grads: torch.Tensor):
        """Return the bias gradient summed over the batch."""
        return torch.sum(out_grads, dim=0)  # f_out

    @staticmethod
    def cov_diag_weight(module, in_data, out_grads):
        """Return the batch sum of element-wise squared per-sample weight gradients."""
        in_in = in_data.mul(in_data)  # n x f_in
        grad_grad = out_grads.mul(out_grads)  # n x f_out
        return torch.matmul(grad_grad.T, in_in)  # f_out x f_in

    @staticmethod
    def cov_diag_bias(module, out_grads):
        """Return the batch sum of element-wise squared per-sample bias gradients."""
        grad_grad = out_grads.mul(out_grads)  # n x f_out
        return grad_grad.sum(dim=0)  # f_out x 1

    @staticmethod
    def cov_kron_A(module, in_data):
        """Return the (unnormalised) second moment of the inputs."""
        return torch.matmul(in_data.T, in_data)  # f_in x f_in

    @classmethod
    def cov_swift_kron_A(cls, module, in_data):
        """Return `in_data` itself when n < f_in, otherwise `cov_kron_A`.

        Callers therefore receive either an n x f_in or an f_in x f_in tensor.
        """
        n, f_in = in_data.shape
        if n < f_in:
            return in_data  # n x f_in
        return cls.cov_kron_A(module, in_data)  # f_in x f_in

    @staticmethod
    def cov_kron_B(module, out_grads):
        """Return the (unnormalised) second moment of the output gradients."""
        return torch.matmul(out_grads.T, out_grads)  # f_out x f_out

    @classmethod
    def cov_swift_kron_B(cls, module, out_grads):
        """Return `out_grads` itself when n < f_out, otherwise `cov_kron_B`.

        Callers therefore receive either an n x f_out or an f_out x f_out tensor.
        """
        n, f_out = out_grads.shape
        if n < f_out:
            return out_grads  # n x f_out
        return cls.cov_kron_B(module, out_grads)  # f_out x f_out

    @classmethod
    def cov_kfe_A(cls, module, in_data):
        """Return an orthonormal basis (as columns) derived from the inputs.

        For n < f_in the full set of right singular vectors of `in_data` is
        used; otherwise the eigenvectors of `cov_kron_A`.
        NOTE(review): the column order presumably differs between the two
        branches (SVD: descending singular values, eigh: ascending
        eigenvalues) - confirm callers do not rely on the order.
        """
        n, f_in = in_data.shape
        if n < f_in:
            _, _, Vt = torch.linalg.svd(in_data, full_matrices=True)
            return Vt.T  # f_in x f_in
        A = cls.cov_kron_A(module, in_data)
        _, U = torch.linalg.eigh(A)
        return U  # f_in x f_in

    @classmethod
    def cov_kfe_B(cls, module, out_grads):
        """Return an orthonormal basis (as columns) derived from the output gradients.

        For n < f_out the full set of right singular vectors of `out_grads`
        is used; otherwise the eigenvectors of `cov_kron_B`.
        NOTE(review): the column order presumably differs between the two
        branches (SVD: descending singular values, eigh: ascending
        eigenvalues) - confirm callers do not rely on the order.
        """
        n, f_out = out_grads.shape
        if n < f_out:
            _, _, Vt = torch.linalg.svd(out_grads, full_matrices=True)
            return Vt.T  # f_out x f_out
        B = cls.cov_kron_B(module, out_grads)
        _, U = torch.linalg.eigh(B)
        return U  # f_out x f_out

    @classmethod
    def cov_kfe_scale(cls, module, in_data, out_grads, Ua, Ub, bias=True):
        """Return batch-averaged squared gradients in the basis given by `Ua` / `Ub`.

        `in_data` is projected with `Ua` and `out_grads` with `Ub`.

        Returns:
            `(scale_w, scale_b)` if `bias` is true, else the 1-tuple
            `(scale_w,)`. `scale_w` is f_out x f_in, `scale_b` is f_out.
        """
        n = in_data.shape[0]
        in_data_kfe = in_data.mm(Ua)  # n x f_in
        out_grads_kfe = out_grads.mm(Ub)  # n x f_out
        # Squared once and shared between the weight and the bias scale.
        out_grads_kfe_sq = out_grads_kfe ** 2  # n x f_out
        scale_w = torch.mm(out_grads_kfe_sq.T, in_data_kfe ** 2) / n
        if bias:
            scale_b = out_grads_kfe_sq.mean(dim=0)
            return scale_w, scale_b
        return scale_w,

    @staticmethod
    def cov_unit_wise(module, in_data, out_grads):
        """Return one f_in x f_in block per output unit.

        Block o is the batch sum of out_grads[:, o]^2 * in_data in_data^T.
        """
        n, f_in = in_data.shape[0], in_data.shape[1]
        in_in = torch.bmm(in_data.unsqueeze(2), in_data.unsqueeze(1)).view(n, -1)  # n x (f_in x f_in)
        grad_grad = out_grads.mul(out_grads)  # n x f_out
        return torch.matmul(grad_grad.T, in_in).view(-1, f_in, f_in)  # f_out x f_in x f_in

    @staticmethod
    def gram_A(module, in_data1, in_data2=None):
        """Return the Gram matrix between two input batches.

        `in_data2` defaults to `in_data1`.
        """
        other = in_data1 if in_data2 is None else in_data2
        return torch.matmul(in_data1, other.T)  # n x n

    @staticmethod
    def gram_B(module, out_grads1, out_grads2=None):
        """Return the Gram matrix between two output-gradient batches.

        `out_grads2` defaults to `out_grads1`.
        """
        other = out_grads1 if out_grads2 is None else out_grads2
        return torch.matmul(out_grads1, other.T)  # n x n

    @staticmethod
    def rfim_relu(module, in_data, out_data):
        """Return, per output unit, the batch sum of nu * in_data in_data^T.

        NOTE(review): the weight nu is sigmoid(out_data)^2 - presumably a
        smooth surrogate for the ReLU derivative; confirm against the
        reference derivation.
        """
        nu = torch.sigmoid(out_data) ** 2  # n x f_out
        xxt = torch.einsum('bi,bj->bij', in_data, in_data)  # n x f_in x f_in
        return torch.einsum('bi,bjk->ijk', nu, xxt)  # f_out x f_in x f_in

    @staticmethod
    def rfim_softmax(module, in_data, out_data):
        """Return the batch sum of (diag(p) - p p^T) (x) (x x^T), p = softmax(out_data)."""
        # equivalent to fisher_exact_for_cross_entropy
        probs = F.softmax(out_data, dim=1)  # n x f_out
        ppt = torch.bmm(probs.unsqueeze(2), probs.unsqueeze(1))  # n x f_out x f_out
        # Batched diagonal matrices without a Python loop; unlike stacking a
        # list of per-sample torch.diag results this also handles n == 0.
        diag_p = torch.diag_embed(probs)  # n x f_out x f_out
        f = diag_p - ppt  # n x f_out x f_out
        xxt = torch.einsum('bi,bj->bij', in_data, in_data)  # n x f_in x f_in
        return torch.einsum('bij,bkl->ikjl', f, xxt)  # (f_out)(f_in)(f_out)(f_in)

    @staticmethod
    def in_data_mean(module, in_data):
        """Return the batch mean of the inputs."""
        return in_data.mean(dim=0)  # f_in

    @staticmethod
    def out_data_mean(module, out_data):
        """Return the batch mean of the outputs."""
        return out_data.mean(dim=0)  # f_out

    @staticmethod
    def out_spatial_size(module, out_data):
        """Return the spatial size of the output, which is always 1 here."""
        return 1

    @staticmethod
    def out_grads_mean(module, out_grads):
        """Return the batch mean of the output gradients."""
        return out_grads.mean(dim=0)  # f_out

    @staticmethod
    def bfgs_kron_s_As(module, in_data):
        """Return the pair `(s, As)` built from `module.bfgs.kron.A_inv`.

        s  = A_inv @ mean(in_data)
        As = in_data^T (in_data s), i.e. (in_data^T in_data) s evaluated
             without forming the f_in x f_in matrix.
        """
        H = module.bfgs.kron.A_inv  # f_in x f_in
        s = torch.mv(H, in_data.mean(dim=0))  # f_in
        indata_s = torch.mv(in_data, s)  # n
        As = torch.mv(in_data.T, indata_s)  # f_in
        return s, As
