import torch
import torch.nn.functional as F
from torch import nn


class MultiLayerFC(nn.Module):
    """Fully connected network: Linear layers with one activation in between.

    The stack is ``input_size -> hidden_size``, then ``num_hidden_layers``
    layers of ``hidden_size -> hidden_size``, then ``hidden_size ->
    output_size``.  The activation is applied after every layer except the
    last one, whose raw output is returned.

    Args:
        input_size: Number of input features.
        num_hidden_layers: Number of hidden-to-hidden layers placed between
            the input projection and the output projection.
        output_size: Number of output features.
        hidden_size: Width of every hidden layer.
        activation: ``'relu'`` or ``'sigmoid'``.

    Raises:
        NotImplementedError: If ``activation`` is not a supported name.
    """

    # Supported activation names mapped to the functions implementing them.
    _ACTIVATIONS = {'relu': torch.relu, 'sigmoid': torch.sigmoid}

    def __init__(self, input_size, num_hidden_layers, output_size, hidden_size=2000, activation='relu'):
        super(MultiLayerFC, self).__init__()
        # Fail at construction time instead of on the first forward pass.
        self._lookup_activation(activation)
        self.activation = activation

        # Layers are created in input -> hidden -> output order.
        self.linears = nn.ModuleList([nn.Linear(input_size, hidden_size)])
        self.linears.extend([nn.Linear(hidden_size, hidden_size) for _ in range(num_hidden_layers)])
        self.linears.append(nn.Linear(hidden_size, output_size))

    @classmethod
    def _lookup_activation(cls, name):
        """Return the function registered for ``name`` or raise NotImplementedError."""
        try:
            return cls._ACTIVATIONS[name]
        except (KeyError, TypeError):
            supported = ', '.join(repr(key) for key in cls._ACTIVATIONS)
            raise NotImplementedError(
                'Unsupported activation {!r}; expected one of: {}'.format(name, supported)) from None

    def forward(self, x):
        """Apply every layer in order, activating all but the last one."""
        # Resolved on each call so that reassigning ``self.activation`` after
        # construction is still honoured (and still validated).
        activate = self._lookup_activation(self.activation)
        for layer in self.linears[:-1]:
            x = activate(layer(x))
        return self.linears[-1](x)


class IdentityMLP(nn.Module):
    """Single square Linear layer initialised to a scaled identity map.

    The weight starts as ``eye(input_size) / division_scale`` and the bias as
    zeros, so a freshly built module returns ``x / division_scale``.  Both
    remain ordinary parameters of ``self.fc``.

    Args:
        input_size: Number of input (and output) features.
        division_scale: Divisor applied to the identity weight matrix.
    """

    def __init__(self, input_size, division_scale: int = 1):
        super(IdentityMLP, self).__init__()
        self.fc = nn.Linear(input_size, input_size)
        # Overwrite the default random initialisation in place, without
        # recording the writes for autograd.
        with torch.no_grad():
            scaled_identity = torch.eye(input_size, dtype=torch.float32) / division_scale
            self.fc.weight.copy_(scaled_identity)
            self.fc.bias.zero_()

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.fc(x)


class DecisionHead(nn.Module):
    """Learned module that maps a number to class scores.

    When ``rank_increase`` is truthy the input is first rescaled and passed
    through ``rank_increaser``, a sigmoid MultiLayerFC taking 1 feature and
    producing ``rank_to_increase`` features; the result is fed to
    ``classifier``.  Otherwise ``classifier`` receives the input directly and
    expects 1 feature.

    Args:
        rank_increase: Whether to build and use the rank increaser.
        rank_to_increase: Output width of the rank increaser.
        rank_increase_layer: ``num_hidden_layers`` of the rank increaser.
        num_hidden_layers: ``num_hidden_layers`` of the classifier.
        num_classes: Output width of the classifier.
    """

    # TODO this should be integrated with learn modulo model

    # Constants of the rescaling ``x / scale - offset`` applied before the
    # rank increaser.
    # NOTE(review): 400 is presumably the largest expected input value, so the
    # rescaled input would lie in [-0.5, 0.5] -- confirm against the callers.
    _INPUT_SCALE = 400
    _INPUT_OFFSET = 0.5

    def __init__(self, rank_increase, rank_to_increase, rank_increase_layer, num_hidden_layers, num_classes):
        super().__init__()
        self.rank_increaser = MultiLayerFC(
            1, rank_increase_layer, rank_to_increase, activation='sigmoid') if rank_increase else None

        true_input_size = rank_to_increase if rank_increase else 1
        self.classifier = MultiLayerFC(true_input_size, num_hidden_layers, num_classes)

    def forward(self, x):
        """Return the classifier output for ``x``, rank-increasing it first if enabled."""
        if self.rank_increaser is not None:
            x = self.rank_increaser((x / self._INPUT_SCALE) - self._INPUT_OFFSET)
        return self.classifier(x)

    def load_model_parameters(self, model_path):
        """Load this module's parameters from a checkpoint file.

        Args:
            model_path: Path of a file readable by ``torch.load`` whose
                content has a ``'state_dict'`` entry matching this module.
        """
        # Deserialise onto the CPU so that a checkpoint saved from a GPU can
        # be read on a machine without one; ``load_state_dict`` then copies
        # the values into this module's existing parameters.
        checkpoint = torch.load(model_path, map_location='cpu')
        self.load_state_dict(checkpoint['state_dict'])


