import torch
from torch import nn
from torch.nn import functional as F
from time import time
import torch_geometric.nn as gnn
from collections import defaultdict, deque
import pytorch_lightning as pl
from .mlp import MLP
import torch_scatter
from torch_sparse import SparseTensor

class ABGNN(nn.Module):
    """Level-by-level GraphSAGE encoder that reuses a single convolution.

    Input features are projected to ``hidden_dim`` by a linear layer. Then,
    for each level ``i`` in ``0 .. forward_level.max()``, the one shared
    ``SAGEConv`` (sum aggregation) is applied using only the edges whose
    target node has ``forward_level == i``.

    ``out_dim`` and ``n_layers`` are stored on the module but are not used to
    build the network; the output width is always ``hidden_dim``.
    """

    def __init__(
        self,
        hidden_dim,    # dim of the hidden layers
        out_dim,       # dim of the last layer (stored only)
        dropout,       # dropout rate
        n_layers=None,  # number of layers (stored only)
        in_dim=16,      # dim of the input layer
        activation=torch.relu,  # activation function
    ):
        super(ABGNN, self).__init__()
        self.activation = activation
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.n_layers = n_layers
        self.dropout = nn.Dropout(p=dropout)
        # Kept for attribute compatibility; no layer is ever appended to it.
        self.layers = nn.ModuleList()
        self.fc_init = nn.Linear(in_dim, hidden_dim)

        # One convolution shared by every level; it maps hidden -> hidden.
        self.conv = gnn.SAGEConv(
            hidden_dim,
            hidden_dim,
            aggr='add',
        )

    def forward(self, features, edge_index, forward_level):
        """Encode all nodes by sweeping the graph one level at a time.

        Args:
            features: Node feature tensor fed to the input projection.
            edge_index: ``(2, num_edges)`` tensor; row 0 holds source nodes and
                row 1 holds target nodes.
            forward_level: Integer level of every node, indexed by node id.

        Returns:
            The node embeddings after the last level, with dimension 1
            squeezed away when it has size 1. When ``forward_level`` is empty
            the projected input features are returned unchanged.
        """
        h = self.activation(self.fc_init(features))
        if forward_level.numel() == 0:
            # Nothing to sweep over; ``max()`` of an empty tensor would raise.
            return h.squeeze(1)

        # Read the depth once as a Python int so the loop bookkeeping below
        # does not trigger a tensor comparison on every iteration.
        depth = int(forward_level.max()) + 1
        # Level of each edge's target node; identical for every iteration.
        target_level = forward_level[edge_index[1]]

        for i in range(depth):
            if i != 0:
                h = self.dropout(h)
            edge_i = edge_index[:, target_level == i]
            h = self.conv(h, edge_i)  # node embeddings after the current level
            if i != depth - 1:
                h = self.activation(h)
        return h.squeeze(1)
    

