import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from dgl.nn.pytorch.conv import GATConv, SAGEConv, GraphConv
from quattention import QuantumAttentionHead, MultiHeadQuantumAttention, QuantumAttentionHeadCustom


class CGraphClassification(nn.Module):
    """Two-layer GAT graph classifier with a mean-over-nodes readout.

    Each ``GATConv`` layer produces ``n_heads`` attention heads of ``h_feats``
    features; the heads are flattened into one ``h_feats * n_heads`` vector
    per node before being fed to the next layer.

    Args:
        in_feats: Size of the input node features.
        h_feats: Output features per attention head.
        num_classes: Number of output logits per graph.
        n_heads: Number of attention heads in every ``GATConv`` layer.
    """

    def __init__(self, in_feats, h_feats, num_classes, n_heads=4):
        super().__init__()
        flat_feats = h_feats * n_heads
        self.conv1 = GATConv(in_feats, h_feats, n_heads)
        self.conv2 = GATConv(flat_feats, h_feats, n_heads)
        self.linear = nn.Linear(flat_feats, num_classes)

    @staticmethod
    def _flatten_heads(h):
        """Flatten per-head GAT output into a single vector per node.

        The transpose keeps the original feature-major layout (head index
        varies fastest), so previously trained weights remain compatible.
        """
        return h.transpose(2, 1).reshape((h.shape[0], -1))

    def forward(self, g, in_feat):
        """Return per-graph logits for the (possibly batched) graph ``g``.

        The readout runs inside ``g.local_scope()`` so the temporary ``'h'``
        node feature is not written onto (and does not clobber) the caller's
        graph.
        """
        h = F.relu(self._flatten_heads(self.conv1(g, in_feat)))
        h = F.relu(self._flatten_heads(self.conv2(g, h)))
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        return self.linear(hg)

    def reset_parameters(self):
        """Re-initialise every learnable layer in place."""
        for layer in (self.conv1, self.conv2, self.linear):
            layer.reset_parameters()
        
        
class GCNRegression(nn.Module):
    """GCN graph regressor with a mean-over-nodes readout.

    The network is one input ``GraphConv`` followed by ``n_layers`` hidden
    ``GraphConv`` layers, so it contains ``n_layers + 1`` graph convolutions
    in total. Every convolution is built with ``allow_zero_in_degree=True``.

    Args:
        in_feats: Size of the input node features.
        h_feats: Hidden feature size of every graph convolution.
        num_targets: Number of regression outputs per graph.
        n_layers: Number of hidden ``GraphConv`` layers after the first one.
    """

    def __init__(self, in_feats, h_feats, num_targets, n_layers=4):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats, allow_zero_in_degree=True)
        self.hidden_conv = nn.ModuleList(
            [GraphConv(h_feats, h_feats, allow_zero_in_degree=True) for _ in range(n_layers)]
        )
        self.linear = nn.Linear(h_feats, num_targets)

    def forward(self, g, in_feat):
        """Return per-graph regression outputs for the (possibly batched) ``g``.

        The readout runs inside ``g.local_scope()`` so the temporary ``'h'``
        node feature is not written onto (and does not clobber) the caller's
        graph.
        """
        h = F.relu(self.conv1(g, in_feat))
        for conv in self.hidden_conv:
            h = F.relu(conv(g, h))
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        return self.linear(hg)

    def reset_parameters(self):
        """Re-initialise every learnable layer in place."""
        self.conv1.reset_parameters()
        for conv in self.hidden_conv:
            conv.reset_parameters()
        self.linear.reset_parameters()

