import random

import torch
from torch import nn


class CnnAccumulator(nn.Module):
    """
    Stack of five 3x3 convolutions with a single output channel, each followed by a ReLU.

    With the 'uniform' scheme every kernel weight is 1 and every bias is 0, so each
    convolution adds up the values inside its receptive field.
    """

    def __init__(self, weight_init_scheme='uniform', input_channels=3):
        """
        weight_init_scheme: only 'uniform' is supported, anything else raises NotImplementedError
        input_channels: number of channels consumed by the first convolution
        """
        super().__init__()

        # per-convolution settings; every convolution uses out_channels=1 and kernel_size=3
        conv_settings = (
            dict(in_channels=input_channels, stride=3, padding=1),
            dict(in_channels=1, stride=3),
            dict(in_channels=1, stride=3, padding=2),
            dict(in_channels=1, stride=3),
            dict(in_channels=1),
        )
        stack = []
        for settings in conv_settings:
            stack.append(nn.Conv2d(out_channels=1, kernel_size=3, **settings))
            stack.append(nn.ReLU())
        self.layers = nn.Sequential(*stack)

        if weight_init_scheme != 'uniform':
            raise NotImplementedError
        self.uniform_weight_init()

    def forward(self, x):
        """Run the convolution stack and flatten everything except the batch dimension."""
        features = self.layers(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)

    def uniform_weight_init(self):
        """Set every convolution kernel to all ones and every convolution bias to zero."""
        conv_layers = [module for module in self.layers if isinstance(module, nn.Conv2d)]
        for conv in conv_layers:
            nn.init.ones_(conv.weight.data)
            nn.init.zeros_(conv.bias.data)

    def complementary_weight_init(self):
        """
        Intentionally a no-op.

        Note: it is not possible to implement complementary weight init
        that can perform accumulation with in_channel = 1 and out_channel = 1
        """
        return None


class PartialNonUniformCnnAccumulator(nn.Module):
    """
    Accumulator built from five (3x3 "expand" convolution, 1x1 "merge" convolution) pairs,
    with a ReLU after every convolution.

    Each 3x3 convolution splits its result over `random_expand_to` maps using random
    non-negative weights that add up to 1 at every kernel position; the following 1x1
    convolution (all-ones weights) adds the maps back together.
    """

    def __init__(self, input_channels=3, random_expand_to=3):
        """
        random_expand_to defines the feature map will be expanded to how many "random" maps
        later, the 'random' maps will be merged, so the operation is still summation
        """
        super().__init__()

        # (in_channels, extra keyword arguments) of the 3x3 convolution of each pair
        expand_specs = (
            (input_channels, dict(stride=3, padding=1)),
            (1, dict(stride=3)),
            (1, dict(stride=3, padding=2)),
            (1, dict(stride=3)),
            (1, dict()),
        )
        stack = []
        for in_channels, extra in expand_specs:
            stack += [
                nn.Conv2d(in_channels=in_channels, out_channels=random_expand_to, kernel_size=3, **extra),
                nn.ReLU(),
                nn.Conv2d(in_channels=random_expand_to, out_channels=1, kernel_size=1, stride=1),
                nn.ReLU(),
            ]
        self.layers = nn.Sequential(*stack)
        self.random_expand_to = random_expand_to
        self.complementary_weight_init()

    def forward(self, x):
        """Run the convolution stack and flatten everything except the batch dimension."""
        features = self.layers(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)

    def complementary_weight_init(self):
        """
        Set merge (1x1) kernels to ones, expand (3x3) kernels to random complementary
        weights, and every bias to zero.
        """
        num_maps = self.random_expand_to
        for module in self.layers:
            if not isinstance(module, nn.Conv2d):
                continue

            if module.kernel_size == (1, 1):
                # merge layer: plain sum of the expanded maps
                nn.init.ones_(module.weight.data)
            else:
                # NOTE(review): the replacement tensor always has ONE input channel, also for
                # the first convolution that was built with in_channels=input_channels --
                # confirm this is intended when input_channels != 1.
                kernels = torch.zeros((num_maps, 1, 3, 3))
                for position in range(9):
                    row, col = divmod(position, 3)
                    # random cut points in [0, 1]; consecutive gaps are non-negative and sum to 1
                    cuts = [0.] + [round(random.random(), 3) for _ in range(num_maps - 1)] + [1.]
                    cuts.sort()
                    for map_idx in range(num_maps):
                        kernels[map_idx, 0, row, col] = cuts[map_idx + 1] - cuts[map_idx]
                module.weight.data = kernels

            # biases are zero for both kinds of convolution
            nn.init.zeros_(module.bias.data)


