import torch
import torch.nn as nn
import math
from torch_geometric.loader import NeighborLoader
from model.layer import SAGE_AGG, GCN_AGG, SGC_AGG, GAT_AGG, GIN_AGG
from tqdm import tqdm

class BaseGNN(nn.Module):
    """Generic multi-layer GNN assembled from a pluggable aggregation layer.

    The model holds the aggregation layers in ``self.aggs``. Every layer except
    the last one is followed by an optional normalisation, the activation and
    (in ``forward`` only) dropout.

    Args:
        conv_layer: Class/callable used to build each layer. It is called as
            ``conv_layer(version, in_dim, out_dim, is_batch)`` (plus
            ``last=True`` for the final layer of a multi-layer model) and the
            resulting module is called as ``layer(x, edge_index, coeff)``.
        version: Forwarded to every layer. The value ``"base"`` makes
            ``create_coeff`` return ``0``.
        num_layers: Number of aggregation layers. With a single layer the
            input is mapped straight to ``output_dim`` (that layer is built
            without ``last=True``).
        input_dim: Input width of the first layer.
        hidden_dim: Width of the intermediate layers.
        output_dim: Output width of the final layer.
        dropout_ratio: Probability given to ``nn.Dropout``.
        activation: Module applied after every non-final layer.
        norm_type: ``"batch"`` selects ``nn.BatchNorm1d``, ``"layer"`` selects
            ``nn.LayerNorm``; any other value (e.g. ``"none"``) disables
            normalisation.
        is_batch: Forwarded to every layer; also switches the inference
            helpers between the ``NeighborLoader`` path and the full-graph
            path.
    """

    # Seed-node batch size of the loader the inference helpers build themselves.
    _INFERENCE_BATCH_SIZE = 8092

    def __init__(self, conv_layer: nn.Module, version: str, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int, dropout_ratio: float, activation: nn.Module, norm_type: str, is_batch: bool):
        super().__init__()
        self.version = version
        self.num_layers = num_layers
        self.norm_type = norm_type
        self.dropout = nn.Dropout(dropout_ratio)
        self.activation = activation
        self.is_batch = is_batch
        # Used by the random perturbation function: the number of edges in the
        # train graph. Stored as a frozen parameter; nothing in this class
        # updates it, so it is presumably assigned by the training code.
        self.edge_num = nn.Parameter(torch.tensor(0.0), requires_grad=False)

        # A single-layer model maps the input directly to the output width.
        first_out_dim = hidden_dim if num_layers > 1 else output_dim
        self.aggs = nn.ModuleList([conv_layer(version, input_dim, first_out_dim, is_batch)])
        # Entries are None when normalisation is disabled.
        self.norms = nn.ModuleList([self._get_norm_layer(first_out_dim)])

        for _ in range(num_layers - 2):
            self.aggs.append(conv_layer(version, hidden_dim, hidden_dim, is_batch))
            self.norms.append(self._get_norm_layer(hidden_dim))

        if num_layers > 1:
            self.aggs.append(conv_layer(version, hidden_dim, output_dim, is_batch, last=True))

    def reset_parameters(self):
        """Re-initialise every aggregation layer and every normalisation layer."""
        for agg in self.aggs:
            agg.reset_parameters()
        # Norm layers carry learnable affine weights (and BatchNorm running
        # statistics) that must not leak from one training run into the next.
        for norm in self.norms:
            if norm is not None:
                norm.reset_parameters()

    def _get_norm_layer(self, dim):
        """Return the normalisation module for width ``dim``, or None if disabled."""
        if self.norm_type == "batch":
            return nn.BatchNorm1d(dim)
        elif self.norm_type == "layer":
            return nn.LayerNorm(dim)
        return None

    def _norm_and_activate(self, layer_idx, h):
        """Apply layer ``layer_idx``'s normalisation (if any), then the activation."""
        norm = self.norms[layer_idx]
        # Checking the stored layer rather than the norm_type string keeps this
        # in step with _get_norm_layer, which yields None for any unknown type.
        if norm is not None:
            h = norm(h)
        return self.activation(h)

    def _default_subgraph_loader(self, data):
        """Build the loader used for layer-wise inference over ``data``.

        Seed nodes are visited in fixed order (``shuffle=False``) and
        ``num_neighbors=[-1]`` requests every neighbour for a single hop.
        """
        return NeighborLoader(
            data,
            input_nodes=None,
            num_neighbors=[-1],
            batch_size=self._INFERENCE_BATCH_SIZE,
            shuffle=False)

    def create_coeff(self, x, edge_index):
        """Compute the scalar coefficient handed to every aggregation layer.

        Args:
            x: Tensor whose first dimension is taken as the number of nodes.
            edge_index: Tensor whose second dimension is taken as the number
                of edges.

        Returns:
            ``0`` when ``self.version == "base"``; otherwise
            ``(edge_num + num_nodes) / (num_edges + num_nodes)`` as a tensor.
        """
        if self.version == "base":
            return 0

        # Only the coefficient is needed, so the full adjacency matrix is never
        # built. Self-loops are accounted for by adding one edge per node.
        num_nodes = x.size(0)
        total_edges = edge_index.size(1) + num_nodes
        # Deterministic case would be: num_nodes / total_edges
        return (self.edge_num + num_nodes) / total_edges

    def forward(self, x, edge_index):
        """Run all layers; norm, activation and dropout follow every layer but the last."""
        h = x
        coeff = self.create_coeff(x, edge_index)
        last = self.num_layers - 1
        for i, agg in enumerate(self.aggs):
            h = agg(h, edge_index, coeff)
            if i != last:
                h = self._norm_and_activate(i, h)
                h = self.dropout(h)

        return h

    # Efficient way of inference with all edges in large graph
    # Inspired by the implementation from pyg-team
    # https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py
    def inference_0(self, data, device, loader=None):
        """Layer-wise inference that keeps only the final output (memory efficient).

        Dropout is not applied and gradients are disabled.

        Args:
            data: Graph object providing ``x`` and ``edge_index``.
            device: Device on which each layer is evaluated.
            loader: Iterable of sub-graph batches exposing ``n_id``,
                ``edge_index`` and ``batch_size``; only used when
                ``self.is_batch`` is true. When omitted, a ``NeighborLoader``
                over ``data`` is created.

        Returns:
            The output of the last layer, on CPU.
        """
        if self.is_batch and loader is None:
            # Previously a missing loader crashed with "NoneType is not iterable".
            loader = self._default_subgraph_loader(data)

        last = self.num_layers - 1
        # The context manager closes the progress bar even if a layer raises.
        with torch.no_grad(), tqdm(total=data.x.size(0) * self.num_layers) as pbar:
            pbar.set_description('Inference Full Graph')

            x_all = data.x
            coeff = self.create_coeff(data.x, data.edge_index)
            for i, agg in enumerate(self.aggs):
                xs = []

                if self.is_batch:
                    # Batched version using NeighborLoader
                    for batch in loader:
                        x = x_all[batch.n_id].to(device)
                        edge_index = batch.edge_index.to(device)
                        x = agg(x, edge_index, coeff)
                        # Keep only the rows of the batch's seed nodes.
                        x = x[:batch.batch_size]

                        if i != last:
                            x = self._norm_and_activate(i, x)
                        xs.append(x.cpu())
                        pbar.update(batch.batch_size)
                else:
                    # Full-batch version
                    edge_index = data.edge_index.to(device)
                    x = x_all.to(device)
                    x = agg(x, edge_index, coeff)

                    if i != last:
                        x = self._norm_and_activate(i, x)

                    xs.append(x.cpu())
                    pbar.update(x_all.size(0))  # Update based on full batch size

                # This layer's output becomes the next layer's input.
                x_all = torch.cat(xs, dim=0)

        return x_all

    def inference(self, data, device):
        """Layer-wise inference that also returns every intermediate matrix.

        Dropout is not applied and gradients are disabled.

        Args:
            data: Graph object providing ``x`` and ``edge_index``.
            device: Device on which each layer is evaluated.

        Returns:
            A tuple ``(x_all, h_all, agg_all)`` where ``x_all`` is the final
            output, ``h_all`` is ``[data.x, out_1, ..., out_L]`` (each layer's
            output after norm/activation where applicable) and ``agg_all``
            holds each layer's raw aggregation output before norm/activation.
            Layer outputs are on CPU.
        """
        x_all = data.x
        h_all = [x_all]
        agg_all = []
        last = self.num_layers - 1

        # The loader yields the same batches for every layer, so build it once.
        subgraph_loader = self._default_subgraph_loader(data) if self.is_batch else None

        with torch.no_grad(), tqdm(total=data.x.size(0) * self.num_layers) as pbar:
            pbar.set_description('Inference Full Graph')

            coeff = self.create_coeff(data.x, data.edge_index)
            for i, agg in enumerate(self.aggs):
                xs = []
                aggs = []

                if self.is_batch:
                    # Batched version using NeighborLoader
                    for batch in subgraph_loader:
                        x = x_all[batch.n_id].to(device)
                        edge_index = batch.edge_index.to(device)
                        x = agg(x, edge_index, coeff)
                        # Keep only the rows of the batch's seed nodes.
                        x = x[:batch.batch_size]
                        aggs.append(x.cpu())

                        if i != last:
                            x = self._norm_and_activate(i, x)
                        xs.append(x.cpu())
                        pbar.update(batch.batch_size)
                else:
                    # Full-batch version
                    edge_index = data.edge_index.to(device)
                    x = x_all.to(device)
                    x = agg(x, edge_index, coeff)

                    aggs.append(x.cpu())

                    if i != last:
                        x = self._norm_and_activate(i, x)

                    xs.append(x.cpu())
                    pbar.update(x_all.size(0))  # Update based on full batch size

                agg_all.append(torch.cat(aggs, dim=0))
                x_all = torch.cat(xs, dim=0)
                h_all.append(x_all)

        return x_all, h_all, agg_all

    def agg_inference(self, data, embedding, edge_index, layer_index, device):
        """Apply only the aggregation layer ``layer_index`` to ``embedding``.

        No normalisation, activation or dropout is applied, and gradients are
        disabled.

        Args:
            data: Graph object providing ``x`` and ``edge_index``; used for
                the coefficient and, in batched mode, for neighbour sampling.
            embedding: Input matrix of the layer, indexed by node id.
            edge_index: Edges used in full-batch mode. In batched mode each
                batch's own ``edge_index`` (sampled from ``data``) is used
                instead, because it is the one that matches the rows of
                ``embedding[batch.n_id]``.
            layer_index: Index into ``self.aggs``.
            device: Device on which the layer is evaluated.

        Returns:
            The aggregation output for all nodes, on CPU.
        """
        layer = self.aggs[layer_index]
        coeff = self.create_coeff(data.x, data.edge_index)

        if not self.is_batch:
            # Full-batch version
            with torch.no_grad():
                out = layer(embedding.to(device), edge_index.to(device), coeff)
            return out.cpu()  # Return the result on CPU

        # Batched version
        subgraph_loader = self._default_subgraph_loader(data)
        aggs = []
        with torch.no_grad(), tqdm(total=data.x.size(0)) as pbar:
            pbar.set_description(f'{layer_index}th layer AGG')
            for batch in subgraph_loader:
                x = embedding[batch.n_id].to(device)
                # The features are a sub-batch, so the edges must be the
                # batch-local ones; global indices would not line up with x.
                x = layer(x, batch.edge_index.to(device), coeff)
                # Keep only the rows of the batch's seed nodes.
                aggs.append(x[:batch.batch_size].cpu())
                pbar.update(batch.batch_size)

        # Concatenate all aggs and return as a tensor
        return torch.cat(aggs, dim=0)

