import torch
import torch.nn as nn
import torch.optim as optim

import argparse
import numpy as np


def get_circ_index(rc, cc, k, sym=False):
    """Build the index map that expands block-circulant parameters into a dense matrix.

    Both dimensions are rounded up to the next multiple of ``k`` (``rc_pad``,
    ``cc_pad``). Entry ``[a, b]`` of the result is
    ``(a // k) * cc_pad + (b // k) * k + (a - b) % k``, i.e. it points into a
    flat parameter vector laid out as ``[rc_pad // k, cc_pad // k, k]`` with one
    length-``k`` vector per circulant block.

    Args:
        rc: number of rows of the (unpadded) matrix.
        cc: number of columns of the (unpadded) matrix.
        k: circulant block size; must be positive.
        sym: unused; kept only for call compatibility.

    Returns:
        ``np.int64`` array of shape ``[rc_pad, cc_pad]``.

    Raises:
        ValueError: if ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k}")
    # Round up to a multiple of k with exact integer arithmetic. Adding
    # ``rc % k`` (as before) only lands on a multiple when the remainder is 0
    # or k / 2, and float division can lose precision for large sizes.
    rc = -(-rc // k) * k
    cc = -(-cc // k) * k

    lane = np.arange(k)
    # Position inside a block: (p - q) mod k. This transposed pattern is kept
    # to follow the Caffe implementation.
    within = np.tile((lane[:, None] - lane[None, :]) % k, [rc // k, cc // k])
    # Start of each block's vector in the flat [rows, cols, k] layout.
    row_base = (np.arange(rc) // k) * cc
    col_base = (np.arange(cc) // k) * k
    return (within + row_base[:, None] + col_base[None, :]).astype(np.int64)


class BlockCircMV(torch.autograd.Function):
    """Block-circulant matrix times a batch of vectors, with an analytic backward.

    The weight is an ``r x s`` grid of ``circ x circ`` circulant blocks, each
    stored as one length-``circ`` vector ``w[0, i, j]``. A circulant block
    times a vector is a circular convolution, so every block product is done
    in the frequency domain. All operands are real, so the real-input FFT pair
    (``rfft`` / ``irfft`` with ``n=circ``) is used. Because the inverse FFT is
    linear, sums over blocks are taken before inverting, so one inverse
    transform is needed per result block instead of one per block pair.
    """

    @staticmethod
    def forward(x, w, bias):
        """Return ``y`` with ``y[:, i] = sum_j w[0, i, j] (*) x[:, j] (+ bias)``.

        ``(*)`` is circular convolution and ``x[:, j]`` is the j-th
        length-``circ`` chunk of each row of ``x``.

        Args:
            x: input of shape ``[b, s * circ]``.
            w: block vectors of shape ``[1, r, s, circ]``.
            bias: ``None`` or a tensor of shape ``[r * circ]``.

        Returns:
            Tensor of shape ``[b, r * circ]``.
        """
        _, r, s, circ = w.shape
        fft_x = torch.fft.rfft(x.reshape(-1, 1, s, circ))  # [b, 1, s, F]
        fft_w = torch.fft.rfft(w)  # [1, r, s, F]
        # Sum over input blocks in the frequency domain, then invert once.
        output = torch.fft.irfft((fft_w * fft_x).sum(dim=2), n=circ)  # [b, r, circ]
        output = output.reshape(-1, r * circ)
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        """Save the real inputs for ``backward``.

        Spectra are recomputed in ``backward`` only for the gradients that are
        actually requested, instead of recomputing both FFTs here a second time
        and storing complex spectra twice the size of the real inputs.
        """
        x, w, bias = inputs
        ctx.save_for_backward(x, w, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """Gradients of ``forward`` with respect to ``x``, ``w`` and ``bias``.

        The adjoint of circular convolution is circular cross-correlation,
        i.e. multiplication by the complex-conjugate spectrum.
        """
        x, w, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        _, rows, cols, circ = w.shape

        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]):
            return grad_input, grad_weight, grad_bias

        fft_o = torch.fft.rfft(grad_output.reshape(-1, rows, 1, circ))  # [b, r, 1, F]
        if ctx.needs_input_grad[1]:
            fft_x = torch.fft.rfft(x.reshape(-1, 1, cols, circ))  # [b, 1, s, F]
            # Sum over the batch before inverting: [1, r, s, F] -> [1, r, s, circ].
            grad_weight = torch.fft.irfft(
                (torch.conj(fft_x) * fft_o).sum(dim=0, keepdim=True), n=circ
            )
        if ctx.needs_input_grad[0]:
            fft_w = torch.fft.rfft(w)  # [1, r, s, F]
            # Sum over result blocks before inverting: [b, s, F] -> [b, s, circ].
            grad_input = torch.fft.irfft(
                (torch.conj(fft_w) * fft_o).sum(dim=1), n=circ
            ).reshape(-1, cols * circ)
        return grad_input, grad_weight, grad_bias


class CIRCULANTLinear(nn.Module):
    """Linear layer whose weight matrix is block-circulant.

    The ``out_dim x in_dim`` weight is padded up to multiples of
    ``block_size`` and split into ``rows x cols`` circulant blocks, each stored
    as one length-``block_size`` vector. The layer therefore holds
    ``rows * cols * block_size`` weights in ``self.circulant`` instead of
    ``out_dim * in_dim``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        block_size: int = 128,
        use_bias: bool = False,
    ):
        """
        Args:
            in_dim: number of input features (last axis of the input).
            out_dim: number of output features.
            block_size: circulant block size; must be positive.
            use_bias: if True, a learnable zero-initialised bias of shape
                ``[out_dim]`` is added to the output.

        Raises:
            ValueError: if ``block_size`` is not positive.
        """
        super().__init__()
        # Reset CUDA peak-memory statistics only when CUDA is available;
        # calling this unconditionally raises on CPU-only machines.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(device=None)
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.block_size = block_size
        self.use_bias = use_bias
        self.n = max(in_dim, out_dim)  # not read anywhere in this file
        self.disabled = False  # not read anywhere in this file
        if self.block_size <= 0:
            raise ValueError(f"`block_size` should be a positive integer value but the value passed is {self.block_size}")
        # Number of circulant blocks per axis: ceil(dim / block_size).
        # (``(d + d % bs) // bs`` undercounts whenever d % bs is not 0 or bs/2.)
        self.rows = -(-self.out_dim // block_size)
        self.cols = -(-self.in_dim // block_size)
        # One length-block_size vector per block. Equals
        # len(np.unique(get_circ_index(out_dim, in_dim, block_size))) whenever
        # that call succeeds, without building the dense index matrix.
        size = self.rows * self.cols * block_size
        print("init from random")
        self.circulant = nn.Parameter(torch.randn(size, dtype=torch.float32), requires_grad=True)
        torch.nn.init.kaiming_normal_(
            self.circulant.view([self.rows, self.cols, block_size])
        )
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(out_dim, dtype=torch.float32))
        else:
            self.register_parameter("bias", None)

        print(
            f"out_features={self.out_dim},"
            f"in_features={self.in_dim},"
            f"block_size={block_size},"
            f"param={size}"
        )
        self.to(torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` of shape ``[batch, seq_len, in_dim]``.

        The product is computed in float32 and cast back to ``x.dtype``.
        Returns a tensor of shape ``[batch, seq_len, out_dim]``.
        """
        previous_dtype = x.dtype
        x = x.to(torch.float32)

        block_size = self.block_size
        batch, seq_len, _ = x.shape
        # reshape (not view): x may be non-contiguous.
        x = x.reshape(batch * seq_len, -1)
        # Zero-pad features up to cols * block_size so they split evenly into
        # circulant blocks; the padded zeros contribute nothing to the product.
        pad = self.cols * block_size - x.shape[-1]
        if pad > 0:
            x = nn.functional.pad(x, (0, pad))

        y = BlockCircMV.apply(
            x,
            self.circulant.view([1, self.rows, self.cols, block_size]),
            None,
        )
        # Drop the padded output features beyond out_dim.
        y = y[:, :self.out_dim].reshape(batch, seq_len, -1)
        if self.bias is not None:
            y = y + self.bias

        return y.to(previous_dtype)
    

def main():
    """Measure peak CUDA memory around one CIRCULANTLinear training step.

    Builds a layer from the command-line arguments, then runs one
    forward / backward / Adam step on random data on the ``"cuda"`` device.
    It prints ``torch.cuda.max_memory_allocated`` after each phase, so a CUDA
    GPU is required.
    """
    parser = argparse.ArgumentParser(description="Training script with hyperparameters")
    parser.add_argument('--block_size', type=int, default=1024, help="block size for finetuning")
    parser.add_argument('--input_dim', type=int, default=1024, help="number of input features")
    parser.add_argument('--output_dim', type=int, default=1024, help="number of output features")
    args = parser.parse_args()
    block_size = args.block_size
    input_dim = args.input_dim
    output_dim = args.output_dim
    device = "cuda"

    circulant_ft = CIRCULANTLinear(in_dim=input_dim, out_dim=output_dim, block_size=block_size).to(device)
    print('ft', torch.cuda.max_memory_allocated(device=None))
    # A single batch of shape [1, 1, input_dim]; the leading axis is looped over.
    x = torch.randn(1, 1, 1, input_dim, device=device)
    target = torch.randn(1, 1, output_dim, device=device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(circulant_ft.parameters(), lr=0.0001)
    print('targr+op_i', torch.cuda.max_memory_allocated(device=None))

    for batch_data in x:
        print(batch_data.shape)
        output = circulant_ft(batch_data)
        print("Forward pass output shape:", output.shape)
        print('fw', torch.cuda.max_memory_allocated(device=None))

        optimizer.zero_grad()
        loss = criterion(output, target)

        loss.backward()
        print('bw', torch.cuda.max_memory_allocated(device=None))
        optimizer.step()
        print('op_s', torch.cuda.max_memory_allocated(device=None))


# Entry point: run main() only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()