import torch
from torch import nn, Tensor
import torch.nn.functional as F
from numpy.ma import floor


class BinaryFuzzyRelationNetwork(nn.Module):
    """
    Cosine-similarity fuzzy relation between rows of non-negative feature matrices.

    S[i, j] = <x_i, x_j> / (max(||x_i||, epsilon) * max(||x_j||, epsilon)), clipped to [0, 1].
    """

    def __init__(self, epsilon=1e-15):
        """
        :param epsilon: lower bound applied to every row norm before it is inverted,
                        so that all-zero rows do not cause a division by zero
        """
        super(BinaryFuzzyRelationNetwork, self).__init__()
        self.__epsilon = epsilon

    def _inverse_norm(self, squared_norm: Tensor) -> Tensor:
        """
        Convert squared row norms into clipped inverse norms.

        :param squared_norm: shape=[n], squared L2 norm of every row
        :return: shape=[n], 1 / max(sqrt(squared_norm), epsilon) (up to rounding)
        """
        # Clip *before* the square root: sqrt's backward at 0 is 0 / 0 = NaN, which would
        # poison the gradients of the whole network whenever an input row is all zeros.
        norm = torch.sqrt(input=torch.clip(input=squared_norm, min=self.__epsilon ** 2))
        # Keep the forward value >= epsilon even if epsilon ** 2 underflows to zero.
        norm = torch.clip(input=norm, min=self.__epsilon)
        return 1.0 / norm

    def forward(self, X: Tensor) -> Tensor:
        """
        :param X: shape=[n, d], 0 <= X[i,j]
        :return:  shape=[n, n], 0 <= S[i,j] <= 1, S[i,j] = S[j,i],
                  S[i,i] = 1 for every non-zero row (an all-zero row gives a zero row/column)
        """
        S = torch.mm(input=X, mat2=torch.transpose(input=X, dim0=0, dim1=1))  # shape=[n, n]
        inv = self._inverse_norm(torch.diagonal(input=S))  # shape=[n]
        # Broadcasting equals diag(inv) @ S @ diag(inv) but costs O(n^2) instead of O(n^3).
        S = S * inv.unsqueeze(1) * inv.unsqueeze(0)
        return torch.clip(input=S, min=0.0, max=1.0)

    def forward_X1_X2(self, X1: Tensor, X2: Tensor) -> Tensor:
        """
        :param X1: shape=[n1, d], 0 <= X1[i,j]
        :param X2: shape=[n2, d], 0 <= X2[i,j]
        :return:  shape=[n1, n2], 0 <= S <= 1
        """
        S = torch.mm(input=X1, mat2=torch.transpose(input=X2, dim0=0, dim1=1))  # shape=[n1, n2]
        # Row-wise squared norms directly, instead of the diagonal of an [n, n] Gram matrix.
        inv1 = self._inverse_norm(torch.sum(input=X1 * X1, dim=1))  # shape=[n1]
        inv2 = self._inverse_norm(torch.sum(input=X2 * X2, dim=1))  # shape=[n2]
        S = S * inv1.unsqueeze(1) * inv2.unsqueeze(0)
        return torch.clip(input=S, min=0.0, max=1.0)


