import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Union

import numpy as np

def factorize(n: int, bias=0) -> List[int]:
    """Split ``n`` into two integer factors that are close to each other.

    Candidate divisors are scanned downward from ``int(sqrt(n)) + 1`` to 2.
    The first divisor that is not skipped is returned together with its
    cofactor.

    Args:
        n: Integer to split.
        bias: Number of qualifying divisors to skip before accepting one.
            Skipping typically yields a less balanced split.

    Returns:
        ``[d, n // d]`` for the accepted divisor ``d``, or ``[n, 1]`` when no
        divisor in the scanned range is accepted (e.g. ``n`` is prime).
        The larger factor may come first or second depending on ``n``
        (``factorize(12) == [4, 3]`` but ``factorize(8) == [2, 4]``).
    """
    upper = int(np.sqrt(n)) + 1
    candidates = (d for d in range(upper, 1, -1) if n % d == 0)
    for skipped, divisor in enumerate(candidates):
        if skipped == bias:
            return [divisor, n // divisor]
    return [n, 1]

def kron(a: Union[torch.Tensor, np.ndarray], b: Union[torch.Tensor, np.ndarray], s: Union[torch.Tensor, np.ndarray]=None) -> torch.Tensor:
    """Sum of Kronecker products over matching slices of `a` and `b`.

    Computes ``sum_k torch.kron(a[k], b[k])``. Dimension 0 of both factors
    indexes the terms of the sum (e.g. the rank of a Kronecker decomposition).

    Args:
        a: First factor, stacked along dim 0.
        b: Second factor, stacked along dim 0. It must have the same size
            along dim 0 as `a`.
        s: Optional elementwise scale with the shape of one slice ``a[k]``.
            Every ``a[k]`` is multiplied by it before the products are taken.

    Returns:
        Tensor containing the summed Kronecker products. NumPy inputs are
        converted with ``torch.from_numpy`` (memory is shared).

    Raises:
        ValueError: if `s` does not have shape ``a.shape[1:]``, or if `a` and
            `b` have different sizes along dim 0.
    """
    # Convert before scaling so that `s` also works with NumPy inputs.
    a = torch.from_numpy(a) if isinstance(a, np.ndarray) else a
    b = torch.from_numpy(b) if isinstance(b, np.ndarray) else b
    if s is not None:
        s = torch.from_numpy(s) if isinstance(s, np.ndarray) else s
        if a.shape[1:] != s.shape:
            raise ValueError(
                f"s should have the shape of a single slice a[k]: "
                f"expected {tuple(a.shape[1:])}, got {tuple(s.shape)}"
            )
        # Broadcast the same scale over every rank slice of `a`.
        a = s.unsqueeze(0) * a
    if a.shape[0] != b.shape[0]:
        # Otherwise extra slices of `b` would be silently ignored (or IndexError).
        raise ValueError(
            f"a and b must have the same number of slices along dim 0: "
            f"{a.shape[0]} != {b.shape[0]}"
        )

    return torch.stack([torch.kron(a[k], b[k]) for k in range(a.shape[0])]).sum(dim=0)

    
    

class KronConv2d(nn.Module):
    """2-D convolution whose weight is parameterized as a sum of Kronecker products."""

    def __init__(self, in_channels, out_channels, kernel_size, patchsize=None, shape_bias=0, structured_sparse=False, bias=True, rank_rate=0.1, rank=0, stride=1, padding=0, dilation=1) -> None:
        """Kronecker Convolution Layer

        The dense weight of shape ``(out_channels, in_channels, kh, kw)`` is
        rebuilt on every forward pass as ``sum_r kron(a[r], b[r])``, where

        * ``a[r]`` has shape ``(out_a, in_a, 1, 1)`` (channel mixing only), and
        * ``b[r]`` has shape ``(out_b, in_b, kh, kw)`` (carries the spatial kernel),

        with ``out_a * out_b == out_channels`` and ``in_a * in_b == in_channels``.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int or tuple): Kernel size, an int or a ``(kh, kw)`` pair.
            patchsize (tuple, optional): ``(in_b, out_b)`` channel sizes of `b`.
                They must divide ``in_channels`` / ``out_channels``. When None,
                both channel counts are split with ``factorize``.
            shape_bias (int, optional): Passed as ``bias`` to ``factorize`` to
                pick a less balanced channel split. Defaults to 0.
            structured_sparse (bool, optional): If True, learn a mask ``s`` with
                the shape of one ``a[r]`` that scales every ``a[r]``
                elementwise. Defaults to False.
            bias (bool, optional): If True, add a learnable per-output-channel
                bias. Defaults to True.
            rank_rate (float, optional): When ``rank <= 0``, the rank is this
                fraction of the smallest channel factor. Defaults to 0.1.
            rank (int, optional): Number of Kronecker terms. Values <= 0 fall
                back to ``rank_rate``. The effective rank is at least 1.
            stride, padding, dilation (optional): Forwarded to
                ``torch.nn.functional.conv2d``. Defaults: 1, 0, 1.
        """
        super().__init__()
        # Normalize first: both the factor shapes and forward() index into it.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        kernel_size = tuple(kernel_size)
        assert len(kernel_size) == 2, "kernel_size should be an int or a tuple of two integers"
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Split the channel counts between the two factors.
        if patchsize is not None:
            assert len(patchsize) == 2, "The patchsize should be a tuple of two integers"
            assert in_channels % patchsize[0] == 0 and out_channels % patchsize[1] == 0, "The input and output features should be divisible by the patchsize"
            in_b, out_b = patchsize
            in_a, out_a = in_channels // in_b, out_channels // out_b
        else:
            in_a, in_b = factorize(in_channels, shape_bias)
            out_a, out_b = factorize(out_channels, shape_bias)

        rank_value = rank if rank > 0 else min(in_a, out_a, in_b, out_b) * rank_rate
        self.rank = max(int(rank_value), 1)

        self.structured_sparse = structured_sparse

        kh, kw = kernel_size
        # Registered as Parameters so the factors are trained and follow .to()/.cuda().
        self.a = nn.Parameter(torch.empty(self.rank, out_a, in_a, 1, 1), requires_grad=True)
        self.b = nn.Parameter(torch.empty(self.rank, out_b, in_b, kh, kw), requires_grad=True)
        nn.init.xavier_uniform_(self.a)
        nn.init.xavier_uniform_(self.b)
        self.a_shape = self.a.shape
        self.b_shape = self.b.shape

        if structured_sparse:
            # One mask shared by all rank slices; shape matches a single a[r],
            # which is what kron(..., s) expects.
            self.s = nn.Parameter(torch.randn(*self.a_shape[1:]), requires_grad=True)
        else:
            self.s = None

        if bias:
            self.bias = nn.Parameter(torch.randn(out_channels), requires_grad=True)
        else:
            self.bias = None

    def forward(self, x):
        """Convolve `x` with the weight reconstructed from the Kronecker factors."""
        # kron() sums over the rank dimension and applies `s` (if any) to every a[r].
        weight = kron(self.a, self.b, self.s)
        # The product already has this shape; view() guards against silent mismatches.
        weight = weight.view(self.out_channels, self.in_channels, *self.kernel_size)
        return F.conv2d(x, weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)