import hydra
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.utils import to_dense_batch

from nn.dynamic_graph_constructor import GraphConstructor
from nn.pooling import Aggregator


def to_pyg_batch(node_features, edge_features, edge_index, node_mask):
    """Pack per-graph dense tensors into one ``torch_geometric`` ``Batch``.

    One ``Data`` object is built for every index along the first dimension
    of ``node_features`` and the resulting list is collated.

    Args:
        node_features: Indexed as ``node_features[i][node_mask[i]]``; the
            first dimension enumerates graphs.
        edge_features: Dense edge features, indexed as
            ``edge_features[i, source, target]``.
        edge_index: Sequence with one edge-index tensor per graph; row 0
            and row 1 are used as the two node indices of every edge.
        node_mask: Per-graph mask selecting the node rows that are kept.

    Returns:
        The ``Batch`` produced by ``Batch.from_data_list``.
    """
    # NOTE(review): node rows are filtered by the mask while edge indices are
    # used unchanged, so this presumably assumes masked-out nodes only appear
    # after all valid ones -- confirm against GraphConstructor.
    graphs = []
    for graph_id in range(node_features.shape[0]):
        connectivity = edge_index[graph_id]
        sources, targets = connectivity[0], connectivity[1]
        graphs.append(
            torch_geometric.data.Data(
                x=node_features[graph_id][node_mask[graph_id]],
                edge_index=connectivity,
                edge_attr=edge_features[graph_id, sources, targets],
            )
        )
    return torch_geometric.data.Batch.from_data_list(graphs)