class FuzzyPermissibleLoss(nn.Module):
    """
    Hinge-style loss that pulls a predicted fuzzy relation matrix towards the crisp
    (0/1) class equivalence relation, allowing the tolerances alpha and beta.

    Abbreviations used in argument names:
    F: Fuzzy
    E: Equivalence
    S: Similarity
    R: Relation
    Mat: Matrix
    """

    def __init__(self, alpha=0.05, beta=0.9):
        """
        The loss is averaged over all entries of the relation matrix.
        Intended range (not validated here): 0 <= alpha < 0.5 < beta <= 1
        :param alpha: float, largest tolerated prediction for pairs of different classes
        :param beta: float, smallest tolerated prediction for pairs of the same class
        """
        super(FuzzyPermissibleLoss, self).__init__()
        self.__alpha = alpha
        self.__beta = beta

    def extra_repr(self) -> str:
        return 'alpha={}, beta={}'.format(
            self.__alpha, self.__beta
        )

    def forward(self, FR_Mat: Tensor, class_ER_Mat: Tensor):
        """
        Averaged over all entries of FR_Mat.
        :param FR_Mat: shape=[n1, n2], 0 <= FR_Mat[i,j] <= 1,
                       the predicted fuzzy relation matrix:
                       a fuzzy similarity relation matrix or a fuzzy equivalence relation matrix
                       (i.e. the transitive closure of the predicted fuzzy similarity relation matrix)
        :param class_ER_Mat: shape=[n1, n2], class_ER_Mat[i,j] = 0 or 1,
                             the equivalence relation matrix derived from the class labels
        :return: scalar Tensor; 0 for an empty matrix
        """
        same_class = class_ER_Mat == 1.0

        # Pairs of different classes (class relation 0): penalize predictions above alpha.
        loss_0 = torch.sum(input=F.relu(input=FR_Mat[~same_class] - self.__alpha))

        # Pairs of the same class (class relation 1): penalize predictions below beta.
        loss_1 = torch.sum(input=F.relu(input=self.__beta - FR_Mat[same_class]))

        # Average over all entries; the max() guard avoids 0 / 0 = NaN for an empty matrix.
        return (loss_0 + loss_1) / max(FR_Mat.numel(), 1)


class L2NormRegularization(nn.Module):
    """
    Averaged L2 penalty over the parameters whose qualified name contains 'Linear'.

    Each matching parameter W contributes ||W||_F^2 / W.numel() (the mean of its squared
    entries); the accumulated total is multiplied by 0.5.
    """

    def __init__(self):
        super(L2NormRegularization, self).__init__()

    def forward(self, nn_brm, device=torch.device('cpu')) -> Tensor:
        """
        :param nn_brm: nn.Module whose named parameters are scanned
        :param device: device of the zero tensor the sum starts from
        :return: scalar Tensor, 0.5 * sum of the per-parameter mean squared values
                 (a zero tensor requiring grad when no parameter name matches)
        """
        start = torch.tensor(data=0.0, device=device, requires_grad=True)
        penalties = (
            torch.pow(input=torch.norm(input=weight, p="fro"), exponent=2) / weight.numel()
            for param_name, weight in nn_brm.named_parameters()
            if 'Linear' in param_name
        )
        # sum() adds left to right onto `start`, same accumulation order as an explicit loop.
        return 0.5 * sum(penalties, start)


class FullConnectNetwork(nn.Module):
    """Stack of fully-connected layers, each optionally followed by BN, an activation and dropout."""

    def __init__(self, in_features, layer_structure_list, network_name='FCN'):
        """
        :param in_features: int, the number of input features
        :param layer_structure_list: sequence of per-layer dicts, in order:
               ['node_num']: int, the number of output units of the layer
               ['has_bias']: bool, whether the Linear layer has a bias
               ['has_BN']: bool, whether a BatchNorm1d follows the Linear layer
               ['activation']: None or nn.Module, appended after BN when not None
               ['has_dropout']: bool, whether an nn.Dropout() (default p) is appended last
        :param network_name: string, prefix of every sub-module name
                             ('<network_name>-layer-<k>-Linear', '-BN', '-Activation', '-Dropout')
        """
        super(FullConnectNetwork, self).__init__()
        self.__net = nn.Sequential()
        width = in_features
        for index, spec in enumerate(layer_structure_list):
            prefix = network_name + '-layer-' + str(index + 1)
            out_width = spec['node_num']
            self.__net.add_module(
                name=prefix + '-Linear',
                module=nn.Linear(in_features=width, out_features=out_width, bias=spec['has_bias']),
            )
            # The next layer consumes this layer's output width.
            width = out_width

            if spec['has_BN']:
                self.__net.add_module(name=prefix + '-BN',
                                      module=nn.BatchNorm1d(num_features=out_width))
            if spec['activation'] is not None:
                self.__net.add_module(name=prefix + '-Activation', module=spec['activation'])
            if spec['has_dropout']:
                self.__net.add_module(name=prefix + '-Dropout', module=nn.Dropout())

    def forward(self, x):
        return self.__net(x)


