import torch
from torch import nn
import torch.nn.functional as F

from nn.dynamic_graph_constructor import GraphConstructor
from nn.relation_transformer import RTLayer
from nn.pooling import Aggregator


class DynamicRelationTransformer(nn.Module):
    """Relation-transformer model over graphs built by ``GraphConstructor``.

    The forward pass (1) turns a batch into node/edge features with
    ``GraphConstructor``, (2) optionally prepends a learnable CLS node,
    (3) runs a stack of ``RTLayer`` blocks, (4) pools the result into one
    vector per graph according to ``graph_features`` and (5) maps that vector
    to ``d_out`` values with a three-layer MLP head (``proj_out``).
    """

    # Pooling modes whose pooled vector is taken from the edge features.
    _EDGE_POOLING_MODES = ("mean_edge", "max_edge", "last_layer_edge")
    # Pooling modes that run ``Aggregator`` over all node features.
    _AGGREGATOR_MODES = (
        "attentional_aggregation",
        "set_transformer",
        "graph_multiset_transformer",
    )
    # Pooling modes that run ``Aggregator`` over the selected last nodes only.
    _LAST_AGGREGATOR_MODES = (
        "last_attentional_aggregation",
        "last_set_transformer",
        "last_graph_multiset_transformer",
    )

    def __init__(
        self,
        d_in,
        d_edge_in,
        d_node,
        d_edge,
        d_attn_hid,
        d_node_hid,
        d_edge_hid,
        d_out_hid,
        d_out,
        n_layers,
        n_heads,
        num_probe_features,
        max_num_hidden_layers,
        inr_model=None,
        dropout=0.0,
        node_update_type="rt",
        disable_edge_updates=False,
        use_cls_token=True,
        graph_features="mean",
        rev_edge_features=False,
        zero_out_bias=False,
        zero_out_weights=False,
        sin_emb=False,
        input_layers=1,
        use_pos_embed=True,
        inp_factor=1.0,
        modulate_v=True,
        use_ln=True,
        tfixit_init=False,
        stats=None,
        input_channels=3,
        linear_as_conv=True,
        flattening_method='repeat_nodes',
        max_spatial_resolution=64,
        num_classes=10,
    ):
        """Build the graph constructor, the transformer stack and the head.

        Args:
            d_node: Node feature width; also the size of the CLS token and
                the default input width of the output head.
            d_edge: Edge feature width passed to ``GraphConstructor`` and
                ``RTLayer``.
            d_attn_hid, d_node_hid, d_edge_hid, n_heads, dropout,
            node_update_type, modulate_v, use_ln, tfixit_init: Forwarded to
                every ``RTLayer``.
            d_out_hid: Hidden width of the output head (and of the
                ``Aggregator`` when one is used).
            d_out: Output width of the model.
            n_layers: Number of ``RTLayer`` blocks.
            disable_edge_updates: Disable edge updates in every layer (the
                last layer always has them disabled), except when
                ``graph_features == "mean_edge"``.
            use_cls_token: Prepend a learnable CLS node. Must be ``True``
                exactly when ``graph_features == "cls_token"``.
            graph_features: Name of the pooling mode used in ``forward``.
            num_classes: Number of trailing valid nodes selected by the
                ``last_layer*`` / ``cat_last_layer`` pooling modes.
            d_in, d_edge_in, max_num_hidden_layers, rev_edge_features,
            zero_out_bias, zero_out_weights, sin_emb, use_pos_embed,
            input_layers, inp_factor, num_probe_features, inr_model, stats,
            input_channels, linear_as_conv, flattening_method,
            max_spatial_resolution: Forwarded unchanged to
                ``GraphConstructor``.
        """
        super().__init__()
        assert use_cls_token == (graph_features == "cls_token"), (
            "use_cls_token must be True if and only if "
            "graph_features == 'cls_token'"
        )
        self.graph_features = graph_features
        self.rev_edge_features = rev_edge_features
        self.out_features = d_out
        self.num_classes = num_classes
        self.construct_graph = GraphConstructor(
            d_in=d_in,
            d_edge_in=d_edge_in,
            d_node=d_node,
            d_edge=d_edge,
            d_out=d_out,
            max_num_hidden_layers=max_num_hidden_layers,
            rev_edge_features=rev_edge_features,
            zero_out_bias=zero_out_bias,
            zero_out_weights=zero_out_weights,
            sin_emb=sin_emb,
            use_pos_embed=use_pos_embed,
            input_layers=input_layers,
            inp_factor=inp_factor,
            num_probe_features=num_probe_features,
            inr_model=inr_model,
            stats=stats,
            input_channels=input_channels,
            linear_as_conv=linear_as_conv,
            flattening_method=flattening_method,
            max_spatial_resolution=max_spatial_resolution,
            num_classes=num_classes,
        )
        self.use_cls_token = use_cls_token
        if use_cls_token:
            self.cls_token = nn.Parameter(torch.randn(d_node))

        # NOTE(review): only "mean_edge" is exempted from disabling edge
        # updates; "max_edge" and "last_layer_edge" also read edge features
        # but still get no edge update in the last layer -- confirm intended.
        self.layers = nn.ModuleList(
            [
                torch.jit.script(
                    RTLayer(
                        d_node,
                        d_edge,
                        d_attn_hid,
                        d_node_hid,
                        d_edge_hid,
                        n_heads,
                        float(dropout),
                        node_update_type=node_update_type,
                        disable_edge_updates=(
                            (disable_edge_updates or (i == n_layers - 1))
                            and (graph_features != "mean_edge")),
                        modulate_v=modulate_v,
                        use_ln=use_ln,
                        tfixit_init=tfixit_init,
                        n_layers=n_layers,
                    )
                )
                for i in range(n_layers)
            ]
        )

        # Width of the pooled vector fed to the output head.
        if graph_features == "cat_last_layer":
            num_graph_features = num_classes * d_node
        elif graph_features == "cat_all_layers":
            num_graph_features = (d_in + 3 * 16 + num_classes) * d_node  # TODO: Remove hardcoding
        elif graph_features in self._EDGE_POOLING_MODES:
            # Edge-pooled vectors are assumed to have the edge width d_edge
            # (the width given to GraphConstructor and RTLayer) -- TODO
            # confirm. Identical to the previous d_node sizing whenever
            # d_node == d_edge.
            num_graph_features = d_edge
        else:
            num_graph_features = d_node

        if graph_features in self._AGGREGATOR_MODES + self._LAST_AGGREGATOR_MODES:
            self.pool = Aggregator(d_node, d_out_hid, d_node, graph_features)

        self.proj_out = nn.Sequential(
            nn.Linear(num_graph_features, d_out_hid),
            nn.ReLU(),
            nn.Linear(d_out_hid, d_out_hid),
            nn.ReLU(),
            nn.Linear(d_out_hid, d_out),
        )

    def _last_node_index(self, node_mask):
        """Return advanced-indexing tensors for the last valid nodes.

        Position indices are multiplied by ``node_mask`` so positions where
        the mask is zero collapse to index 0; the ``num_classes`` largest
        remaining indices are then returned in ascending order.

        Args:
            node_mask: Per-graph node mask; presumably shaped
                ``(batch, num_nodes)`` -- TODO confirm against
                ``GraphConstructor``.

        Returns:
            ``(batch_index, node_index)`` with shapes ``(batch, 1)`` and
            ``(batch, num_classes)``, usable together as
            ``features[batch_index, node_index]``.

        Raises:
            RuntimeError: From ``topk`` when ``node_mask`` has fewer than
                ``num_classes`` columns.
        """
        positions = torch.arange(
            node_mask.shape[1], device=node_mask.device
        )[None, :]
        valid_positions = positions * node_mask
        node_index = valid_positions.topk(
            k=self.num_classes, dim=1
        ).values.fliplr()
        batch_index = torch.arange(
            node_mask.shape[0], device=node_mask.device
        )[:, None]
        return batch_index, node_index

    def _pool_graph(self, node_features, edge_features, node_mask):
        """Reduce node/edge features to one vector per graph.

        The reduction is chosen by ``self.graph_features``. The last-node
        selection is computed only for the modes that use it, so the other
        modes also work on graphs with fewer than ``num_classes`` nodes.

        Raises:
            ValueError: If ``self.graph_features`` is not a known mode.
        """
        mode = self.graph_features

        if mode == "cls_token":
            return node_features[:, 0]
        if mode == "mean":
            return node_features.mean(dim=1)
        if mode == "max":
            return node_features.max(dim=1).values
        if mode == "cat_all_layers":
            return node_features.flatten(1, 2)
        if mode == "mean_edge":
            return edge_features.mean(dim=(1, 2))
        if mode == "max_edge":
            return edge_features.flatten(1, 2).max(dim=1).values
        if mode in self._AGGREGATOR_MODES:
            # FIXME: Node features are not masked, some contain garbage
            return self.pool(node_features)

        if mode in (
            ("last_layer", "cat_last_layer", "last_layer_edge")
            + self._LAST_AGGREGATOR_MODES
        ):
            batch_index, node_index = self._last_node_index(node_mask)
            if mode == "last_layer":
                return node_features[batch_index, node_index].mean(dim=1)
            if mode == "cat_last_layer":
                return node_features[batch_index, node_index].flatten(1, 2)
            if mode == "last_layer_edge":
                return edge_features[batch_index, node_index, :].mean(
                    dim=(1, 2)
                )
            return self.pool(node_features[batch_index, node_index])

        raise ValueError(f"Unknown graph_features mode: {mode!r}")

    def forward(self, batch):
        """Run the model on ``batch`` and return the head output.

        Args:
            batch: Input passed unchanged to ``GraphConstructor``.

        Returns:
            The output of ``proj_out`` applied to the pooled graph vector.

        Raises:
            ValueError: If ``self.graph_features`` is not a known mode.
        """
        node_features, edge_features, mask, node_mask = self.construct_graph(batch)

        if self.use_cls_token:
            # Broadcast the CLS vector to (batch, 1, d) without copying;
            # equivalent to einops ``repeat(cls_token, "d -> b 1 d")``.
            cls_nodes = self.cls_token.unsqueeze(0).expand(
                node_features.size(0), 1, -1
            )
            node_features = torch.cat([cls_nodes, node_features], dim=1)
            # Add one zero row/column in front of the two dimensions that
            # precede the feature dimension, matching the extra CLS node.
            edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0)

        for layer in self.layers:
            node_features, edge_features = layer(node_features, edge_features, mask)

        graph_features = self._pool_graph(node_features, edge_features, node_mask)
        return self.proj_out(graph_features)
