import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from gin.models.mlp import MLP
from models.layers import GraphConvolution
from utils.utils import preprocessing, normalize_adj

class GIN(nn.Module):
    """Stack of GIN-style aggregation layers over batched node features.

    ``num_layers - 1`` layers are created. Each layer owns one ``MLP``, one
    ``BatchNorm1d(hidden_dim)`` and one learnable scalar taken from
    ``self.eps`` (initialised to zero).

    Args:
        input_dim: Feature size handed to the first layer's ``MLP``.
        hidden_dim: Feature size used for every later ``MLP`` input and for
            the batch-norm layers.
        num_layers: Layer count; only ``num_layers - 1`` layers are built.
        num_mlp_layers: Forwarded as the first positional argument of each
            ``MLP`` (presumably its depth -- confirm against ``MLP``).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_mlp_layers):
        super().__init__()
        self.num_layers = num_layers
        depth = self.num_layers - 1
        self.eps = nn.Parameter(torch.zeros(depth))
        self.mlps = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        # Only the first MLP consumes the raw input width; all later ones
        # consume the hidden width produced by the previous layer.
        in_features = input_dim
        for _ in range(depth):
            self.mlps.append(MLP(num_mlp_layers, in_features, hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
            in_features = hidden_dim

    def forward(self, x, edge_adj):
        """Run every layer and return the final node features.

        Args:
            x: Rank-3 tensor; the shape unpack below fails for any other
                rank. The first two axes are treated as batch and node.
            edge_adj: Tensor cast with ``.float()`` and left-multiplied onto
                ``x`` via ``torch.matmul`` (presumably one adjacency matrix
                per batch element -- confirm with the caller).

        Returns:
            The features after the last layer, reshaped to
            ``(batch_size, node_num, -1)``. With zero layers ``x`` is
            returned untouched.
        """
        batch_size, node_num, _ = x.shape
        flat_rows = batch_size * node_num
        for idx in range(self.num_layers - 1):
            neighbor = torch.matmul(edge_adj.float(), x)
            agg = (1 + self.eps[idx]) * x + neighbor
            # MLP and BatchNorm1d run on a flattened (batch * node, feature)
            # view; the result is folded back to three axes before the ReLU.
            transformed = self.mlps[idx](agg.view(flat_rows, -1))
            normalised = self.batch_norms[idx](transformed)
            x = F.relu(normalised.view(batch_size, node_num, -1))
        return x

class MultiChannelGINEncoder(nn.Module):
    """Variational encoder that runs one ``GIN`` channel per edge type.

    The per-channel outputs are concatenated along the last axis and mapped
    by two linear heads to ``mu`` and ``logvar``.

    Args:
        input_dim: Input feature size given to every ``GIN`` channel.
        hidden_dim: Hidden size of every ``GIN`` channel.
        latent_dim: Output size of both linear heads.
        num_hops: Passed to each ``GIN`` as its ``num_layers``; also stored
            as ``self.num_layers``.
        num_mlp_layers: Passed through to each ``GIN``.
        num_edge_types: Number of ``GIN`` channels to build.
        dropout: Accepted but not used anywhere in this class.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_hops, num_mlp_layers, num_edge_types, dropout):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_hops
        self.num_edge_types = num_edge_types

        channels = [
            GIN(input_dim, hidden_dim, num_hops, num_mlp_layers)
            for _ in range(num_edge_types)
        ]
        self.gin_channels = nn.ModuleList(channels)

        # Both heads read the concatenation of all channel outputs.
        fused_dim = self.hidden_dim * self.num_edge_types
        self.fc1 = nn.Linear(fused_dim, self.latent_dim)
        self.fc2 = nn.Linear(fused_dim, self.latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` in training mode.

        Outside training mode ``mu`` itself is returned (same object).
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise.mul(std).add_(mu)

    def forward(self, ops, adjs):
        """Encode ``ops`` under every edge-type adjacency.

        Args:
            ops: Rank-3 tensor of node features (the shape unpack below
                rejects any other rank).
            adjs: Rank-4 tensor whose axis 1 indexes the edge type.

        Returns:
            Tuple ``(mu, logvar, z)`` where ``z`` comes from
            ``reparameterize``.
        """
        # Swap the first two axes so that indexing axis 0 selects one edge type:
        # [num_edge_types, batch_size, adj_size, adj_size]
        per_type_adjs = adjs.permute(1, 0, 2, 3)
        _batch, _nodes, _feats = ops.shape  # raises unless ops is rank 3

        channel_feats = [
            self.gin_channels[k](ops, per_type_adjs[k])
            for k in range(self.num_edge_types)
        ]
        fused = torch.cat(channel_feats, dim=-1)

        mu = self.fc1(fused)
        logvar = self.fc2(fused)
        return mu, logvar, self.reparameterize(mu, logvar)

# Multi-channel Decoder
class MultiChannelDecoder(nn.Module):
    """Decode node embeddings into operations and one adjacency per edge type.

    Args:
        embedding_dim: Size of the last axis of the incoming embedding.
        input_dim: Output size of the operation head.
        dropout: Drop probability applied to the embedding in ``forward``.
        num_edge_types: Number of adjacency heads (one ``nn.Linear`` each).
        activation_adj: Callable applied to every adjacency score matrix.
        activation_ops: Callable applied to the operation head's output.
        adj_hidden_dim: Accepted for interface compatibility; not used.
        ops_hidden_dim: Accepted for interface compatibility; not used.
    """

    def __init__(self, embedding_dim, input_dim, dropout, num_edge_types, activation_adj=torch.sigmoid, activation_ops=torch.sigmoid, adj_hidden_dim=None, ops_hidden_dim=None):
        super().__init__()
        self.activation_adj = activation_adj
        self.activation_ops = activation_ops
        self.num_edge_types = num_edge_types
        self.dropout = dropout

        self.weight_ops = nn.Linear(embedding_dim, input_dim)
        adj_heads = [nn.Linear(embedding_dim, embedding_dim) for _ in range(num_edge_types)]
        self.weight_adjs = nn.ModuleList(adj_heads)

    def forward(self, embedding):
        """Return ``(ops, adj_recon)`` for a batch of embeddings.

        ``ops`` is ``activation_ops(weight_ops(embedding))``. ``adj_recon``
        is a list with one entry per edge type, each being
        ``activation_adj(weight_adjs[k](embedding) @ embedding^T)`` where the
        transpose swaps the last two of three axes.
        """
        embedding = F.dropout(embedding, p=self.dropout, training=self.training)
        ops_scores = self.weight_ops(embedding)
        adj_recon = [
            self.activation_adj(
                torch.matmul(self.weight_adjs[k](embedding), embedding.permute(0, 2, 1))
            )
            for k in range(self.num_edge_types)
        ]
        return self.activation_ops(ops_scores), adj_recon


class MultiChannelModel(nn.Module):
    """Encoder/decoder pair built from the multi-channel GIN components.

    Args:
        input_dim: Forwarded to the encoder and, as its output size, to the
            decoder.
        hidden_dim: Forwarded to the encoder.
        latent_dim: Encoder output size and decoder embedding size.
        num_hops: Forwarded to the encoder; also stored as ``self.num_layers``.
        num_mlp_layers: Forwarded to the encoder.
        dropout: Forwarded to both encoder and decoder.
        num_edge_types: Forwarded to both encoder and decoder.
        is_constrained: Stored on the instance; the constrained path in
            ``forward`` is still a TODO and currently does nothing.
        **kwargs: Extra keyword arguments passed to ``MultiChannelDecoder``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_hops, num_mlp_layers,
                 dropout, num_edge_types, is_constrained=False, **kwargs):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_hops
        self.is_constrained = is_constrained

        self.encoder = MultiChannelGINEncoder(
            input_dim, hidden_dim, latent_dim, num_hops, num_mlp_layers, num_edge_types, dropout
        )
        self.decoder = MultiChannelDecoder(
            latent_dim, input_dim, dropout, num_edge_types=num_edge_types, **kwargs
        )

    def forward(self, ops, adjs):
        """Encode, sample and decode.

        Returns:
            Tuple ``(ops_recon, adj_recon, mu, logvar)``; the sampled latent
            ``z`` is fed to the decoder but not returned.
        """
        mu, logvar, z = self.encoder(ops, adjs)
        ops_recon, adj_recon = self.decoder(z)
        if self.is_constrained:
            # TODO: Masked ops_recon for the constrained model
            pass
        return ops_recon, adj_recon, mu, logvar

# VAE loss function
class VAEReconstructed_Loss(nn.Module):
    """Weighted reconstruction loss plus a KL term.

    Args:
        w_ops: Weight on the operation reconstruction loss.
        w_adj: Weight on the summed adjacency reconstruction losses.
        loss_ops: Callable ``(ops_recon, ops) -> loss``. Defaults to ``None``,
            in which case ``forward`` cannot be called.
        loss_adj: Callable ``(recon_adj, adj) -> loss`` applied once per
            edge type. Defaults to ``None``, in which case ``forward``
            cannot be called.
    """

    def __init__(self, w_ops=1.0, w_adj=1.0, loss_ops=None, loss_adj=None):
        super().__init__()
        self.w_ops, self.w_adj = w_ops, w_adj
        self.loss_ops, self.loss_adj = loss_ops, loss_adj

    def forward(self, inputs, targets, mu, logvar):
        """Compute the total loss.

        Args:
            inputs: Indexable; ``inputs[0]`` is the reconstructed ops and
                ``inputs[1]`` an iterable of reconstructed adjacencies.
            targets: Indexable; ``targets[0]`` is the target ops and
                ``targets[1]`` a rank-4 adjacency tensor whose axis 1 indexes
                the edge type.
            mu: Latent mean.
            logvar: Latent log-variance.

        Returns:
            ``w_ops * loss_ops + w_adj * sum(loss_adj per edge type) + KLD``.
        """
        ops_recon = inputs[0]
        adj_recon = inputs[1]
        ops = targets[0]
        # Swap the first two axes so iteration yields one edge type at a time:
        # [num_edge_types, batch_size, adj_size, adj_size]
        adj_by_type = targets[1].permute(1, 0, 2, 3)

        ops_term = self.loss_ops(ops_recon, ops)

        adj_term = 0
        for predicted, target in zip(adj_recon, adj_by_type):
            adj_term += self.loss_adj(predicted, target)

        recon = self.w_ops * ops_term + self.w_adj * adj_term

        # NOTE(review): the sum runs over dim=1; if mu is rank 3 this is not
        # the last axis -- confirm that reducing over this axis is intended.
        kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_scale = -0.5 / ops.size(0)
        kld = kl_scale * torch.mean(torch.sum(kl_inner, dim=1))
        return recon + kld