class GNNForGeneralization(nn.Module):
    """GNN that maps a batch of input graphs to ``d_out`` outputs per graph.

    The forward pass builds node/edge features with ``GraphConstructor``,
    runs the configured GNN backbone on them, reduces the node features to
    one vector per graph according to ``graph_features`` and feeds that
    vector through an MLP head (``proj_out``).

    Supported ``graph_features`` modes:
        * ``"mean"`` / ``"max"``: reduce over all nodes.
        * ``"last_layer"``: mean over the ``num_classes`` highest-index
          valid nodes.
        * ``"cat_last_layer"``: concatenation of those nodes.
        * ``"layer_<k>"``: mean over the nodes of layer ``k``.
        * ``"cat_all_layers"``: concatenation of per-layer means.
        * ``"attentional_aggregation"``, ``"set_transformer"``,
          ``"graph_multiset_transformer"``: learned pooling over all nodes.
        * the same three names prefixed with ``"last_"``: learned pooling
          over the ``num_classes`` highest-index valid nodes.
    """

    _POOLING_METHODS = (
        "attentional_aggregation",
        "set_transformer",
        "graph_multiset_transformer",
    )
    _LAST_POOLING_METHODS = tuple(f"last_{name}" for name in _POOLING_METHODS)

    def __init__(
        self,
        d_in,
        d_edge_in,
        d_hid,
        d_out,
        gnn_backbone,
        rev_edge_features,
        graph_features,
        zero_out_bias,
        zero_out_weights,
        sin_emb,
        input_layers,
        use_pos_embed,
        num_probe_features,
        max_num_hidden_layers,
        inr_model=None,
        inp_factor=1,
        stats=None,
        compile=False,
        jit=False,
        input_channels=3,
        linear_as_conv=True,
        flattening_method="repeat_nodes",
        max_spatial_resolution=64,
        num_classes=10,
    ):
        """Build the graph constructor, the GNN backbone and the output head.

        Args:
            d_hid: Hidden width; used as node and edge width of the graph
                constructor and as the width of the output MLP.
            d_out: Output dimension of ``proj_out``.
            gnn_backbone: Config handed to ``hydra.utils.instantiate``. If
                it has a non-``None`` ``"deg"`` entry, that entry is passed
                to the backbone as a ``torch.long`` tensor.
            rev_edge_features: If truthy, ``forward`` appends the reversed
                edges to every graph's edge index.
            graph_features: Name of the graph-level read-out (see the class
                docstring).
            compile: If true, wrap the backbone with
                ``torch_geometric.compile``.
            jit: If true, script the backbone with ``torch.jit.script``.
            num_classes: Number of highest-index nodes used by the
                ``last_*`` read-outs.

        All remaining arguments are forwarded unchanged to
        ``GraphConstructor``.
        """
        super().__init__()
        self.graph_features = graph_features
        self.out_features = d_out
        self.num_classes = num_classes
        self.rev_edge_features = rev_edge_features

        self.construct_graph = GraphConstructor(
            d_in=d_in,
            d_edge_in=d_edge_in,
            d_node=d_hid,
            d_edge=d_hid,
            d_out=d_out,
            max_num_hidden_layers=max_num_hidden_layers,
            rev_edge_features=rev_edge_features,
            zero_out_bias=zero_out_bias,
            zero_out_weights=zero_out_weights,
            sin_emb=sin_emb,
            input_layers=input_layers,
            use_pos_embed=use_pos_embed,
            inp_factor=inp_factor,
            num_probe_features=num_probe_features,
            inr_model=inr_model,
            stats=stats,
            input_channels=input_channels,
            linear_as_conv=linear_as_conv,
            flattening_method=flattening_method,
            max_spatial_resolution=max_spatial_resolution,
            num_classes=num_classes,
        )

        num_graph_features = d_hid
        if graph_features == "cat_last_layer":
            num_graph_features = num_classes * d_hid
        elif graph_features == "cat_all_layers":
            # NOTE(review): forward() concatenates one d_hid vector per
            # layer for this mode -- confirm that this width matches the
            # layouts produced by the data pipeline.
            num_graph_features = (
                input_channels + max_num_hidden_layers + num_classes
            ) * d_hid

        # The "last_*" read-outs call self.pool in forward() as well, so the
        # aggregator has to exist for them too.
        if graph_features in self._POOLING_METHODS:
            pool_method = graph_features
        elif graph_features in self._LAST_POOLING_METHODS:
            # NOTE(review): assumes Aggregator expects the method name
            # without the "last_" prefix -- confirm against nn.pooling.
            pool_method = graph_features[len("last_"):]
        else:
            pool_method = None
        if pool_method is not None:
            self.pool = Aggregator(d_hid, d_hid, d_hid, pool_method)

        self.proj_out = nn.Sequential(
            nn.Linear(num_graph_features, d_hid),
            nn.ReLU(),
            nn.Linear(d_hid, d_hid),
            nn.ReLU(),
            nn.Linear(d_hid, d_out),
        )

        # Only override "deg" when the backbone config actually provides it;
        # backbones without a degree histogram are instantiated as-is.
        gnn_kwargs = {}
        deg = gnn_backbone.get("deg")
        if deg is not None:
            gnn_kwargs["deg"] = torch.tensor(deg, dtype=torch.long)

        self.gnn = hydra.utils.instantiate(gnn_backbone, **gnn_kwargs)
        if jit:
            self.gnn = torch.jit.script(self.gnn)
        if compile:
            self.gnn = torch_geometric.compile(self.gnn)

    @staticmethod
    def _layer_boundaries(layer_layouts):
        """Return cumulative node offsets ``[0, n_0, n_0 + n_1, ...]``.

        ``layer_layouts`` holds one sequence of per-layer node counts per
        graph. Per-layer slicing of the dense node tensor is only meaningful
        when every graph shares the same layout, so anything else raises.

        Raises:
            ValueError: If the batch is empty or the layouts differ.
        """
        layouts = [[int(count) for count in layout] for layout in layer_layouts]
        if not layouts:
            raise ValueError("Cannot derive layer boundaries from an empty batch.")
        if any(layout != layouts[0] for layout in layouts[1:]):
            raise ValueError(
                "Layer-wise graph features require identical layer layouts "
                "for all graphs in the batch."
            )
        boundaries = [0]
        for count in layouts[0]:
            boundaries.append(boundaries[-1] + count)
        return boundaries

    def _gather_last_nodes(self, node_features, node_mask):
        """Select the ``num_classes`` highest-index valid nodes per graph.

        Positions are taken from ``node_mask``; masked-out positions
        contribute index 0 to the top-k. The selected indices are returned
        in ascending order along dimension 1.
        """
        positions = torch.arange(node_mask.shape[1], device=node_mask.device)
        valid_positions = positions[None, :] * node_mask
        # topk yields descending values; flip to restore ascending order.
        last_indices = valid_positions.topk(k=self.num_classes, dim=1).values.fliplr()
        batch_range = torch.arange(node_mask.shape[0], device=node_mask.device)[:, None]
        return node_features[batch_range, last_indices]

    def forward(self, batch):
        """Compute the model output for a batch of graphs.

        Args:
            batch: Batch object that supports ``len(batch)`` and
                ``batch[i].edge_index``, is accepted by ``GraphConstructor``
                and, for the layer-wise read-outs, exposes
                ``batch.layer_layout``.

        Returns:
            The output of ``proj_out`` applied to the pooled graph features.

        Raises:
            ValueError: If ``graph_features`` is not a supported mode, a
                requested layer does not exist, or a layer-wise read-out is
                used with differing layouts.
        """
        node_features, edge_features, _, node_mask = self.construct_graph(batch)

        # Index the batch once per graph; each lookup rebuilds the example.
        graphs = [batch[i] for i in range(len(batch))]
        if self.rev_edge_features:
            edge_index = [
                torch.cat([g.edge_index, g.edge_index.flip(dims=(0,))], dim=-1)
                for g in graphs
            ]
        else:
            edge_index = [g.edge_index for g in graphs]

        new_batch = to_pyg_batch(node_features, edge_features, edge_index, node_mask)
        out_node, _ = self.gnn(
            x=new_batch.x, edge_index=new_batch.edge_index, edge_attr=new_batch.edge_attr
        )
        node_features = to_dense_batch(out_node, new_batch.batch)[0]

        mode = self.graph_features
        if mode == "mean":
            graph_features = node_features.mean(dim=1)
        elif mode == "max":
            graph_features = node_features.max(dim=1).values
        elif mode == "last_layer":
            graph_features = self._gather_last_nodes(node_features, node_mask).mean(dim=1)
        elif mode == "cat_last_layer":
            graph_features = self._gather_last_nodes(node_features, node_mask).flatten(1, 2)
        elif isinstance(mode, str) and mode.startswith("layer_"):
            # TODO: This only works for same layouts
            boundaries = self._layer_boundaries(batch.layer_layout)
            layer = int(mode.split("_")[1])
            if not 0 <= layer < len(boundaries) - 1:
                raise ValueError(
                    f"Layer {layer} requested, but the layout has only "
                    f"{len(boundaries) - 1} layers."
                )
            graph_features = node_features[
                :, boundaries[layer]:boundaries[layer + 1]
            ].mean(dim=1)
        elif mode == "cat_all_layers":
            # TODO: This only works for same layouts
            boundaries = self._layer_boundaries(batch.layer_layout)
            graph_features = torch.cat(
                [
                    node_features[:, start:end].mean(dim=1)
                    for start, end in zip(boundaries[:-1], boundaries[1:])
                ],
                dim=1,
            )
        elif mode in self._LAST_POOLING_METHODS:
            graph_features = self.pool(self._gather_last_nodes(node_features, node_mask))
        elif mode in self._POOLING_METHODS:
            graph_features = self.pool(node_features)
        else:
            raise ValueError(f"Unsupported graph_features mode: {mode!r}")

        return self.proj_out(graph_features)
