import torch
import numpy as np
import torch.nn as nn
from tensorly.decomposition import parafac as CPD_svd
from tensorly.decomposition._cp_power import parafac_power_iteration as CPD_tpm

class Tensor2DS_decomp(nn.Module):
    """
    Tensor2DS decomposes a pretrained regular CONV layer into a depthwise
    separable layer, or into CPD-initialized PW-DW-PW kernels.

    Supported decompositions (the ``order`` argument):

    - ``'DW_PW'``: depthwise conv followed by a pointwise conv. Every
      input-channel slice ``W[:, i, :]`` is replaced by its best rank-1
      approximation (leading singular triple).
    - ``'PW_DW'``: pointwise conv followed by a depthwise conv. Every
      output-channel slice ``W[o, :, :]`` is replaced by its best rank-1
      approximation.
    - ``'CPD_svd'`` / ``'CPD_tpm'``: rank-``rank`` CP decomposition of the
      ``[C_out, C_in, kh*kw]`` weight tensor, realised as
      pointwise -> depthwise -> pointwise.

    Attributes set on the instance: ``order``, ``rank``, ``bias_original``,
    ``stride``, ``padding``, ``dilation``, ``padding_mode``, ``conv_dw``,
    ``conv_pw`` (a layer, or the tuple ``(conv_pw1, conv_pw2)`` for the CPD
    orders) and, for the CPD orders only, ``conv_pw1`` / ``conv_pw2``.
    """

    _SLICE_ORDERS = ('DW_PW', 'PW_DW')
    _CPD_ORDERS = ('CPD_svd', 'CPD_tpm')

    def __init__(self, conv_nn_module, order='DW_PW', rank=10):
        """
        - conv_nn_module: an nn.Module object, which is the selected pretrained layer
          (a regular, non-grouped 2D convolution).
        - order: a string, which decides the type of decomposition. There are four
          choices: ['DW_PW', 'PW_DW', 'CPD_svd', 'CPD_tpm'].
        - rank: CP rank, only used by the 'CPD_svd' and 'CPD_tpm' orders.

        Raises ValueError for an unknown ``order`` or a grouped convolution.
        """
        super().__init__()

        if order not in self._SLICE_ORDERS + self._CPD_ORDERS:
            raise ValueError(
                "order must be one of ['DW_PW', 'PW_DW', 'CPD_svd', 'CPD_tpm'], got %r" % (order,)
            )
        if getattr(conv_nn_module, 'groups', 1) != 1:
            # A grouped conv stores [C_out, C_in / groups, kh, kw]; the
            # decompositions below assume the full [C_out, C_in, kh, kw] tensor.
            raise ValueError('only regular (groups=1) convolutions can be decomposed')

        self.order = order
        self.bias_original = conv_nn_module.bias is not None

        # Spatial hyper-parameters are inherited by the depthwise layer, which
        # is the only layer with a spatial extent.
        self.stride = conv_nn_module.stride
        self.padding = conv_nn_module.padding
        self.dilation = getattr(conv_nn_module, 'dilation', 1)
        self.padding_mode = getattr(conv_nn_module, 'padding_mode', 'zeros')
        self.rank = rank

        # get dw_set and pw_set
        weights_4d = conv_nn_module.weight.data
        dw_set, pw_set = self._slice_svd_output(weights_4d, cut_type=self.order, rank=self.rank)

        # get conv_dw and conv_pw
        self.conv_dw, self.conv_pw = self._ds_init(
            dw_set, pw_set, weights_4d, init_type=self.order, bias_or_not=self.bias_original
        )
        if self.order in self._CPD_ORDERS:
            # conv_pw is a plain tuple here; registering its members as
            # attributes is what makes them sub-modules of this layer.
            self.conv_pw1, self.conv_pw2 = self.conv_pw

        # init the bias
        if self.bias_original:
            self._init_biases(conv_nn_module.bias.data)

    @staticmethod
    def _rank1_slices(slices):
        """
        Best rank-1 approximation of every matrix in a batch.

        - slices: tensor of shape [n, m, p].
        Returns (left, right) with shapes [n, m] and [n, p] such that
        ``left[j][:, None] * right[j][None, :]`` approximates ``slices[j]``.
        The singular value is folded into ``left``.
        """
        U, s, V = torch.svd(slices)  # [n, m, q], [n, q], [n, p, q]; s is sorted descending
        # Index 0 is the leading singular triple, i.e. the optimal rank-1 fit.
        left = U[:, :, 0] * s[:, :1]
        right = V[:, :, 0]
        return left, right

    @staticmethod
    def _cpd_factors(weights_3d, cut_type, rank):
        """
        Run the CP decomposition of a [C_out, C_in, kh*kw] tensor.

        Returns (out_factor [C_out, rank], in_factor [C_in, rank],
        spatial_factor [kh*kw, rank]) as torch tensors, with the CP weights
        already multiplied into ``out_factor``.
        """
        # detach/cpu so that CUDA or grad-tracking weights can be converted.
        array = weights_3d.detach().cpu().numpy()

        if cut_type == 'CPD_svd':
            (cp_weights, factors), _errors = CPD_svd(array, rank=rank, return_errors=True)
        else:
            (cp_weights, factors) = CPD_tpm(array, rank=rank)

        out_factor, in_factor, spatial_factor = [torch.as_tensor(np.asarray(f)) for f in factors]

        # The decomposition is sum_r w_r * a_r o b_r o c_r. The power-iteration
        # variant returns normalised factors with the scale held in w, so the
        # weights must be kept; they are absorbed into the output factor.
        if cp_weights is not None:
            scale = torch.as_tensor(np.asarray(cp_weights)).to(out_factor.dtype)
            out_factor = out_factor * scale.reshape(1, -1)

        return out_factor, in_factor, spatial_factor

    def _slice_svd_output(self, weights_4d, cut_type, rank):
        """
        Compute the factor matrices used to initialise the new layers.

        Returns (dw_set, pw_set) with shapes
        - 'DW_PW': dw_set [C_in, kh*kw],  pw_set [C_out, C_in]
        - 'PW_DW': dw_set [C_out, kh*kw], pw_set [C_in, C_out]
        - CPD:     dw_set [rank, kh*kw],  pw_set = (pw1 [C_in, rank], pw2 [rank, C_out])
        """
        C_out, C_in = weights_4d.shape[0], weights_4d.shape[1]
        weights_3d = weights_4d.reshape(C_out, C_in, -1)  # [C_out, C_in, kh*kw]

        if cut_type == 'DW_PW':
            # one [C_out, kh*kw] slice per input channel
            left, right = self._rank1_slices(weights_3d.permute(1, 0, 2))
            return right, left.T  # [C_in, kh*kw], [C_out, C_in]

        if cut_type == 'PW_DW':
            # one [C_in, kh*kw] slice per output channel
            left, right = self._rank1_slices(weights_3d)
            return right, left.T  # [C_out, kh*kw], [C_in, C_out]

        out_factor, in_factor, spatial_factor = self._cpd_factors(weights_3d, cut_type, rank)
        return spatial_factor.T, (in_factor, out_factor.T)

    def _make_conv(self, like, in_channels, out_channels, bias, depthwise_kernel=None):
        """
        Create a pointwise conv, or a depthwise conv when ``depthwise_kernel``
        (a (kh, kw) tuple) is given. The layer is placed on the device and
        dtype of ``like`` so that it matches the pretrained weights.
        """
        if depthwise_kernel is None:
            layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
        else:
            layer = nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=depthwise_kernel,
                groups=in_channels, bias=bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, padding_mode=self.padding_mode,
            )
        return layer.to(device=like.device, dtype=like.dtype)

    @staticmethod
    def _load(param, values):
        """Copy ``values`` into ``param`` (reshaped to its shape) without aliasing storage."""
        with torch.no_grad():
            param.copy_(values.reshape(param.shape))

    def _ds_init(self, dw_set, pw_set, weights_4d, init_type, bias_or_not=True):
        """
        Build the new layers and load the decomposed weights into them.

        Returns (conv_dw, conv_pw); for the CPD orders conv_pw is the tuple
        (conv_pw1, conv_pw2).
        """
        kernel = tuple(weights_4d.shape[2:])  # (kh, kw), also valid for non-square kernels

        if init_type == 'DW_PW':
            C_out, C_in = pw_set.shape  # [C_out, C_in]
            conv_dw = self._make_conv(weights_4d, C_in, C_in, bias_or_not, depthwise_kernel=kernel)
            conv_pw = self._make_conv(weights_4d, C_in, C_out, bias_or_not)

            self._load(conv_dw.weight, dw_set)  # [C_in, 1, kh, kw]
            self._load(conv_pw.weight, pw_set)  # [C_out, C_in, 1, 1]
            return conv_dw, conv_pw

        if init_type == 'PW_DW':
            C_in, C_out = pw_set.shape  # [C_in, C_out]
            conv_pw = self._make_conv(weights_4d, C_in, C_out, bias_or_not)
            conv_dw = self._make_conv(weights_4d, C_out, C_out, bias_or_not, depthwise_kernel=kernel)

            # pw_set is stored input-major, Conv2d wants [C_out, C_in, 1, 1].
            self._load(conv_pw.weight, pw_set.T)
            self._load(conv_dw.weight, dw_set)  # [C_out, 1, kh, kw]
            return conv_dw, conv_pw

        pw1_set, pw2_set = pw_set
        C_in, rank = pw1_set.shape  # [C_in, rank]
        _, C_out = pw2_set.shape  # [rank, C_out]

        conv_pw1 = self._make_conv(weights_4d, C_in, rank, bias_or_not)
        conv_dw = self._make_conv(weights_4d, rank, rank, bias_or_not, depthwise_kernel=kernel)
        conv_pw2 = self._make_conv(weights_4d, rank, C_out, bias_or_not)

        self._load(conv_pw1.weight, pw1_set.T)  # [rank, C_in, 1, 1]
        self._load(conv_dw.weight, dw_set)  # [rank, 1, kh, kw]
        self._load(conv_pw2.weight, pw2_set.T)  # [C_out, rank, 1, 1]
        return conv_dw, (conv_pw1, conv_pw2)

    def _init_biases(self, bias):
        """
        Put the pretrained bias on the last layer of the chain and zero the
        biases of the preceding layers.

        Only the last layer's bias is added directly to the output, so that is
        the only place where the original bias keeps its meaning; a non-zero
        bias on an earlier layer would be filtered by the following layers.
        The values are copied, so the original layer's bias is not shared.
        """
        if self.order == 'DW_PW':
            earlier, last = (self.conv_dw,), self.conv_pw
        elif self.order == 'PW_DW':
            earlier, last = (self.conv_pw,), self.conv_dw
        else:
            earlier, last = (self.conv_pw1, self.conv_dw), self.conv_pw2

        with torch.no_grad():
            for layer in earlier:
                layer.bias.zero_()
            last.bias.copy_(bias)

    def forward(self, x):
        """Apply the decomposed layers in the order selected at construction."""
        if self.order == 'DW_PW':
            return self.conv_pw(self.conv_dw(x))
        if self.order == 'PW_DW':
            return self.conv_dw(self.conv_pw(x))
        if self.order in self._CPD_ORDERS:
            return self.conv_pw2(self.conv_dw(self.conv_pw1(x)))
        raise ValueError('unknown decomposition order: %r' % (self.order,))