class ConvolutionNeuralNetwork(nn.Module):
    """
    Stack of 2-D convolution / pooling layers, each optionally followed by batch norm,
    an activation and dropout. The height, width and channel count produced by the
    stack are tracked while it is built and can be read with get_output_shape().
    """

    def __init__(self, in_high, in_width, in_channels, layer_structure_list, network_name='CNN'):
        """
        :param in_high: int, height of the input image
        :param in_width: int, width of the input image
        :param in_channels: int, number of channels of the input image
        :param layer_structure_list: the structure of every layer (the dicts are not modified)
               layer_structure_list[l]: structure of the l-th layer network, dict
               layer_structure_list[l]['function']: belongs to {'conv', 'max-pool', 'avg-pool'}
               layer_structure_list[l]['channels_num']: int, the number of channels of the input image;
                   required for 'conv'; optional for the pooling layers, where it is only used as the
                   BN width and defaults to the tracked channel count
               layer_structure_list[l]['conv_deep']: int, the number of channels of the output image, only for 'conv'
               layer_structure_list[l]['has_bias']: bool, only for 'conv'
               layer_structure_list[l]['kernel_size_h']: int, the size of kernel
               layer_structure_list[l]['kernel_size_w']: int, the size of kernel
               layer_structure_list[l]['stride_h']: int
               layer_structure_list[l]['stride_w']: int
               layer_structure_list[l]['padding_h']: int, optional (default 0), used by 'conv' and 'max-pool'
               layer_structure_list[l]['padding_w']: int, optional (default 0), used by 'conv' and 'max-pool'
               layer_structure_list[l]['has_BN']: bool
               layer_structure_list[l]['activation']: None or nn.Module
               layer_structure_list[l]['has_dropout']: bool
        :param network_name: string, prefix of every sub-module name
        :raises Exception: if a layer's 'function' is not one of the supported values
        """
        super(ConvolutionNeuralNetwork, self).__init__()
        self.__output_shape = {'high': in_high,
                               'width': in_width,
                               'channels_num': in_channels}
        self.__net = nn.Sequential()
        for index, structure in enumerate(layer_structure_list):
            function = structure['function']
            if function not in ('conv', 'max-pool', 'avg-pool'):
                raise Exception("Invalid function param " + structure['function'])

            prefix = network_name + '-layer-' + str(index + 1)
            kernel_size = (structure['kernel_size_h'], structure['kernel_size_w'])
            stride = (structure['stride_h'], structure['stride_w'])
            if function == 'avg-pool':
                # avg-pool has never taken padding from the spec; keep ignoring it for compatibility.
                padding = (0, 0)
            else:
                # Read defaults without writing them back into the caller's dict.
                padding = (structure.get('padding_h', 0), structure.get('padding_w', 0))

            if function == 'conv':
                layer = nn.Conv2d(in_channels=structure['channels_num'],
                                  out_channels=structure['conv_deep'],
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  bias=structure['has_bias'])
                out_channels = structure['conv_deep']
                bn_num_features = out_channels
            else:
                if function == 'max-pool':
                    layer = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding)
                else:
                    layer = nn.AvgPool2d(kernel_size=kernel_size, stride=stride)
                # Pooling keeps the channel count.
                out_channels = self.__output_shape['channels_num']
                # 'channels_num' is documented as conv-only, so do not require it here.
                bn_num_features = structure.get('channels_num', out_channels)
            # Module names are '<prefix>-conv', '<prefix>-max-pool' or '<prefix>-avg-pool'.
            self.__net.add_module(name=prefix + '-' + function, module=layer)

            self.__output_shape['high'] = self._output_size(
                self.__output_shape['high'], kernel_size[0], stride[0], padding[0])
            self.__output_shape['width'] = self._output_size(
                self.__output_shape['width'], kernel_size[1], stride[1], padding[1])
            self.__output_shape['channels_num'] = out_channels

            if structure['has_BN']:
                self.__net.add_module(name=prefix + '-BN',
                                      module=nn.BatchNorm2d(num_features=bn_num_features))

            if structure['activation'] is not None:
                self.__net.add_module(name=prefix + '-Activation',
                                      module=structure['activation'])

            if structure['has_dropout']:
                self.__net.add_module(name=prefix + '-Dropout',
                                      module=nn.Dropout())

    @staticmethod
    def _output_size(size, kernel_size, stride, padding, dilation=1):
        """
        Output length along one axis of a conv/pool window (floor mode):
        floor((size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
        """
        return int((size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1)

    def forward(self, x):
        return self.__net(x)

    def get_output_shape(self):
        """:return: dict with keys 'high', 'width', 'channels_num' describing the stack's output"""
        return self.__output_shape


class ModelM3(nn.Module):
    """
    References
    [1] Sanghyeon An, Minjun Lee, Sanglee Park, Heerin Yang, Jungmin So.
        An Ensemble of Simple Convolutional Neural Network Models for MNIST Digit Recognition.
        https://doi.org/10.48550/arXiv.2008.10400
    The code is download from: https://github.com/ansh941/MnistSimpleCNN
    """
    def __init__(self):
        super(ModelM3, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, bias=False)  # output becomes 26x26
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 48, 3, bias=False)  # output becomes 24x24
        self.conv2_bn = nn.BatchNorm2d(48)
        self.conv3 = nn.Conv2d(48, 64, 3, bias=False)  # output becomes 22x22
        self.conv3_bn = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 80, 3, bias=False)  # output becomes 20x20
        self.conv4_bn = nn.BatchNorm2d(80)
        self.conv5 = nn.Conv2d(80, 96, 3, bias=False)  # output becomes 18x18
        self.conv5_bn = nn.BatchNorm2d(96)
        self.conv6 = nn.Conv2d(96, 112, 3, bias=False)  # output becomes 16x16
        self.conv6_bn = nn.BatchNorm2d(112)
        self.conv7 = nn.Conv2d(112, 128, 3, bias=False)  # output becomes 14x14
        self.conv7_bn = nn.BatchNorm2d(128)
        self.conv8 = nn.Conv2d(128, 144, 3, bias=False)  # output becomes 12x12
        self.conv8_bn = nn.BatchNorm2d(144)
        self.conv9 = nn.Conv2d(144, 160, 3, bias=False)  # output becomes 10x10
        self.conv9_bn = nn.BatchNorm2d(160)
        self.conv10 = nn.Conv2d(160, 176, 3, bias=False)  # output becomes 8x8
        self.conv10_bn = nn.BatchNorm2d(176)
        self.fc1 = nn.Linear(11264, 10, bias=False)
        self.fc1_bn = nn.BatchNorm1d(10)

    def get_logits(self, x):
        x = (x - 0.5) * 2.0
        conv1 = F.relu(self.conv1_bn(self.conv1(x)))
        conv2 = F.relu(self.conv2_bn(self.conv2(conv1)))
        conv3 = F.relu(self.conv3_bn(self.conv3(conv2)))
        conv4 = F.relu(self.conv4_bn(self.conv4(conv3)))
        conv5 = F.relu(self.conv5_bn(self.conv5(conv4)))
        conv6 = F.relu(self.conv6_bn(self.conv6(conv5)))
        conv7 = F.relu(self.conv7_bn(self.conv7(conv6)))
        conv8 = F.relu(self.conv8_bn(self.conv8(conv7)))
        conv9 = F.relu(self.conv9_bn(self.conv9(conv8)))
        conv10 = F.relu(self.conv10_bn(self.conv10(conv9)))
        flat1 = torch.flatten(conv10.permute(0, 2, 3, 1), 1)
        logits = self.fc1_bn(self.fc1(flat1))
        return logits

    def forward(self, x):
        logits = self.get_logits(x)
        return F.log_softmax(logits, dim=1)


