import itertools
import random
import time

import numpy as np
import tensorboard_logger as tf_logger
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

# Print numpy arrays with up to 1,000,000 elements in full instead of summarizing them.
np.set_printoptions(threshold=1_000_000)

# Number of components (set elements) in every sample.
COMPONENT_NUM = 2
# Dimensionality of each component's feature vector.
FEATURE_DIM = 1
# Large negative additive mask (close to float32's most negative finite value);
# adding it to a logit effectively removes that entry from an argmax.
FLOAT_MIN = -3.4e38


def generate_dataset(feature_num, component_num=None):
    """Build the exhaustive integer-product dataset.

    Every tuple in ``{1, ..., feature_num} ** component_num`` (in
    ``itertools.product`` order) becomes one input row, and its label is the
    product of that row's entries.

    :param feature_num: largest integer value a single component can take.
    :param component_num: number of components per row; ``None`` means
        ``COMPONENT_NUM``.
    :return: ``(X, Y)`` float32 tensors of shapes
        ``[feature_num ** component_num, component_num]`` and
        ``[feature_num ** component_num, 1]``.
    """
    if component_num is None:
        component_num = COMPONENT_NUM

    rows = itertools.product(range(1, feature_num + 1), repeat=component_num)
    # reshape keeps the 2-D [n_samples, component_num] layout even with no rows.
    X = np.asarray(list(rows), dtype=np.float32).reshape(-1, component_num)

    X = th.from_numpy(X)
    Y = X.prod(dim=-1, keepdim=True)
    print("Label:", Y)
    print("Input:", X)
    print("Data size:", len(X))

    return X, Y


class DeepSet(nn.Module):
    """Sum-pooled DeepSets regressor.

    Each component is embedded independently by a shared linear layer, the
    embeddings are summed over the component axis, and a ReLU + linear head
    maps the pooled vector to one scalar per sample.
    """

    def __init__(self, hidden_num=32):
        super().__init__()
        self.hidden_num = hidden_num
        # Per-component encoder, shared by every component.
        self.embedding = nn.Sequential(nn.Linear(FEATURE_DIM, hidden_num))
        # Head applied to the pooled (summed) embedding.
        self.prediction = nn.Sequential(nn.ReLU(), nn.Linear(hidden_num, 1))

    def forward(self, X, print_log=False):
        """
        :param X: [bs, 2]
        :param print_log: unused; kept so every model shares one call signature.
        :return: [bs, 1]
        """
        flat = X.view(-1, FEATURE_DIM)                        # [bs * 2, 1]
        per_item = self.embedding(flat)                       # [bs * 2, hidden_num]
        per_item = per_item.view(*X.shape, self.hidden_num)   # [bs, 2, hidden_num]
        pooled = per_item.sum(dim=-2)                         # [bs, hidden_num]
        return self.prediction(pooled)


class MLP(nn.Module):
    """Plain one-hidden-layer perceptron on the concatenated components.

    The components are fed in their given order, so this baseline is not
    permutation invariant.
    """

    def __init__(self, hidden_num=32):
        super().__init__()
        in_features = COMPONENT_NUM * FEATURE_DIM
        self.prediction = nn.Sequential(
            nn.Linear(in_features, hidden_num),
            nn.ReLU(),
            nn.Linear(hidden_num, 1),
        )

    def forward(self, X, print_log=False):
        """
        :param X: [bs, 2]
        :param print_log: unused; kept so every model shares one call signature.
        :return: [bs, 1]
        """
        return self.prediction(X)


