import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class LSTM(nn.Module):
    """Two-layer LSTM that predicts numerical, categorical and binary node values.

    The per-node sequences in ``data.x`` are converted into one feature vector
    per time step (numerical values, one-hot categorical codes, binary values),
    run through ``nn.LSTM``, and the hidden state of the last step is decoded
    by one linear head per feature type.

    Args:
        num_nodes: Number of nodes per sample; ``data.x`` is reshaped to
            ``[batch, num_nodes, n_steps]``.
        n_steps: Length of each node's input sequence.
        node_info: Dict with keys ``"cat_nodes"``, ``"nume_nodes"``,
            ``"bin_nodes"`` (1-D node-index collections exposing ``.shape``)
            and ``"cat_ranges"`` (2-D ``[n_cat_nodes, max_cat]`` table of
            category codes; entries < 0 are ignored -- presumably padding).
        **kwargs: Must contain ``device``.
    """

    def __init__(self, num_nodes, n_steps, node_info, **kwargs):
        super().__init__()
        self.num_nodes = num_nodes
        self.n_steps = n_steps
        self.device = kwargs['device']
        self.cat_nodes = node_info["cat_nodes"]
        self.nume_nodes = node_info["nume_nodes"]
        self.bin_nodes = node_info["bin_nodes"]
        self.hidden_dim = 64
        out_dim = 64
        self.num_layers = 2
        self.dropout_rate = 0.2

        cat_ranges = torch.tensor(node_info["cat_ranges"])
        cat_ranges_np = cat_ranges.cpu().numpy()
        # np.tile with 4 reps on the 3-D [cat_nodes, 1, max_cat] array prepends
        # an axis, giving [1, cat_nodes, n_steps, max_cat].
        selector = np.tile(np.expand_dims(cat_ranges_np, [1]), [1, 1, self.n_steps, 1])

        # Non-persistent buffers follow the module through .to()/.cuda()/.cpu()
        # without adding keys to state_dict().
        self.register_buffer("cat_ranges", cat_ranges.to(self.device), persistent=False)
        self.register_buffer("input_selector", torch.tensor(selector).to(self.device),
                             persistent=False)
        self.register_buffer("input_flag", self.input_selector.flatten() >= 0,
                             persistent=False)

        # Number of valid (>= 0) category slots = one-hot width per time step.
        n_cat_feats = int((cat_ranges_np >= 0).sum())
        self.input_dim = n_cat_feats + self.nume_nodes.shape[0] + self.bin_nodes.shape[0]

        self.lstm = nn.LSTM(input_size=self.input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.hidden2out = nn.Linear(self.hidden_dim, out_dim)

        self.dense_cat = torch.nn.Linear(out_dim, n_cat_feats)
        self.dense_nume = torch.nn.Linear(out_dim, self.nume_nodes.shape[0])
        self.dense_bin = torch.nn.Linear(out_dim, self.bin_nodes.shape[0])

    def init_hidden(self, batch_size):
        """Return zero ``(h0, c0)`` of shape ``[num_layers, batch, hidden_dim]``.

        The states are created on the device/dtype of the module's parameters
        so they stay consistent after ``.to()``, ``.cuda()`` or ``.double()``.
        """
        weight = self.hidden2out.weight
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim,
                         device=weight.device, dtype=weight.dtype)
        c0 = torch.zeros_like(h0)
        return (h0, c0)

    def forward(self, data, **kwargs):
        """Predict node values from ``data.x``.

        Args:
            data: Object with attribute ``x`` holding ``batch * num_nodes *
                n_steps`` elements.

        Returns:
            Dict with (only for feature types that exist):
            ``'numerical'`` ``[batch, n_nume, 1]``,
            ``'categorical'`` ``[batch, n_cat, max_cat]`` logits (invalid
            category slots filled with -1000), ``'binary'`` ``[batch, n_bin, 1]``.
        """
        x = data.x.reshape([-1, self.num_nodes, self.n_steps])
        batch_size = x.shape[0]

        # Every entry below is laid out as [batch, n_steps, features] so that the
        # LSTM's time axis really is the time axis.
        step_feats = []
        if self.nume_nodes.shape[0] != 0:
            step_feats.append(x[:, self.nume_nodes, :].transpose(1, 2))

        if self.cat_nodes.shape[0] != 0:
            # [batch, cat_nodes, n_steps, 1] compared with the category table
            # -> one-hot [batch, cat_nodes, n_steps, max_cat].
            cat_x = x[:, self.cat_nodes, :].unsqueeze(3)
            onehot = (cat_x == self.input_selector).to(x.dtype)
            onehot = onehot.permute(0, 2, 1, 3).reshape(batch_size, self.n_steps, -1)
            # Keep only valid category slots (same order as cat_ranges.flatten()).
            step_feats.append(onehot[:, :, self.cat_ranges.flatten() >= 0])

        if self.bin_nodes.shape[0] != 0:
            step_feats.append(x[:, self.bin_nodes, :].transpose(1, 2))

        feat_input = torch.cat(step_feats, dim=2)
        # Dropout is applied after one-hot encoding: dropping out the raw codes
        # would rescale them by 1/(1-p) and break the equality match above.
        feat_input = F.dropout(feat_input, p=self.dropout_rate, training=self.training)

        self.hidden = self.init_hidden(batch_size)
        lstm_out, _ = self.lstm(feat_input, self.hidden)
        out = self.hidden2out(lstm_out[:, -1, :])
        xout = F.dropout(out, p=self.dropout_rate, training=self.training)

        pred = dict()
        if self.nume_nodes.shape[0] != 0:
            pred['numerical'] = self.dense_nume(xout).unsqueeze(2)

        if self.cat_nodes.shape[0] != 0:
            sparse_logits = self.dense_cat(xout)
            # Scatter the dense predictions into the full [n_cat * max_cat]
            # grid; unused slots get a large negative logit.
            logits = torch.full((batch_size, self.cat_ranges.numel()), -1000.0,
                                dtype=sparse_logits.dtype, device=sparse_logits.device)
            logits[:, self.cat_ranges.flatten() >= 0] = sparse_logits
            pred['categorical'] = logits.reshape(
                [batch_size, self.cat_ranges.shape[0], self.cat_ranges.shape[1]])

        if self.bin_nodes.shape[0] != 0:
            pred['binary'] = self.dense_bin(xout).unsqueeze(2)
        return pred

    def to(self, device):
        """Move parameters and buffers to ``device``, record it, and return ``self``."""
        super().to(device)
        self.device = device
        return self

    def get_additional_loss_terms(self):
        """No extra regularization terms for this model."""
        return 0