class ModelM5(nn.Module):
    """
    M5 network of [1]: five unpadded 5x5 convolutions (each with BN + ReLU), then a
    bias-free linear layer with BN. fc1's input size 10240 = 8 * 8 * 160 matches a
    28x28 single-channel input shrinking by 4 pixels per convolution.

    References
    [1] Sanghyeon An, Minjun Lee, Sanglee Park, Heerin Yang, Jungmin So.
        An Ensemble of Simple Convolutional Neural Network Models for MNIST Digit Recognition.
        https://doi.org/10.48550/arXiv.2008.10400
    The original code was downloaded from: https://github.com/ansh941/MnistSimpleCNN
    """
    # Input channels followed by the output channels of conv1 ... conv5.
    _CHANNELS = (1, 32, 64, 96, 128, 160)
    _KERNEL = 5

    def __init__(self):
        super(ModelM5, self).__init__()
        # Registered as conv{k}, conv{k}_bn in the reference order, so the state_dict keys
        # and the order of random weight initialisation stay the same.
        pairs = zip(self._CHANNELS[:-1], self._CHANNELS[1:])
        for k, (c_in, c_out) in enumerate(pairs, start=1):
            setattr(self, f'conv{k}', nn.Conv2d(c_in, c_out, self._KERNEL, bias=False))
            setattr(self, f'conv{k}_bn', nn.BatchNorm2d(c_out))
        self.fc1 = nn.Linear(10240, 10, bias=False)  # 10240 = 8 * 8 * 160
        self.fc1_bn = nn.BatchNorm1d(10)

    def get_logits(self, x):
        # Affine rescale: maps [0, 1] to [-1, 1] (assumes inputs are scaled to [0, 1]).
        h = (x - 0.5) * 2.0
        for k in range(1, len(self._CHANNELS)):
            h = F.relu(getattr(self, f'conv{k}_bn')(getattr(self, f'conv{k}')(h)))
        # Channels-last flatten, keeping the reference implementation's feature order.
        features = torch.flatten(h.permute(0, 2, 3, 1), 1)
        return self.fc1_bn(self.fc1(features))

    def forward(self, x):
        return F.log_softmax(self.get_logits(x), dim=1)