class DPN(nn.Module):
    """Regressor that first reorders its components with a hard learned permutation.

    ``permutation_net`` scores every component for every output position.
    Positions are filled one after another with the highest-scoring component
    that has not been used yet. The input is reordered by the resulting
    permutation matrix, and an MLP head predicts a scalar from the reordered
    vector.
    """

    def __init__(self, hidden_num=32):
        super(DPN, self).__init__()
        # Head applied to the reordered components.
        self.prediction = nn.Sequential(
            nn.Linear(COMPONENT_NUM * FEATURE_DIM, hidden_num),
            nn.ReLU(),
            nn.Linear(hidden_num, 1),
        )
        # Maps each component to one logit per output position.
        self.permutation_net = nn.Sequential(
            nn.Linear(FEATURE_DIM, hidden_num),
            nn.ReLU(),
            nn.Linear(hidden_num, COMPONENT_NUM),
        )

    def forward(self, X, print_log=False):
        """Reorder the components of ``X`` with a hard permutation, then predict.

        :param X: [bs, 2]
        :param print_log: if True, print the permutation matrices and the
            permuted features.
        :return: [bs, 1]
        """
        # (1) %%%%%%%%%%%%%%%%%% compute assignment matrix %%%%%%%%%%%%%%%%%%
        expanded_X = X.unsqueeze(dim=-1)  # [bs, n_component, 1]
        logits = self.permutation_net(expanded_X)  # [bs, n_component, n_position]
        # Marks the components that have already been selected. zeros_like keeps
        # X's device and dtype; th.zeros(X.shape) would always live on the CPU.
        _invalid_position_mask = th.zeros_like(X)  # [bs, n_component]
        permutation_matrices = []
        # Positions are filled in a fixed order; each takes the best still-unused component.
        for position_idx in range(COMPONENT_NUM):
            logit = logits[:, :, position_idx]  # [bs, n_component]
            selection_prob = self.get_selection_prob(logit, _invalid_position_mask)  # [bs, n_component]
            _invalid_position_mask = _invalid_position_mask + selection_prob.detach()
            permutation_matrices.append(selection_prob)
        permutation_matrices = th.stack(permutation_matrices, dim=1)  # [bs, n_position, n_component]

        # (2) %%%%%%%%%%%%%%%%%% permute the order of the input %%%%%%%%%%%%%%%%%%
        # [bs, n_position, n_component] @ [bs, n_component, 1] -> [bs, n_position]
        permuted_fea = th.matmul(permutation_matrices, expanded_X).squeeze(dim=-1)
        if print_log:
            # .cpu() so printing also works for CUDA tensors.
            print(permutation_matrices.detach().cpu().numpy())
            print(permuted_fea.detach().cpu().numpy())

        # (3) %%%%%%%%%%%%%%%%%% feed into the output layer %%%%%%%%%%%%%%%%%%
        return self.prediction(permuted_fea)

    @staticmethod
    def onehot_from_logits(logits, dim):
        """Return a one-hot tensor marking the argmax of ``logits`` along ``dim``.

        The result has the same shape, dtype and device as ``logits``. It carries
        no gradient, because argmax is not differentiable.
        """
        index = logits.max(dim, keepdim=True)[1]
        return th.zeros_like(logits).scatter_(dim, index, 1.0)

    def get_selection_prob(self, logit, invalid_components):
        """Pick, as a one-hot vector, the highest-logit component that is still valid.

        :param logit: [bs, n_component] scores for one output position.
        :param invalid_components: [bs, n_component] mask; 1 marks an
            already-selected component.
        :return: [bs, n_component] one-hot selection.
        """
        # Push the logits of used components towards -inf so argmax skips them.
        negative_mask = invalid_components * FLOAT_MIN
        logit_masked = logit + negative_mask  # [bs, n_component]

        # Differentiable alternative (would let gradients reach permutation_net):
        # return F.gumbel_softmax(logit_masked, tau=0.1, hard=True, dim=-1)
        return self.onehot_from_logits(logit_masked, dim=-1)

    def parameters(self, recurse=True):
        """Expose only the prediction head's parameters (e.g. to the optimizer).

        NOTE(review): ``permutation_net`` is excluded. With the hard argmax in
        ``get_selection_prob`` it receives no gradient anyway. If you switch to
        the Gumbel-softmax alternative, it would stay frozen. Confirm this is
        intended.
        """
        return self.prediction.parameters(recurse)


class InfiniteWeightNet(nn.Module):
    """Embeds every integer component value with its own learned weight row.

    Component value ``v`` (expected in ``1..feature_number``) selects row
    ``v - 1`` of ``weights``. The selected rows are scaled by ``v``, summed over
    the components, shifted by ``bias``, and passed through a ReLU + linear head.
    """

    def __init__(self, feature_number, hidden_num=32):
        super(InfiniteWeightNet, self).__init__()
        self.weights = Parameter(th.randn([feature_number, hidden_num], dtype=th.float32))
        self.bias = Parameter(th.randn([1, hidden_num], dtype=th.float32))

        self.prediction = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_num, 1),
        )

    def forward(self, X, print_log=False):
        """
        :param X: [bs, n_component] integer-valued components in
            ``1..feature_number``. May be a float tensor; ``.long()`` truncates it.
        :param print_log: unused; kept so every model shares one call signature.
        :return: [bs, 1]
        :raises IndexError: if a component value maps to a row outside ``weights``.
        """
        bs, number = X.shape
        # Values are 1-based; convert them to 0-based row indices of ``weights``.
        index_X = X.reshape(-1).long() - 1  # [bs * n_component]
        # Negative indices would silently wrap around to the last rows of
        # ``weights``, so reject them (indices that are too large already raise).
        if index_X.numel() > 0 and index_X.min().item() < 0:
            raise IndexError(
                "InfiniteWeightNet expects component values >= 1, got {}".format(
                    index_X.min().item() + 1))
        # Explicit last dim (not -1) so an empty batch can still be reshaped.
        corresponding_W = self.weights[index_X].view(
            bs, number, self.weights.shape[-1])  # [bs, n_component, hidden_num]
        embedding = th.sum(X.unsqueeze(-1) * corresponding_W, dim=-2) + self.bias  # [bs, hidden_num]

        return self.prediction(embedding)


