import copy
import torch.nn as nn

from neural_networks.base_model import BaseModel
import torch
import numpy as np


class FusionModel(BaseModel):
    def __init__(self, model1, model2, method, data=None, lambdas=None,
                 sigmas=None, out_ens=False, reverse=False, pgd=False):
        """Build a fused network from two models using ``method``.

        Every configuration and bookkeeping attribute is assigned first and
        the fusion itself (``fuse_two_models``) runs last, so the values that
        the ``fuse_*_layer`` helpers record on ``self`` (``n1``/``n2``/``n3``)
        survive construction instead of being reset afterwards.

        Args:
            model1: First model. Must provide ``get_layer_names()``,
                ``get_layer_names_with_weights()``, ``get_layer_by_name()``,
                ``get_residual_layers()`` and an ``input_size`` attribute.
            model2: Second model. Must provide ``get_layer_names()``,
                ``get_layer_names_with_weights()`` and ``get_layer_by_name()``.
            method: Object whose ``combine_layers(...)`` returns the three
                weight blocks for a layer.
            data: Forwarded unchanged to ``method.combine_layers``.
            lambdas: Forwarded to ``method.combine_layers``; ``[0.5, 0.5]``
                is used when ``None``.
            sigmas: Forwarded unchanged to ``method.combine_layers``.
            out_ens: Forwarded to ``method.combine_layers``; also selects the
                last-layer layout in ``fuse_linear_layer``.
            reverse: Forwarded to ``method.combine_layers``; also selects the
                block layout in ``fuse_linear_layer``.
            pgd: Forwarded unchanged to ``method.combine_layers``.
        """
        super(FusionModel, self).__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if lambdas is None:
            lambdas = [0.5, 0.5]

        # Inputs and fusion configuration.
        self.models = [model1, model2]
        self.method = method
        self.data = data
        self.lambdas = lambdas
        self.sigmas = sigmas
        self.out_ens = out_ens
        self.reverse = reverse
        self.pgd = pgd

        # Layer-name bookkeeping for both source models.
        self.layers1 = model1.get_layer_names()
        self.layers2 = model2.get_layer_names()
        self.layers_with_weights1 = model1.get_layer_names_with_weights()
        self.layers_with_weights2 = model2.get_layer_names_with_weights()
        self.input_size = self.models[0].input_size

        # State that the fusion step writes to. It has to exist (and hold its
        # initial value) BEFORE fuse_two_models() runs; initialising it after
        # the call would discard whatever the fuse_*_layer helpers recorded.
        self.fused_layers = []
        self.non_zero_weights = 0
        self.n1 = 0
        self.n2 = 0
        self.n3 = 0
        self.res_layers_merged = {}

        self.fuse_two_models(model1, model2)


    def fuse_two_models(self, model1, model2):
        for layer1, layer2 in zip(self.layers1, self.layers2):
            # assume everything is linear
            if layer1 not in self.layers_with_weights1:
                self.fused_layers.append(model1.get_layer_by_name(layer1))
                continue
            elif layer1 in model1.get_residual_layers()[0]:
                continue
            # here i assume A, B, C to be (numpy) matrices and for the neurons to be sorted
            # A are the weights of the fused layer, B of the first model, C of the second model
            A_1, A_2, A_3 = self.method.combine_layers(model1, model2, self.data, self.lambdas, layer1,
                                                       sigmas=self.sigmas, out_ens=self.out_ens, reverse=self.reverse, pgd=self.pgd)
            if isinstance(model2.get_layer_by_name(layer2), nn.Linear):
                self.non_zero_weights += A_1.size + A_2.size + A_3.size
                self.fuse_linear_layer(A_1, A_2, A_3, layer1)
            elif isinstance(model2.get_layer_by_name(layer2), nn.Conv2d):
                self.non_zero_weights += A_1.size + A_2.size + A_3.size
                self.fuse_conv_layer(A_1, A_2, A_3, layer1, model1)
            else:
                print("unknown layer: ", model2.get_layer_by_name(layer2))

        self.fused_layers = torch.nn.ModuleList(self.fused_layers)

    def fuse_conv_layer(self, A_1, A_2, A_3, layer1, model1):
        c_out1, c_in1, kH, kW = A_1.shape
        c_out2, c_in2, _, _ = A_2.shape
        c_out3, c_in3, _, _ = A_3.shape

        m, n = c_out1, c_in1 + c_in2 + c_in3
        m2 = c_out1 - c_out3
        m3 = c_out1 - c_out2
        m1 = abs(c_out1 - c_out3 - c_out2)

        self.n1 = c_in1
        self.n2 = c_in2
        self.n3 = c_in3

        if layer1 != self.layers_with_weights1[-1] and layer1 != self.layers_with_weights1[0]:
            # middle layers
            A = np.zeros((m, n, kH, kW), dtype=A_1.dtype)
            A[:, :c_in1] = A_1
            A[:(m1 + m2), c_in1:(c_in1 + c_in2)] = A_2
            A[:m1, (c_in1 + c_in2):] = A_3[:m1]
            A[(m1 + m2):, (c_in1 + c_in2):] = A_3[m1:]

        elif layer1 == self.layers_with_weights1[0]:
            # first layer
            A = A_1

        else:
            # last layer
            m = m1
            A = np.zeros((m, c_in1 + c_in2 + c_in3, kH, kW), dtype=A_1.dtype)
            A[:, :c_in1] = self.final_weighting[0] * A_1
            A[:, c_in1:(c_in1 + c_in2)] = self.final_weighting[1] * A_2
            A[:, (c_in1 + c_in2):] = self.final_weighting[2] * A_3

        # Create fused Conv2d layer
        pad = model1.get_layer_by_name(layer1).padding
        stride = model1.get_layer_by_name(layer1).stride
        new_layer = nn.Conv2d(in_channels=A.shape[1], out_channels=A.shape[0], kernel_size=(kH, kW), stride=stride,
                              padding=pad, bias=False, device=self.device)
        with torch.no_grad():
            new_layer.weight.copy_(torch.from_numpy(A))
        self.fused_layers.append(new_layer)


    def fill_vector(self, w1, w2, w3, n1, n2, n3):
        n = n1 + n2 + n3
        W = np.zeros(n)
        W[:n1] = w1
        W[n1:n1 + n2] = w2
        W[n1 + n2:n] = w3
        return torch.from_numpy(W)


    def fuse_linear_layer(self, A_1, A_2, A_3, layer1):
        if self.reverse:
            n, n2, n3 = A_1.shape[1], A_1.shape[1] - A_3.shape[1], A_1.shape[1] - A_2.shape[1]
            m1, m2, m3 = A_1.shape[0], A_2.shape[0], A_3.shape[0]
            n1 = abs((A_1.shape[1] - A_3.shape[1] - A_2.shape[1]))
            m = m1 + m2 + m3
            self.n1 = n1
            self.n2 = n2
            self.n3 = n3
        else:
            n1, n2, n3 = A_1.shape[1], A_2.shape[1], A_3.shape[1]
            m, m2, m3 = A_1.shape[0], A_1.shape[0] - A_3.shape[0], A_1.shape[0] - A_2.shape[0]
            m1 = abs((A_1.shape[0] - A_3.shape[0] - A_2.shape[0]))
            n = n1 + n2 + n3
            self.n1 = n1
            self.n2 = n2
            self.n3 = n3

        if layer1 != self.layers_with_weights1[-1] and layer1 != self.layers_with_weights1[0]:
            # middle layers
            A = np.zeros((m, n))
            if self.reverse:
                A[:m1, :] = A_1
                A[m1:(m1+m2), :(n1+n2)] = A_2
                A[(m1+m2):, :n1] = A_3[:, :n1]
                A[(m1+m2):, (n1+n2):] = A_3[:, n1:]
            else:
                A[:, :n1] = A_1
                A[:(m1 + m2), n1:(n1 + n2)] = A_2
                A[:m1, (n1 + n2):] = A_3[:m1, :]
                A[(m1 + m2):, (n1 + n2):] = A_3[m1:, :]
            weights = torch.from_numpy(A)
            new_layer = nn.Linear(n, m, bias=False, device=self.device)
            with torch.no_grad():
                new_layer.weight.copy_(weights)
            self.fused_layers.append(new_layer)

        elif layer1 == self.layers_with_weights1[0]:
            # first layer
            A = np.zeros((m, n))
            if self.reverse:
                A[:m1, :] = A_1
                A[m1:(m1+m2), :] = A_2
                A[(m1 + m2):, :] = A_3
            else:
                A = A_1
            weights = torch.from_numpy(A)
            new_layer = nn.Linear(A.shape[1], A.shape[0], bias=False, device=self.device)
            with torch.no_grad():
                new_layer.weight.copy_(weights)
            self.fused_layers.append(new_layer)

        else:
            # last layer
            if self.out_ens:
                m = m2 + m3
                A = np.zeros((m, n))
                A[:, :n1] = A_1
                A[:m2, n1:(n1 + n2)] = A_2
                A[m2:, (n1 + n2):] = A_3
                weights = torch.from_numpy(A)
                new_layer = nn.Linear(n, m, bias=False, device=self.device)
                with torch.no_grad():
                    new_layer.weight.copy_(weights)
                self.fused_layers.append(new_layer)

            else:
                m = m1
                A = np.zeros((m, n))
                if self.reverse:
                    A = A_1
                else:
                    A[:, :n1] = A_1
                    A[:, n1:(n1 + n2)] = A_2
                    A[:, (n1 + n2):] = A_3
                weights = torch.from_numpy(A)
                new_layer = nn.Linear(n, m, bias=False, device=self.device)
                with torch.no_grad():
                    new_layer.weight.copy_(weights)
                self.fused_layers.append(new_layer)

    def get_layer_names(self):
        layer_names = []
        for name, layer in self.named_modules():
            if name != '' and name != 'fused_layers':
                layer_names.append(name)
        return layer_names

    def get_layer_names_with_weights(self):
        layer_names = []
        for name, layer in self.named_modules():
            # Check if the layer has a 'weight' attribute
            if name != '' and name != 'fused_layers' and hasattr(layer, 'weight'):
                layer_names.append(name)
        return layer_names

    def forward(self, x):
        x = x.to(self.device)
        if self.input_size is not None:
            x = x.view(-1, self.input_size)
        for idx, layer in enumerate(self.fused_layers):
            x = layer(x)
        return x


    def get_prev_fused_weights(self):
        if len(self.fused_layers) > 0:
            counter = -1
            while not hasattr(self.fused_layers[counter], 'weight'):
                counter -= 1
            return self.fused_layers[counter].weight.detach().cpu().numpy()
        else:
            raise Exception("No layers were fused so far")
