#%%
import torch
import math
import collections

# Import the submodule explicitly: `import collections` alone is not guaranteed
# to bind `collections.abc`, so the original code only worked when some earlier
# import (e.g. torch) had already loaded it as a side effect.
import collections.abc

# Python 3.10 removed these deprecated ABC aliases from the top-level
# `collections` namespace. Re-expose them for code that still uses the old
# names (presumably a dependency imported further down, e.g. tltorch -- TODO
# confirm which one needs it).
for _abc_name in ("Iterable", "Mapping", "MutableSet", "MutableMapping"):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
del _abc_name


import tltorch
# Floating-point tensors created without an explicit dtype (for example the
# torch.randn / torch.zeros calls in low_rank_CP) are allocated as float32.
_DEFAULT_FLOAT_DTYPE = torch.float32
torch.set_default_dtype(_DEFAULT_FLOAT_DTYPE)


class low_rank_CP(torch.nn.Module):
    """2-D convolution whose kernel is stored as a CP (CANDECOMP/PARAFAC) factorisation.

    The kernel is represented (see ``construct_weight_tensor``) as

        W[a, b, c, d] = sum_i s[i] * U0[a, i] * U1[b, i] * U2[c, i] * U3[d, i]

    where ``self.us = [U0, U1, U2, U3]`` have shapes ``(out_channels, rank)``,
    ``(in_channels, rank)``, ``(kernel_h, rank)``, ``(kernel_w, rank)`` and
    ``self.s`` has shape ``(rank,)``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: an int (square kernel) or a ``(kernel_h, kernel_w)`` pair.
        stride, padding, dilation: forwarded to the CP convolution in ``forward``.
        groups: stored on the module but NOT forwarded to the convolution.
        bias: if truthy, a learnable bias of shape ``(out_channels,)`` is created;
            otherwise ``self.bias`` is ``None``.
        maximal_rank: CP rank as a fraction of the largest kernel dimension.
            The resulting rank is never below 3.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        groups=1,
        bias=True,
        dilation=1,
        maximal_rank=0.5,
    ) -> None:

        super().__init__()

        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        # Normalise to a list [kernel_h, kernel_w]. (The original read the
        # not-yet-assigned self.kernel_size here, so tuple inputs crashed.)
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = list(kernel_size)
        self.out_channels = out_channels
        self.in_channels = in_channels
        # Dimensions of the dense kernel tensor, one CP factor per entry.
        self.dims = [self.out_channels, self.in_channels] + self.kernel_size
        minimal_rank = 3
        self.rank = max(int(max(self.dims) * maximal_rank), minimal_rank)
        # CP component weights: frozen by default; activate_upper_level()
        # makes them trainable.
        self.s = torch.nn.Parameter(torch.randn(self.rank), requires_grad=False)
        # One (dim, rank) factor matrix per kernel dimension.
        self.us = torch.nn.ParameterList(
            [
                torch.nn.Parameter(torch.randn(size=(d, self.rank)), requires_grad=True)
                for d in self.dims
            ]
        )
        # The original condition `bias is not None or bias is not False` was
        # always True, so bias=False still produced a bias parameter.
        if bias:
            self.bias = torch.nn.Parameter(
                torch.randn(out_channels), requires_grad=True
            )
        else:
            self.register_parameter("bias", None)
        # Dense, frozen kernel-shaped tensor; forward() does not read it.
        self.weight = torch.nn.Parameter(
            torch.zeros(
                (
                    self.out_channels,
                    self.in_channels,
                    self.kernel_size[0],
                    self.kernel_size[1],
                )
            ),
            requires_grad=False,
        )
        self.reset_parameters()  # parameter initialization

    @torch.no_grad()
    def reset_parameters(self):
        """Re-initialise ``s`` uniformly in [0, 10] and each factor with Kaiming-uniform."""
        torch.nn.init.uniform_(self.s, b=10)
        for u in self.us:
            torch.nn.init.kaiming_uniform_(u, a=math.sqrt(5))
            # Orthonormalize bases
            # u.data, _ = torch.linalg.qr(u.data, 'reduced')

    # @torch.no_grad()
    # def reset_parameters(self):
    #     w = torch.empty(size = [self.out_channels,self.in_channels]+self.kernel_size)
    #     torch.nn.init.kaiming_uniform_(w, a=math.sqrt(5))
    #     if self.bias is not None:
    #         fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(w)
    #         if fan_in != 0:
    #             bound = 1 / math.sqrt(fan_in)
    #             torch.nn.init.uniform_(self.bias, -bound, bound)
    #     cp = parafac(tly.tensor(w),rank = self.rank,verbose=False,tol = 1e-3,n_iter_max=10)
    #     self.s.data = torch.tensor(cp.weights)
    #     for i in range(len(self.dims)):
    #         self.us[i].data = torch.tensor(cp.factors[i])
    #     del w

    # @torch.no_grad()
    # def format_weight(self):
    #     # cp = parafac(tly.tensor(self.weight),rank = self.rank,verbose=False,tol = 1e-2,n_iter_max=10,init='svd',orthogonalise=True,svd = 'randomized_svd')
    #     cp = tntorch.Tensor(
    #         self.weight, ranks_cp=self.rank, verbose=False, device=self.weight.device
    #     )
    #     cp.round_tucker(rmax=self.rank)
    #     # self.s.copy_(torch.tensor(cp.weights,device = self.weight.device))
    #     for i in range(len(self.dims)):
    #         self.us[i][:, : cp.Us[i].shape[1]].copy_(
    #             torch.tensor(cp.Us[i], device=self.weight.device)
    #         )
    #     del self.weight

    def forward(self, input):
        """Convolve ``input`` with the CP-factorised kernel.

        Delegates to ``tltorch.functional.convolution.cp_conv``, using the
        first ``self.rank`` CP components (``self.s`` and columns of ``self.us``)
        plus ``self.bias`` (``None`` when the layer was built without a bias).
        ``input`` is presumably ``(batch, in_channels, H, W)`` -- TODO confirm.

        NOTE(review): ``self.groups`` is not passed to ``cp_conv``, so grouped
        convolution is silently ignored -- confirm whether it is needed.
        """

        # Dense equivalent (for reference):
        # weight = torch.einsum('i,ai,bi,ci,di->abcd',self.s[:self.rank],*[u[:,:self.rank] for u in self.us])
        # result = torch.nn.functional.conv2d(input, weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
        result = tltorch.functional.convolution.cp_conv(
            x=input,
            cp_tensor=tltorch.CPTensor(
                weights=self.s[: self.rank],
                factors=[p[:, : self.rank] for p in self.us],
                rank=self.rank,
            ),
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
        return result

    @torch.no_grad()
    def spectral_prune(self, pruning_idx):
        """Remove the CP components whose indices appear in ``pruning_idx``.

        Drops the matching entries of ``self.s`` and columns of every factor in
        ``self.us``, then sets ``self.rank`` to the number of kept components.
        """
        keep_idx = [i for i in range(self.s.shape[0]) if i not in pruning_idx]
        for u in self.us:
            u.data = u.data[:, keep_idx]
        self.s.data = self.s.data[keep_idx]
        self.rank = len(keep_idx)

    @torch.no_grad()
    def construct_weight_tensor(self):
        """Materialise the dense ``(out, in, kh, kw)`` kernel. Debugging only."""
        return torch.einsum("i,ai,bi,ci,di->abcd", self.s, *self.us)

    def activate_upper_level(self):
        """Make both the factors ``us`` and the CP weights ``s`` trainable."""
        for p in self.us:
            p.requires_grad = True
        self.s.requires_grad = True

    def activate_lower_level(self):
        """Train only the factors ``us``: freeze ``s`` and clear its gradient."""
        for p in self.us:
            p.requires_grad = True
        self.s.requires_grad = False
        self.s.grad = None

    @torch.no_grad()
    def get_hypergradient(self):
        """Add a hypergradient term to ``self.s.grad``.

        For each CP component ``i`` (``i < self.rank``) this adds

            sum_k <U_k[:, i], grad(U_k)[:, i]> / (s[i] + 1e-4)

        over all factors ``U_k`` in ``self.us``. It reads ``u.grad`` for every
        factor, so it must be called after a backward pass has populated them.
        If ``self.s.grad`` is ``None`` it is initialised to zeros first.
        """
        eps = 1e-4
        r = self.rank
        # diag(U[:, :r].T @ G[:, :r]) is a column-wise dot product; computing it
        # directly avoids forming the r x r matrix. Every factor is paired with
        # its OWN gradient (the original paired us[3] with us[2].grad).
        hypergradient_diagonal = sum(
            (u[:, :r] * u.grad[:, :r]).sum(dim=0) for u in self.us
        )
        if self.s.grad is None:
            self.s.grad = torch.zeros_like(self.s)
        self.s.grad.add_(hypergradient_diagonal / (self.s[:r] + eps))
