
import torch
from torch import nn
import numpy as np
from typing import Dict, Tuple
from overrides import overrides
from torch.nn import Parameter
import torch.nn.functional as F

import matplotlib.pyplot as plt
import io
import copy

from  .flow import Flow

def get_circ_index(rc, cc, k, sym=False):
    """Map every entry of an (rc, cc) matrix to the id of its k x k circulant diagonal.

    The matrix is tiled into k x k blocks. Inside a block, entry (r, c) lies on
    diagonal ``(r - c) % k``. Every block gets its own range of ids, offset by
    ``(r // k) * cc + (c // k) * k``, so ids never collide across blocks. The
    result can be fed to ``np.bincount`` to average a dense matrix onto its
    best block-circulant approximation. The orientation is the transpose of
    ``i + j``, to follow the caffe implementation.

    Args:
        rc: number of rows; must be a multiple of ``k``.
        cc: number of columns; must be a multiple of ``k``.
        k: circulant block size.
        sym: unused.

    Returns:
        ``np.int64`` array of shape ``(rc, cc)``.
    """
    assert rc % k == 0, f"rc={rc}, k={k}"
    assert cc % k == 0, f"cc={cc}, k={k}"
    # Round up to a multiple of k (only has an effect when asserts are stripped).
    rows = int((rc + k - 1) / k) * k
    cols = int((cc + k - 1) / k) * k

    row_ids = np.arange(rows).reshape(rows, 1)
    col_ids = np.arange(cols).reshape(1, cols)

    # Diagonal id inside the k x k block the entry belongs to.
    within_block = (row_ids - col_ids) % k
    # First id reserved for that block.
    block_base = (row_ids // k) * cols + (col_ids // k) * k
    return (within_block + block_base).astype(np.int64)

def cdvft_mv(diag_vecs, circ_vecs, x, topn=1e9):
    """Apply the trailing ``topn`` factors of ``d0 c0 d1 c1 ... d_last`` to each row of ``x``.

    Factors are numbered ``0 .. n2_1 - 1`` in the order ``[d0, c0, d1, c1, ...]``
    and act right-to-left, so factor ``n2_1 - 1`` touches ``x`` first. With
    ``topn < n2_1`` only the last ``topn`` factors are applied. The result is
    the input seen by factor ``n2_1 - 1 - topn``.

    A diagonal factor multiplies elementwise. A circulant factor is applied as
    ``irfft(rfft(h) * w, n=h.shape[-1])``, i.e. a circular convolution.
    The pairing of indices assumes ``len(diag_vecs) == len(circ_vecs) + 1``.

    Args:
        diag_vecs: indexable of real factor vectors (rows of a tensor or a list).
        circ_vecs: indexable of rfft spectra of the circulant factors.
        x: batch of row vectors (rows are transformed independently; the last
            dim is the transform length).
        topn: number of trailing factors to apply (default: all).

    Returns:
        Tensor with the same shape as ``x``. Every factor is moved to
        ``x.device`` before use.
    """
    dev = x.device
    total = len(diag_vecs) + len(circ_vecs)
    topn = min(total, topn)

    def _scale(h, idx):
        # Elementwise diagonal factor, broadcast over the batch dimension.
        return h * diag_vecs[idx].to(dev).unsqueeze(0)

    y = x
    # Walk (diagonal, circulant) pairs from the innermost factor outwards.
    for step in range(total - 1, total - topn, -2):
        y = _scale(y, step // 2)
        spectrum = torch.fft.rfft(y) * circ_vecs[(step - 1) // 2].to(dev).unsqueeze(0)
        y = torch.fft.irfft(spectrum, n=y.shape[-1])

    # An odd count leaves one more diagonal factor to apply.
    if topn % 2 > 0:
        y = _scale(y, (total - topn) // 2)
    return y

class CDVFT_fftw(torch.autograd.Function):
    """Autograd op for ``y = d0 c0 d1 c1 ... d_last x (+ bias)``, circulants given as rfft spectra.

    ``apply(x, bias, diag_vecs, circ_vecs)`` with:
        x: batch of rows, transformed independently (last dim has length N).
        bias: vector added to every output row, or None.
        diag_vecs: stacked diagonal factors; row 0 is applied last.
        circ_vecs: stacked rfft spectra (``N // 2 + 1`` complex bins) of the
            circulant factors; row 0 is applied last.

    The forward pass is ``cdvft_mv``. The backward pass walks the factors from
    the output side and recomputes each factor's input with ``cdvft_mv``. This
    trades compute for memory; no intermediate activations are saved.
    """

    # Only ``forward`` (which takes ``ctx`` and saves tensors itself) and
    # ``backward`` are defined. Overriding ``setup_context`` as well makes
    # PyTorch >= 2.0 call ``forward`` *without* ``ctx``, which breaks this
    # signature.

    @staticmethod
    def forward(ctx, x, bias, diag_vecs, circ_vecs):
        output = cdvft_mv(diag_vecs, circ_vecs, x)
        if bias is not None:
            # Out-of-place add: cdvft_mv returns ``x`` itself when there are no factors.
            output = output + bias.unsqueeze(0).expand_as(output)
        ctx.save_for_backward(x, bias, diag_vecs, circ_vecs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, bias, diag_vecs, circ_vecs = ctx.saved_tensors
        need_x, need_bias, need_diag, need_circ = ctx.needs_input_grad
        grad_input = grad_bias = grad_diag = grad_circ = None

        if bias is not None and need_bias:
            grad_bias = grad_output.sum(0)

        # Nothing behind the bias needs a gradient: skip the factor walk.
        if not (need_x or need_diag or need_circ):
            return grad_input, grad_bias, grad_diag, grad_circ

        n2_1 = len(diag_vecs) + len(circ_vecs)
        n = x.shape[-1]
        # rfft bins 1.. stand for conjugate pairs of full-spectrum bins, except
        # the last one when n is even (the Nyquist bin, counted once).
        interior_end = -1 if n % 2 == 0 else None
        grad_diag_list = []
        grad_circ_list = []

        g = grad_output
        for i in reversed(range(n2_1)):  # factor 0 (outermost) first
            wi = (n2_1 - i - 1) // 2
            if i % 2 == 0:  # diagonal factor diag_vecs[wi]
                if need_diag:
                    # Input seen by this factor in the forward pass.
                    layer_in = cdvft_mv(diag_vecs, circ_vecs, x, i)
                    grad_diag_list.append(torch.sum(g * layer_in, dim=0))
                g = g * diag_vecs[wi].unsqueeze(0)
            else:  # circulant factor circ_vecs[wi]
                fft_g = torch.fft.rfft(g)
                if need_circ:
                    layer_in = cdvft_mv(diag_vecs, circ_vecs, x, i)
                    # Gradient w.r.t. the rfft spectrum: irfft scales by 1/n and
                    # every interior bin appears twice in the full spectrum.
                    grad_w = torch.conj(torch.fft.rfft(layer_in)) * fft_g / n
                    grad_w[:, 1:interior_end] *= 2
                    grad_circ_list.append(grad_w.sum(dim=0))
                fft_w = circ_vecs[wi].unsqueeze(0)
                # Adjoint of circular convolution; n= keeps odd lengths intact.
                g = torch.fft.irfft(torch.conj(fft_w) * fft_g, n=n)

        if need_x:
            grad_input = g
        if need_diag and grad_diag_list:
            grad_diag = torch.stack(grad_diag_list)
        if need_circ and grad_circ_list:
            grad_circ = torch.stack(grad_circ_list)
        return grad_input, grad_bias, grad_diag, grad_circ

class CDVFT(torch.autograd.Function):
    """Reference variant of ``CDVFT_fftw`` whose circulant factors are given in the time domain.

    ``apply(x, bias, diag_vecs, circ_vecs)`` where each row of ``circ_vecs`` is
    the first column ``c`` of a circulant matrix (length N). It is converted
    with ``rfft`` and applied as ``irfft(rfft(h) * rfft(c), n=N)``. The gradient
    w.r.t. ``c`` is returned in the time domain.
    """

    # Only ``forward`` (which takes ``ctx`` and saves tensors itself) and
    # ``backward`` are defined. Overriding ``setup_context`` as well makes
    # PyTorch >= 2.0 call ``forward`` *without* ``ctx``, which breaks this
    # signature.

    @staticmethod
    def forward(ctx, x, bias, diag_vecs, circ_vecs):
        circ_vecs_fftw = [torch.fft.rfft(circ_vec) for circ_vec in circ_vecs]
        output = cdvft_mv(diag_vecs, circ_vecs_fftw, x)
        if bias is not None:
            # Out-of-place add: cdvft_mv returns ``x`` itself when there are no factors.
            output = output + bias.unsqueeze(0).expand_as(output)
        ctx.save_for_backward(x, bias, diag_vecs, circ_vecs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, bias, diag_vecs, circ_vecs = ctx.saved_tensors
        need_x, need_bias, need_diag, need_circ = ctx.needs_input_grad
        grad_input = grad_bias = grad_diag = grad_circ = None

        if bias is not None and need_bias:
            grad_bias = grad_output.sum(0)

        if not (need_x or need_diag or need_circ):
            return grad_input, grad_bias, grad_diag, grad_circ

        n2_1 = len(diag_vecs) + len(circ_vecs)
        n = x.shape[-1]
        grad_diag_list = []
        grad_circ_list = []

        circ_vecs_fftw = [torch.fft.rfft(circ_vec) for circ_vec in circ_vecs]

        g = grad_output
        for i in reversed(range(n2_1)):  # factor 0 (outermost) first
            wi = (n2_1 - i - 1) // 2
            if i % 2 == 0:  # diagonal factor diag_vecs[wi]
                if need_diag:
                    # Input seen by this factor in the forward pass.
                    layer_in = cdvft_mv(diag_vecs, circ_vecs_fftw, x, i)
                    grad_diag_list.append(torch.sum(g * layer_in, dim=0))
                g = g * diag_vecs[wi].unsqueeze(0)
            else:  # circulant factor circ_vecs[wi]
                fft_g = torch.fft.rfft(g)
                if need_circ:
                    layer_in = cdvft_mv(diag_vecs, circ_vecs_fftw, x, i)
                    # d/dc of a circular convolution is the cross-correlation of
                    # the factor input with the incoming gradient.
                    cross = (torch.conj(torch.fft.rfft(layer_in)) * fft_g).sum(dim=0)
                    grad_circ_list.append(torch.fft.irfft(cross, n=n))
                fft_w = circ_vecs_fftw[wi].unsqueeze(0)
                # Adjoint of circular convolution; n= keeps odd lengths intact.
                g = torch.fft.irfft(torch.conj(fft_w) * fft_g, n=n)

        if need_x:
            grad_input = g
        if need_diag and grad_diag_list:
            grad_diag = torch.stack(grad_diag_list)
        if need_circ and grad_circ_list:
            grad_circ = torch.stack(grad_circ_list)
        return grad_input, grad_bias, grad_diag, grad_circ

class CDVFTLinear(nn.Module):
    """Square linear map built from alternating diagonal and circulant factors.

    The weight matrix is ``mat = d0 @ c0 @ d1 @ c1 @ ... @ d_last`` (all factors
    ``circ x circ``). Diagonal factors are stored as vectors in ``diag_weight``.
    Circulant factors are stored as their rfft spectra (``circ // 2 + 1``
    complex bins) in ``circ_weight``, so ``forward`` computes
    ``x @ mat.T (+ bias)`` with FFTs and the log-determinant is read off directly.

    Args:
        input_features: size of each input row.
        output_features: size of each output row.
        n2_1: total number of factors; must be odd (``(n2_1 + 1) // 2``
            diagonal and ``n2_1 // 2`` circulant factors).
        bias: whether to learn an additive bias of length ``circ``.

    NOTE(review): ``circ = max(input_features, output_features)`` and the
    factors are square, so inputs presumably already have ``circ`` columns.
    """

    def __init__(self, input_features, output_features, n2_1=2, bias=False):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        assert self.input_features > 0 and self.output_features > 0

        self.circ = max([self.input_features, self.output_features])
        self.n2_1 = n2_1
        assert self.n2_1 % 2, f"odd number only for n2_1={n2_1}"

        # Factor order: [d0, c0, d1, c1, d2, ...]
        # Random init is kept close to identity: a careless init easily yields a
        # large condition number for the final matrix, which hurts stability.
        weight_list = []
        for i in range(self.n2_1):
            if i % 2 == 0:
                diag_w = torch.randn(self.circ) * 1 / self.circ + 1  # N(1, I / circ^2)
                weight_list.append(diag_w)
            else:
                # rfft(N(0, I)) / circ: each bin has E|.|^2 = 1 / circ.
                fft_w = torch.fft.rfft(torch.randn(self.circ)) / self.circ
                fft_w = fft_w + 1  # spectrum centred on 1 (identity circulant)
                weight_list.append(fft_w)

        self.circ_weight = nn.Parameter(torch.stack(weight_list[1::2]))
        self.diag_weight = nn.Parameter(torch.stack(weight_list[0::2]))

        if bias:
            self.bias = nn.Parameter(torch.randn(self.circ))
        else:
            # Register the optional parameter as None so it is always present.
            self.register_parameter('bias', None)

    def init_from_mat(self, mat):
        """Re-initialise the factors so that ``mat`` is approximated by one circulant.

        All diagonal factors become ones and all circulant factors the identity
        (spectrum of the unit impulse). The exception is the last circulant,
        which becomes the circulant projection of ``mat``: entry ``j`` of its
        first column is the mean of ``mat[r, c]`` over ``(r - c) % circ == j``.
        Values are copied into the existing parameters. This keeps their
        device, dtype and any optimizer references intact.

        Args:
            mat: square ``[circ, circ]`` tensor.
        """
        assert mat.shape[0] == self.circ and mat.shape[1] == self.circ

        # Start from the identity factorisation.
        weight_list = []
        for i in range(self.n2_1):
            if i % 2 == 0:
                diag_w = torch.tensor([1.] * self.circ, dtype=torch.float)
                weight_list.append(diag_w)
            else:
                fft_w = torch.fft.rfft(torch.tensor([1.] + [0] * (self.circ - 1), dtype=torch.float))
                weight_list.append(fft_w)

        # Replace the last circulant factor with the projection of ``mat``.
        ind = get_circ_index(self.circ, self.circ, self.circ)
        w = mat.detach().cpu().float().numpy()
        circ_w = np.bincount(ind.flatten(), weights=w.flatten()) / float(self.circ)
        fft_w = torch.fft.rfft(torch.tensor(circ_w, dtype=torch.float))
        weight_list[-2] = fft_w

        # Future refinement idea for w = d0 @ c0 @ d1 @ c1 @ d2:
        #   d0 = approx( w @ (c0 @ d1 @ c1 @ d2)^{-1} )
        #   c0 = approx( d0^{-1} @ w @ (d1 @ c1 @ d2)^{-1} )
        #   d1 = approx( (d0 @ c0)^{-1} @ w @ (c1 @ d2)^{-1} )
        #   ...
        with torch.no_grad():
            self.circ_weight.copy_(torch.stack(weight_list[1::2]))
            self.diag_weight.copy_(torch.stack(weight_list[0::2]))

    def forward(self, input):
        """Return ``input @ mat.T (+ bias)``; rescales the weights in place first."""
        self.eigen_normalization()
        res = CDVFT_fftw.apply(input, self.bias, self.diag_weight, self.circ_weight)
        return res

    def _circ_interior_end(self):
        # rfft bins 1.. each stand for a conjugate pair of eigenvalues, except the
        # last bin when circ is even (the Nyquist bin, which appears once).
        return -1 if self.circ % 2 == 0 else None

    def logabsdet(self, h, w):
        """Return ``log|det(mat)| * h * w`` (one copy of ``mat`` per spatial position).

        Eigenvalues of each circulant factor are its full DFT spectrum. It is
        rebuilt from the stored rfft bins by counting each interior bin twice.
        """
        log_mag = torch.log(self.circ_weight.abs())
        res = log_mag.sum() + log_mag[:, 1:self._circ_interior_end()].sum()
        res = res + torch.log(self.diag_weight.abs()).sum()
        return res * h * w

    def logabsdet2(self, h, w):
        """Same quantity as ``logabsdet`` computed on CPU via "2*all - unpaired bins"."""
        log_mag = torch.log(self.circ_weight.cpu().abs())
        res = log_mag.sum() * 2 - log_mag[:, 0].sum()
        if self.circ % 2 == 0:
            # Nyquist bin exists only for even length and must be counted once.
            res = res - log_mag[:, -1].sum()
        res = res + torch.log(self.diag_weight.cpu().abs()).sum()
        return res * h * w

    def reverse(self, y, weight_list=None):
        """Return ``x`` such that ``forward(x) == y`` (bias included).

        y = x @ mat.T + b  =>  x = (y - b) @ (mat^{-1}).T, with
        mat^{-1} = d_last^{-1} @ ... @ c0^{-1} @ d0^{-1}. This is obtained by
        reversing the factor order and inverting every diagonal and every
        circulant spectrum elementwise. Passing the identity as ``y`` (no bias)
        yields ``(mat^{-1}).T``. ``weight_list`` is unused.
        """
        if self.bias is not None:
            y = y - self.bias

        x = CDVFT_fftw.apply(y, None, reversed(1 / self.diag_weight), reversed(1 / self.circ_weight))
        return x

    def get_matT(self):
        """Return ``mat.T`` (bias ignored) by applying the transform to the identity."""
        res = CDVFT_fftw.apply(torch.eye(self.circ), None, self.diag_weight, self.circ_weight)
        return res

    def get_invT(self):
        """Return ``(mat^{-1}).T`` (bias ignored)."""
        res = CDVFT_fftw.apply(torch.eye(self.circ), None, reversed(1 / self.diag_weight), reversed(1 / self.circ_weight))
        return res

    def extra_repr(self):
        # Extra information shown when the module is printed.
        return 'input_features={}, output_features={}, bias={}, circ={}'.format(
            self.input_features, self.output_features, self.bias is not None, self.circ
        )

    def eigen_normalization(self):
        """Rescale both weight tensors in place so their largest magnitude is 1."""
        with torch.no_grad():
            self.circ_weight.div_(self.circ_weight.abs().max())
            self.diag_weight.div_(self.diag_weight.abs().max())


    
class Conv2dCDVFT(Flow):
    """Invertible 1x1-convolution flow whose channel-mixing matrix is a CDVFT product.

    The same :class:`CDVFTLinear` (diagonal/circulant factorisation) is applied
    to the channel vector at every spatial position.
    """

    def __init__(self, in_channels, inverse=False, n2_1=3):
        super(Conv2dCDVFT, self).__init__(inverse)
        self.in_channels = in_channels
        # Channel-mixing transform shared by all spatial positions.
        self.cdvft = CDVFTLinear(in_channels, in_channels, n2_1)

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Apply the channel transform at every spatial position.

        Args:
            input: Tensor [batch, in_channels, H, W]
        Returns:
            out: [batch, in_channels, H, W], the output of the flow
            logdet: 0-dim tensor, the log determinant of
                :math:`\partial output / \partial input` (identical for every sample)
        """
        batch, c, h, w = input.shape
        # [B, C, H, W] -> [B*H*W, C]: one row per pixel.
        out = input.permute(0, 2, 3, 1).reshape(-1, c)
        out = self.cdvft(out)
        # Back to [B, C, H, W].
        out = out.reshape(batch, h, w, c).permute(0, 3, 1, 2)
        logdet = self.cdvft.logabsdet(h, w)
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Invert the flow at every spatial position.

        Args:
            input: Tensor [batch, in_channels, H, W]
        Returns:
            out: [batch, in_channels, H, W], the inverse of the flow applied to ``input``
            logdet: 0-dim tensor, the log determinant of
                :math:`\partial output / \partial input` (negated forward logdet)
        """
        batch, c, h, w = input.shape
        # [B, C, H, W] -> [B*H*W, C]: one row per pixel.
        out = input.permute(0, 2, 3, 1).reshape(-1, c)
        out = self.cdvft.reverse(out)
        # Back to [B, C, H, W].
        out = out.reshape(batch, h, w, c).permute(0, 3, 1, 2)
        logdet = -self.cdvft.logabsdet(h, w)
        return out, logdet

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        # Data-dependent init is a plain forward pass here; init_scale is unused.
        with torch.no_grad():
            return self.forward(data)

    def sync(self):  # TODO: check the right way of implementing sync
        pass
        # self.weight_inv.copy_(self.weight.data.inverse()) # this is the one in 1x1 conv

    def visualize_weights(self):
        """Render each (diag_weight, circ_weight) pair as an image array.

        Returns:
            List of uint8 arrays of shape [height, width, channels] taken from
            matplotlib's raw pixel buffer, one per circulant factor.

        NOTE: ``zip`` stops at the shorter of the two weight stacks. There is one
        more diagonal factor than circulant factor, so the final diagonal factor
        is not plotted.
        """
        ret = []

        # diag: [n2_1 // 2 + 1, circ] real; circ: [n2_1 // 2, circ // 2 + 1] rfft spectra.
        diag_weights = self.cdvft.diag_weight.detach().cpu().numpy()
        circ_weights = self.cdvft.circ_weight.detach().cpu().numpy()

        for i, (diag_w, circ_w) in enumerate(zip(diag_weights, circ_weights)):
            fig, axs = plt.subplots(1, 2, figsize=(12, 6))

            # Diagonal weights as a one-row heat map.
            im1 = axs[0].imshow(np.expand_dims(diag_w, axis=0), cmap='viridis', aspect="auto")
            axs[0].set_title(f"Diag Weight {i}")
            fig.colorbar(im1, ax=axs[0])

            # Circulant spectrum, real part only.
            im2 = axs[1].imshow(np.expand_dims(circ_w.real, axis=0), cmap='viridis', aspect="auto")
            axs[1].set_title(f"Circ Weight {i} (Real Part)")
            fig.colorbar(im2, ax=axs[1])

            # Render the figure into a byte buffer.
            with io.BytesIO() as buff:
                fig.savefig(buff, format='raw')
                buff.seek(0)
                data = np.frombuffer(buff.getvalue(), dtype=np.uint8)

            # Reshape the flat buffer using the canvas size in pixels.
            w, h = fig.canvas.get_width_height()
            img_array = data.reshape((int(h), int(w), -1))

            plt.close(fig)  # release the figure's memory
            ret.append(img_array)

        return ret

    @overrides
    def extra_repr(self):
        return f'inverse={self.inverse}, in_channels={self.in_channels}'

    @classmethod
    def from_params(cls, params: Dict) -> "Conv2dCDVFT":
        return cls(**params)