class CGraphRegression(nn.Module):
    """GAT graph regressor with a mean-over-nodes readout.

    The network is one input ``GATConv`` followed by ``n_layers`` hidden
    ``GATConv`` layers. After each layer the ``n_heads`` attention heads are
    flattened into one ``h_feats * n_heads`` vector per node.

    Args:
        in_feats: Size of the input node features.
        h_feats: Output features per attention head.
        num_targets: Number of regression outputs per graph.
        n_heads: Number of attention heads in every ``GATConv`` layer.
        n_layers: Number of hidden ``GATConv`` layers after the first one.
    """

    def __init__(self, in_feats, h_feats, num_targets, n_heads=4, n_layers=4):
        super().__init__()
        flat_feats = h_feats * n_heads
        self.conv1 = GATConv(in_feats, h_feats, n_heads)
        self.hidden_conv = nn.ModuleList(
            [GATConv(flat_feats, h_feats, n_heads) for _ in range(n_layers)]
        )
        self.linear = nn.Linear(flat_feats, num_targets)

    @staticmethod
    def _flatten_heads(h):
        """Flatten per-head GAT output into a single vector per node.

        The transpose keeps the original feature-major layout (head index
        varies fastest), so previously trained weights remain compatible.
        """
        return h.transpose(2, 1).reshape((h.shape[0], -1))

    def forward(self, g, in_feat):
        """Return per-graph regression outputs for the (possibly batched) ``g``.

        The readout runs inside ``g.local_scope()`` so the temporary ``'h'``
        node feature is not written onto (and does not clobber) the caller's
        graph.
        """
        h = F.relu(self._flatten_heads(self.conv1(g, in_feat)))
        for conv in self.hidden_conv:
            h = F.relu(self._flatten_heads(conv(g, h)))
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        return self.linear(hg)

    def reset_parameters(self):
        """Re-initialise every learnable layer in place."""
        self.conv1.reset_parameters()
        for conv in self.hidden_conv:
            conv.reset_parameters()
        self.linear.reset_parameters()

class QGraphNetworkCustom(nn.Module):
    """Two-layer quantum-attention network operating on dense graph tensors.

    Graphs are supplied as tensors (``ising_matrices`` / ``adjacency``)
    rather than as a ``DGLGraph``. The second layer's output is averaged over
    dimension 1 and passed through a final linear layer.

    Args:
        in_feats: Size of the input node features.
        h_feats: Feature size produced by each attention layer.
        num_classes: Number of outputs of the final linear layer.
        n_att_layers, observables, apply_softmax, only_neighbors:
            Forwarded unchanged to both ``QuantumAttentionHeadCustom`` layers.
    """

    def __init__(self, in_feats, h_feats, num_classes, n_att_layers, observables, apply_softmax=False, only_neighbors=False):
        super().__init__()
        # Both attention layers share the same circuit configuration.
        shared = (n_att_layers, observables, apply_softmax, only_neighbors)
        self.conv1 = QuantumAttentionHeadCustom(in_feats, h_feats, *shared)
        self.conv2 = QuantumAttentionHeadCustom(h_feats, h_feats, *shared)
        self.linear = nn.Linear(h_feats, num_classes)

    def forward(self,
                in_feat,
                ising_matrices,
                adjacency,
                batch_size=1,
                precomputed_attention1=None,
                precomputed_attention2=None,
                observables=None):
        """Run both attention layers, pool over dim 1, and apply the classifier.

        ``precomputed_attention1`` / ``precomputed_attention2`` are forwarded
        to ``conv1`` / ``conv2`` respectively; every other argument is passed
        unchanged to both layers. Dim 1 is presumably the node axis of a
        ``(batch, nodes, h_feats)`` tensor; TODO confirm.
        """
        h = F.relu(self.conv1(in_feat, ising_matrices, adjacency, batch_size,
                              precomputed_attention1, observables=observables))
        h = F.relu(self.conv2(h, ising_matrices, adjacency, batch_size,
                              precomputed_attention2, observables=observables))
        return self.linear(torch.mean(h, dim=1))

    def update_observable_device(self, device):
        """Ask both attention layers to move their observables to ``device``."""
        self.conv1.update_observable_device(device)
        self.conv2.update_observable_device(device)

    def reset_parameters(self):
        """Re-initialise every learnable layer in place."""
        for layer in (self.conv1, self.conv2, self.linear):
            layer.reset_parameters()