class CnnMultiColorAccumulator(nn.Module):
    """
    Per-color accumulator: five 3x3 convolutions (each followed by a ReLU) whose kernels
    connect output channel i to input channel i with all-ones weights.

    The first convolution additionally accepts `redundant_channels` extra input channels,
    whose kernels are drawn uniformly at random from [0, 1).
    """

    def __init__(self, num_colors, weight_init_scheme: str = 'uniform', redundant_channels: int = 0):
        """
        num_colors: number of color channels that are accumulated independently
        weight_init_scheme: 'uniform' or 'uniform_with_bias'; anything else raises NotImplementedError
        redundant_channels: number of extra input channels of the first convolution (must be >= 0)
        """
        super().__init__()
        assert redundant_channels >= 0, 'num of redundant channels need to be non negative'
        self.redundant_channels = redundant_channels
        self.layers = nn.Sequential(
            nn.Conv2d(
                in_channels=num_colors + redundant_channels,
                out_channels=num_colors,
                kernel_size=3,
                stride=3,
                padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=num_colors, out_channels=num_colors, kernel_size=3, stride=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=num_colors, out_channels=num_colors, kernel_size=3, stride=3, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=num_colors, out_channels=num_colors, kernel_size=3, stride=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=num_colors, out_channels=num_colors, kernel_size=3),
            nn.ReLU(),
        )
        self.num_colors = num_colors

        if weight_init_scheme == 'uniform':
            self.uniform_weight_init()
        elif weight_init_scheme == 'uniform_with_bias':
            self.uniform_weight_init_with_bias()
        else:
            raise NotImplementedError

    def forward(self, x):
        """Run the convolution stack and flatten everything except the batch dimension."""
        x = self.layers(x)
        return x.view(x.size(0), -1)

    def _set_identity_kernels(self, layer):
        """Connect output channel i to input channel i with an all-ones 3x3 kernel."""
        for i in range(self.num_colors):
            layer.weight.data[i, i, :, :] = torch.ones((3, 3))

    def _set_redundant_kernels(self, layer):
        """Give every redundant input channel of `layer` uniform random kernels in [0, 1)."""
        # one torch.rand call per redundant channel; does nothing when there are none
        for i in range(self.redundant_channels):
            layer.weight.data[:, self.num_colors + i, :, :] = torch.rand(self.num_colors, 3, 3)

    def uniform_weight_init(self):
        """
        Zero all weights and biases, then set the per-color identity kernels to ones.
        The redundant input channels (first convolution only) get random kernels.
        """
        is_first_conv = True
        for layer in self.layers:
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.zeros_(layer.weight.data)
            nn.init.zeros_(layer.bias.data)
            self._set_identity_kernels(layer)
            if is_first_conv:
                # only the first convolution has redundant input channels
                self._set_redundant_kernels(layer)
                is_first_conv = False

    def uniform_weight_init_with_bias(self):
        """
        this method initializes the CNN module with biased kernel at the first layer,
        then uniform 1s after the first layer.

        In the first convolution the kernels between different colors are 0.9, the
        per-color identity kernels are 1 and every redundant input channel gets random
        kernels. All biases are zero.
        """
        is_first_conv = True
        for layer in self.layers:
            if not isinstance(layer, nn.Conv2d):
                continue
            if is_first_conv:
                nn.init.constant_(layer.weight.data, 0.9)
            else:
                nn.init.zeros_(layer.weight.data)
            nn.init.zeros_(layer.bias.data)
            self._set_identity_kernels(layer)
            if is_first_conv:
                # Randomise ALL redundant channels. The loop used to run over
                # range(<bool flag>), i.e. a single iteration, which left every
                # redundant channel after the first at the constant 0.9.
                self._set_redundant_kernels(layer)
                is_first_conv = False

    def complementary_weight_init(self):
        # Note: it is not possible to implement complementary weight init
        # that can perform accumulation with in_channel = 1 and out_channel = 1
        # same reason as for the CNNAccumulator
        pass


