import torch
from torch import nn
import torch_butterfly
import math


class KMatrix(nn.Module):
    """Channel-mixing layer built from two batched complex butterfly stages.

    One square complex butterfly matrix of size ``in_size`` is kept per
    (output channel, input channel) pair for each of the two stages. The
    forward pass applies the first stage along the last axis, swaps axes 1
    and 3, applies the second stage, and then sums over input channels.

    Args:
        in_size: Size passed as both input and output size to every
            butterfly module.
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Accepted but not used by this implementation.
        param_filt: Accepted but not used by this implementation.
        randinit: Accepted but not used by this implementation.
        tied: Accepted but not used by this implementation.
        nblocks: Accepted but not used by this implementation.
    """

    def __init__(self, in_size, in_ch, out_ch, kernel_size, param_filt=True,
                 randinit=False, tied=False, nblocks=1):
        super().__init__()

        self.in_size = in_size
        self.in_ch = in_ch
        self.out_ch = out_ch

        self.r2c = torch_butterfly.complex_utils.Real2Complex()
        self.c2r = torch_butterfly.complex_utils.Complex2Real()

        # Construction order is kept as: all first-stage seeds, all
        # second-stage seeds, then each batched layer. Reordering would change
        # which random draws end up in which parameter under a fixed seed.
        n_pairs = self.out_ch * self.in_ch
        first_stage_seeds = self._standalone_butterflies(self.in_size, n_pairs)
        second_stage_seeds = self._standalone_butterflies(self.in_size, n_pairs)

        self.b1_bmm = self._batched_butterfly(self.in_size, first_stage_seeds)
        self.b2_bmm = self._batched_butterfly(self.in_size, second_stage_seeds)

    @staticmethod
    def _standalone_butterflies(size, count):
        """Build ``count`` separate bias-free complex ``size`` x ``size`` Butterfly modules."""
        return [
            torch_butterfly.Butterfly(size, size, bias=False, complex=True)
            for _ in range(count)
        ]

    @staticmethod
    def _batched_butterfly(size, seeds):
        """Build a ButterflyBmm with one matrix per seed, loaded with the seeds' twiddles."""
        batched = torch_butterfly.ButterflyBmm(
            size, size, matrix_batch=len(seeds), bias=False, complex=True)
        # Overwrite the batched layer's own initial twiddle with the seeds'
        # twiddles concatenated along the first axis.
        with torch.no_grad():
            batched.twiddle.copy_(torch.cat([seed.twiddle for seed in seeds]))
        return batched

    def forward(self, x):
        """Apply both butterfly stages and reduce over input channels.

        Args:
            x: Input tensor. Assumed to be laid out as
                (batch, in_ch, in_size, in_size); this is inferred from the
                transpose/reshape below -- TODO confirm against callers.

        Returns:
            The result of ``self.c2r`` applied to a tensor reshaped to
            (batch, out_ch, in_size, in_size) after the channel reduction.
        """
        size, c_in, c_out = self.in_size, self.in_ch, self.out_ch
        n_pairs = c_out * c_in

        z = self.r2c(x)
        n_batch = z.shape[0]

        # Swap axes 1 and 2, then add a singleton output-channel axis so the
        # same input can be broadcast to every output channel.
        per_row = z.transpose(1, 2).reshape(n_batch, size, 1, c_in, size)
        tiled = per_row.expand(n_batch, size, c_out, c_in, size)
        # Merge (out_ch, in_ch) into the single matrix-batch axis used by the
        # batched butterfly layers.
        stage_in = tiled.reshape(n_batch, size, n_pairs, size)

        after_first = self.b1_bmm(stage_in)
        # Swap axes 1 and 3 so the second stage acts along the other
        # in_size-sized axis.
        after_second = self.b2_bmm(after_first.transpose(1, 3))

        # Bring the matrix-batch axis forward and split it back into
        # (out_ch, in_ch).
        per_pair = after_second.permute(0, 2, 3, 1).reshape(
            n_batch, c_out, c_in, size, size)

        # Sum over input channels, scaled by 1/sqrt(in_ch).
        mixed = per_pair.sum(dim=2) / math.sqrt(c_in)
        return self.c2r(mixed)
