import torch
# Reproducibility setup executed at import time: disable cuDNN autotuning,
# require deterministic cuDNN kernels, and seed every CUDA RNG plus the CPU RNG
# with 0.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed_all(0)
torch.manual_seed(0)
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F


def count(seq, seq_nor, n, type):
    """Split a value range into ``n`` equal-width bins and return one number per bin.

    Bins are half-open ``[lo, lo + width)`` except the last one, which is
    closed on the right so values equal to the top of the range (normally the
    maximum) are counted instead of silently dropped.

    Modes, selected by ``type`` (the name shadows the builtin but is kept so
    keyword callers keep working):

    * ``0`` -- count the elements of ``seq[0]`` per bin; bins span
      ``[0, max(seq)]``, so negative values fall in no bin. ``seq_nor`` is
      not read in this mode.
    * ``1`` -- sum the ``seq_nor`` values falling in each bin; bins span
      ``[min(seq_nor), max(seq_nor)]``.
    * ``2`` -- count the elements of ``seq[0]`` per bin, with bins starting at
      0 and having the width of the ``seq_nor`` range; the counts are then
      passed through ``softmax1`` and flattened.
      NOTE(review): this mixes raw ``seq`` values with a width derived from
      ``seq_nor``'s range -- confirm that is intended.
    * ``3`` (or any other value) -- bin by ``seq_nor`` and sum the matching
      ``seq`` values.

    Args:
        seq: tensor of raw values. Presumably shaped ``(1, rows, cols)`` when
            called from ``NFGI`` -- TODO confirm.
        seq_nor: normalised counterpart of ``seq`` (must be broadcast-compatible
            with ``seq`` in the last mode); unused in mode 0.
        n: number of bins.
        type: mode selector, see above.

    Returns:
        1-D float32 CPU tensor with ``n`` entries.

    Raises:
        ValueError: in mode 2, from ``softmax1``'s ``math.log10`` when a bin
            is empty.
    """
    # Range to be binned: seq_nor's range for truthy modes, [0, max(seq)] otherwise.
    if type:
        top = torch.max(seq_nor)
        bottom = torch.min(seq_nor)
    else:
        top = torch.max(seq)
        bottom = 0
    width = (top - bottom) / n

    def in_bin(values, base, upper, i):
        """Boolean mask of ``values`` in bin ``i`` (the last bin is closed at ``upper``)."""
        mask = values >= base + i * width
        if i == n - 1:
            # Closed last bin: the value at the top of the range belongs here.
            return mask & (values <= upper)
        return mask & (values < base + (i + 1) * width)

    def n_in_bin(values, base, upper, i):
        """Number of entries of ``values[0]`` that fall in bin ``i``."""
        return float(in_bin(values, base, upper, i)[0].sum().item())

    def sum_in_bin(source, i):
        """Sum of ``source`` where ``seq_nor`` falls in bin ``i`` (float64 accumulation)."""
        selected = torch.masked_select(source, in_bin(seq_nor, bottom, top, i))
        return float(selected.sum(dtype=torch.float64).item())

    if type == 0:
        return torch.tensor([n_in_bin(seq, 0, top, i) for i in range(n)])
    elif type == 1:
        return torch.tensor([sum_in_bin(seq_nor, i) for i in range(n)])
    elif type == 2:
        # Bins start at 0 on raw seq; their total span is the width of seq_nor's range.
        counts = [n_in_bin(seq, 0, top - bottom, i) for i in range(n)]
        weights = softmax1(torch.tensor(counts), n)
        return torch.as_tensor(np.asarray(weights), dtype=torch.float32).reshape(-1)
    else:
        return torch.tensor([sum_in_bin(seq, i) for i in range(n)])


def softmax1(inputMatrix, n):
    """Normalise the base-10 logarithms of the first ``n`` entries of ``inputMatrix``.

    ``out[0, i] = log10(x[i]) / sum_{j < n} log10(x[j])``

    Despite the name, no exponentials are involved. This is a log-ratio
    normalisation, not a softmax.

    Args:
        inputMatrix: indexable sequence (list or 1-D tensor) with at least
            ``n`` positive entries.
        n: number of leading entries to use.

    Returns:
        ``numpy.matrix`` of shape ``(1, n)``, float64. If the log sum is zero,
        NumPy division yields inf/nan (with a RuntimeWarning) rather than raising.

    Raises:
        ValueError: from ``math.log10`` if any of the first ``n`` entries is <= 0.
    """
    logs = np.array([math.log10(inputMatrix[idx]) for idx in range(n)], dtype=np.float64)
    # Sequential left-to-right summation, matching a plain accumulation loop.
    total = sum(logs.tolist())
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    return np.asmatrix(logs / total)


def softmax(inputMatrix):
    """Row-wise softmax of a 2-D matrix, optionally wrapped in a leading size-1 dim.

    Args:
        inputMatrix: tensor of shape ``(1, r, c)`` or ``(r, c)``. A leading
            dimension of size 1 is squeezed away.

    Returns:
        Detached float64 CPU tensor of shape ``(r, c)`` whose rows sum to 1.
        The leading batch dimension is not restored, so the result stays 2-D
        for existing callers.

    Raises:
        ValueError: if the input is not 2-D after squeezing dim 0.
    """
    # Detach and move to CPU, so CUDA tensors and tensors requiring grad work too.
    # float64 matches the precision of the result callers already receive.
    matrix = inputMatrix.squeeze(0).detach().cpu().to(torch.float64)
    if matrix.dim() != 2:
        raise ValueError(
            f"softmax expects a 2-D input after squeezing dim 0, got shape {tuple(matrix.shape)}"
        )
    # torch.softmax subtracts the row max internally, so large logits do not
    # overflow (math.exp raises OverflowError above ~709).
    return torch.softmax(matrix, dim=1)


class NFGI(nn.Module):
    """Pool a sequence tensor into a feature vector.

    The output is the mean of ``seq`` over dim 1. When ``args.addVector`` is
    truthy, an ``args.n``-bin histogram from ``count`` (mode 0) is appended
    along dim 1.

    Args:
        args: configuration object; ``forward`` reads ``args.addVector`` and
            ``args.n``. Presumably an argparse ``Namespace`` -- TODO confirm.
    """

    def __init__(self, args):
        # Initialise nn.Module first so attribute assignment goes through a
        # fully constructed module.
        super().__init__()
        self.args = args

    def forward(self, seq):
        """Return ``mean(seq, 1)``, optionally concatenated with a histogram of ``seq``.

        ``seq`` is presumably ``(1, rows, cols)``: the histogram is unsqueezed
        to one row before concatenation. TODO confirm.
        """
        pooled = torch.mean(seq, 1)
        if not self.args.addVector:
            return pooled
        # count() mode 0 bins only ``seq`` and never reads its ``seq_nor``
        # argument, so the softmax-normalised copy is not computed here.
        hist = count(seq, None, self.args.n, 0).unsqueeze(0)
        # count() builds its result on the CPU. Move it next to ``pooled`` so
        # torch.cat also works for CUDA inputs. dtype is left to cat's promotion.
        hist = hist.to(pooled.device)
        return torch.cat([pooled, hist], dim=1)