class QGraphNetworkCustomParallel(nn.Module):
    """Model-parallel two-layer quantum-attention network split over two devices.

    ``conv1`` is placed on ``dev0``; ``conv2`` and ``linear`` are placed on
    ``dev1``. ``forward`` moves each input to the device of the layer that
    consumes it, so the returned tensor lives on ``dev1``.

    Calling ``.to(device)`` on the whole module collapses both halves onto one
    device; call ``move_parameters_to_device`` to restore the split.

    Args:
        in_feats: Size of the input node features.
        h_feats: Feature size produced by each attention layer.
        num_classes: Number of outputs of the final linear layer.
        n_att_layers, observables, apply_softmax, only_neighbors:
            Forwarded unchanged to both ``QuantumAttentionHeadCustom`` layers.
        dev0: Device hosting ``conv1``.
        dev1: Device hosting ``conv2`` and ``linear``.
    """

    def __init__(self, in_feats, h_feats, num_classes, n_att_layers, observables, dev0, dev1, apply_softmax=False, only_neighbors=False):
        super().__init__()
        # Both attention layers share the same circuit configuration.
        shared = (n_att_layers, observables, apply_softmax, only_neighbors)
        self.conv1 = QuantumAttentionHeadCustom(in_feats, h_feats, *shared).to(dev0)
        self.conv2 = QuantumAttentionHeadCustom(h_feats, h_feats, *shared).to(dev1)
        self.dev0 = dev0
        self.dev1 = dev1
        self.linear = nn.Linear(h_feats, num_classes).to(dev1)

    @staticmethod
    def _to(tensor, device):
        """Move ``tensor`` to ``device``; ``None`` is passed through unchanged."""
        return None if tensor is None else tensor.to(device)

    def forward(self,
                in_feat,
                ising_matrices,
                adjacency,
                batch_size=1,
                precomputed_attention1=None,
                precomputed_attention2=None,
                observables=None):
        """Run ``conv1`` on ``dev0``, then ``conv2`` + classifier on ``dev1``.

        Optional tensors (``ising_matrices``, ``adjacency`` and the
        precomputed attentions) may be ``None``; they are then forwarded as
        ``None`` instead of being moved. ``observables`` is passed to both
        layers as-is, without any device transfer.
        """
        h = self.conv1(in_feat.to(self.dev0),
                       self._to(ising_matrices, self.dev0),
                       self._to(adjacency, self.dev0),
                       batch_size,
                       self._to(precomputed_attention1, self.dev0),
                       observables=observables)
        h = F.relu(h)
        h = self.conv2(h.to(self.dev1),
                       self._to(ising_matrices, self.dev1),
                       self._to(adjacency, self.dev1),
                       batch_size,
                       self._to(precomputed_attention2, self.dev1),
                       observables=observables)
        h = F.relu(h)
        # Dim 1 is presumably the node axis; TODO confirm against the layer.
        return self.linear(torch.mean(h, dim=1))

    def update_observable_device(self, device=None):
        """Move the observables of both attention layers.

        Args:
            device: Target device for both layers. When ``None`` (the
                default), each layer's observables follow that layer's
                parameters: ``conv1`` -> ``dev0`` and ``conv2`` -> ``dev1``.
        """
        self.conv1.update_observable_device(self.dev0 if device is None else device)
        self.conv2.update_observable_device(self.dev1 if device is None else device)

    def reset_parameters(self):
        """Re-initialise every learnable layer in place."""
        for layer in (self.conv1, self.conv2, self.linear):
            layer.reset_parameters()

    def move_parameters_to_device(self):
        """Re-apply the two-device placement (e.g. after ``module.to(...)``)."""
        self.conv1 = self.conv1.to(self.dev0)
        self.conv2 = self.conv2.to(self.dev1)
        self.linear = self.linear.to(self.dev1)


