
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter






from .util import softmax

def encode_onehot(labels):
    """Return a one-hot ``int32`` matrix with one row per entry of ``labels``.

    Columns are assigned to the distinct labels in sorted order, so the
    encoding is reproducible across runs and does not depend on set
    iteration order (which varies with insertion order and, for strings,
    with hash randomization).

    Args:
        labels: A re-iterable sequence of hashable, mutually orderable
            labels (it is traversed twice).

    Returns:
        ``np.ndarray`` of dtype ``int32``. For non-empty input the shape is
        ``(len(labels), n_distinct_labels)``; for empty input an empty 1-D
        array is returned.
    """
    classes = sorted(set(labels))
    # Build the identity matrix once and reuse its rows, instead of
    # allocating a fresh matrix for every class.
    identity = np.identity(len(classes), dtype=np.int32)
    column_of = {label: column for column, label in enumerate(classes)}
    rows = [identity[column_of[label]] for label in labels]
    return np.array(rows, dtype=np.int32)


class Graph_Generator(nn.Module):
    """Produce a ``node_num x node_num`` matrix from per-node series.

    Each node's series is passed through three dilated convolutions and a
    linear layer to obtain a per-node embedding. Embeddings are then paired
    for every ordered ``(receiver, sender)`` combination (self-pairs
    included) and mapped to two logits per pair; ``sample`` returns the
    first of the two softmax outputs, arranged as a square matrix.

    Args:
        node_num: Number of nodes. Inputs given to ``sample`` must have
            this many nodes, because the pairing matrices are sized here.
    """

    def __init__(self,node_num):
        super().__init__()

        self.embedding_dim = 128

        # Layer creation order is kept stable so that seeded parameter
        # initialisation is unaffected.
        self.conv1 = torch.nn.Conv2d(1, 8, (1,2),stride=1,dilation=2)
        self.conv2 = torch.nn.Conv2d(8, 16,(1,3),stride=1,dilation=2)
        self.conv3 = torch.nn.Conv2d(16, 32,(1,3),stride=1,dilation=2)

        # 64 input features = 32 channels * 2 remaining time steps; the
        # convolutions shorten the last axis by 2 + 4 + 4, so ``sample``
        # needs a horizon of 12 for this layer to fit.
        self.fc = torch.nn.Linear(64, self.embedding_dim)

        self.bn1 = torch.nn.BatchNorm2d(8)
        self.bn2 = torch.nn.BatchNorm2d(16)
        self.bn3 = torch.nn.BatchNorm2d(32)
        self.bn4 = torch.nn.BatchNorm1d(self.embedding_dim)
        self.bn5 = torch.nn.BatchNorm1d(2)

        self.fc_out = nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        self.fc_cat = nn.Linear(self.embedding_dim, 2)

        # Not used by ``sample``; retained so existing checkpoints (which
        # contain these parameters) still load.
        self.fc_continuous_1 = nn.Linear(self.embedding_dim,64)
        self.fc_continuous_2 = nn.Linear(64,1)

        # Pairing matrices of shape (node_num**2, node_num). Row
        # k = i * node_num + j of ``rel_rec`` is one-hot at i and the same
        # row of ``rel_send`` is one-hot at j, i.e. the one-hot encodings of
        # the row/column indices of a dense node_num x node_num grid.
        identity = torch.eye(node_num, dtype=torch.float32)
        rel_rec = identity.repeat_interleave(node_num, dim=0)
        rel_send = identity.repeat(node_num, 1)
        # Registered as buffers so they follow ``.to(device)`` / ``.cuda()``
        # together with the parameters instead of being pinned to CUDA at
        # construction time (which fails on CPU-only hosts). Non-persistent
        # keeps the ``state_dict`` keys identical to earlier checkpoints.
        self.register_buffer("rel_rec", rel_rec, persistent=False)
        self.register_buffer("rel_send", rel_send, persistent=False)

    def sample(self,x,prior_form,hard=False):
        """Compute the pairwise output matrix for a batch of node series.

        Args:
            x: Tensor of shape ``(batch_size, node_num, horizon)``. The
                ``fc`` layer requires ``horizon == 12`` (see ``__init__``),
                and ``node_num`` must match the value given at construction.
            prior_form: Only ``'binary'`` is implemented.
            hard: Forwarded unchanged as the second argument of the
                project's ``softmax`` helper (its exact semantics are
                defined in ``.util`` and not visible here).

        Returns:
            Tensor of shape ``(batch_size, node_num, node_num)`` holding
            component 0 of the helper's output for every ordered node pair;
            entry ``[b, i, j]`` pairs receiver ``i`` with sender ``j``.

        Raises:
            ValueError: If ``prior_form`` is not ``'binary'``.
        """
        # Fail fast: previously an unsupported value ran the whole network
        # and then died with UnboundLocalError on the return statement.
        if prior_form != 'binary':
            raise ValueError(
                "Unsupported prior_form %r; only 'binary' is implemented"
                % (prior_form,)
            )

        batch_size, node_num, _ = x.shape
        # reshape (rather than view) so non-contiguous inputs are accepted.
        x = x.reshape(batch_size, 1, node_num, -1)

        x = self.bn1(F.relu(self.conv1(x)))
        x = self.bn2(F.relu(self.conv2(x)))
        x = self.bn3(F.relu(self.conv3(x)))

        # NOTE(review): x is (batch, 32, node_num, width) here and ``view``
        # only reinterprets the memory layout; it does not move the channel
        # axis next to the time axis. Confirm this grouping is intended.
        x = x.view(batch_size,node_num, -1)
        x = F.relu(self.fc(x))
        # NOTE(review): likewise a reinterpretation, not a transpose, of the
        # (batch, node_num, embedding_dim) tensor before/after BatchNorm1d.
        x = x.view(batch_size,-1,node_num)
        x = self.bn4(x)
        x = x.view(batch_size,node_num,-1)

        # Gather the embedding of the receiver and sender node of each pair.
        receivers = torch.matmul(self.rel_rec, x)
        senders = torch.matmul(self.rel_send, x)

        x = torch.cat([senders, receivers], dim=-1)
        x = torch.relu(self.fc_out(x))

        x = self.fc_cat(x)
        # BatchNorm1d normalises over dim 1, so move the 2 logits there and
        # back again afterwards.
        x = torch.permute(x, (0,2,1))
        x = self.bn5(x)
        x = torch.permute(x, (0,2,1))
        logp = x
        out = softmax(logp, hard)
        out_matrix = out[:,:,0].view(batch_size,node_num, node_num)

        return out_matrix




