import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric.data as gd
from torch_geometric.transforms import AddRandomWalkPE

from gflownet.graphenv import GraphEnv, GraphState, ActionType, Children, Parents


class GraphStateFeaturizer:
    """Turn ``GraphEnv`` states into featurized ``torch_geometric`` graphs.

    Each state becomes a ``gd.Data`` holding one-hot node/edge features
    (optionally augmented with random-walk PE, one-hot degree and clustering
    coefficient) plus the forward -- and, when ``set_backward`` is enabled,
    backward -- action indices/masks derived from the state's children/parents.
    """

    def __init__(
        self,
        env: GraphEnv,
        random_walk_length=0,
        add_degree=False,
        add_clustering_coef=False,
        set_backward=False,
    ):
        """
        Args:
            env: Environment providing ``num_node_types``, ``num_edge_types``,
                ``max_degree`` and the ``children``/``parents`` enumerations.
            random_walk_length: Number of random-walk PE features appended to
                each node; 0 disables them.
            add_degree: Append a one-hot node degree of width
                ``env.max_degree + 1``.
            add_clustering_coef: Append the networkx clustering coefficient
                (one extra column).
            set_backward: Also compute parent (backward) action attributes in
                ``transform`` and ``collate``.
        """
        self.env = env
        self.randomwalk_pe = AddRandomWalkPE(random_walk_length, 'rwpe')
        self.add_degree = add_degree
        self.add_clustering_coef = add_clustering_coef
        # Width of ``x`` built by ``set_graph_features``; must mirror the
        # blocks concatenated there. Explicit conditionals (rather than
        # bool arithmetic) keep this correct for any truthy/falsy flag value.
        self.node_dim = (
            env.num_node_types
            + random_walk_length
            + (env.max_degree + 1 if add_degree else 0)
            + (1 if add_clustering_coef else 0)
        )
        self.edge_dim = env.num_edge_types
        self.set_backward = set_backward

    @staticmethod
    def _to_edge_index(pairs) -> torch.Tensor:
        """Pack an iterable of ``(source, target)`` pairs into a ``(2, E)`` long tensor.

        An empty iterable yields a ``(2, 0)`` tensor instead of a 1-D one.
        """
        pair_tensor = torch.tensor(list(pairs), dtype=torch.long)
        return pair_tensor.reshape(-1, 2).t().contiguous()

    def transform(self, states) -> dict:
        """Featurize every state in ``states``.

        Returns:
            A dict with keys ``'graph'`` (one ``gd.Data`` per state),
            ``'children'`` (``env.children(state)`` per state) and
            ``'parents'`` (``env.parents(state)`` per state; left empty
            unless ``set_backward`` is enabled).
        """
        output = {'graph': [], 'children': [], 'parents': []}
        for state in states:
            torch_graph = self.state_to_graph(state)

            children = self.env.children(state)
            self.set_forward_attr(torch_graph, children)
            output['children'].append(children)

            if self.set_backward:
                parents = self.env.parents(state)
                self.set_backward_attr(torch_graph, parents)
                output['parents'].append(parents)

            output['graph'].append(torch_graph)
        return output

    def state_to_graph(self, state: GraphState) -> gd.Data:
        """Build a ``gd.Data`` with ``x``, ``edge_index`` and ``edge_attr`` for ``state``."""
        node_types = torch.tensor(list(state.node_types), dtype=torch.long)
        edge_types = torch.tensor(list(state.edge_types), dtype=torch.long)
        # Each undirected edge is stored in both directions, back to back:
        # (i, j) then (j, i). ``set_graph_features`` relies on this layout.
        edge_index = self._to_edge_index(
            pair for i, j in state.edge_list for pair in ((i, j), (j, i))
        )
        torch_graph = gd.Data(
            node_types=node_types,
            edge_types=edge_types,
            edge_index=edge_index,
        )
        self.set_graph_features(torch_graph, state)
        return torch_graph

    def set_graph_features(self, torch_graph: gd.Data, state: GraphState):
        """Replace the raw ``node_types``/``edge_types`` on ``torch_graph`` with ``x``/``edge_attr``.

        ``x`` concatenates, in order: one-hot node type, random-walk PE (if
        enabled), one-hot degree (if enabled) and clustering coefficient (if
        enabled), so its width equals ``self.node_dim``. A graph with no nodes
        gets a single all-zero virtual node (index 0), which
        ``set_forward_attr`` uses as the source of AddNode actions.
        """
        features = [
            F.one_hot(torch_graph.node_types, num_classes=self.env.num_node_types).float()
        ]
        if self.randomwalk_pe.walk_length > 0:
            features.append(self.randomwalk_pe(torch_graph).rwpe)
        if self.add_degree:
            degree = torch.tensor(state.degree, dtype=torch.long)
            features.append(F.one_hot(degree, num_classes=self.env.max_degree + 1).float())
        if self.add_clustering_coef:
            clustering = nx.clustering(state.to_nx())
            # Reads the dict by keys 0..n-1, i.e. assumes nodes are labelled that way.
            coef = [clustering[i] for i in range(len(clustering))]
            features.append(torch.tensor(coef, dtype=torch.float)[:, None])
        x = torch.cat(features, dim=1)

        if x.shape[0] == 0:
            # Empty graph: substitute one all-zero virtual node.
            x = torch.zeros(1, x.shape[1], dtype=torch.float)

        # ``edge_types`` has one entry per undirected edge, while an
        # ``edge_index`` built by ``state_to_graph`` lists every edge twice
        # consecutively, so duplicate each type in place.
        edge_types = torch_graph.edge_types.repeat_interleave(2, dim=0)
        torch_graph.x = x
        torch_graph.edge_attr = F.one_hot(edge_types, num_classes=self.env.num_edge_types).float()
        del torch_graph.node_types
        del torch_graph.edge_types

    def set_forward_attr(self, torch_graph: gd.Data, children: Children):
        """Attach forward-action attributes derived from ``children.actions``.

        Sets on ``torch_graph``:
            add_node_index: source node of each AddNode action, shape ``(A_n,)``.
            add_node_mask: one-hot of the flattened ``(node_type, edge_type)``
                pair of each AddNode action, shape
                ``(A_n, num_node_types * num_edge_types)``.
            add_edge_index: endpoints of each AddEdge action, shape ``(2, A_e)``.
            add_edge_mask: one-hot edge type of each AddEdge action, shape
                ``(A_e, num_edge_types)``.
            stop_mask: ``(1,)`` bool tensor, True if a Stop action is present.
        Actions of any other type are ignored.
        """
        num_edge_types = self.env.num_edge_types
        add_node_index, add_node_type = [], []
        add_edge_index, add_edge_type = [], []
        has_stop = False
        for action in children.actions:
            if action.type == ActionType.AddNode:
                # For the empty graph, `action.source` and `action.edge_type` are None.
                # The empty graph carries a virtual node, so source = 0;
                # `edge_type` is (arbitrarily) set to 0.
                # TODO: find a better way to handle this.
                source = action.source or 0
                edge_type = action.edge_type or 0
                add_node_index.append(source)
                # Flatten (node_type, edge_type) into a single class index.
                add_node_type.append(action.node_type * num_edge_types + edge_type)
            elif action.type == ActionType.AddEdge:
                add_edge_index.append((action.source, action.target))
                add_edge_type.append(action.edge_type)
            elif action.type == ActionType.Stop:
                has_stop = True

        add_node_type = torch.tensor(add_node_type, dtype=torch.long)
        add_edge_type = torch.tensor(add_edge_type, dtype=torch.long)

        torch_graph.add_node_index = torch.tensor(add_node_index, dtype=torch.long)
        torch_graph.add_edge_index = self._to_edge_index(add_edge_index)
        # An explicit `num_classes` keeps these well-shaped even when there are no actions.
        torch_graph.add_node_mask = F.one_hot(
            add_node_type, num_classes=self.env.num_node_types * num_edge_types
        ).bool()
        torch_graph.add_edge_mask = F.one_hot(add_edge_type, num_classes=num_edge_types).bool()
        torch_graph.stop_mask = torch.tensor([has_stop], dtype=torch.bool)

    def set_backward_attr(self, torch_graph: gd.Data, parents: Parents):
        """Attach backward-action attributes derived from ``parents.actions``.

        Sets ``del_node_index`` (shape ``(R_n,)``, the node of each RemoveNode
        action) and ``del_edge_index`` (shape ``(2, R_e)``, the endpoints of
        each RemoveEdge action). Actions of any other type are ignored.
        """
        del_node_index = []
        del_edge_index = []
        for action in parents.actions:
            if action.type == ActionType.RemoveNode:
                del_node_index.append(action.source)
            elif action.type == ActionType.RemoveEdge:
                del_edge_index.append((action.source, action.target))

        torch_graph.del_node_index = torch.tensor(del_node_index, dtype=torch.long)
        torch_graph.del_edge_index = self._to_edge_index(del_edge_index)

    def collate(self, torch_graphs) -> gd.Batch:
        """Batch featurized graphs and record which graph each action belongs to.

        Adds ``fwd_logit_batch``, the graph index of every forward action
        concatenated as [stop, add-node, add-edge]. This presumably matches
        the order in which the policy concatenates its logits; confirm against
        the model. When ``set_backward`` is enabled, also adds
        ``bck_logit_batch`` ordered as [remove-node, remove-edge]. The
        auxiliary ``*_batch`` / ``*_ptr`` tensors created by ``follow_batch``
        are deleted afterwards.
        """
        batch = gd.Batch.from_data_list(
            torch_graphs, follow_batch=['add_edge_index', 'del_edge_index']
        )

        # `stop_mask` holds one flag per graph, so its non-zero positions are graph ids.
        stop_batch = batch.stop_mask.nonzero().flatten()
        # PyG offsets `*index*` attributes by each graph's node offset during
        # batching, so indexing `batch.batch` with them yields the owning graph.
        add_node_batch = batch.batch[batch.add_node_index]
        add_edge_batch = batch.add_edge_index_batch
        batch.fwd_logit_batch = torch.cat([stop_batch, add_node_batch, add_edge_batch], dim=0)
        del batch.add_edge_index_batch
        del batch.add_edge_index_ptr

        if self.set_backward:
            del_node_batch = batch.batch[batch.del_node_index]
            del_edge_batch = batch.del_edge_index_batch
            batch.bck_logit_batch = torch.cat([del_node_batch, del_edge_batch], dim=0)
            del batch.del_edge_index_batch
            del batch.del_edge_index_ptr
        return batch