import torch
import collections

# Compatibility shim: these ABC aliases were removed from the top-level
# ``collections`` namespace in Python 3.10. They are re-created here, before the
# tensor libraries below are imported, presumably because one of those
# dependencies still refers to the old names (TODO confirm which one).
collections.Iterable = collections.abc.Iterable
collections.Mapping = collections.abc.Mapping
collections.MutableSet = collections.abc.MutableSet
collections.MutableMapping = collections.abc.MutableMapping

import tltorch
import math

import torch.nn.functional as F
import tensorly as tly

# Import-time side effects: importing this module switches TensorLy to its
# PyTorch backend and sets torch's process-wide default dtype to float32.
tly.set_backend('pytorch')
torch.set_default_dtype(torch.float32)

import numpy as np

from einops import rearrange

from .tucker_conv_base import _ConvNd
from .tucker_conv_base import _pair


class Conv2d_tucker_fixed(_ConvNd):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,  #: torch.nn.common_types._size_2_t,
            stride,  #: torch.nn.common_types._size_2_t = 1,
            padding,  #: str | torch.nn.common_types._size_2_t = 0,
            dilation,  #: torch.nn.common_types._size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = "zeros",
            dtype=None,
            device=None,
            low_rank_percent=None,
            # convert_from_weights: torch.torch.Tensor = None,
            # existing_bias: torch.Tensor = None,
    ) -> None:
        """
        Initializer for the fixed-rank Tucker convolutional layer, an extension of
        the classical PyTorch convolutional layer.

        The kernel is stored as a Tucker decomposition: a core ``C`` and one basis
        ``Us[i]`` per mode (out channels, in channels, kernel height, kernel width),
        together with the auxiliary tensors ``Ks``, ``Qst`` and ``M_hats`` used by
        the K/core training steps.

        INPUTS:
        in_channels: number of input channels (PyTorch standard)
        out_channels: number of output channels (PyTorch standard)
        kernel_size: kernel size of the convolutional filter (PyTorch standard)
        stride: stride of the filter (PyTorch standard)
        padding: padding of the convolution (PyTorch standard)
        dilation: dilation of the convolution (PyTorch standard)
        groups: forwarded to the base class
        bias: flag for the bias to be included (PyTorch standard)
        padding_mode: forwarded to the base class
        dtype: dtype used for every tensor allocated here
        device: device used for every tensor allocated here
        low_rank_percent: number p used to shrink the two channel modes; their
                ranks are ``max(int(d * (1 - p)), 3)``. The kernel modes keep
                their full size. Required despite the ``None`` default.

        RAISES:
        TypeError: if ``low_rank_percent`` is None.
        """
        # TODO: fix init
        #   TODO: maybe remove this and simply use adaptive instead but just dont call the adapt
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
            # ====================== DLRT params =========================================
            low_rank_percent=low_rank_percent,
            fixed_rank=True,
        )

        self.device, self.dtype = device, dtype
        # self.root_module = True  # to update the step

        if low_rank_percent is None:
            # Previously surfaced as an opaque "float - NoneType" arithmetic error.
            raise TypeError(
                "low_rank_percent must be a number (fraction of the channel rank to drop), got None"
            )

        self.dims = [self.out_channels, self.in_channels] + list(self.kernel_size)

        # make sure that there are at least 3 channels, for rgb images
        # NOTE(review): the floor of 3 is not capped by the mode size itself, so a
        # channel dimension smaller than 3 gets a rank larger than the dimension -
        # confirm whether that is intended.
        channel_ranks = [max(int(d * (1. - low_rank_percent)), 3) for d in self.dims[:2]]
        self.rmax = channel_ranks + self.dims[2:]
        self.rank = self.rmax
        self.dynamic_rank = self.rank
        #### update starting ranks to satisfy the compatibility condition ri<=prod(r-i)
        self.rmax = [min(d, self.get_r_mod_i(i)) for i, d in enumerate(self.rmax)]
        self.dynamic_rank = self.rmax
        ####
        self.fixed = True

        if self.bias is not None:
            self.bias = torch.nn.Parameter(
                torch.zeros(self.out_channels, **factory_kwargs),
                requires_grad=False,
            )

        # Every tensor is allocated directly with the requested device/dtype.
        # Calling ``.to(device)`` on a Parameter may return a plain tensor, which
        # ParameterList (depending on the PyTorch version) either rejects or
        # re-wraps as a trainable Parameter, silently dropping requires_grad=False.
        self.C = torch.nn.Parameter(torch.empty(size=list(self.dynamic_rank), **factory_kwargs))

        # Mode bases (not trained by gradient).
        self.Us = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(d, r, **factory_kwargs), requires_grad=False)
             for d, r in zip(self.dims, self.rmax)])

        # K-step variables, one per mode, same shape as the corresponding basis.
        self.Ks = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(d, r, **factory_kwargs))
             for d, r in zip(self.dims, self.rmax)])

        # Per-mode tensors of shape (r_i, *other ranks) used as the core in the K-steps.
        self.Qst = torch.nn.ParameterList()
        for i, r in enumerate(self.rmax):
            other_dims = [d for j, d in enumerate(self.rmax) if j != i]
            self.Qst.append(
                torch.nn.Parameter(torch.empty(r, *other_dims, **factory_kwargs), requires_grad=False))

        # Per-mode (r_i, r_i) matrices filled in K_postprocess_step.
        self.M_hats = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.empty(r, r, **factory_kwargs), requires_grad=False)
             for r in self.rmax])
        # todo: convert from full rank
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """
        Randomly initialise the core, every per-mode tensor and (if present) the bias.

        The core and the per-mode tensors are filled with Kaiming-uniform values
        (a = sqrt(5)). The bias is drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)],
        where fan_in is taken from a temporary dense kernel of shape
        (out_channels, in_channels // groups, *kernel_size). Finally
        ``basic_number_weights`` is recomputed.
        """
        slope = math.sqrt(5)
        kaiming = torch.nn.init.kaiming_uniform_

        kaiming(self.C, a=slope)
        # Draw order per mode is kept as Us, Qst, Ks, M_hats so seeded runs are unchanged.
        for mode, _ in enumerate(self.dims):
            for family in (self.Us, self.Qst, self.Ks, self.M_hats):
                kaiming(family[mode], a=slope)

        if self.bias is not None:
            dense_shape = (self.out_channels, self.in_channels // self.groups, *self.kernel_size)
            probe = torch.empty(dense_shape, device=self.device, dtype=self.dtype)
            # The probe is only needed for its fan-in; it is still initialised so
            # the random stream advances as before.
            kaiming(probe, a=slope)
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(probe)
            if fan_in != 0:
                limit = 1 / math.sqrt(fan_in)
                torch.nn.init.uniform_(self.bias, -limit, limit)
            del probe

        per_input = self.out_channels // self.groups + self.kernel_size_number
        self.basic_number_weights = self.in_channels * per_input


    def forward(self, input):
        """
        Forward phase for the convolutional layer, dispatched on ``self.step``.

        Every step is written in terms of the Tucker decomposition of the kernel:

        * ``1``..``4``: K-step for mode 0..3. The core is ``Qst[mode]`` and the
          basis of that mode is replaced by ``Ks[mode]``. For the two channel
          modes ``Ks`` is applied through an explicit einsum and an identity
          matrix takes its place among the factors.
        * ``'core'`` / ``'test'``: convolution with the core ``C`` and bases ``Us``.
        * ``None``: dense convolution with ``self.weight``.

        INPUTS:
        input: batch passed to the convolution (indexed as n, c, x, y in the einsums)

        RETURNS:
        the convolution output tensor

        RAISES:
        ValueError: if ``self.step`` is none of the values above.
        """
        step = self.step
        ranks = self.dynamic_rank
        conv_args = dict(stride=self.stride, padding=self.padding, dilation=self.dilation)

        if step == 1:  # first mode (output channels)
            core, factors = self._k_step_operands(0)
            # Identity stands in for the output basis; Ks[0] is applied afterwards.
            factors.insert(0, torch.eye(n=core.shape[0], device=core.device, dtype=core.dtype))
            first_conv = tltorch.functional.tucker_conv(
                input,
                tucker_tensor=tltorch.TuckerTensor(core, factors, rank=ranks),
                **conv_args)
            result = torch.einsum('fl,nluv->nfuv', self.Ks[0][:, :ranks[0]], first_conv)
            if self.bias is not None:
                result += self.bias.view(self.bias.shape[0], 1, 1)
            return result

        elif step == 2:  # second mode (input channels)
            core, factors = self._k_step_operands(1)
            # Ks[1] is applied to the input first; identity stands in for the input basis.
            reduced = torch.einsum('cl,ncxy->nlxy', self.Ks[1][:, :ranks[1]], input)
            factors.insert(1, torch.eye(n=core.shape[1], device=core.device, dtype=core.dtype))
            return tltorch.functional.tucker_conv(
                reduced,
                tucker_tensor=tltorch.TuckerTensor(core, factors, rank=ranks),
                bias=self.bias, **conv_args)

        elif step == 3 or step == 4:  # third / fourth mode (kernel axes)
            mode = 2 if step == 3 else 3
            core, factors = self._k_step_operands(mode)
            factors.insert(mode, self.Ks[mode][:, :ranks[mode]])
            return tltorch.functional.tucker_conv(
                input,
                tucker_tensor=tltorch.TuckerTensor(core, factors, rank=ranks),
                bias=self.bias, **conv_args)

        elif step == 'core' or step == 'test':  # core step and evaluation share one path
            C = self.C[:ranks[0], :ranks[1], :ranks[2], :ranks[3]]
            Us = [U[:, :ranks[i]] for i, U in enumerate(self.Us)]
            return tltorch.functional.tucker_conv(
                input,
                tucker_tensor=tltorch.TuckerTensor(C, Us, rank=ranks),
                bias=self.bias, **conv_args)

        elif step is None:
            # Dense fallback; groups is forwarded so grouped kernels are handled.
            return F.conv2d(input, self.weight, self.bias, self.stride, self.padding,
                            self.dilation, self.groups)

        else:
            raise ValueError(f'incorrect step value {self.step}')

    def _k_step_operands(self, mode):
        """
        Return the core and the remaining bases for the K-step of ``mode``.

        The core is ``Qst[mode]`` truncated to the current ranks with its leading
        axis moved to position ``mode``; the list holds the truncated ``Us`` of
        every other mode, in order (the caller inserts the factor for ``mode``).
        """
        ranks = self.dynamic_rank
        other_ranks = [r for j, r in enumerate(ranks) if j != mode]
        factors = [u[:, :ranks[j]] for j, u in enumerate(self.Us) if j != mode]
        core = self.Qst[mode][:ranks[mode], :other_ranks[0], :other_ranks[1], :other_ranks[2]]
        return torch.movedim(core, 0, mode), factors

    @torch.no_grad()
    def update_step(self, step):
        """
        Store ``step`` as the flag that selects the branch taken by ``forward``.
        """
        self.step = step

    @torch.no_grad()
    def K_preprocess_step(self):
        """
        Set the initial conditions for the integration of the Ks ODEs.

        For each mode i, the truncated core is unfolded along i, the R factor of
        the QR decomposition of the transposed unfolding is taken, and
        ``Ks[i]`` is overwritten in place with ``Us[i] @ R.T`` (all truncated to
        the current rank of that mode).
        """
        ranks = self.dynamic_rank
        r1, r2, r3, r4 = ranks
        # The core is not modified inside the loop, so one truncated view suffices.
        core = self.C[:r1, :r2, :r3, :r4]

        for mode in range(len(self.C.shape)):
            rank = ranks[mode]
            unfolded = tly.base.unfold(core, mode=mode)
            _, upper = torch.linalg.qr(unfolded.T)
            s_transposed = upper[:rank, :rank]
            self.Ks[mode][:, :rank] = self.Us[mode][:, :rank] @ s_transposed.T

    @torch.no_grad()
    def L_preprocess_step(self):
        """
        No-op: this layer has no L ODE. The hook exists only so the layer exposes
        the same interface as the linear layers.
        """
        return None

    @torch.no_grad()
    def S_preprocess_step(self):
        """
        Set the initial conditions for the integration of the core ODE.

        The truncated core is contracted along each of its four modes with the
        matching truncated ``M_hats`` matrix and the result is written back into
        the same region of ``C``.
        """
        ranks = self.dynamic_rank
        # Region of C addressed by the current ranks.
        window = tuple(slice(ranks[axis]) for axis in range(4))
        projections = [M[:ranks[i], :ranks[i]] for i, M in enumerate(self.M_hats)]

        self.C[window] = torch.einsum('abcd,ia,jb,kc,ld->ijkl', self.C[window], *projections)

    @torch.no_grad()
    def K_postprocess_step(self):
        """
        Compute the new basis for every mode from the Ks (unconventional
        integrator) together with the matrices ``M_hats``.

        For each mode i, ``[Ks[i] | Us[i]]`` is orthonormalised by QR and the first
        r_i columns of Q become the new basis. ``M_hats[i]`` is set to
        ``new_basis.T @ old_basis`` before ``Us[i]`` is overwritten in place.
        """
        for i in range(len(self.C.shape)):
            r = self.dynamic_rank[i]
            old_basis = self.Us[i][:, :r]
            stacked = torch.hstack((self.Ks[i][:, :r], old_basis))

            try:
                U_hat, _ = torch.linalg.qr(stacked)
            except RuntimeError:
                # Best-effort fallback to NumPy when the torch QR fails. The data is
                # moved to the CPU for NumPy and the result is returned to the
                # original device/dtype so the assignments below stay valid.
                q, _ = np.linalg.qr(stacked.detach().cpu().numpy())
                U_hat = torch.from_numpy(q).to(device=stacked.device, dtype=stacked.dtype)

            new_basis = U_hat[:, :r]
            # M_hats must be computed from the old basis, i.e. before Us[i] is overwritten.
            self.M_hats[i][:r, :r] = new_basis.T @ old_basis
            self.Us[i][:, :r] = new_basis

    @torch.no_grad()
    def L_postprocess_step(self):
        """
        No-op: there is no L basis / stiffness matrix N for this layer. The hook is
        kept only for coherency with the linear layers.
        """
        return None

    @torch.no_grad()
    def S_postprocess_step(self):
        """
        No-op: the fixed-rank unconventional integrator performs no rank adaption.
        """
        return None

    @torch.no_grad()
    def update_Q(self):
        """
        Update the right cores ``Qst`` used by the K-steps.

        For each mode i, the Q factor of the QR decomposition of the transposed
        mode-i unfolding of the truncated core is folded to shape
        ``[r_i, *other ranks]`` and written into ``Qst[i]``.

        The result is copied in place into the existing parameter rather than
        assigned to the list slot. Assigning a raw tensor replaces the registered
        non-trainable Parameter; depending on the PyTorch version that either
        raises or re-wraps the tensor as a new trainable Parameter.
        """
        ranks = self.dynamic_rank
        C = self.C[:ranks[0], :ranks[1], :ranks[2], :ranks[3]]
        for i in range(len(self.dims)):
            Mat_i_C = tly.unfold(C, mode=i).T
            Q, _ = torch.linalg.qr(Mat_i_C)
            other_ranks = [r for j, r in enumerate(ranks) if j != i]
            # NOTE(review): Q has shape (prod(other ranks), r_i) and is folded
            # without transposing - confirm this layout is what the K-steps in
            # forward expect.
            Q_ten = tly.fold(Q, mode=i, shape=[ranks[i]] + other_ranks)
            # Same region that forward reads from Qst[i].
            self.Qst[i][:ranks[i], :other_ranks[0], :other_ranks[1], :other_ranks[2]] = Q_ten

    @torch.no_grad()
    def get_r_mod_i(self, i):
        """
        Return the current rank of mode ``i`` capped by the product of the ranks
        of all other modes (compatibility condition r_i <= prod of the others).
        """
        ranks = self.dynamic_rank
        complement = 1
        for mode, rank in enumerate(ranks):
            if mode != i:
                complement *= rank
        return min(ranks[i], complement)