class BoundaryABGNN(pl.LightningModule):
    """Lightning module that predicts boundary nodes of a sub-graph.

    Two ``ABGNN`` encoders embed the ``sub_aig_*`` graph and the ``pm_*``
    graph of a batch. Each ``pm`` node embedding is concatenated with the
    mean-pooled ``sub_aig`` embedding of its sample and classified by an MLP
    into three classes: 0 (other), 1 (input boundary), 2 (output boundary).
    """

    def __init__(self, args):
        """Build the encoders and the classifier head.

        Args:
            args: Configuration object; this class reads ``args.hidden``,
                ``args.lr`` and ``args.batch_size``.
        """
        super().__init__()
        self.args = args
        self.aig_encoder = ABGNN(args.hidden, args.hidden, in_dim=3, dropout=0.2)
        self.pm_encoder = ABGNN(args.hidden, args.hidden, in_dim=64, dropout=0.2)
        self.save_hyperparameters()
        self.training_step_outputs = []
        self.test_step_outputs = []
        self.val_step_outputs = []
        self.mlp = MLP(2 * self.args.hidden, 2 * self.args.hidden, 3, num_layer=3)

    def find_nodes_between_start_and_end(self, edge_index, start_nodes, end_nodes):
        """Collect the nodes reachable from ``start_nodes`` without passing ``end_nodes``.

        A breadth-first search follows the directed edges of ``edge_index``
        (row 0 -> row 1). End nodes are added to the result when reached but
        are never expanded. Start nodes are always part of the result.

        Args:
            edge_index: ``(2, num_edges)`` integer tensor of directed edges.
            start_nodes: Node ids (tensor or sequence) the search starts from.
            end_nodes: Node ids (tensor or sequence) at which the search stops.

        Returns:
            A 1-D ``torch.long`` tensor, on ``edge_index``'s device, holding
            the reached node ids in ascending order.
        """
        device = edge_index.device
        start_ids = torch.as_tensor(start_nodes, device=device).long().unique().tolist()
        end_set = set(torch.as_tensor(end_nodes, device=device).long().unique().tolist())

        # Build a CSR adjacency on the host with a single transfer, so the
        # traversal below runs on plain Python ints instead of issuing one
        # device synchronisation per visited node.
        edges = edge_index.detach().cpu()
        if edges.numel() > 0:
            src, dst = edges[0], edges[1]
            num_nodes = int(edges.max()) + 1
            col = dst[torch.argsort(src)].tolist()
            counts = torch.bincount(src, minlength=num_nodes)
            crow = [0] + torch.cumsum(counts, dim=0).tolist()
        else:
            num_nodes, crow, col = 0, [0], []

        reached = set(start_ids)
        queue = deque(start_ids)
        while queue:
            current = queue.popleft()
            # End nodes stop the search. Ids beyond the largest edge endpoint
            # have no outgoing edges (and no CSR row to look up).
            if current in end_set or current >= num_nodes:
                continue
            for neighbor in col[crow[current]:crow[current + 1]]:
                if neighbor not in reached:
                    reached.add(neighbor)
                    queue.append(neighbor)

        return torch.tensor(sorted(reached), dtype=torch.long, device=device)

    def forward_boundary(self, batch, batch_idx):
        """Run the model on one batch and compute the loss and metrics.

        Args:
            batch: Batch object exposing the ``sub_aig_*`` and ``pm_*``
                attributes read below.
            batch_idx: Index of the batch (unused).

        Returns:
            A tuple ``(loss, acc, iou, dice)``. ``acc`` is always ``0.``
            (no accuracy is computed here). In training mode ``iou`` and
            ``dice`` are ``0.`` as well; otherwise they compare the predicted
            segment with the nodes listed in ``batch.sub_aig_to_cell``.
        """
        device = batch.pm_x.device

        # Encode the sub_aig graph and mean-pool it per sample.
        aig_features = F.one_hot(batch.sub_aig_gate_type.squeeze(-1), num_classes=3).float()
        sub_aig_hf = self.aig_encoder(
            aig_features, batch.sub_aig_edge_index, batch.sub_aig_forward_level
        )
        sub_aig_hf = torch_scatter.scatter_mean(sub_aig_hf, batch.sub_aig_batch, dim=0)

        # Encode the pm graph and attach each node's pooled sub_aig embedding.
        pm_hf_g = self.pm_encoder(batch.pm_x, batch.pm_edge_index, batch.pm_forward_level)
        pm_fuse = torch.cat([pm_hf_g, sub_aig_hf[batch.pm_batch]], dim=-1)

        pred_logits = self.mlp(pm_fuse)

        # Boundary labels: 1 for input boundary, 2 for output boundary.
        # -1 entries are dropped on both sides; used as an index, -1 would
        # silently label the last node instead.
        input_boundary = batch.sub_aig_to_cell[batch.sub_aig_forward_level == 0].unique()
        input_boundary = input_boundary[input_boundary != -1]
        output_boundary = batch.sub_aig_to_cell[batch.sub_aig_backward_level == 0].unique()
        output_boundary = output_boundary[output_boundary != -1]
        label = torch.zeros_like(batch.pm_forward_index).long()
        label[input_boundary] = 1
        label[output_boundary] = 2

        # Three-class cross-entropy on class indices (equivalent to one-hot
        # probability targets, without materialising them).
        loss = F.cross_entropy(pred_logits, label)

        # The segmentation metrics are only evaluated outside of training.
        if self.training:
            return loss, 0., 0., 0.

        pred_label = pred_logits.argmax(dim=-1)
        seg_label = torch.isin(batch.pm_forward_index, batch.sub_aig_to_cell).long().to(device)
        start_nodes = batch.pm_forward_index[pred_label == 1]
        end_nodes = batch.pm_forward_index[pred_label == 2]
        # Fall back to the level-0 nodes when no boundary was predicted.
        if start_nodes.shape[0] == 0:
            start_nodes = batch.pm_forward_index[batch.pm_forward_level == 0]
        if end_nodes.shape[0] == 0:
            end_nodes = batch.pm_forward_index[batch.pm_backward_level == 0]

        seg_node = self.find_nodes_between_start_and_end(
            batch.pm_edge_index, start_nodes, end_nodes
        )
        pred_seg_label = torch.zeros_like(seg_label).long().to(device)
        pred_seg_label[seg_node] = 1
        pred_seg_label[start_nodes] = 1
        pred_seg_label[end_nodes] = 1

        union = torch.logical_or(pred_seg_label == 1, seg_label == 1).float().sum()
        intersection = torch.logical_and(pred_seg_label == 1, seg_label == 1).float().sum()
        iou = intersection / (union + 1e-5)
        dice = (2 * intersection) / (pred_seg_label.float().sum() + seg_label.float().sum() + 1e-5)
        return loss, 0., iou, dice

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch; return the loss."""
        loss, acc, iou, dice = self.forward_boundary(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)
        # Store a detached copy: keeping the live loss would retain the whole
        # autograd graph of every step for as long as the list exists.
        self.training_step_outputs.append({'loss': loss.detach(), 'acc': acc, 'iou': iou, 'dice': dice})

        return loss

    def on_train_epoch_end(self):
        """Drop the per-step training records so the list does not grow forever."""
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Compute and log loss and metrics for one validation batch."""
        loss, acc, iou, dice = self.forward_boundary(batch, batch_idx)

        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)
        self.log('val_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)
        self.log('val_iou', iou, on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)
        self.log('val_dice', dice, on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)
        self.val_step_outputs.append({'loss': loss, 'acc': acc, 'iou': iou, 'dice': dice})

        return loss

    def on_validation_epoch_end(self):
        """Log the epoch means of the validation metrics and reset the records."""
        outputs = self.val_step_outputs
        if not outputs:
            # No validation batch ran; there is nothing to average.
            return

        count = len(outputs)
        total_acc = sum(x['acc'] for x in outputs) / count
        total_iou = sum(x['iou'] for x in outputs) / count
        total_dice = sum(x['dice'] for x in outputs) / count
        self.log('val_acc_epoch', round(float(total_acc), 4), on_epoch=True, prog_bar=True, logger=True, batch_size=self.args.batch_size)
        self.log('val_iou_epoch', round(float(total_iou), 4), on_epoch=True, prog_bar=True, logger=True, batch_size=self.args.batch_size)
        self.log('val_dice_epoch', round(float(total_dice), 4), on_epoch=True, prog_bar=True, logger=True, batch_size=self.args.batch_size)
        self.val_step_outputs.clear()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``args.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr)
        return optimizer
    
