import torch
import copy
from define_devices import my_device, my_device_1


#####################################################
################## Solving EB #######################


def compute_tensor_s_for_conv_to_lin(masked_unfolded_b: torch.Tensor) -> torch.Tensor:
    """
    Compute the matrix S of proposition 3.2 when adding neurons
    between a convolution and a linear layer

    Parameters
    ----------
    masked_unfolded_b: torch.Tensor
        tensor M indexed as ``M[i, l, a]`` (i: sample, l: spatial position,
        a: feature), i.e. of shape (n, L, A)

    Returns
    -------
    torch.Tensor
        S of shape (A, A) with
        S^(a b) = sum_(i, l) M^(i l a) M^(i l b)
        Note: the sum is NOT divided by the batch size n.
    """
    return torch.einsum('ila, ilb -> ab', masked_unfolded_b, masked_unfolded_b)


def compute_tensor_n_for_conv_to_lin(dv: torch.Tensor, masked_unfolded_b: torch.Tensor) -> torch.Tensor:
    """
    Compute the matrix N of proposition 3.2 when adding neurons
    between a convolution and a linear layer.

    ``masked_unfolded_b`` is indexed ``[i, l, a]`` and ``dv`` is indexed
    ``[i, c]``; the shared leading axis ``i`` is summed over.

    Returns
    -------
    torch.Tensor
        tensor of shape (A, L * C) where the (l, c) pair is flattened
        row-major (c varies fastest).
    """
    # One (L, C) slab per feature a, before merging the trailing axes.
    per_feature = torch.einsum('ila, ic -> alc', masked_unfolded_b, dv)
    return per_feature.flatten(start_dim=1)


def compute_matrix_n_for_conv_to_conv(masked_unfolded_b: torch.Tensor,
                                      tensor_t: torch.Tensor,
                                      dv: torch.Tensor
                                      ) -> torch.Tensor:
    """
    Compute the matrix N of proposition 3.2 when adding neurons
    between two convolutional layers.

    Parameters
    ----------
    masked_unfolded_b: torch.Tensor
        indexed ``[i, l, a]`` (i: sample, l: spatial position HW, a: feature)
    tensor_t: torch.Tensor
        dense or sparse tensor T indexed ``[x, k, l]``
        (x: output position H'[+1]W'[+1], k: kernel position, l: HW)
    dv: torch.Tensor
        indexed ``[i, c, ...]``; every axis from the third on is merged
        into the output position x

    Returns
    -------
    torch.Tensor
        N of shape (A, C * K), the (c, k) pair flattened with k fastest

    Raises
    ------
    AssertionError
        if the shapes of the three operands are inconsistent
    """
    # The name ``dv`` is rebound on purpose: the assertion messages below
    # report ``dv.shape[2]`` of the flattened tensor.
    dv = dv.flatten(start_dim=2)

    assert tensor_t.shape[0] == dv.shape[2], \
        f"{tensor_t.shape[0]=} and {dv.shape[2]=} should be equal to H'[+1]W'[+1]"
    assert tensor_t.shape[2] == masked_unfolded_b.shape[1], \
        f"{tensor_t.shape[2]=} and {masked_unfolded_b.shape[1]=} should be equal to HW"
    assert masked_unfolded_b.shape[0] == dv.shape[0], \
        f"{masked_unfolded_b.shape[0]=} and {dv.shape[0]=} should be equal to n (batch size)"

    dense_t = tensor_t.to_dense()
    n_per_channel = torch.einsum('xkl,ila,icx->ack', dense_t, masked_unfolded_b, dv)
    return n_per_channel.flatten(start_dim=1)


def compute_matrix_s_for_conv_to_conv(masked_unfolded_b: torch.Tensor,
                                      tt: torch.Tensor
                                      ) -> torch.Tensor:
    """
    Compute the matrix S of proposition 3.2 when adding neurons
    between two convolutional layers.

    With M = ``masked_unfolded_b`` indexed ``[i, l, a]`` and TT indexed
    ``[l, m]`` (dense or sparse), returns the (A, A) matrix
    S^(a b) = sum_(i, l, m) M^(i l a) TT^(l m) M^(i m b).

    Raises
    ------
    AssertionError
        if the first axis of ``tt`` does not match the spatial axis of M
    """
    assert tt.shape[0] == masked_unfolded_b.shape[1], \
        f"{tt.shape[0]=} and {masked_unfolded_b.shape[1]=} should be equal to HW"
    m_tensor = masked_unfolded_b
    tt_dense = tt.to_dense()
    return torch.einsum('ila,lm,imb->ab', m_tensor, tt_dense, m_tensor)