class SyntheticModel(nn.Module):
    """Chain an optional color detector, an accumulator and a decision head.

    Args:
        accumulator: Module applied to the (optionally color-detected) input.
        decision_head: Module applied to the accumulator output.
        color_detector: Optional module applied to the raw input first.
        softmax: When truthy, a softmax over the last dimension is applied to
            the decision head output.
    """

    def __init__(self, accumulator, decision_head, color_detector=None, softmax=False):
        super().__init__()
        self.color_detector = color_detector
        self.accumulator = accumulator
        self.decision_head = decision_head
        if softmax:
            self.softmax = nn.Softmax(dim=-1)
        else:
            self.softmax = None

    def forward(self, x):
        """Run the pipeline on ``x`` and return scores (or probabilities with softmax)."""
        features = x
        if self.color_detector is not None:
            features = self.color_detector(features)
        scores = self.decision_head(self.accumulator(features))
        if self.softmax is None:
            return scores
        return self.softmax(scores)


class IdentifyNumberModel(nn.Module):
    """Fixed-weight ReLU network that fires only for one specific number.

    For integer-valued inputs the output is 1 when the input equals
    ``target_number`` and 0 otherwise.  All weights are set by hand in
    ``set_weight``; they are still regular parameters of the module.
    """

    def __init__(self, target_number):
        super().__init__()
        self.target_number = target_number
        self.first_stage = nn.Linear(1, 3)
        self.second_stage = nn.Linear(3, 2)
        self.third_stage = nn.Linear(2, 1)
        self.set_weight()

    def set_weight(self):
        """Overwrite the three layers with the hand-crafted detector weights."""
        target = self.target_number
        # Stage 1 computes x - (t - 1), x - t and x - (t + 1).
        thresholds = torch.tensor([-(target - 1), -target, -target - 1], dtype=torch.float32)
        # Stage 2 takes the differences of neighbouring stage-1 outputs.
        neighbour_diff = torch.tensor([[1, -1, 0], [0, 1, -1]], dtype=torch.float32)
        # Stage 3 subtracts the second difference from the first.
        final_diff = torch.tensor([1, -1], dtype=torch.float32)

        with torch.no_grad():
            self.first_stage.weight.fill_(1.0)
            self.first_stage.bias.copy_(thresholds)
            self.second_stage.weight.copy_(neighbour_diff)
            self.second_stage.bias.zero_()
            self.third_stage.weight.copy_(final_diff)
            self.third_stage.bias.zero_()

    def forward(self, x):
        """Apply the three stages, each followed by a ReLU."""
        hidden = torch.relu(self.first_stage(x))
        hidden = torch.relu(self.second_stage(hidden))
        return torch.relu(self.third_stage(hidden))