class PartialNonUniformCnnMultiColorAccumulator(nn.Module):
    """
    Per-color accumulator with random complementary kernels.

    The first 3x3 convolution maps color i to output channel i with all-ones kernels and
    gives the redundant input channels small random kernels. It is followed by four
    (3x3 "expand", 1x1 "merge") convolution pairs; a ReLU follows every convolution.
    """

    def __init__(self, num_colors, redundant_channels: int = 0, inv_variance: int = 5, random_expand_to: int = 3):
        """
        num_colors: number of color channels that are accumulated independently
        redundant_channels: number of extra input channels of the first convolution (must be >= 0)
        inv_variance: the random kernels of the redundant channels are divided by this value
        random_expand_to: number of maps each color is split into by every expand convolution
        """
        super().__init__()
        assert redundant_channels >= 0, 'num of redundant channels need to be non negative'
        self.redundant_channels = redundant_channels
        self.inv_variance = inv_variance

        expanded_channels = num_colors * random_expand_to
        stack = [
            nn.Conv2d(
                in_channels=num_colors + redundant_channels,
                out_channels=num_colors,
                kernel_size=3,
                stride=3,
                padding=1),
            nn.ReLU(),
        ]
        # extra keyword arguments of the 3x3 convolution of each expand/merge pair
        for extra in (dict(stride=3), dict(stride=3, padding=2), dict(stride=3), dict()):
            stack.extend([
                nn.Conv2d(in_channels=num_colors, out_channels=expanded_channels, kernel_size=3, **extra),
                nn.ReLU(),
                nn.Conv2d(in_channels=expanded_channels, out_channels=num_colors, kernel_size=1, stride=1),
                nn.ReLU(),
            ])
        self.layers = nn.Sequential(*stack)

        self.num_colors = num_colors
        self.random_expand_to = random_expand_to
        self.complementary_weight_init()

    def forward(self, x):
        """Run the convolution stack and flatten everything except the batch dimension."""
        features = self.layers(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)

    def complementary_weight_init(self):
        """
        Initialise all convolutions: identity kernels plus random redundant kernels for the
        first one, complementary random kernels for expand layers, block-wise ones for merge
        layers, and zero biases everywhere.
        """
        num_maps = self.random_expand_to
        convs = [module for module in self.layers if isinstance(module, nn.Conv2d)]

        for position, conv in enumerate(convs):
            nn.init.zeros_(conv.weight.data)
            weights = conv.weight.data

            if position == 0:
                # first layer: color i -> output i with all-ones kernels
                for color in range(self.num_colors):
                    weights[color, color, :, :] = torch.ones((3, 3))
                # redundant inputs get random kernels, scaled down by inv_variance
                for extra in range(self.redundant_channels):
                    weights[:, self.num_colors + extra, :, :] = \
                        torch.rand(self.num_colors, 3, 3) / self.inv_variance
            elif conv.kernel_size == (1, 1):
                # merge layer: output i sums the num_maps maps that belong to color i
                for color in range(self.num_colors):
                    weights[color, color * num_maps:(color + 1) * num_maps, 0, 0] = 1.
            else:
                # expand layer: per color and kernel position, random non-negative
                # weights over the num_maps maps that add up to 1
                for color in range(self.num_colors):
                    for cell in range(9):
                        row, col = divmod(cell, 3)
                        cuts = [0.] + [round(random.random(), 3) for _ in range(num_maps - 1)] + [1.]
                        cuts.sort()
                        for map_idx in range(num_maps):
                            weights[color * num_maps + map_idx, color, row, col] = \
                                cuts[map_idx + 1] - cuts[map_idx]

            # Initialize all biases to 0
            nn.init.zeros_(conv.bias.data)


