import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import circuit.var_config as vc

from gin.models.mlp import MLP
from models.layers import GraphConvolution
from torch_geometric.nn import global_add_pool, GATv2Conv, GCN2Conv
from utils.utils import normalize_adj


class Model(nn.Module):
    """Variational graph autoencoder pairing a ``GINEncoder`` with a ``Decoder``.

    ``forward(ops, adj)`` encodes the graph into ``(mu, logvar)``, samples a
    latent via the reparameterization trick (or uses ``mu`` in eval mode) and
    decodes it, returning ``(ops_recon, adj_recon, mu, logvar)``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_hops, num_mlp_layers,
                 dropout, model_flag, **kwargs):
        super(Model, self).__init__()
        # Node count handed to the decoder: vc.num_gates plus two
        # (presumably fixed start/end nodes -- confirm against the dataset).
        self.num_node = vc.num_gates + 2
        self.input_dim, self.hidden_dim, self.latent_dim = input_dim, hidden_dim, latent_dim
        self.num_layers = num_hops
        # The encoder is created before the decoder; keep this order because it
        # fixes parameter registration order and random initialisation.
        self.encoder = GINEncoder(input_dim, hidden_dim, latent_dim, num_hops, num_mlp_layers,
                                  model_flag=model_flag, **kwargs)
        self.decoder = Decoder(self.num_node, latent_dim, input_dim, dropout, model_flag, **kwargs)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` in training mode; return ``mu`` otherwise."""
        if not self.training:
            return mu
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return noise.mul(sigma).add_(mu)

    def forward(self, ops, adj):
        """Encode, sample and decode; returns ``(ops_recon, adj_recon, mu, logvar)``."""
        mu, logvar = self.encoder(ops, adj)
        ops_recon, adj_recon = self.decoder(self.reparameterize(mu, logvar))
        return ops_recon, adj_recon, mu, logvar

class GAE(nn.Module):
    """Deterministic graph autoencoder built from ``Encoder`` and ``Decoder``.

    Args:
        dims: encoder layer widths; ``dims[0]`` is the input feature width and
            ``dims[-1]`` the embedding width fed to the decoder.
        normalize, reg_emb, dropout: forwarded to ``Encoder``.
        reg_dec_l2: if truthy, ``forward`` returns the summed L2 norms of all
            decoder parameters as the fourth output.
        reg_dec_gp: if truthy (and ``reg_dec_l2`` is not), ``forward`` returns the
            embeddings as the fifth output.
        encode_flag: passed to ``Decoder`` as its ``model_flag``.
        num_node: node count the decoder reconstructs. Defaults to
            ``vc.num_gates + 2``, the same value ``Model`` uses.
        **kwargs: extra keyword arguments for ``Decoder``.
    """

    def __init__(self, dims, normalize, reg_emb, reg_dec_l2, reg_dec_gp, dropout, encode_flag,
                 num_node=None, **kwargs):
        super(GAE, self).__init__()
        if num_node is None:
            num_node = vc.num_gates + 2
        self.encoder = Encoder(dims, normalize, reg_emb, dropout)
        # Decoder's positional order is (num_node, embedding_dim, input_dim, dropout, model_flag).
        self.decoder = Decoder(num_node, dims[-1], dims[0], dropout, encode_flag, **kwargs)
        self.reg_dec_l2 = reg_dec_l2
        self.reg_dec_gp = reg_dec_gp

    def forward(self, ops, adj):
        """Return ``(ops_recon, adj_recon, emb_loss, dec_reg, extra)``.

        ``dec_reg`` is the decoder L2 penalty when ``reg_dec_l2`` is set, otherwise a
        one-element zero tensor on the embeddings' device. ``extra`` is the embedding
        tensor when only ``reg_dec_gp`` is set, otherwise ``None``.
        """
        x, emb_loss = self.encoder(ops, adj)
        ops_recon, adj_recon = self.decoder(x)
        if self.reg_dec_l2:
            dec_loss_l2 = sum(torch.norm(p, 2) for p in self.decoder.parameters())
            return ops_recon, adj_recon, emb_loss, dec_loss_l2, None
        # Allocate on x's device rather than hard-coding .cuda(), so CPU runs work.
        zero = x.new_zeros(1)
        if self.reg_dec_gp:
            return ops_recon, adj_recon, emb_loss, zero, x
        return ops_recon, adj_recon, emb_loss, zero, None