#####################################################    
########## Construction of T, F and S ################


def compute_output_shape_conv(input_shape: tuple[int, int],
                              conv: torch.nn.Conv2d
                              ) -> tuple[int, int]:
    """
    Compute the output shape of a convolutional layer

    The shape is obtained from the closed-form formula of ``torch.nn.Conv2d``
    and cross-checked against a forward pass of ``conv`` on a dummy input.

    Parameters
    ----------
    input_shape: tuple
        shape of the input tensor (H, W)
    conv: torch.nn.Conv2d
        convolutional layer; ``conv.padding`` may be a tuple of ints or one
        of the strings ``'same'`` / ``'valid'``

    Returns
    -------
    tuple[int, int]
        output shape of the convolutional layer

    Raises
    ------
    AssertionError
        if the formula disagrees with the shape produced by ``conv``
    """
    in_h, in_w = input_shape
    if conv.padding == 'same':
        # 'same' padding preserves the spatial size (PyTorch requires stride 1).
        h, w = in_h, in_w
    else:
        # 'valid' means no padding; otherwise Conv2d stores a tuple of ints.
        pad = (0, 0) if conv.padding == 'valid' else conv.padding
        h = (in_h + 2 * pad[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) // conv.stride[0] + 1
        w = (in_w + 2 * pad[1] - conv.dilation[1] * (conv.kernel_size[1] - 1) - 1) // conv.stride[1] + 1

    # Sanity check against PyTorch itself. The dummy input must share the
    # layer's dtype (and device), otherwise the forward pass fails.
    with torch.no_grad():
        dummy = torch.empty((1, conv.in_channels, in_h, in_w),
                            dtype=conv.weight.dtype, device=conv.weight.device)
        out_shape = conv(dummy).shape[2:]

    assert h == out_shape[0], f"{h=} {out_shape[0]=} should be equal"
    assert w == out_shape[1], f"{w=} {out_shape[1]=} should be equal"

    return h, w


def compute_mask_tensor_t(input_shape: tuple[int, int],
                          conv: torch.nn.Conv2d
                          ) -> torch.Tensor:
    """
    Compute the tensor T
    For:
    - input tensor: B[-1] in (S[-1], H[-1]W[-1]) and (S[-1], H'[-1]W'[-1]) after the pooling
    - output tensor: B in (S, HW)
    - conv kernel tensor: W in (S, S[-1], Hd, Wd)
    T is the tensor in (HW, HdWd, H'[-1]W'[-1]) such that:
    B = W T B[-1]

    Concretely, ``T[p, k, q] = 1`` iff the k-th kernel position of output
    pixel p reads input pixel q, and 0 otherwise. Kernel positions that
    fall in the zero padding have no entry. Spatial positions are
    flattened row-major.

    Parameters
    ----------
    input_shape: tuple
        spatial shape (H'[-1], W'[-1]) of the input tensor B[-1]
    conv: torch.nn.Conv2d
        convolutional layer applied to the input tensor B[-1]

    Returns
    -------
    tensor_t: torch.Tensor
        dense tensor T in (HW, HdWd, H'[-1]W'[-1]) with
        H'[-1]W'[-1] = input_shape[0] * input_shape[1]
    """
    h, w = compute_output_shape_conv(input_shape, conv)
    in_h, in_w = input_shape
    n_kernel = conv.kernel_size[0] * conv.kernel_size[1]

    tensor_t = torch.zeros((h * w, n_kernel, in_h * in_w))
    unfold = torch.nn.Unfold(kernel_size=conv.kernel_size, padding=conv.padding, stride=conv.stride,
                             dilation=conv.dilation)

    # Label every input pixel with its 1-based flat index so that 0 marks padding.
    # float64 keeps the labels exact even for very large images.
    labels = torch.arange(1, in_h * in_w + 1, dtype=torch.float64).reshape((1, 1, in_h, in_w))
    # t_info[k, p] = label of the input pixel seen by kernel position k at output p.
    t_info = unfold(labels)[0].long()

    # Vectorized replacement of a Python loop over every (p, k) pair.
    k_idx, p_idx = torch.nonzero(t_info > 0, as_tuple=True)
    tensor_t[p_idx, k_idx, t_info[k_idx, p_idx] - 1] = 1
    return tensor_t


def compute_mask_tensor_t_and_tt(model: 'TINY', depth: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build the tensors T and TT for the convolution at ``depth``.

    Parameters
    ----------
    model: TINY
        model; only these attributes are read:
            - model.outputs_size_after_activation[depth - 1]: its first two
              entries give the spatial size (H'[-1], W'[-1]) of the input
            - model.layer[depth]['C']: the convolution at this depth
    depth: int
        depth at which T and TT are computed

    Returns
    -------
    tensor_t: torch.Tensor
        tensor T in (HW, HdWd, H'[-1]W'[-1]), moved to ``my_device_1``,
        where H'[-1]W'[-1] is the size of the input tensor AFTER THE POOLING
    tensor_tt: torch.Tensor
        TT[l, m] = sum_(x, k) T[x, k, l] T[x, k, m], of shape
        (H'[-1]W'[-1], H'[-1]W'[-1]), moved to ``my_device_1``
    """
    in_h, in_w = model.outputs_size_after_activation[depth - 1][:2]
    conv = model.layer[depth]['C']

    mask_t = compute_mask_tensor_t((in_h, in_w), conv)
    # Contract the output-position and kernel axes of T with themselves.
    mask_tt = torch.einsum('xkl, xkm->lm', mask_t, mask_t)

    return mask_t.to(my_device_1), mask_tt.to(my_device_1)


def _pooling_kernel_area(pool, depth):
    """Return the number of elements of the (validated) pooling window of ``pool``."""
    if isinstance(pool.kernel_size, int):
        return pool.kernel_size ** 2
    if isinstance(pool.kernel_size, tuple):
        assert pool.kernel_size[0] == pool.kernel_size[1], \
            f"MaxPool kernel size should be a square, got {pool.kernel_size}"\
            f" at depth {depth} (this may be not a problem and this could be considered as a warning)"
        return pool.kernel_size[0] * pool.kernel_size[1]
    raise ValueError(
        f"MaxPool kernel size should be an int or a tuple[int, int], got {pool.kernel_size}")


def creation_T_C_pour_BCR(model, depth):
    """
    Build the window-averaging matrix of the pooling layer that follows the
    convolution at ``depth``.

    Parameters
    ----------
    model: TINY
        model; reads ``model.fct[depth]`` (the pooling module, detected by a
        ``kernel_size`` attribute), ``model.layer[depth]['C']`` and
        ``model.outputs_size_after_activation[depth - 1]``
    depth: int
        depth of the convolution

    Returns
    -------
    torch.Tensor or None
        ``None`` if ``model.fct[depth]`` has no ``kernel_size`` attribute.
        Otherwise a sparse matrix on ``my_device_1`` of shape
        (number of pooling windows, H * W), where (H, W) is the output size
        of the convolution. Row r holds ``1 / kernel_area`` on the row-major
        flattened pixels covered by the r-th window (windows enumerated
        row-major, as produced by ``torch.nn.Unfold``).

    Raises
    ------
    AssertionError
        if the pooling kernel is a non-square tuple
    ValueError
        if the pooling kernel size is neither an int nor a tuple

    Notes
    -----
    The windows follow the pooling module's own kernel_size / stride /
    padding / dilation. The older version hard-coded 2x2 windows with
    stride 2 and misplaced windows on non-square maps or maps with one odd
    side. ``ceil_mode`` is not modeled; windows are always taken in floor mode.
    """
    pool = model.fct[depth]
    if not hasattr(pool, 'kernel_size'):  # not a pooling layer: nothing to build
        return None

    kernel_area = _pooling_kernel_area(pool, depth)

    input_shape = model.outputs_size_after_activation[depth - 1][:2]
    h, w = compute_output_shape_conv(input_shape, model.layer[depth]['C'])

    # Pooling modules default their stride to kernel_size; guard against None anyway.
    stride = getattr(pool, 'stride', None) or pool.kernel_size
    unfold = torch.nn.Unfold(kernel_size=pool.kernel_size,
                             stride=stride,
                             padding=getattr(pool, 'padding', 0),
                             dilation=getattr(pool, 'dilation', 1))

    # Label every conv-output pixel with its 1-based row-major index (0 = padding);
    # float64 keeps the labels exact for large maps.
    labels = torch.arange(1, h * w + 1, dtype=torch.float64).reshape(1, 1, h, w)
    window_pixels = unfold(labels)[0].long()  # (kernel elements, number of windows)

    k_idx, win_idx = torch.nonzero(window_pixels > 0, as_tuple=True)
    t_tot = torch.zeros((window_pixels.shape[1], h * w))
    t_tot[win_idx, window_pixels[k_idx, win_idx] - 1] = 1.

    return (t_tot / kernel_area).to_sparse().to(my_device_1)