class IdentifyAndSubtractModel(nn.Module):
    """Fixed-weight ReLU network that zeroes channels equal to ``modulo_number``.

    Every one of the ``input_dim`` channels is processed independently.  A
    channel with value ``x`` yields ``relu(x) - modulo_number * detect(x)``,
    where ``detect`` is a three-threshold detector that is 1 at
    ``x == modulo_number`` and 0 at every other integer.  Layer widths shrink
    from 4 to 3 to 2 to 1 neurons per channel.
    """

    def __init__(self, input_dim, modulo_number):
        super().__init__()
        self.input_dim = input_dim
        self.modulo_number = modulo_number
        self.first_stage = nn.Linear(self.input_dim, 4 * self.input_dim)
        self.second_stage = nn.Linear(4 * self.input_dim, 3 * self.input_dim)
        self.third_stage = nn.Linear(3 * self.input_dim, 2 * self.input_dim)
        self.fourth_stage = nn.Linear(2 * self.input_dim, self.input_dim)
        self.set_weight()

    def set_weight(self):
        """Overwrite all four layers with their hand-crafted weights."""
        self.set_weight_first_stage()
        self.set_weight_second_stage()
        self.set_weight_third_stage()
        self.set_weight_fourth_stage()

    def set_weight_first_stage(self):
        """Per channel: copy x four times, shifted by 0, -(N-1), -N and -(N+1)."""
        dim = self.input_dim
        channel = torch.arange(dim)
        weight = torch.zeros(4 * dim, dim, dtype=torch.float32)
        for copy_idx in range(4):
            weight[4 * channel + copy_idx, channel] = 1
        shifts = [0, -self.modulo_number + 1, -self.modulo_number, -self.modulo_number - 1]
        # Repeating the list gives every channel the same four shifts.
        bias = torch.tensor(shifts * dim, dtype=torch.float32)
        with torch.no_grad():
            self.first_stage.bias.copy_(bias)
            self.first_stage.weight.copy_(weight)

    def set_weight_second_stage(self):
        """Per channel: keep the unshifted copy and difference the shifted neighbours."""
        dim = self.input_dim
        channel = torch.arange(dim)
        weight = torch.zeros(3 * dim, 4 * dim, dtype=torch.float32)
        weight[3 * channel, 4 * channel] = 1
        weight[3 * channel + 1, 4 * channel + 1] = 1
        weight[3 * channel + 1, 4 * channel + 2] = -1
        weight[3 * channel + 2, 4 * channel + 2] = 1
        weight[3 * channel + 2, 4 * channel + 3] = -1
        with torch.no_grad():
            self.second_stage.bias.zero_()
            self.second_stage.weight.copy_(weight)

    def set_weight_third_stage(self):
        """Per channel: keep the value and reduce the two differences to the detector output."""
        dim = self.input_dim
        channel = torch.arange(dim)
        weight = torch.zeros(2 * dim, 3 * dim, dtype=torch.float32)
        weight[2 * channel, 3 * channel] = 1
        weight[2 * channel + 1, 3 * channel + 1] = 1
        weight[2 * channel + 1, 3 * channel + 2] = -1
        with torch.no_grad():
            self.third_stage.bias.zero_()
            self.third_stage.weight.copy_(weight)

    def set_weight_fourth_stage(self):
        """Per channel: value minus ``modulo_number`` times the detector output."""
        dim = self.input_dim
        channel = torch.arange(dim)
        weight = torch.zeros(dim, 2 * dim, dtype=torch.float32)
        weight[channel, 2 * channel] = 1
        weight[channel, 2 * channel + 1] = -self.modulo_number
        with torch.no_grad():
            self.fourth_stage.bias.zero_()
            self.fourth_stage.weight.copy_(weight)

    def forward(self, x):
        """Apply the four stages; the first three are followed by a ReLU, the last is not."""
        hidden = x
        for stage in (self.first_stage, self.second_stage, self.third_stage):
            hidden = torch.relu(stage(hidden))
        return self.fourth_stage(hidden)


