import sys
import numpy as np

sys.path.append("../")
from neural_networks.base_model import BaseModel


class BaseFusion:
    """Base class for partial fusion of two models via layer-wise couplings.

    Subclasses implement :meth:`get_mapping`. For every fusable layer this class
    asks ``get_mapping`` for the fused / isolated mass of each neuron of both
    models plus a forward and a backward coupling kernel. It then uses those to
    split each layer's incoming weights into three blocks:

    * ``A1`` -- the fused part (a lambda-weighted combination of both models),
    * ``A2`` -- the part isolated to the first model,
    * ``A3`` -- the part isolated to the second model.
    """

    def __init__(self, eps=10 ** -8, act=False, fix_mu=False, combine_costs=False):
        """Create a fusion helper.

        Args:
            eps: mass threshold used by :meth:`fuse_two_models_partial` to decide
                whether a neuron belongs to the fused and/or the isolated part.
            act: when True, :meth:`get_support` applies the model's next
                activation function (if any) to data-based activations.
            fix_mu: when True, the second model's isolated weights that read
                fused inputs (``w_b_if``) are not scaled by its lambda in ``A3``.
            combine_costs: when True, :meth:`compute_kernels` passes an extra
                weight-based support to :meth:`get_mapping` alongside the
                primary (data- or weight-based) support.
        """
        # Cached fusion results (dicts keyed by layer name), see combine_layers().
        self.A1_list = None
        self.A2_list = None
        self.A3_list = None
        self.eps = eps
        self.act = act
        # Kernels from the most recent compute_kernels / compute_kernels_pcd call.
        self.kernel_forward = []
        self.kernel_backward = []
        # Masks and lambdas recorded by fuse_two_models_partial. They accumulate
        # across calls and are not cleared by reset().
        self.mfels = []
        self.miels = []
        self.nfels = []
        self.niels = []
        self.lambdas_l = []
        self.fix_mu = fix_mu
        self.combine_costs = combine_costs

    def reset(self):
        """Drop the cached A1/A2/A3 results so the next combine_layers() recomputes them."""
        self.A1_list = None
        self.A2_list = None
        self.A3_list = None

    def combine_layers(self, model1, model2, data, lambdas, layer, sigmas=None, out_ens=False, reverse=False, pgd=False):
        """Return the ``(A1, A2, A3)`` weight blocks of ``layer`` for the fusion of two models.

        The full fusion is computed by :meth:`fuse_two_models_partial` on the
        first call and cached on the instance. Later calls return blocks from
        that cache whatever arguments they pass, until :meth:`reset` is called.

        Args:
            model1, model2: the two models to fuse.
            data: optional input batch used to compute activation-based supports.
            lambdas: fusion weights (see :meth:`initialize_fusion`).
            layer: layer name to look up in the cached result dicts.
            sigmas, out_ens, reverse, pgd: forwarded to :meth:`fuse_two_models_partial`.

        Raises:
            KeyError: if ``layer`` was not fused (e.g. it has 1-D weights).
        """
        if self.A1_list is None or self.A2_list is None or self.A3_list is None:
            A1, A2, A3 = self.fuse_two_models_partial([model1, model2], data=data, lambdas=lambdas,
                                                      sigmas=sigmas, out_ens=out_ens, reverse=reverse, pgd=pgd)
            self.A1_list = A1
            self.A2_list = A2
            self.A3_list = A3
        return self.A1_list[layer], self.A2_list[layer], self.A3_list[layer]

    def _fusable_layers(self, model):
        """Return the names of ``model``'s layers with weights, excluding its residual layers."""
        # NOTE(review): assumes get_residual_layers()[0] is a container of layer names to skip.
        residual = model.get_residual_layers()[0]
        return [layer for layer in model.get_layer_names_with_weights() if layer not in residual]

    def initialize_fusion(self, models: "list[BaseModel]", lambdas=None, sigmas=None, reverse=False, pgd=False):
        """Normalise the fusion inputs and build the initial mass / kernel lists.

        ``lambdas`` may be ``None`` (uniform ``1/K`` weights for every layer), a
        flat list of per-model weights, or a list of such lists (one entry per
        layer, used cyclically). The models are ordered by ascending weight in
        the last lambda entry (stable sort). The same permutation is applied to
        every lambda entry and to ``sigmas``, so each weight stays attached to
        its model. The caller's ``lambdas`` list is not modified.

        Returns:
            ``(models, layers1, layers2, L, mu_f_list, nu_f_list, mu_i_list,
            nu_i_list, kernels_forward, kernels_backward, lambdas,
            lambdas_length, sigmas)``. Without ``pgd`` the mass / kernel lists
            hold only the input-layer entry. With ``pgd`` they hold one entry
            per layer with non 1-D weights plus one for the last layer's output.
        """
        K = len(models)
        # Both models are expected to have the same number of fusable layers.
        L = len(self._fusable_layers(models[0]))
        if lambdas is None:
            lambdas = [[1.0 / K] * K for _ in range(L)]
            lambdas_length = L
        elif not isinstance(lambdas[0], list):
            lambdas = [list(lambdas)]
            lambdas_length = 1
        else:
            lambdas = [list(entry) for entry in lambdas]
            lambdas_length = len(lambdas)

        # Stable ascending order of the models by their weight in the last lambda
        # entry (truncated to the shortest sequence, like zip() would).
        last = lambdas[-1]
        sizes = [len(last), len(models)] + ([] if sigmas is None else [len(sigmas)])
        order = sorted(range(min(sizes)), key=lambda i: last[i])
        lambdas = [[entry[i] for i in order] for entry in lambdas]
        models = [models[i] for i in order]
        if sigmas is not None:
            sigmas = [sigmas[i] for i in order]

        # Layer names come from the *sorted* models so layers1 / layers2 always
        # belong to models[0] / models[1].
        layers1 = self._fusable_layers(models[0])
        layers2 = self._fusable_layers(models[1])
        L = len(layers1)
        if reverse:
            layers1.reverse()
            layers2.reverse()
            init_dimension = models[0].get_incoming_weights(layers1[0], numpy=True).shape[0]
        else:
            init_dimension = models[0].get_incoming_weights(layers1[0], numpy=True).shape[1]

        if not pgd:
            # initial values are for the input layer, by definition co-monotone coupling
            mu_f_list = [np.ones(init_dimension) / init_dimension]
            nu_f_list = [np.ones(init_dimension) / init_dimension]
            mu_i_list = [np.zeros(init_dimension)]
            nu_i_list = [np.zeros(init_dimension)]
            kernels_forward = [np.identity(init_dimension)]
            kernels_backward = [np.identity(init_dimension)]
        else:
            # initialization of all kernels and all lists
            kernels_forward = []
            kernels_backward = []
            mu_f_list = []
            nu_f_list = []
            mu_i_list = []
            nu_i_list = []
            layers1_no_norm = []
            layers2_no_norm = []
            for name1, name2 in zip(layers1, layers2):
                w = self.get_support(models[0], name1)
                if len(w.shape) == 1:
                    # Layers with 1-D weights (presumably normalization layers) are skipped.
                    continue
                layers1_no_norm.append(name1)
                layers2_no_norm.append(name2)
                dim = w.shape[1]
                kernels_forward.append(np.identity(dim))
                kernels_backward.append(np.identity(dim))
                mu_f_list.append(np.ones(dim) / dim)
                nu_f_list.append(np.ones(dim) / dim)
                mu_i_list.append(np.zeros(dim))
                nu_i_list.append(np.zeros(dim))
            layers1 = layers1_no_norm
            layers2 = layers2_no_norm
            L = len(layers1)
            # append last layer (its output dimension)
            w = self.get_support(models[0], layers1[-1])
            dim = w.shape[0]
            kernels_forward.append(np.identity(dim))
            kernels_backward.append(np.identity(dim))
            mu_f_list.append(np.ones(dim) / dim)
            nu_f_list.append(np.ones(dim) / dim)
            mu_i_list.append(np.zeros(dim))
            nu_i_list.append(np.zeros(dim))
        return (models, layers1, layers2, L, mu_f_list, nu_f_list, mu_i_list, nu_i_list,
                kernels_forward, kernels_backward, lambdas, lambdas_length, sigmas)

    def align(self, mu, nu, mu_f, nu_f, ker_adj):
        """Restrict two supports to the kept inputs and apply ``ker_adj`` to ``mu``.

        ``mu_f`` / ``nu_f`` are boolean masks over the previous layer's neurons,
        and ``ker_adj`` is that layer's kernel. Only ``mu`` is multiplied by it.

        * 4-D supports (conv weights): if the masks have one entry per input
          channel, channels are selected and mixed by the kernel before
          flattening to ``(out, C*H*W)``. Otherwise the support is flattened
          first and the masks / kernel are applied to the flattened columns.
        * 2-D supports whose column count differs from the mask length (a
          linear layer fed by flattened conv features): each mask entry is
          repeated ``k = columns / len(mask)`` times, and the kernel is expanded
          to ``kron(ker_adj, I_k)`` before selection and multiplication.

        Returns:
            ``(mu, nu)`` as 2-D arrays.
        """
        #CNN alignment
        if len(mu.shape) == 4:
            mu_f = np.array(mu_f).flatten().tolist()
            nu_f = np.array(nu_f).flatten().tolist()
            if len(mu_f) == mu.shape[1]:
                mu = mu[:, mu_f, :, :]
                nu = nu[:, nu_f, :, :]
                mu = np.einsum('bchw,cd->bdhw', mu, ker_adj)  # mu @ k_adjust
                mu = mu.reshape(mu.shape[0], mu.shape[1] * mu.shape[2] * mu.shape[3])
                nu = nu.reshape(nu.shape[0], nu.shape[1] * nu.shape[2] * nu.shape[3])
            else:
                mu = mu.reshape(mu.shape[0], mu.shape[1] * mu.shape[2] * mu.shape[3])
                nu = nu.reshape(nu.shape[0], nu.shape[1] * nu.shape[2] * nu.shape[3])
                mu = mu[:, mu_f]
                nu = nu[:, nu_f]
                mu = mu @ ker_adj
            return mu, nu

        if mu_f.shape[0] != mu.shape[1]:
            # CNN to linear layer
            k = int(mu.shape[1] / mu_f.shape[0])
            mu_f = [x for x in mu_f for _ in range(k)]
            nu_f = [x for x in nu_f for _ in range(k)]
            n = ker_adj.shape[0]
            I = np.eye(k)
            ker_back_extended = ker_adj[:, :, None, None] * I[None, None, :, :]

            # Rearrange to shape (n*k, n*k), i.e. kron(ker_adj, I_k)
            ker_back_extended = ker_back_extended.transpose(0, 2, 1, 3).reshape(n * k, n * k)
            ker_adj = ker_back_extended

        mu_f = np.array(mu_f).flatten().tolist()
        nu_f = np.array(nu_f).flatten().tolist()
        mu = mu[:, mu_f] @ ker_adj
        nu = nu[:, nu_f]
        return mu, nu

    def compute_kernels_pcd(self, models, layers1, layers2, L, mu_f_list, nu_f_list, mu_i_list, nu_i_list,
                            kernels_forward, kernels_backward, iter=10):
        """Iteratively refine the hidden-layer couplings (used when ``pgd=True``).

        For ``iter`` sweeps (the name shadows the builtin but is kept for
        keyword compatibility), each hidden position ``l + 1`` is recomputed
        from two supports:

        * the incoming weights of layer ``l``, aligned with the masks / kernel
          at position ``l``;
        * the transposed incoming weights of layer ``l + 1``, aligned with the
          masks / kernel at position ``l + 2``.

        If the backward support has more rows (conv -> linear), its rows are
        averaged in consecutive groups.

        The passed-in lists are updated in place and also returned.

        Returns:
            ``(mu_f_list, nu_f_list, mu_i_list, nu_i_list, kernels_forward, kernels_backward)``.
        """
        for _ in range(iter):
            for l in range(L-1): # linear order
                mu_for = self.get_support(models[0], layers1[l])
                nu_for = self.get_support(models[1], layers2[l])
                mu_f = mu_f_list[l] > 10 ** -8
                nu_f = nu_f_list[l] > 10 ** -8
                mu_for, nu_for = self.align(mu_for, nu_for, mu_f, nu_f, kernels_backward[l])
                mu_back = self.get_support(models[0], layers1[l+1], reverse=True)
                nu_back = self.get_support(models[1], layers2[l+1], reverse=True)
                mu_f = mu_f_list[l+2] > 10 ** -8
                nu_f = nu_f_list[l+2] > 10 ** -8
                mu_back, nu_back = self.align(mu_back, nu_back, mu_f, nu_f, kernels_backward[l+2])
                if mu_for.shape[0] != mu_back.shape[0]:
                    # Average groups of k backward rows so both supports have one row per neuron.
                    k = mu_back.shape[0] // mu_for.shape[0]
                    mu_back_new = np.zeros((mu_for.shape[0], mu_back.shape[1]))
                    nu_back_new = np.zeros((nu_for.shape[0], nu_back.shape[1]))
                    for j in range(mu_for.shape[0]):
                        mu_back_new[j] = np.mean(mu_back[j * k:(j + 1) * k, :], axis=0)
                        nu_back_new[j] = np.mean(nu_back[j * k:(j + 1) * k, :], axis=0)
                    mu_back = mu_back_new
                    nu_back = nu_back_new
                mu = [mu_for, mu_back]
                nu = [nu_for, nu_back]
                mu_fuse, mu_iso, nu_fuse, nu_iso, k_for, k_back = self.get_mapping(mu, nu)
                kernels_forward[l+1] = k_for
                kernels_backward[l+1] = k_back
                mu_i_list[l+1] = mu_iso[:, None].copy()
                nu_i_list[l+1] = nu_iso[:, None].copy()
                mu_f_list[l+1] = mu_fuse[:, None].copy()
                nu_f_list[l+1] = nu_fuse[:, None].copy()
        self.kernel_forward = kernels_forward
        self.kernel_backward = kernels_backward
        return mu_f_list, nu_f_list, mu_i_list, nu_i_list, kernels_forward, kernels_backward

    def compute_kernels(self, models, layers1, layers2, L, mu_f_list, nu_f_list, mu_i_list, nu_i_list, kernels_forward,
                        kernels_backward, data=None, out_ens=False, reverse=False, pgd=False):
        """Compute neuron masses and coupling kernels for every fusable layer.

        The lists from :meth:`initialize_fusion` are extended in place with one
        entry per linear (2-D) or conv (4-D) layer; other layers are skipped.
        With ``pgd=True`` the work is delegated to :meth:`compute_kernels_pcd`.
        Supports are activations on ``data`` when it is given. Otherwise they
        are the incoming weights, aligned through the previous layer's kernel.

        Returns:
            ``(mu_f_list, nu_f_list, mu_i_list, nu_i_list, kernels_forward, kernels_backward)``.
        """
        if pgd:
            return self.compute_kernels_pcd(models, layers1, layers2, L, mu_f_list, nu_f_list, mu_i_list, nu_i_list,
                                            kernels_forward, kernels_backward)
        for l in range(L):
            w_a = models[0].get_incoming_weights(layers1[l], numpy=True)
            l1 = layers1[l]
            l2 = layers2[l]
            # With data in reverse mode the support comes from the next layer's activations.
            if data is not None and reverse and l!=L-1 and not self.combine_costs:
                l1 = layers1[l+1]
                l2 = layers2[l+1]
            if len(w_a.shape) == 2 and l1:
                mu = self.get_support(models[0], l1, data=data, reverse=reverse)
                nu = self.get_support(models[1], l2, data=data, reverse=reverse)
                if data is None:
                    mu_f = mu_f_list[-1] > 10 ** -8
                    nu_f = nu_f_list[-1] > 10 ** -8
                    mu, nu = self.align(mu, nu, mu_f, nu_f, kernels_backward[-1])

                if self.combine_costs:
                    mu_f = mu_f_list[-1] > 10 ** -8
                    nu_f = nu_f_list[-1] > 10 ** -8
                    mu_w = self.get_support(models[0], l1, reverse=reverse)
                    nu_w = self.get_support(models[1], l2, reverse=reverse)
                    mu_w, nu_w = self.align(mu_w, nu_w, mu_f, nu_f, kernels_backward[-1])
                    mu = [mu, mu_w]
                    nu = [nu, nu_w]
                else:
                    mu = [mu]
                    nu = [nu]
                mu_fuse, mu_iso, nu_fuse, nu_iso, k_for, k_back = self.get_mapping(mu, nu)
                if data is not None and reverse and l == L - 1:
                    # Reverse mode with data: the last position is the input; fuse every feature uniformly.
                    mu = data.view(data.size(0), -1).T.numpy()
                    mu_fuse = np.ones(mu.shape[0]) / mu.shape[0]
                    nu_fuse = mu_fuse
                    mu_iso = np.zeros(mu.shape[0])
                    nu_iso = mu_iso
                mu_i_list.append(mu_iso[:, None].copy())
                nu_i_list.append(nu_iso[:, None].copy())
                mu_f_list.append(mu_fuse[:, None].copy())
                nu_f_list.append(nu_fuse[:, None].copy())

                if l < L - 1:
                    kernels_forward.append(k_for.copy())
                    kernels_backward.append(k_back.copy())
                else:
                    # case l == L-1: identity kernels on the output layer
                    if not reverse or data is None:
                        mu = self.get_support(models[0], layers1[L - 1], data=data, reverse=reverse)
                        nu = self.get_support(models[1], layers2[L - 1], data=data, reverse=reverse)
                    if out_ens:
                        # NOTE(review): an empty forward kernel matches the zero-row fused output
                        # blocks built for out_ens; the 2x backward size is kept as-is — confirm intent.
                        kernels_forward.append(np.identity(0 * len(mu)))
                        kernels_backward.append(np.identity(2 * len(nu)))
                    else:
                        kernels_forward.append(np.identity(len(mu)))
                        kernels_backward.append(np.identity(len(nu)))
            elif len(w_a.shape) == 4: #CNNs
                mu = self.get_support(models[0], layers1[l], data=data, reverse=reverse)
                nu = self.get_support(models[1], layers2[l], data=data, reverse=reverse)
                if data is not None:
                    # get_support returned activations.T; transpose back to (B, C, H, W)
                    mu = mu.transpose(3, 2, 1, 0)
                    nu = nu.transpose(3, 2, 1, 0)
                    # Reshape to (C, B * H * W)
                    mu = mu.reshape(mu.shape[0], mu.shape[1], -1).transpose(1, 0, 2).reshape(mu.shape[1], -1)
                    nu = nu.reshape(nu.shape[0], nu.shape[1], -1).transpose(1, 0, 2).reshape(nu.shape[1], -1)
                else:
                    mu_f = mu_f_list[-1] > 10 ** -8
                    nu_f = nu_f_list[-1] > 10 ** -8
                    ker_adj = kernels_backward[-1]
                    mu_f = np.array(mu_f).flatten().tolist()
                    nu_f = np.array(nu_f).flatten().tolist()
                    mu, nu = self.align(mu, nu, mu_f, nu_f, ker_adj)
                mu = [mu]
                nu = [nu]
                if self.combine_costs:
                    mu_w = self.get_support(models[0], layers1[l])
                    nu_w = self.get_support(models[1], layers2[l])
                    mu_w = mu_w.reshape(mu_w.shape[0], mu_w.shape[1] * mu_w.shape[2] * mu_w.shape[3])
                    nu_w = nu_w.reshape(nu_w.shape[0], nu_w.shape[1] * nu_w.shape[2] * nu_w.shape[3])
                    mu.append(mu_w)
                    nu.append(nu_w)
                mu_fuse, mu_iso, nu_fuse, nu_iso, k_for, k_back = self.get_mapping(mu, nu)
                mu_i_list.append(mu_iso[:, None].copy())
                nu_i_list.append(nu_iso[:, None].copy())
                mu_f_list.append(mu_fuse[:, None].copy())
                nu_f_list.append(nu_fuse[:, None].copy())

                if l < L - 1:
                    kernels_forward.append(k_for.copy())
                    kernels_backward.append(k_back.copy())
                else:
                    # case l == L-1: identity kernels on the output layer
                    mu = self.get_support(models[0], layers1[L - 1], data=data, reverse=reverse)
                    nu = self.get_support(models[1], layers2[L - 1], data=data, reverse=reverse)
                    if out_ens:
                        kernels_forward.append(np.identity(0*len(mu)))
                        kernels_backward.append(np.identity(2*len(nu)))
                    else:
                        kernels_forward.append(np.identity(len(mu)))
                        kernels_backward.append(np.identity(len(nu)))

        self.kernel_forward = kernels_forward
        self.kernel_backward = kernels_backward
        return mu_f_list, nu_f_list, mu_i_list, nu_i_list, kernels_forward, kernels_backward

    def fuse_two_models_partial(self, models: "list[BaseModel]", data=None, lambdas=None, sigmas=None, out_ens=False, reverse=False, pgd=False):
        """Partially fuse two models layer by layer.

        For every linear or conv layer, the incoming weights of both models are
        split by the fused / isolated masks of the layer's inputs and outputs,
        then recombined through the coupling kernels.

        Args:
            models: the two models; they are reordered by :meth:`initialize_fusion`.
            data: optional input batch for activation-based supports.
            lambdas: fusion weights, see :meth:`initialize_fusion`.
            sigmas: optional per-model values, reordered together with the models.
            out_ens: when True, the output layer is not fused; all of it goes
                into the isolated blocks.
            reverse: process layers with transposed weights.
            pgd: refine the kernels iteratively via :meth:`compute_kernels_pcd`.

        Returns:
            ``(A1_dict, A2_dict, A3_dict)`` keyed by the layer names of the
            (sorted) first model. ``A1`` is the fused part, ``A2`` the part
            isolated to the first model and ``A3`` the part isolated to the
            second model.
        """
        (models, layers1, layers2, L, mu_f_list, nu_f_list, mu_i_list, nu_i_list, kernels_forward, kernels_backward,
         lambdas, lambdas_length, sigmas) = self.initialize_fusion(models, lambdas=lambdas, sigmas=sigmas,
                                                                    reverse=reverse, pgd=pgd)
        mu_f_list, nu_f_list, mu_i_list, nu_i_list, kernels_forward, kernels_backward \
            = self.compute_kernels(models, layers1, layers2, L, mu_f_list, nu_f_list, mu_i_list, nu_i_list,
                                   kernels_forward, kernels_backward, data=data, out_ens=out_ens, reverse=reverse, pgd=pgd)
        # recompute as they might have changed with pgd
        # NOTE(review): unlike initialize_fusion, these are not reversed when reverse=True — confirm intent.
        layers1 = self._fusable_layers(models[0])
        layers2 = self._fusable_layers(models[1])
        L = len(layers1)
        A1_dict = {}
        A2_dict = {}
        A3_dict = {}
        # a counter for actual linear / conv layers (index into the mass and kernel lists)
        l = 0

        for l_all in range(L):
            alpha_l = lambdas[l % lambdas_length]
            self.lambdas_l.append(alpha_l)
            w_a = models[0].get_incoming_weights(layers1[l_all], numpy=True)
            w_b = models[1].get_incoming_weights(layers2[l_all], numpy=True)
            if reverse:
                # NOTE(review): for 4-D conv weights .T reverses all axes and the conv
                # branch below does not transpose back; reverse mode looks linear-only.
                w_a = w_a.T
                w_b = w_b.T
            # Masks: mu/nu = first/second model, f/i = fused/isolated,
            # suffix p = outputs of this layer, no suffix = its inputs.
            mfel = mu_f_list[l] > self.eps
            mfelp = mu_f_list[l + 1] > self.eps
            miel = mu_i_list[l] > self.eps
            mielp = mu_i_list[l + 1] > self.eps
            nfel = nu_f_list[l] > self.eps
            nfelp = nu_f_list[l + 1] > self.eps
            niel = nu_i_list[l] > self.eps
            nielp = nu_i_list[l + 1] > self.eps

            # the below is to split weights in case neurons contribute to both the fused and isolated parts...
            # Note that we always handle such cases based on support for the layer where the weights are incoming.
            mwfp = mu_f_list[l + 1][mfelp, None].astype(float) / (mu_f_list[l + 1][mfelp, None].astype(float) + mu_i_list[l + 1][mfelp, None].astype(float))
            mwip = mu_i_list[l + 1][mielp, None].astype(float) / (mu_i_list[l + 1][mielp, None].astype(float) + mu_f_list[l + 1][mielp, None].astype(float))
            nwfp = nu_f_list[l + 1][nfelp, None].astype(float) / (nu_f_list[l + 1][nfelp, None].astype(float) + nu_i_list[l + 1][nfelp, None].astype(float))
            nwip = nu_i_list[l + 1][nielp, None].astype(float) / (nu_i_list[l + 1][nielp, None].astype(float) + nu_f_list[l + 1][nielp, None].astype(float))

            mfelp = np.array(mfelp).flatten()
            mielp = np.array(mielp).flatten()
            nfelp = np.array(nfelp).flatten()
            nielp = np.array(nielp).flatten()
            mfel = np.array(mfel).flatten()
            miel = np.array(miel).flatten()
            nfel = np.array(nfel).flatten()
            niel = np.array(niel).flatten()

            if len(w_a.shape) == 2:
                # fusion of linear layers
                if l_all > 0 and mfel.shape[0] != w_a.shape[1]:
                    # A linear layer fed by a conv layer sees flattened (C * H * W)
                    # features while the masks are per channel: repeat each mask
                    # entry over its k spatial positions. This is needed for the
                    # last layer as well as for hidden layers.
                    k = int(w_a.shape[1] / mfel.shape[0])
                    mfel = np.repeat(mfel, k)
                    miel = np.repeat(miel, k)
                    nfel = np.repeat(nfel, k)
                    niel = np.repeat(niel, k)
                # get relevant decompositions of the weight matrices: (ff is fuse-fuse, fi is fuse-isolated, etc.)
                if l_all > 0 and l_all < L - 1:
                    w_a_ff = w_a[mfelp, :][:, mfel] * mwfp
                    w_a_fi = w_a[mielp, :][:, mfel] * mwip
                    w_a_if = w_a[mfelp, :][:, miel] * mwfp
                    w_a_ii = w_a[mielp, :][:, miel] * mwip

                    w_b_ff = w_b[nfelp, :][:, nfel] * nwfp
                    w_b_fi = w_b[nielp, :][:, nfel] * nwip
                    w_b_if = w_b[nfelp, :][:, niel] * nwfp
                    w_b_ii = w_b[nielp, :][:, niel] * nwip
                elif l_all == L - 1:
                    # output layer: all output neurons are fused
                    w_a_ff = w_a[:, :][:, mfel]
                    w_a_fi = w_a[0:0, :][:, mfel]  # 0
                    w_a_if = w_a[:, :][:, miel]
                    w_a_ii = w_a[0:0, :][:, miel]  # 0

                    w_b_ff = w_b[:, :][:, nfel]
                    w_b_fi = w_b[0:0, :][:, nfel]  # 0
                    w_b_if = w_b[:, :][:, niel]
                    w_b_ii = w_b[0:0, :][:, niel]  # 0

                    if out_ens:
                        # output ensembling: all output neurons are isolated instead
                        w_a_ff = w_a[0:0, :][:, mfel] # 0
                        w_a_fi = w_a[:, :][:, mfel]
                        w_a_if = w_a[0:0, :][:, miel] # 0
                        w_a_ii = w_a[:, :][:, miel]

                        w_b_ff = w_b[0:0, :][:, nfel] # 0
                        w_b_fi = w_b[:, :][:, nfel]
                        w_b_if = w_b[0:0, :][:, niel] # 0
                        w_b_ii = w_b[:, :][:, niel]
                else:  # first layer (l_all == 0): all inputs are fused
                    w_a_ff = w_a[mfelp, :] * mwfp
                    w_a_fi = w_a[mielp, :] * mwip
                    w_a_if = w_a[mfelp, 0:0] * mwfp
                    w_a_ii = w_a[mielp, 0:0] * mwip

                    w_b_ff = w_b[nfelp, :] * nwfp
                    w_b_fi = w_b[nielp, :] * nwip
                    w_b_if = w_b[nfelp, 0:0] * nwfp
                    w_b_ii = w_b[nielp, 0:0] * nwip

                ker_for = kernels_forward[l + 1]
                ker_back = kernels_backward[l]
                l += 1

                if w_a_ff.shape[1] != ker_back.shape[0]:
                    # expand the per-channel kernel to flattened conv features: kron(ker_back, I_k)
                    k = int(w_a_ff.shape[1] / ker_back.shape[0])
                    n = ker_back.shape[0]
                    I = np.eye(k)
                    ker_back_extended = ker_back[:, :, None, None] * I[None, None, :, :]

                    # Rearrange to shape (n*k, n*k)
                    ker_back_extended = ker_back_extended.transpose(0, 2, 1, 3).reshape(n * k, n * k)
                    ker_back = ker_back_extended

                weighting_a = alpha_l[0]
                weighting_b = alpha_l[1]
                if self.fix_mu:
                    A3 = np.concatenate([w_b_if, w_b_ii], axis=0)
                else:
                    A3 = np.concatenate([weighting_b * w_b_if, w_b_ii], axis=0)
                A2 = np.concatenate([ker_for @ (weighting_a * w_a_if), w_a_ii], axis=0)
                A1 = np.concatenate(
                    [weighting_b * w_b_ff + ker_for @ (weighting_a * w_a_ff) @ ker_back,
                     w_a_fi @ ker_back, w_b_fi], axis=0)

                if reverse:
                    A1 = A1.T
                    A2 = A2.T
                    A3 = A3.T

                A1_dict[layers1[l_all]] = A1
                A2_dict[layers1[l_all]] = A2
                A3_dict[layers1[l_all]] = A3

                self.mfels.append(mfel)
                self.miels.append(miel)
                self.nfels.append(nfel)
                self.niels.append(niel)

            elif len(w_a.shape) == 4:
                # fusion of convolutional layers (O, I, H, W)
                if l_all > 0 and l_all < L - 1:
                    w_a_ff = w_a[mfelp][:, mfel] * mwfp[:, None, None]
                    w_a_fi = w_a[mielp][:, mfel] * mwip[:, None, None]
                    w_a_if = w_a[mfelp][:, miel] * mwfp[:, None, None]
                    w_a_ii = w_a[mielp][:, miel] * mwip[:, None, None]

                    w_b_ff = w_b[nfelp][:, nfel] * nwfp[:, None, None]
                    w_b_fi = w_b[nielp][:, nfel] * nwip[:, None, None]
                    w_b_if = w_b[nfelp][:, niel] * nwfp[:, None, None]
                    w_b_ii = w_b[nielp][:, niel] * nwip[:, None, None]

                elif l_all == L - 1:
                    w_a_ff = w_a[:, mfel]
                    w_a_fi = w_a[0:0, mfel]
                    w_a_if = w_a[:, miel]
                    w_a_ii = w_a[0:0, miel]

                    w_b_ff = w_b[:, nfel]
                    w_b_fi = w_b[0:0, nfel]
                    w_b_if = w_b[:, niel]
                    w_b_ii = w_b[0:0, niel]

                else:  # l == 0
                    w_a_ff = w_a[mfelp] * mwfp[:, None, None]
                    w_a_fi = w_a[mielp] * mwip[:, None, None]
                    w_a_if = w_a[mfelp, 0:0] * mwfp[:, None, None]
                    w_a_ii = w_a[mielp, 0:0] * mwip[:, None, None]

                    w_b_ff = w_b[nfelp] * nwfp[:, None, None]
                    w_b_fi = w_b[nielp] * nwip[:, None, None]
                    w_b_if = w_b[nfelp, 0:0] * nwfp[:, None, None]
                    w_b_ii = w_b[nielp, 0:0] * nwip[:, None, None]

                ker_for = kernels_forward[l + 1]
                ker_back = kernels_backward[l]
                l += 1
                if self.fix_mu:
                    A3 = np.concatenate([w_b_if, w_b_ii], axis=0)
                else:
                    A3 = np.concatenate([alpha_l[1] * w_b_if, w_b_ii], axis=0)
                A2 = np.concatenate([self.transform_conv_weights(alpha_l[0] * w_a_if, ker_for, np.eye(w_a_if.shape[1])), w_a_ii], axis=0)
                A1 = np.concatenate([
                    alpha_l[1] * w_b_ff +
                    self.transform_conv_weights(alpha_l[0] * w_a_ff, ker_for, ker_back),
                    self.transform_conv_weights(w_a_fi, np.eye(w_a_fi.shape[0]), ker_back),
                    w_b_fi
                ], axis=0)

                A1_dict[layers1[l_all]] = A1
                A2_dict[layers1[l_all]] = A2
                A3_dict[layers1[l_all]] = A3

                self.mfels.append(mfel)
                self.miels.append(miel)
                self.nfels.append(nfel)
                self.niels.append(niel)

        return A1_dict, A2_dict, A3_dict

    def transform_conv_weights(self, w, ker_for, ker_back):
        """Apply coupling kernels to conv weights on the channel axes.

        Args:
            w: conv weights of shape ``(out, in, h, w)``.
            ker_for: ``(out_new, out)`` kernel applied on the output-channel axis.
            ker_back: ``(in, in_new)`` kernel applied on the input-channel axis.

        Returns:
            Array of shape ``(out_new, in_new, h, w)``.
        """
        tmp = np.einsum('bchw,cd->bdhw', w, ker_back)  # (out, in_new, h, w)
        out = np.tensordot(ker_for, tmp, axes=([1], [0]))  # (out_new, in_new, h, w)
        return out

    def get_support(self, model, layer_name, data=None, reverse=False):
        """Return the support of ``layer_name`` used to compute couplings.

        With ``data``, this is the layer's activations on ``data``, transposed
        with ``.T`` (for 2-D activations: neurons x datapoints). If ``self.act``
        is set, the model's next activation function is applied when one exists.
        Without ``data``, it is the layer's incoming weights, with the first two
        axes swapped when ``reverse`` is True.
        """
        if data is not None:
            activations = model.get_activations(layer_name, data, numpy=True).T
            if self.act:
                func = model.get_next_activation(layer_name, numpy=True)
                if func is not None:
                    activations = func(activations)
            # shape: (neurons of layer by datapoints)
            return activations
        else:
            weights = model.get_incoming_weights(layer_name, numpy=True)
            if reverse:
                axes = list(range(weights.ndim))
                axes[0], axes[1] = axes[1], axes[0]
                return weights.transpose(axes)
            # shape (for MLP): (neurons current layer by neurons previous layer)
            return weights

    def get_mapping(self, mu, nu):
        """Compute the partial coupling between two lists of support matrices.

        Must be implemented by subclasses. ``mu`` / ``nu`` are lists of 2-D
        supports (one per cost term) of the first / second model, with one row
        per neuron.

        The callers expect ``(mu_fuse, mu_iso, nu_fuse, nu_iso, k_for, k_back)``:

        * ``mu_fuse``, ``mu_iso``, ``nu_fuse``, ``nu_iso`` are 1-D per-neuron
          masses (fused / isolated);
        * ``k_for``, ``k_back`` are the forward / backward kernel matrices.
        """
        raise NotImplementedError
