import torch
from torch import nn
from torch.nn import functional as F
from time import time
import torch_geometric.nn as gnn
from collections import defaultdict, deque
import pytorch_lightning as pl
from .mlp import MLP
import torch_scatter
from torch_sparse import SparseTensor

class SAGE_MULT(torch.nn.Module):
    """GraphSAGE encoder followed by a linear projection and batch norm.

    The forward pass applies ``num_layers`` ``SAGEConv`` layers, each followed
    by ReLU and dropout, then a bias-free linear layer, ReLU and
    ``BatchNorm1d``.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Size of every hidden layer and of the output.
        num_layers: Number of ``SAGEConv`` layers applied in ``forward``.
        dropout: Dropout probability applied after every convolution.
    """

    def __init__(self, in_channels, hidden_channels, num_layers=4,
                 dropout=0.5):
        super(SAGE_MULT, self).__init__()
        self.num_layers = num_layers

        # NOTE: the first and last layers are always created, so for
        # num_layers < 2 the module owns more convolutions than ``forward``
        # uses. The construction is kept as-is so state_dict keys are stable.
        self.convs = torch.nn.ModuleList()
        self.convs.append(gnn.SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(gnn.SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(gnn.SAGEConv(hidden_channels, hidden_channels))

        # Output head: one bias-free linear projection followed by batch norm.
        self.linear = torch.nn.ModuleList()
        self.linear.append(nn.Linear(hidden_channels, hidden_channels, bias=False))
        self.bn0 = nn.BatchNorm1d(hidden_channels)

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every learnable layer, including the batch norm."""
        for conv in self.convs:
            conv.reset_parameters()
        for lin in self.linear:
            lin.reset_parameters()
        # The batch norm carries both affine parameters and running
        # statistics; leaving it out would keep state from a previous run.
        self.bn0.reset_parameters()

    def forward(self, x, edge_index):
        """Encode node features.

        Args:
            x: Node feature matrix whose last dimension is ``in_channels``.
            edge_index: Graph connectivity passed through to each ``SAGEConv``.

        Returns:
            Node embeddings whose last dimension is ``hidden_channels``.
        """
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            # Honour the configured rate instead of a hard-coded 0.5.
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.linear[0](x)
        x = self.bn0(F.relu(x))

        return x

class BoundaryGamora(pl.LightningModule):
    """Lightning module that predicts boundary nodes of a sub-graph.

    Two ``SAGE_MULT`` encoders embed the sub-AIG (``batch.sub_aig_*``) and the
    PM graph (``batch.pm_*``). Every PM node embedding is concatenated with the
    mean-pooled sub-AIG embedding of its graph and classified by an MLP into
    three classes: 0 = other, 1 = input boundary, 2 = output boundary.

    Args:
        args: Configuration object; ``hidden``, ``lr`` and ``batch_size`` are
            read from it.
    """

    def __init__(self,args):
        super().__init__()
        self.args = args
        self.aig_encoder = SAGE_MULT(in_channels=3, hidden_channels=args.hidden)
        self.pm_encoder = SAGE_MULT(in_channels=64, hidden_channels=args.hidden)
        self.training_step_outputs = []
        self.test_step_outputs = []
        self.val_step_outputs = []
        self.save_hyperparameters()
        self.mlp= MLP(2*self.args.hidden, 2*self.args.hidden, 3, num_layer=3)

    def find_nodes_between_start_and_end(self, edge_index, start_nodes, end_nodes):
        """Collect the nodes reachable from ``start_nodes`` without passing an end node.

        Edges are followed from ``edge_index[0]`` to ``edge_index[1]``. An end
        node that is reached is included in the result but is not expanded
        further. Start nodes are always part of the result.

        Args:
            edge_index: Integer tensor of shape ``(2, num_edges)``.
            start_nodes: Node ids (tensor or sequence) the search starts from.
            end_nodes: Node ids (tensor or sequence) at which the search stops.

        Returns:
            Sorted 1-D ``torch.long`` tensor of node ids on ``edge_index``'s
            device.
        """
        device = edge_index.device
        start_nodes = torch.as_tensor(start_nodes, device=device).long().reshape(-1)
        end_nodes = torch.as_tensor(end_nodes, device=device).long().reshape(-1)

        # Size the masks from every id we will index with, not only from the
        # edge list: isolated start/end nodes or an empty edge list would
        # otherwise index out of range.
        num_nodes = 0
        for ids in (edge_index, start_nodes, end_nodes):
            if ids.numel() > 0:
                num_nodes = max(num_nodes, int(ids.max().item()) + 1)
        if num_nodes == 0 or start_nodes.numel() == 0:
            return torch.empty(0, dtype=torch.long, device=device)

        src, dst = edge_index[0].long(), edge_index[1].long()

        visited = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        visited[start_nodes] = True
        is_end = torch.zeros_like(visited)
        is_end[end_nodes] = True

        # Level-synchronous BFS: expand the whole frontier at once so there is
        # one host sync per level instead of several per node.
        frontier = visited & ~is_end
        while bool(frontier.any()):
            reached = torch.zeros_like(visited)
            reached[dst[frontier[src]]] = True
            newly_visited = reached & ~visited
            visited |= newly_visited
            # End nodes are recorded as visited but never expanded.
            frontier = newly_visited & ~is_end

        return visited.nonzero(as_tuple=False).reshape(-1)

    def forward_boundary(self, batch, batch_idx):
        """Run the model on one batch and compute the loss and metrics.

        Args:
            batch: Batch object providing the ``sub_aig_*`` and ``pm_*``
                attributes read below.
            batch_idx: Index of the batch (unused).

        Returns:
            Tuple ``(loss, acc, iou, dice)``. In training mode only ``loss`` is
            computed and the three metrics are ``0.``. In eval mode ``acc`` is
            the per-node classification accuracy, and ``iou`` / ``dice``
            compare the predicted segment with the set of nodes listed in
            ``batch.sub_aig_to_cell``.
        """
        # Encode the sub-AIG and mean-pool it to one vector per graph.
        gate_one_hot = F.one_hot(batch.sub_aig_gate_type.squeeze(-1), num_classes=3).float()
        sub_aig_hf = self.aig_encoder(gate_one_hot, batch.sub_aig_edge_index)
        sub_aig_hf = torch_scatter.scatter_mean(sub_aig_hf, batch.sub_aig_batch, dim=0)

        # Encode the PM graph and attach the pooled sub-AIG vector to each node.
        pm_hf_g = self.pm_encoder(batch.pm_x, batch.pm_edge_index)
        pm_fuse = torch.cat([pm_hf_g, sub_aig_hf[batch.pm_batch]], dim=-1)

        pred_logits = self.mlp(pm_fuse)

        # Boundary labels. -1 in sub_aig_to_cell is treated as "no cell" and is
        # dropped for both boundaries; otherwise it would index the last node.
        input_boundary = batch.sub_aig_to_cell[batch.sub_aig_forward_level == 0].unique()
        input_boundary = input_boundary[input_boundary != -1]
        output_boundary = batch.sub_aig_to_cell[batch.sub_aig_backward_level == 0].unique()
        output_boundary = output_boundary[output_boundary != -1]

        target = torch.zeros_like(batch.pm_forward_index).long()
        target[input_boundary] = 1
        target[output_boundary] = 2

        # Multi-class cross-entropy. Class-index targets give the same value
        # as the equivalent one-hot float targets.
        loss = F.cross_entropy(pred_logits, target)

        # Metrics are only needed outside of training.
        if self.training:
            return loss, 0., 0., 0.

        pred_label = pred_logits.argmax(dim=-1)
        acc = (pred_label == target).float().mean()

        true_mask = torch.isin(batch.pm_forward_index, batch.sub_aig_to_cell)
        start_nodes = batch.pm_forward_index[pred_label == 1]
        end_nodes = batch.pm_forward_index[pred_label == 2]

        # Fall back to the graph's level-0 nodes when no boundary of a given
        # kind was predicted.
        if start_nodes.shape[0] == 0:
            start_nodes = batch.pm_forward_index[batch.pm_forward_level == 0]
        if end_nodes.shape[0] == 0:
            end_nodes = batch.pm_forward_index[batch.pm_backward_level == 0]

        seg_node = self.find_nodes_between_start_and_end(batch.pm_edge_index, start_nodes, end_nodes)
        pred_mask = torch.zeros_like(true_mask)
        pred_mask[seg_node] = True
        pred_mask[start_nodes] = True
        # End nodes that were not reached still count as part of the segment.
        pred_mask[end_nodes] = True

        union = (pred_mask | true_mask).float().sum()
        intersection = (pred_mask & true_mask).float().sum()
        iou = intersection / (union + 1e-5)
        dice = (2 * intersection) / (pred_mask.float().sum() + true_mask.float().sum() + 1e-5)
        return loss, acc, iou, dice

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss, acc, iou, dice = self.forward_boundary(batch, batch_idx)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)

        # Store a detached copy so the list does not keep autograd state alive.
        self.training_step_outputs.append({'loss': loss.detach(), 'acc': acc, 'iou': iou, 'dice': dice})

        return loss

    def on_train_epoch_end(self):
        """Drop the per-step training records so they do not grow across epochs."""
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Compute and log loss and metrics for one validation batch."""
        loss, acc, iou, dice = self.forward_boundary(batch, batch_idx)

        for name, value in (('val_loss', loss), ('val_acc', acc), ('val_iou', iou), ('val_dice', dice)):
            self.log(name, value, on_step=True, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)
        self.val_step_outputs.append({'loss': loss, 'acc': acc, 'iou': iou, 'dice': dice})

        return loss

    def on_validation_epoch_end(self):
        """Log the mean of the per-step validation metrics and reset the records."""
        outputs = self.val_step_outputs
        # Nothing was recorded, so there is nothing to average.
        if not outputs:
            return
        for key in ('acc', 'iou', 'dice'):
            mean_value = sum(x[key] for x in outputs) / len(outputs)
            self.log(f'val_{key}_epoch', round(float(mean_value), 4), on_epoch=True, prog_bar=True, logger=True, batch_size=self.args.batch_size)
        outputs.clear()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``args.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr)
        return optimizer
    