class ModelM7(nn.Module):
    """
    M7 network of [1]: four unpadded 7x7 convolutions (each with BN + ReLU), then a
    bias-free linear layer with BN. fc1's input size 3072 = 4 * 4 * 192 matches a
    28x28 single-channel input shrinking by 6 pixels per convolution.

    References
    [1] Sanghyeon An, Minjun Lee, Sanglee Park, Heerin Yang, Jungmin So.
        An Ensemble of Simple Convolutional Neural Network Models for MNIST Digit Recognition.
        https://doi.org/10.48550/arXiv.2008.10400
    The original code was downloaded from: https://github.com/ansh941/MnistSimpleCNN
    """
    # Input channels followed by the output channels of conv1 ... conv4.
    _CHANNELS = (1, 48, 96, 144, 192)
    _KERNEL = 7

    def __init__(self):
        super(ModelM7, self).__init__()
        # Registered as conv{k}, conv{k}_bn in the reference order, so the state_dict keys
        # and the order of random weight initialisation stay the same.
        pairs = zip(self._CHANNELS[:-1], self._CHANNELS[1:])
        for k, (c_in, c_out) in enumerate(pairs, start=1):
            setattr(self, f'conv{k}', nn.Conv2d(c_in, c_out, self._KERNEL, bias=False))
            setattr(self, f'conv{k}_bn', nn.BatchNorm2d(c_out))
        self.fc1 = nn.Linear(3072, 10, bias=False)  # 3072 = 4 * 4 * 192
        self.fc1_bn = nn.BatchNorm1d(10)

    def get_logits(self, x):
        # Affine rescale: maps [0, 1] to [-1, 1] (assumes inputs are scaled to [0, 1]).
        h = (x - 0.5) * 2.0
        for k in range(1, len(self._CHANNELS)):
            h = F.relu(getattr(self, f'conv{k}_bn')(getattr(self, f'conv{k}')(h)))
        # Channels-last flatten, keeping the reference implementation's feature order.
        features = torch.flatten(h.permute(0, 2, 3, 1), 1)
        return self.fc1_bn(self.fc1(features))

    def forward(self, x):
        return F.log_softmax(self.get_logits(x), dim=1)