class GVAE(nn.Module):
    """Variational graph autoencoder built from ``VAEncoder`` and ``Decoder``.

    Args:
        dims: encoder layer widths; ``dims[0]`` is the input feature width and
            ``dims[-1]`` the latent width fed to the decoder.
        normalize, dropout: forwarded to ``VAEncoder``.
        encode_flag: passed to ``Decoder`` as its ``model_flag``.
        num_node: node count the decoder reconstructs. Defaults to
            ``vc.num_gates + 2``, the same value ``Model`` uses.
        **kwargs: extra keyword arguments for ``Decoder``.
    """

    def __init__(self, dims, normalize, dropout, encode_flag, num_node=None, **kwargs):
        super(GVAE, self).__init__()
        if num_node is None:
            num_node = vc.num_gates + 2
        self.encoder = VAEncoder(dims, normalize, dropout)
        # Decoder's positional order is (num_node, embedding_dim, input_dim, dropout, model_flag).
        self.decoder = Decoder(num_node, dims[-1], dims[0], dropout, encode_flag, **kwargs)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` in training mode; return ``mu`` otherwise."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        return mu

    def forward(self, ops, adj):
        """Encode, sample and decode; returns ``(ops_recon, adj_recon, mu, logvar)``."""
        mu, logvar = self.encoder(ops, adj)
        z = self.reparameterize(mu, logvar)
        ops_recon, adj_recon = self.decoder(z)
        return ops_recon, adj_recon, mu, logvar

class GINEncoder(nn.Module):
    """Graph Isomorphism Network encoder producing per-node ``(mu, logvar)``.

    Each hop computes ``(1 + eps_l) * x + adj @ x``, then applies an MLP,
    BatchNorm1d and leaky ReLU. With ``model_flag == "gsqas"`` only the last
    hop's features feed the ``fc1``/``fc2`` heads. With any other flag, the
    features of every hop are concatenated and fused back to ``hidden_dim`` by
    ``fc_fusion`` first.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_hops, num_mlp_layers,
                 model_flag="gsqas", **kwargs) -> None:
        super(GINEncoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_hops
        self.model_flag = model_flag
        # One learnable epsilon per hop, scaling each node's own features.
        self.eps = nn.Parameter(torch.zeros(self.num_layers))
        self.mlps = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for layer in range(self.num_layers):
            in_dim = input_dim if layer == 0 else hidden_dim
            self.mlps.append(MLP(num_mlp_layers, in_dim, hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # Creation order (fc_fusion, fc1, fc2) matches the original, which keeps
        # parameter order and random initialisation unchanged.
        if self.model_flag != "gsqas":
            self.fc_fusion = nn.Linear(self.hidden_dim * self.num_layers, self.hidden_dim)
        self.fc1 = nn.Linear(self.hidden_dim, self.latent_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.latent_dim)

    def forward(self, ops, adj):
        """Encode node features.

        Args:
            ops: node features; unpacked as ``(batch_size, node_num, feature_dim)``.
            adj: adjacency matrices multiplied against the node features
                (presumably ``(batch_size, node_num, node_num)`` -- confirm with callers).

        Returns:
            ``(mu, logvar)`` from ``fc1`` and ``fc2``.
        """
        batch_size, node_num, _ = ops.shape
        adj = adj.float()  # loop-invariant cast, hoisted out of the hop loop
        x = ops
        hop_outputs = []
        for l in range(self.num_layers):
            neighbor = torch.matmul(adj, x)
            agg = (1 + self.eps[l]) * x.view(batch_size * node_num, -1) \
                  + neighbor.view(batch_size * node_num, -1)
            # BatchNorm1d works on (N, C), so nodes of all graphs share one flat batch.
            h = self.batch_norms[l](self.mlps[l](agg)).view(batch_size, node_num, -1)
            x = F.leaky_relu(h, negative_slope=0.01)
            hop_outputs.append(x)

        if self.model_flag == "gsqas":
            agg_fusion = x
        else:
            agg_fusion = self.fc_fusion(torch.cat(hop_outputs, dim=-1))

        mu = self.fc1(agg_fusion)
        logvar = self.fc2(agg_fusion)
        return mu, logvar

class GATEncoder(nn.Module):
    """Work-in-progress GATv2-based encoder; ``forward`` is not implemented yet.

    Args:
        input_dim: width of the input node features (first layer only).
        hidden_dim: per-head output width of every GATv2 layer.
        latent_dim: stored for the future latent heads.
        num_hops: number of stacked GATv2 layers.
        model_flag: stored for parity with ``GINEncoder``.
        heads: attention heads per layer (default 4, the previously hard-coded value).
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_hops,
                 model_flag="gsqas", heads=4, **kwargs) -> None:
        super(GATEncoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_hops
        self.model_flag = model_flag
        self.heads = heads
        self.gat = nn.ModuleList()

        for layer in range(self.num_layers):
            # GATv2Conv concatenates its heads by default (concat=True), so every
            # layer after the first receives hidden_dim * heads input features.
            in_dim = input_dim if layer == 0 else hidden_dim * heads
            self.gat.append(GATv2Conv(in_dim, hidden_dim, heads=heads))

    def forward(self, ops, adj):
        # Raise instead of silently returning None so a caller notices the stub.
        raise NotImplementedError("GATEncoder.forward is not implemented yet")
        
class GCNEncoder(nn.Module):
    """Work-in-progress GCNII (``GCN2Conv``) encoder; ``forward`` is not implemented yet.

    Args:
        input_dim: width of the input node features.
        hidden_dim: channel width of every ``GCN2Conv`` layer. ``GCN2Conv`` keeps
            in/out width equal, so inputs must be projected to ``hidden_dim``
            before the first layer (TODO once forward is written).
        latent_dim: stored for the future latent heads.
        num_hops: number of stacked ``GCN2Conv`` layers.
        model_flag: stored for parity with ``GINEncoder``.
        alpha: strength of GCN2Conv's initial residual connection (required by GCN2Conv).
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_hops,
                 model_flag="gsqas", alpha=0.1, **kwargs) -> None:
        super(GCNEncoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_hops
        self.model_flag = model_flag
        self.alpha = alpha
        self.gcn = nn.ModuleList()

        for _ in range(self.num_layers):
            # GCN2Conv requires `channels` and `alpha`; calling it with no
            # arguments made this constructor always fail.
            self.gcn.append(GCN2Conv(hidden_dim, alpha))

    def forward(self, ops, adj):
        # NOTE(review): GCN2Conv.forward takes (x, x_0, edge_index), not a dense
        # adjacency -- the conversion is still to be designed.
        raise NotImplementedError("GCNEncoder.forward is not implemented yet")


class Encoder(nn.Module):
    """Stack of ``GraphConvolution`` layers producing deterministic node embeddings.

    Args:
        dims: layer widths; layer ``k`` maps ``dims[k] -> dims[k + 1]``.
        normalize: if truthy, ``normalize_adj`` is applied to ``adj`` before the layers.
        reg_emb: if truthy, ``forward`` also returns an embedding-norm regulariser.
        dropout: dropout rate given to every ``GraphConvolution``.
    """

    def __init__(self, dims, normalize, reg_emb, dropout):
        super(Encoder, self).__init__()
        self.gcs = nn.ModuleList(self.get_gcs(dims, dropout))
        self.normalize = normalize
        self.reg_emb = reg_emb

    def get_gcs(self, dims, dropout):
        """Return one ``GraphConvolution`` per consecutive pair of widths in ``dims``."""
        return [GraphConvolution(dims[k], dims[k + 1], dropout) for k in range(len(dims) - 1)]

    def forward(self, ops, adj):
        """Run the convolution stack.

        Returns:
            ``(x, emb_loss)``: the final node features and, when ``reg_emb`` is set,
            the batch mean of the L2 norms of the dim-1-averaged features
            (otherwise a one-element zero tensor on ``x``'s device).
        """
        if self.normalize:
            adj = normalize_adj(adj)
        x = ops
        for gc in self.gcs:
            x = gc(x, adj)
        if self.reg_emb:
            # Average over dim 1 (nodes, presumably). No .squeeze(): it removed the
            # batch dim when batch_size == 1, which made norm(dim=1) raise.
            emb = x.mean(dim=1)
            emb_loss = torch.mean(torch.norm(emb, p=2, dim=1))
            return x, emb_loss
        # Allocate on x's device rather than hard-coding .cuda(), so CPU runs work.
        return x, x.new_zeros(1)

class VAEncoder(nn.Module):
    """Graph-convolutional encoder that outputs Gaussian parameters ``(mu, logvar)``.

    Args:
        dims: layer widths; the shared trunk maps ``dims[0] -> ... -> dims[-2]`` and
            two heads map ``dims[-2] -> dims[-1]`` for ``mu`` and ``logvar``.
        normalize: if truthy, ``normalize_adj`` is applied to ``adj`` in ``forward``.
        dropout: dropout rate given to every ``GraphConvolution``.
    """

    def __init__(self, dims, normalize, dropout):
        super(VAEncoder, self).__init__()
        # NOTE: get_gcs builds len(dims) - 1 layers, but forward only runs the first
        # len(dims) - 2; the last one is never called. Removing it would change the
        # state_dict keys, so it is left in place.
        self.gcs = nn.ModuleList(self.get_gcs(dims, dropout))
        head_in, head_out = dims[-2], dims[-1]
        self.gc_mu = GraphConvolution(head_in, head_out, dropout)
        self.gc_logvar = GraphConvolution(head_in, head_out, dropout)
        self.normalize = normalize

    def get_gcs(self, dims, dropout):
        """Return one ``GraphConvolution`` per consecutive pair of widths in ``dims``."""
        return [GraphConvolution(d_in, d_out, dropout) for d_in, d_out in zip(dims[:-1], dims[1:])]

    def forward(self, ops, adj):
        """Return ``(mu, logvar)`` computed from the shared trunk's output."""
        adj = normalize_adj(adj) if self.normalize else adj
        hidden = ops
        for layer in self.gcs[:-1]:
            hidden = layer(hidden, adj)
        return self.gc_mu(hidden, adj), self.gc_logvar(hidden, adj)

class Decoder(nn.Module):
    """Decode node embeddings into operation features and an adjacency matrix.

    With ``model_flag == 'gsqas'`` two linear heads produce the op part
    (``input_dim - vc.num_qubits`` columns) and the qubit part (``vc.num_qubits``
    columns). With any other flag a single head produces ``input_dim`` columns,
    which are split into the same two parts. The adjacency is a linear map of the
    embedding Gram matrix ``E @ E^T``, reshaped to ``(-1, num_node, num_node)``.

    NOTE(review): ``activation_ops`` is called with ``dim=-1`` and must accept that
    keyword (e.g. ``torch.softmax``). The default ``torch.sigmoid`` does not, so
    callers appear to rely on passing one explicitly -- confirm.
    """

    def __init__(self, num_node, embedding_dim, input_dim, dropout, model_flag,
                 activation_adj=torch.sigmoid, activation_ops=torch.sigmoid,
                 activation_ops_qubits=torch.sigmoid, adj_hidden_dim=None, ops_hidden_dim=None):
        super(Decoder, self).__init__()
        # Previously these stayed unset whenever a value was passed explicitly.
        self.adj_hidden_dim = embedding_dim if adj_hidden_dim is None else adj_hidden_dim
        self.ops_hidden_dim = embedding_dim if ops_hidden_dim is None else ops_hidden_dim
        self.num_node = num_node
        self.embedding_dim = embedding_dim
        self.input_dim = input_dim
        self.activation_adj = activation_adj
        self.activation_ops = activation_ops
        self.activation_ops_qubits = activation_ops_qubits
        self.model_flag = model_flag
        self.dropout = dropout
        # Creation order (weight_ops, [weight_ops_qubits], weight_adj) is kept so
        # parameter order and random initialisation are unchanged.
        if self.model_flag == 'gsqas':
            self.weight_ops = nn.Linear(embedding_dim, input_dim - vc.num_qubits)
            self.weight_ops_qubits = nn.Linear(embedding_dim, vc.num_qubits)
        else:
            self.weight_ops = nn.Linear(embedding_dim, input_dim)
        self.weight_adj = nn.Linear(num_node, num_node)

    def forward(self, embedding):
        """Return ``(ops_decoding, adj_decoding)`` for a rank-3 ``embedding``.

        ``ops_decoding`` concatenates the activated op and qubit parts on the last
        dim; ``adj_decoding`` is ``activation_adj`` applied to the decoded adjacency.
        """
        embedding = F.dropout(embedding, p=self.dropout, training=self.training)
        # Branch on the same condition as __init__. Previously any flag other than
        # 'gsqas' and the two arch2vec flags reached the two-head branch without
        # weight_ops_qubits and raised AttributeError.
        if self.model_flag == 'gsqas':
            ops = self.weight_ops(embedding)
            ops_qubits = self.weight_ops_qubits(embedding)
        else:
            full_ops = self.weight_ops(embedding)
            ops = full_ops[:, :, :-vc.num_qubits]
            ops_qubits = full_ops[:, :, -vc.num_qubits:]
        full_ops_decoding = torch.cat(
            (self.activation_ops(ops, dim=-1), self.activation_ops_qubits(ops_qubits)), dim=-1)
        adj_input = torch.matmul(embedding, embedding.permute(0, 2, 1))
        adj = self.weight_adj(adj_input).view(-1, self.num_node, self.num_node)
        return full_ops_decoding, self.activation_adj(adj)
    

class Reconstructed_Loss(object):
    """Weighted reconstruction loss: ``w_ops * loss_ops(...) + w_adj * loss_adj(...)``.

    ``loss_ops`` and ``loss_adj`` are callables ``(prediction, target) -> loss``.
    """

    def __init__(self, w_ops=1.0, w_adj=1.0, loss_ops=None, loss_adj=None):
        super().__init__()
        self.w_ops, self.w_adj = w_ops, w_adj
        self.loss_ops, self.loss_adj = loss_ops, loss_adj

    def __call__(self, inputs, targets):
        """Score ``inputs = (ops_recon, adj_recon)`` against ``targets = (ops, adj)``."""
        ops_pred, adj_pred = inputs[0], inputs[1]
        ops_true, adj_true = targets[0], targets[1]
        ops_term = self.loss_ops(ops_pred, ops_true)
        adj_term = self.loss_adj(adj_pred, adj_true)
        return self.w_ops * ops_term + self.w_adj * adj_term


class VAEReconstructed_Loss(object):
    """Weighted reconstruction loss plus a scaled KL-divergence term.

    ``__call__`` returns ``(recon_loss, KLD)``. ``recon_loss`` is
    ``w_ops * loss_ops(ops_recon, ops) + w_adj * loss_adj(adj_recon, adj)``.
    ``KLD`` is the mean over leading dims of the standard-normal KL summed over
    the last dim, then divided by ``ops.shape[0] * ops.shape[1]``.
    """

    def __init__(self, w_ops=1.0, w_adj=1.0, loss_ops=None, loss_adj=None):
        super().__init__()
        self.w_ops, self.w_adj = w_ops, w_adj
        self.loss_ops, self.loss_adj = loss_ops, loss_adj

    def __call__(self, inputs, targets, mu, logvar):
        ops_pred, adj_pred = inputs[0], inputs[1]
        ops_true, adj_true = targets[0], targets[1]
        ops_term = self.loss_ops(ops_pred, ops_true)
        adj_term = self.loss_adj(adj_pred, adj_true)
        recon = self.w_ops * ops_term + self.w_adj * adj_term
        kl_terms = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), -1)
        # Extra normalisation by the first two dims of ops, on top of the mean.
        scale = ops_true.shape[0] * ops_true.shape[1]
        KLD = torch.mean(kl_terms) / scale
        return recon, KLD
    
class DegreeConsistencyLoss(object):
    """Mean of the MSEs between reconstructed and target in-/out-degrees.

    In-degrees are ``adj_recon`` summed over dim -2 and out-degrees over dim -1.
    The targets have their trailing singleton dim squeezed and are cast to float.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, indegree, outdegree, adj_recon):
        in_mse = F.mse_loss(adj_recon.sum(dim=-2), indegree.squeeze(-1).float())
        out_mse = F.mse_loss(adj_recon.sum(dim=-1), outdegree.squeeze(-1).float())
        return 0.5 * (in_mse + out_mse)

# NOTE: The triple-quoted string below is disabled code kept for reference: a
# spectral-normalised `Discriminator` MLP and a `bhattacharyya_distance` helper.
# It is only a string expression evaluated and discarded at import, so none of it
# runs. If it is revived: in bhattacharyya_distance, torch.log(torch.exp(c)) equals
# c, so the exp/log round trip is redundant, and the clamp at -10 caps the
# mean-difference contribution.
'''
class Discriminator(nn.Module):
    def __init__(self, num_node, latent_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            # First layer
            nn.Linear(latent_dim, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(0.2, inplace=False),
            nn.Dropout(0.2),
            
            # Second layer
            nn.Linear(128, 64),
            nn.LayerNorm(64),
            nn.LeakyReLU(0.2, inplace=False),
            nn.Dropout(0.2),
            
            # output
            nn.Linear(64, 1),
        )
    
        # spectral_norm
        self.apply(self._add_spectral_norm)
    
    def _add_spectral_norm(self, module):
        if isinstance(module, nn.Linear):
            nn.utils.spectral_norm(module)

    def forward(self, z):
        batch_size, num_node, feature_dim = z.shape
        #z_ = z.view(batch_size, -1)
        return self.model(z)


# Bhattacharyya distance
def bhattacharyya_distance(p, q):
    """
    Compute the Bhattacharyya distance between two distributions p and q.
    
    Parameters:
    - p, q: Tensors of shape (batch_size, is_real)
            These could represent probabilities or some features of data samples.
    
    Returns:
    - bd: Scalar value representing the Bhattacharyya distance.
    """
    # Compute the mean and variance across batch_size
    p_mean = torch.mean(p, dim=0)  # Shape: (is_real,)
    q_mean = torch.mean(q, dim=0)
    p_var = torch.var(p, dim=0, unbiased=False) + 1e-10  # Shape: (is_real,)
    q_var = torch.var(q, dim=0, unbiased=False) + 1e-10

    # Calculate the coefficient for each feature
    var_mean = 0.5 * (p_var + q_var)
    mean_diff = (p_mean - q_mean) ** 2
    coeff = 0.25 * torch.log(0.25 * (p_var / q_var + q_var / p_var + 2))

    # Ensure numerical stability for the exponential term
    exp_input = -0.25 * mean_diff / var_mean
    exp_input_clamped = torch.clamp(exp_input, min=-10, max=10)  # Clamping to prevent overflow
    exp_term = torch.exp(exp_input_clamped)

    # Summing across all features to get a single scalar
    bd = torch.sum(coeff - torch.log(exp_term))
    return bd
'''