class ModuloModel(nn.Module):
    """
    Fixed-weight ReLU network that computes ``x mod modulo_number``.

    With N = modulo_number and K = max_number // N + 1, the stages are:
    1. ``first_stage`` subtracts 0, N, 2N, ..., K*N from the input (K + 1 neurons).
    2. ``second_stage`` subtracts from each neuron its successor, leaving K
       channels that hold N, the remainder, or 0.
    3. ``third_stage`` (an IdentifyAndSubtractModel) detects channels equal to N
       and subtracts N from them, so they become 0.
    4. Regression mode: ``sum_stage`` adds all channels; since every channel is
       now 0 or the remainder, the sum is the remainder.
       Classification mode: ``fourth_stage``, ``fifth_stage`` and
       ``classification_stage`` sum the channels and run one number detector
       per residue 0..N-1, giving N outputs where, for an integer remainder r,
       output r is 1 and the others are 0.
    """

    def __init__(self, modulo_number, max_number, classification_mode=False):
        super().__init__()
        self.modulo_number = modulo_number
        self.max_number = max_number
        self.classification_mode = classification_mode
        self.num_first_stage_neurons = self.max_number // self.modulo_number + 1
        width = self.num_first_stage_neurons
        self.first_stage = nn.Linear(1, width + 1)
        self.second_stage = nn.Linear(width + 1, width)
        self.third_stage = IdentifyAndSubtractModel(width, self.modulo_number)
        if self.classification_mode:
            self.fourth_stage = nn.Linear(width, 3 * self.modulo_number)
            self.fifth_stage = nn.Linear(3 * self.modulo_number, 2 * self.modulo_number)
            self.classification_stage = nn.Linear(2 * self.modulo_number, self.modulo_number)
        else:
            self.sum_stage = nn.Linear(width, 1)
        self.set_weight()

    def set_weight(self):
        """Overwrite every layer used in the current mode with its hand-crafted weights."""
        self.set_weight_first_stage()
        self.set_weight_second_stage()
        self.set_weight_third_stage()
        if self.classification_mode:
            self.set_weight_fourth_stage()
            self.set_weight_fifth_stage()
            self.set_weight_classification_stage()
        else:
            self.set_weight_sum_stage()

    def set_weight_first_stage(self):
        """Neuron i computes x - i * N."""
        multiples = [-i * self.modulo_number for i in range(self.num_first_stage_neurons + 1)]
        with torch.no_grad():
            self.first_stage.weight.fill_(1.0)
            self.first_stage.bias.copy_(torch.tensor(multiples, dtype=torch.float32))

    def set_weight_second_stage(self):
        """Neuron i computes (stage-1 neuron i) - (stage-1 neuron i + 1)."""
        width = self.num_first_stage_neurons
        row = torch.arange(width)
        weight = torch.zeros(width, width + 1, dtype=torch.float32)
        weight[row, row] = 1
        weight[row, row + 1] = -1
        with torch.no_grad():
            self.second_stage.bias.zero_()
            self.second_stage.weight.copy_(weight)

    def set_weight_third_stage(self):
        """ no need to set, already taken care by submodule """
        pass

    def set_weight_sum_stage(self):
        """Plain sum over all channels."""
        with torch.no_grad():
            self.sum_stage.weight.fill_(1.0)
            self.sum_stage.bias.zero_()

    def set_weight_fourth_stage(self):
        """For each residue k: channel sum shifted by -(k - 1), -k and -(k + 1)."""
        shifts = [offset - residue for residue in range(self.modulo_number) for offset in (1, 0, -1)]
        with torch.no_grad():
            self.fourth_stage.weight.fill_(1.0)
            self.fourth_stage.bias.copy_(torch.tensor(shifts, dtype=torch.float32))

    def set_weight_fifth_stage(self):
        """For each residue k: combine the three shifted copies into two values."""
        count = self.modulo_number
        residue = torch.arange(count)
        weight = torch.zeros(2 * count, 3 * count, dtype=torch.float32)
        weight[2 * residue, 3 * residue] = 1
        weight[2 * residue, 3 * residue + 1] = -1
        weight[2 * residue + 1, 3 * residue + 1] = 1
        # NOTE(review): this entry is +1, whereas the corresponding entry in
        # IdentifyNumberModel's second stage is -1.  Kept as-is so the stored
        # weights do not change -- confirm whether -1 was intended.
        weight[2 * residue + 1, 3 * residue + 2] = 1
        with torch.no_grad():
            self.fifth_stage.bias.zero_()
            self.fifth_stage.weight.copy_(weight)

    def set_weight_classification_stage(self):
        """Output k computes (fifth-stage value 2k) - (fifth-stage value 2k + 1)."""
        count = self.modulo_number
        residue = torch.arange(count)
        weight = torch.zeros(count, 2 * count, dtype=torch.float32)
        weight[residue, 2 * residue] = 1
        weight[residue, 2 * residue + 1] = -1
        with torch.no_grad():
            self.classification_stage.bias.zero_()
            self.classification_stage.weight.copy_(weight)

    def forward(self, x):
        """Return the remainder (regression mode) or N per-residue activations (classification mode)."""
        hidden = x
        for stage in (self.first_stage, self.second_stage, self.third_stage):
            hidden = torch.relu(stage(hidden))
        if not self.classification_mode:
            # The sum is returned without a ReLU.
            return self.sum_stage(hidden)
        for stage in (self.fourth_stage, self.fifth_stage, self.classification_stage):
            hidden = torch.relu(stage(hidden))
        return hidden