class GraphSAGE(BaseGNN):
    """``BaseGNN`` whose layers are all built from ``SAGE_AGG``."""

    def __init__(
        self,
        version,
        num_layers,
        input_dim,
        hidden_dim,
        output_dim,
        dropout_ratio,
        activation,
        norm_type,
        is_batch,
    ):
        # Only the layer class is fixed here; everything else passes through.
        super().__init__(
            conv_layer=SAGE_AGG,
            version=version,
            num_layers=num_layers,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_ratio=dropout_ratio,
            activation=activation,
            norm_type=norm_type,
            is_batch=is_batch,
        )

class GCN(BaseGNN):
    """``BaseGNN`` whose layers are all built from ``GCN_AGG``."""

    def __init__(
        self,
        version,
        num_layers,
        input_dim,
        hidden_dim,
        output_dim,
        dropout_ratio,
        activation,
        norm_type,
        is_batch,
    ):
        # Only the layer class is fixed here; everything else passes through.
        super().__init__(
            conv_layer=GCN_AGG,
            version=version,
            num_layers=num_layers,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_ratio=dropout_ratio,
            activation=activation,
            norm_type=norm_type,
            is_batch=is_batch,
        )

class GIN(BaseGNN):
    """``BaseGNN`` whose layers are all built from ``GIN_AGG``."""

    def __init__(
        self,
        version,
        num_layers,
        input_dim,
        hidden_dim,
        output_dim,
        dropout_ratio,
        activation,
        norm_type,
        is_batch,
    ):
        # Only the layer class is fixed here; everything else passes through.
        super().__init__(
            conv_layer=GIN_AGG,
            version=version,
            num_layers=num_layers,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_ratio=dropout_ratio,
            activation=activation,
            norm_type=norm_type,
            is_batch=is_batch,
        )

class SGC(BaseGNN):
    """``BaseGNN`` whose layers are all built from ``SGC_AGG``."""

    def __init__(
        self,
        version,
        num_layers,
        input_dim,
        hidden_dim,
        output_dim,
        dropout_ratio,
        activation,
        norm_type,
        is_batch,
    ):
        # Only the layer class is fixed here; everything else passes through.
        super().__init__(
            conv_layer=SGC_AGG,
            version=version,
            num_layers=num_layers,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_ratio=dropout_ratio,
            activation=activation,
            norm_type=norm_type,
            is_batch=is_batch,
        )

class GAT(BaseGNN):
    """``BaseGNN`` whose layers are all built from ``GAT_AGG``."""

    def __init__(
        self,
        version,
        num_layers,
        input_dim,
        hidden_dim,
        output_dim,
        dropout_ratio,
        activation,
        norm_type,
        is_batch,
    ):
        # Only the layer class is fixed here; everything else passes through.
        super().__init__(
            conv_layer=GAT_AGG,
            version=version,
            num_layers=num_layers,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_ratio=dropout_ratio,
            activation=activation,
            norm_type=norm_type,
            is_batch=is_batch,
        )