class HyperLinear(nn.Module):
    """Linear layer whose weight matrix is generated per component by a hypernetwork.

    For every component ``x_i`` the small MLP ``hyper_w`` produces an
    ``[input_dim, output_dim]`` weight matrix ``W(x_i)`` from ``x_i`` itself.
    The layer returns ``sum_i (x_i @ W(x_i) + bias)``.
    """

    def __init__(self, input_dim, output_dim, hyper_hidden_size, bias=True):
        super(HyperLinear, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.hyper_w = nn.Sequential(
            nn.BatchNorm1d(num_features=input_dim),
            nn.Linear(input_dim, hyper_hidden_size),
            nn.LeakyReLU(),
            nn.Linear(hyper_hidden_size, hyper_hidden_size),
            nn.LeakyReLU(),
            # One full [input_dim, output_dim] weight matrix per component.
            nn.Linear(hyper_hidden_size, input_dim * output_dim),
        )
        if bias:
            self.bias = Parameter(th.zeros(1, 1, output_dim))
        else:
            self.bias = 0

    def forward(self, x):
        """
        :param x: [bs, n_agent, input_dim]
        :return: [bs, output_dim], the per-component outputs summed over ``n_agent``.
        """
        bs, n_agent, fea_dim = x.shape
        # Flatten the components so BatchNorm1d / Linear see [bs * n_agent, input_dim].
        reshaped_x = x.reshape(-1, fea_dim)
        weights = self.hyper_w(reshaped_x).view(bs, n_agent, self.input_dim, self.output_dim)
        # [bs, n_agent, 1, input_dim] @ [bs, n_agent, input_dim, output_dim]
        #   -> [bs, n_agent, output_dim]
        embedding = th.matmul(x.unsqueeze(2), weights).squeeze(2) + self.bias
        return embedding.sum(dim=1)  # [bs, output_dim]


class Hypernet(nn.Module):
    """Regressor whose first layer is a ``HyperLinear`` that sum-pools over components."""

    def __init__(self, hidden_num=32):
        super().__init__()
        self.hidden_num = hidden_num

        hyper_layer = HyperLinear(
            input_dim=1,
            output_dim=hidden_num,
            hyper_hidden_size=hidden_num,
            bias=True,
        )
        self.prediction = nn.Sequential(hyper_layer, nn.ReLU(), nn.Linear(hidden_num, 1))

    def forward(self, X, print_log=False):
        """
        :param X: [bs, 2]
        :param print_log: unused; kept so every model shares one call signature.
        :return: [bs, 1]
        """
        per_component = X.unsqueeze(dim=-1)  # [bs, 2] -> [bs, 2, 1]
        return self.prediction(per_component)


class SelfAttention(nn.Module):
    """Single-head self-attention encoder followed by sum pooling over components."""

    def __init__(self, hidden_num=32):
        super().__init__()
        self.hidden_num = hidden_num
        h = self.hidden_num

        # Scalar component -> token embedding.
        self.token = nn.Sequential(nn.BatchNorm1d(1), nn.Linear(1, h))

        # Attention projections.
        self.Query = nn.Sequential(nn.Linear(h, h))
        self.Key = nn.Sequential(nn.Linear(h, h))
        self.Value = nn.Sequential(nn.Linear(h, h))
        self.layer_norm = nn.LayerNorm(h)

        self.prediction = nn.Sequential(nn.ReLU(), nn.Linear(h, 1))

    def forward(self, X, print_log=False):
        """
        :param X: [bs, 2]
        :param print_log: unused; kept so every model shares one call signature.
        :return: [bs, 1]
        """
        flat = X.view(-1, 1)  # [bs * 2, 1]
        tokens = self.token(flat).view(*X.shape, self.hidden_num)  # [bs, 2, hidden_num]

        # Scaled dot-product attention. Queries and keys are each divided by
        # hidden_num ** (1/4), so their dot product ends up scaled by
        # 1 / sqrt(hidden_num) without rescaling the score matrix separately.
        scale = self.hidden_num ** 0.25
        q = self.Query(tokens) / scale
        k = self.Key(tokens) / scale
        v = self.Value(tokens)

        scores = th.bmm(q, k.transpose(1, 2))  # [bs, 2, 2]
        attn = F.softmax(scores, dim=2)        # row-wise attention weights
        mixed = th.bmm(attn, v)                # [bs, 2, hidden_num]

        # Residual connection + layer norm, then sum-pool over the components.
        pooled = self.layer_norm(mixed + tokens).sum(dim=1)  # [bs, hidden_num]
        return self.prediction(pooled)


class GraphConvLayer(nn.Module):
    """Graph convolution: a shared linear map followed by adjacency aggregation."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.lin_layer = nn.Linear(input_dim, output_dim)
        self.input_dim, self.output_dim = input_dim, output_dim

    def forward(self, input_feature, input_adj):
        """Return ``input_adj @ lin_layer(input_feature)``.

        :param input_feature: node features, e.g. [bs, N, input_dim].
        :param input_adj: adjacency, e.g. [N, N]; ``th.matmul`` broadcasts it
            over the leading batch dimensions.
        :return: aggregated features, e.g. [bs, N, output_dim].
        """
        return th.matmul(input_adj, self.lin_layer(input_feature))

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__, self.input_dim, self.output_dim)


class GNN(nn.Module):
    """Two-branch graph network over the complete component graph.

    Each component gets a neighbour-aggregated GCN feature and a self (linear)
    feature, both ReLU-activated. Their sum is divided by ``COMPONENT_NUM``,
    sum-pooled over the components, and mapped to a scalar by ``V``.
    """

    def __init__(self):
        super().__init__()
        self.hidden_num = 4
        self.gc1 = GraphConvLayer(1, self.hidden_num)  # neighbour branch
        self.linear = nn.Linear(1, self.hidden_num)    # self branch
        self.V = nn.Linear(self.hidden_num, 1)         # readout
        # Complete graph without self-loops: all ones except a zero diagonal.
        no_self_loops = th.ones(COMPONENT_NUM, COMPONENT_NUM) - th.eye(COMPONENT_NUM)
        self.register_buffer('adj', no_self_loops)

    def forward(self, X, print_log=False):
        """
        :param X: [bs, 2]
        :param print_log: unused; kept so every model shares one call signature.
        :return: [bs, 1]
        """
        nodes = X.unsqueeze(-1)  # [bs, 2] -> [bs, 2, 1]

        neighbour = F.relu(self.gc1(nodes, self.adj))
        own = F.relu(self.linear(nodes))
        feat = (neighbour + own) / COMPONENT_NUM  # [bs, 2, hidden_num]

        pooled = feat.sum(1)  # sum-pool over the component axis
        return self.V(pooled)


if __name__ == '__main__':
    # Seed every RNG the script touches so that runs are reproducible.
    seed = 75981140
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    th.cuda.manual_seed(seed)
    th.backends.cudnn.deterministic = True  # cudnn

    n_epoch = 2500
    feature_num = 30
    data_X, data_Y = generate_dataset(feature_num)

    t = time.time()
    # Alternatives: DeepSet(), MLP(), DPN(), InfiniteWeightNet(feature_num),
    # SelfAttention(), GNN()
    model = Hypernet()

    optimizer = th.optim.Adam(params=model.parameters(), lr=0.05)
    logger = tf_logger.Logger(logdir="./log/{}/{}".format(type(model).__name__, t), flush_secs=0.1)
    for epoch in range(n_epoch):
        # Full-batch training: every epoch uses the whole dataset.
        X, Y = data_X, data_Y
        # Call the module rather than .forward() so nn.Module hooks run;
        # the model prints its diagnostics on the final epoch.
        predicted_Y = model(X, print_log=epoch == n_epoch - 1)
        loss = th.mean((Y - predicted_Y) ** 2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            print("Iter={}: MSE loss={}".format(epoch, loss.item()))
            logger.log_value("MSE", loss.item(), step=epoch)