class CNNColorDetector(nn.Module):
    """
    Per-pixel color detector made of five 1x1 convolutions with hand-set weights,
    each followed by a ReLU in `forward`.

    Output channels: one detector per entry of the user supplied color list (the
    internally appended background color is dropped again by the last stage),
    followed by `redundant_channels` channels computed as 1 minus the sum of all
    detectors, background included.
    """

    def __init__(self, color_list, redundant_channels: int = 0, background_pixel=(0, 0, 0)):
        """
        color_list: iterable of colors, each an iterable of three channel values
        redundant_channels: number of extra output channels (must be >= 0)
        background_pixel: color that is appended internally as the last detector
        """
        super().__init__()

        # trick to append background pixel internally
        self.color_list = (*color_list, background_pixel)

        assert redundant_channels >= 0, 'num of redundant channels cannot be negative'
        self.redundant_channels = redundant_channels
        self.num_colors = len(self.color_list)

        n = self.num_colors
        self.first_stage = nn.Conv2d(in_channels=3, out_channels=9 * n, kernel_size=1)
        self.second_stage = nn.Conv2d(in_channels=9 * n, out_channels=6 * n, kernel_size=1)
        self.third_stage = nn.Conv2d(in_channels=6 * n, out_channels=3 * n, kernel_size=1)
        self.sum_stage = nn.Conv2d(in_channels=3 * n, out_channels=n, kernel_size=1)

        # final stage removes the background color detector and add redundant channels
        self.sum_and_redundant_stage = nn.Conv2d(
            in_channels=n, out_channels=n + self.redundant_channels - 1, kernel_size=1)
        self.weight_init()

    def weight_init(self):
        """Set the weights and biases of all five stages."""
        for set_stage in (
                self.set_weight_first_stage,
                self.set_weight_second_stage,
                self.set_weight_third_stage,
                self.set_weight_sum_stage,
                self.set_weight_sum_and_redundant_stage):
            set_stage()

    def set_weight_first_stage(self):
        """
        Per color: nine outputs, three per input channel, each copying that channel with
        bias -value + 1, -value and -value - 1 (value = the color's value on that channel).
        """
        out_channels = 9 * self.num_colors
        weights = torch.zeros((out_channels, 3, 1, 1), dtype=torch.float32)
        for out_idx in range(out_channels):
            # outputs 0-2 of each group of nine read channel 0, 3-5 channel 1, 6-8 channel 2
            weights[out_idx, (out_idx % 9) // 3, 0, 0] = 1

        biases = [
            -channel_value + shift
            for color in self.color_list
            for channel_value in color
            for shift in (1, 0, -1)
        ]
        with torch.no_grad():
            self.first_stage.weight.copy_(weights)
            self.first_stage.bias.copy_(torch.tensor(biases, dtype=torch.float32))

    def set_weight_second_stage(self):
        """Per color and channel: the two differences between neighbouring first-stage outputs."""
        n = self.num_colors
        weights = torch.zeros((6 * n, 9 * n, 1, 1), dtype=torch.float32)
        for color in range(n):
            for pair in range(6):
                # pairs (0,1),(1,2) / (3,4),(4,5) / (6,7),(7,8) inside the color's nine inputs
                start = 9 * color + 3 * (pair // 2) + pair % 2
                weights[6 * color + pair, start, 0, 0] = 1
                weights[6 * color + pair, start + 1, 0, 0] = -1
        with torch.no_grad():
            self.second_stage.weight.copy_(weights)

        nn.init.zeros_(self.second_stage.bias)

    def set_weight_third_stage(self):
        """Per color and channel: the difference between the two second-stage outputs."""
        n = self.num_colors
        weights = torch.zeros((3 * n, 6 * n, 1, 1), dtype=torch.float32)
        for color in range(n):
            for channel in range(3):
                start = 6 * color + 2 * channel
                weights[3 * color + channel, start, 0, 0] = 1
                weights[3 * color + channel, start + 1, 0, 0] = -1
        with torch.no_grad():
            self.third_stage.weight.copy_(weights)

        nn.init.zeros_(self.third_stage.bias)

    def set_weight_sum_stage(self):
        """Per color: add the three per-channel results, with a bias of -2."""
        n = self.num_colors
        weights = torch.zeros((n, 3 * n, 1, 1), dtype=torch.float32)
        for color in range(n):
            weights[color, 3 * color:3 * (color + 1), 0, 0] = 1
        biases = torch.full((n,), -2, dtype=torch.float32)
        with torch.no_grad():
            self.sum_stage.weight.copy_(weights)
            self.sum_stage.bias.copy_(biases)

    def set_weight_sum_and_redundant_stage(self):
        """
        Pass through every detector except the background one and fill each redundant
        output with 1 minus the sum of all detectors.
        """
        stage = self.sum_and_redundant_stage
        nn.init.zeros_(stage.weight.data)
        nn.init.zeros_(stage.bias.data)

        # remove background pixel detection
        true_num_colors = self.num_colors - 1
        for color in range(true_num_colors):
            stage.weight.data[color, color, 0, 0] = 1.

        # redundant channel don't count background pixel in,
        # as we consider background pixels as in distribution
        # (the background detector is subtracted as well, like every other detector)
        for extra in range(self.redundant_channels):
            stage.weight.data[true_num_colors + extra, :, 0, 0] = -1.
            stage.bias.data[true_num_colors + extra] = 1

    def forward(self, x):
        """Apply the five stages in order, with a ReLU after each of them."""
        y = x
        for stage in (
                self.first_stage,
                self.second_stage,
                self.third_stage,
                self.sum_stage,
                self.sum_and_redundant_stage):
            y = torch.relu(stage(y))
        return y