class QGraphClassification(nn.Module):
    """Two-layer quantum-attention graph classifier with a mean-over-nodes readout.

    Args:
        in_feats: Size of the input node features.
        h_feats: Feature size produced by each attention layer.
        num_classes: Number of output logits per graph.
        n_att_layers, observables, apply_softmax, only_neighbors:
            Forwarded unchanged to both ``QuantumAttentionHead`` layers.
    """

    def __init__(self, in_feats, h_feats, num_classes, n_att_layers, observables, apply_softmax=False, only_neighbors=False):
        super().__init__()
        # Both attention layers share the same circuit configuration.
        shared = (n_att_layers, observables, apply_softmax, only_neighbors)
        self.conv1 = QuantumAttentionHead(in_feats, h_feats, *shared)
        self.conv2 = QuantumAttentionHead(h_feats, h_feats, *shared)
        self.linear = nn.Linear(h_feats, num_classes)

    def forward(self,
                g,
                in_feat,
                ising_matrices,
                unbatch=True,
                batch_size=1,
                precomputed_attention1=None,
                precomputed_attention2=None,
                observables=None):
        """Return per-graph logits for the (possibly batched) graph ``g``.

        ``precomputed_attention1`` / ``precomputed_attention2`` go to
        ``conv1`` / ``conv2`` respectively; every other argument is passed
        unchanged to both layers. The readout runs inside ``g.local_scope()``
        so the temporary ``'h'`` node feature is not written onto (and does
        not clobber) the caller's graph.
        """
        h = F.relu(self.conv1(g, in_feat, ising_matrices, unbatch, batch_size,
                              precomputed_attention1, observables=observables))
        h = F.relu(self.conv2(g, h, ising_matrices, unbatch, batch_size,
                              precomputed_attention2, observables=observables))
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        return self.linear(hg)

    def update_observable_device(self, device):
        """Ask both attention layers to move their observables to ``device``."""
        self.conv1.update_observable_device(device)
        self.conv2.update_observable_device(device)

    def reset_parameters(self):
        """Re-initialise every learnable layer in place."""
        for layer in (self.conv1, self.conv2, self.linear):
            layer.reset_parameters()
        
class MultiHeadQGraphClassification(nn.Module):
    """Two-layer multi-head quantum-attention graph classifier.

    The layer widths assume each ``MultiHeadQuantumAttention`` emits
    ``out_head * n_heads`` features per node; that width feeds ``conv2`` and
    the final linear layer.

    Args:
        in_feats: Size of the input node features.
        out_head: Output features per attention head.
        n_heads: Number of attention heads per layer.
        num_classes: Number of output logits per graph.
        n_att_layers, observables, apply_softmax, only_neighbors:
            Forwarded unchanged to both ``MultiHeadQuantumAttention`` layers.
    """

    def __init__(self, in_feats, out_head, n_heads, num_classes, n_att_layers, observables, apply_softmax=False, only_neighbors=False):
        super().__init__()
        h_feats = out_head * n_heads
        shared = (out_head, n_heads, n_att_layers, observables, apply_softmax, only_neighbors)
        self.conv1 = MultiHeadQuantumAttention(in_feats, *shared)
        self.conv2 = MultiHeadQuantumAttention(h_feats, *shared)
        self.linear = nn.Linear(h_feats, num_classes)

    def forward(self, g, in_feat, ising_matrices, unbatch=True):
        """Return per-graph logits for the (possibly batched) graph ``g``.

        The readout runs inside ``g.local_scope()`` so the temporary ``'h'``
        node feature is not written onto (and does not clobber) the caller's
        graph.
        """
        h = F.relu(self.conv1(g, in_feat, ising_matrices, unbatch))
        h = F.relu(self.conv2(g, h, ising_matrices, unbatch))
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        return self.linear(hg)

    def update_observable_device(self, device):
        """Ask both attention layers to move their observables to ``device``."""
        self.conv1.update_observable_device(device)
        self.conv2.update_observable_device(device)

    def reset_parameters(self):
        """Re-initialise every learnable layer in place."""
        for layer in (self.conv1, self.conv2, self.linear):
            layer.reset_parameters()
        