import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .general import rebuild_tucker, FUNC_LIST

def weight_gen(org_weight, rank, tucker=True):
    """### weight_gen

    Create freshly initialised low-rank factors for a linear/conv weight.
    Only the shape of `org_weight` is read.

    Args:
        org_weight (torch.Tensor): the weight tensor, laid out as
            (out_dim, in_dim, *kernel); kernel is empty for linear weights
        rank (int): low rank
        tucker (bool, optional): when true and the weight has kernel dims,
            put the kernel into a separate `mid` tensor and give `down`
            and `up` size-1 kernel dims

    Returns:
        torch.Tensor: down, up[, mid]
            - tucker conv: down (rank, in_dim, 1...), up (out_dim, rank, 1...),
              mid (rank, rank, *kernel)
            - otherwise: down (rank, in_dim, *kernel), up (out_dim, rank, 1...)
              (plain 2-D matrices for a linear weight)
    """
    out_dim, in_dim, *k = org_weight.shape
    unit_kernel = (1,) * len(k)
    use_tucker = bool(k) and tucker

    # Without tucker the down projection has to carry the kernel itself,
    # otherwise the rebuilt ΔW cannot have the shape of a conv weight.
    down_kernel = unit_kernel if use_tucker else tuple(k)
    down = torch.empty(rank, in_dim, *down_kernel)
    up = torch.empty(out_dim, rank, *unit_kernel)

    nn.init.kaiming_uniform_(down, a=math.sqrt(5))
    # up starts at zero, so up @ down (the initial ΔW) is zero.
    nn.init.constant_(up, 0)

    if not use_tucker:
        return down, up

    mid = torch.empty(rank, rank, *k)
    nn.init.kaiming_uniform_(mid, a=math.sqrt(5))
    return down, up, mid


def diff_weight(d, u, m=None, gamma=1.0):
    """### diff_weight

    Rebuild the weight delta ΔW from its low-rank factors.

    Without `m`, ΔW is the matrix product of the flattened (scaled) up
    weight and the flattened down weight. With `m`, the three factors are
    handed to `rebuild_tucker`.

    Args:
        d (torch.Tensor): weight of down proj linear/conv layer
        u (torch.Tensor): weight of up proj linear/conv layer
        m (torch.Tensor, optional): middle weight of tucker decomposition
        gamma (float, optional): scale factor, normally alpha/rank here

    Returns:
        torch.Tensor: ΔW, reshaped to (u.shape[0], d.shape[1], *kernel),
            where kernel is taken from `m` when given, else from `d`
    """
    _, in_features, *kernel = d.shape
    out_features, _, *_ = u.shape
    scaled_up = u * gamma

    if m is not None:
        # With a middle tensor the kernel dims come from it, not from d.
        _, _, *kernel = m.shape
        up_mat = scaled_up.reshape(scaled_up.size(0), -1).transpose(0, 1)
        down_mat = d.reshape(d.size(0), -1)
        delta = rebuild_tucker(m, up_mat, down_mat)
    else:
        up_mat = scaled_up.reshape(-1, scaled_up.size(1))
        down_mat = d.reshape(d.size(0), -1)
        delta = torch.matmul(up_mat, down_mat)

    return delta.reshape(out_features, in_features, *kernel)


def bypass_forward_diff(x, d, u, m=None, gamma=1.0, extra_args=None):
    """### bypass_forward_diff

    Apply the low-rank branch directly to `x` (down -> [mid] -> up) and
    scale the result by `gamma`, instead of materialising ΔW first.

    The forward function is looked up as `FUNC_LIST[d.dim()]` and used for
    every stage. NOTE(review): FUNC_LIST is defined in `.general`;
    presumably it maps tensor rank to F.linear / F.conv1d/2d/3d — confirm.

    Args:
        x (torch.Tensor): input tensor
        d (torch.Tensor): weight of down proj linear/conv layer
        u (torch.Tensor): weight of up proj linear/conv layer
        m (torch.Tensor, optional): middle weight of tucker decomposition
        gamma (float, optional): scale factor, normally alpha/rank here
        extra_args (dict, optional): extra args for forward func, \
            e.g. padding, stride for Conv1/2/3d. They are passed to the
            mid stage when `m` is given, otherwise to the down stage.
            None (the default) means no extra args.

    Returns:
        torch.Tensor: output tensor
    """
    # A None sentinel replaces the former `{}` default so that no dict
    # object is shared between calls.
    kwargs = {} if extra_args is None else extra_args
    forward_fn = FUNC_LIST[d.dim()]

    if m is None:
        hidden = forward_fn(x, d, **kwargs)
    else:
        # With a middle weight, the extra args go to the mid stage only;
        # down runs without them.
        hidden = forward_fn(forward_fn(x, d), m, **kwargs)

    # The up stage always runs without extra args.
    return forward_fn(hidden, u) * gamma
