"""
An adaptation of Graph-Q-SAT
https://github.com/NVIDIA/GraphQSat
"""

import torch
from torch import nn
from torch.nn import ReLU

from torch_geometric.data import Batch

from mas_sat.graph.vcg import VCGGraph
from mas_sat.model.base import BaseModel
from mas_sat.model.mlp import get_mlp
from mas_sat.utils.scatter import scatter_reduce

class Encoder(nn.Module):
    """Embed the raw VCG features into ``dim // 2``-wide vectors.

    Variables and clauses share a single node encoder. Both edge directions
    share a single edge encoder. The per-instance (global) features get
    their own encoder.
    """

    def __init__(self, dim):
        super().__init__()
        half = dim // 2
        # Presumably get_mlp(in_features, out_features, n_hidden): 2-wide
        # node/edge inputs, 1-wide global input, no hidden layers.
        # TODO(review): confirm against mas_sat.model.mlp.
        # Creation order (edge, node, global) is kept deliberately so that
        # seeded initialisation and state_dict ordering stay stable.
        self.edge_mlp = get_mlp(2, half, 0)
        self.node_mlp = get_mlp(2, half, 0)
        self.global_mlp = get_mlp(1, half, 0)

    def forward(self, graph):
        """Return a dict keyed by node/edge type holding the encoded features.

        Node types ("variable", "original_clause") are encoded from ``.x``,
        "instance" from ``.x``, and both edge types from ``.edge_attr``.
        """
        v2c = ("variable", "in", "original_clause")
        c2v = ("original_clause", "has", "variable")

        encoded = {
            node_type: self.node_mlp(graph[node_type].x)
            for node_type in ("variable", "original_clause")
        }
        encoded["instance"] = self.global_mlp(graph["instance"].x)
        for edge_type in (v2c, c2v):
            encoded[edge_type] = self.edge_mlp(graph[edge_type].edge_attr)
        return encoded

class Process(nn.Module):
    """One message-passing step (edge -> node -> global update) over a VCG.

    On every call the encoder output is concatenated with the current latent
    state for each node type, edge type and the global "instance" entry.
    Edges, nodes and the global state are then updated in that order.

    Args:
        dim: output width of every MLP (the latent state width).
        e2v_agg: pooling used to aggregate edge latents into a node,
            ``"sum"`` or ``"mean"``.
        n_hidden: forwarded to ``get_mlp`` (presumably the number of hidden
            layers).

    Raises:
        ValueError: if ``e2v_agg`` is neither ``"sum"`` nor ``"mean"``.
    """

    def __init__(self, dim, e2v_agg="sum", n_hidden=1):
        super().__init__()
        if e2v_agg not in ["sum", "mean"]:
            raise ValueError("Unknown aggregation function.")
        self.e2v_agg = e2v_agg
        # Width of [encoded | latent]. Assumes the encoder emits dim // 2
        # features and the latent state is dim wide.
        in_dim = dim//2 + dim

        # [sender, receiver, edge, global] -> dim
        self.edge_mlp = get_mlp(in_dim*4, dim, n_hidden, dim)
        # [node, pooled updated edges (dim), global] -> dim
        self.node_mlp = get_mlp(in_dim*2+dim, dim, n_hidden, dim)
        # [global, mean of updated nodes (dim), mean of updated edges (dim)] -> dim
        self.global_mlp = get_mlp(in_dim+dim*2, dim, n_hidden, dim)

    def edge_model(self, latent_dict, edge_index_dict, batch_dict):
        """Update both edge directions in place in ``latent_dict`` and return it."""
        variable_indices, clause_indices = edge_index_dict["variable", "in", "original_clause"]
        u = latent_dict["instance"]

        out_v2c = torch.cat([
            latent_dict["variable"][variable_indices],
            latent_dict["original_clause"][clause_indices],
            latent_dict["variable", "in", "original_clause"],
            u[batch_dict["variable", "in", "original_clause"]]
        ], dim=1)
        latent_dict["variable", "in", "original_clause"] = self.edge_mlp(out_v2c)

        # NOTE(review): the c2v direction reuses the v2c edge_index. This
        # assumes c2v edge i is exactly the reverse of v2c edge i. Confirm
        # against VCGGraph.
        out_c2v = torch.cat([
            latent_dict["original_clause"][clause_indices],
            latent_dict["variable"][variable_indices],
            latent_dict["original_clause", "has", "variable"],
            u[batch_dict["original_clause", "has", "variable"]]
        ], dim=1)
        latent_dict["original_clause", "has", "variable"] = self.edge_mlp(out_c2v)

        return latent_dict

    def node_model(self, latent_dict, edge_index_dict, batch_dict):
        """Update variable and clause latents in place from their pooled edges.

        Variables pool v2c edges grouped by the variable endpoint. Clauses
        pool c2v edges grouped by the clause endpoint. Both use
        ``self.e2v_agg``.
        """
        u = latent_dict["instance"]
        n_variable = latent_dict["variable"].shape[0]
        n_clause = latent_dict["original_clause"].shape[0]
        variable_indices, clause_indices = edge_index_dict["variable", "in", "original_clause"]

        x = latent_dict["variable"]
        out = scatter_reduce(latent_dict["variable", "in", "original_clause"], variable_indices, reduce=self.e2v_agg, dim=0, dim_size=n_variable)
        ui = u[batch_dict["variable"]]
        out_v = torch.cat([x, out, ui], dim=1)
        latent_dict["variable"] = self.node_mlp(out_v)

        x = latent_dict["original_clause"]
        out = scatter_reduce(latent_dict["original_clause", "has", "variable"], clause_indices, reduce=self.e2v_agg, dim=0, dim_size=n_clause)
        ui = u[batch_dict["original_clause"]]
        out_c = torch.cat([x, out, ui], dim=1)
        latent_dict["original_clause"] = self.node_mlp(out_c)

        return latent_dict

    def global_model(self, latent_dict, batch_dict):
        """Update the per-graph "instance" latent from mean node/edge latents."""
        u = latent_dict["instance"]
        # One row per graph. Pinning dim_size keeps the pooled tensors aligned
        # with ``u`` even when trailing graphs have no nodes or no edges.
        # Otherwise the size would be inferred from index.max() + 1, and
        # torch.cat below would fail on mismatched rows.
        n_graphs = u.shape[0]

        node_feature = torch.cat([latent_dict["variable"], latent_dict["original_clause"]], dim=0)
        node_batch = torch.cat([batch_dict["variable"], batch_dict["original_clause"]], dim=0)
        edge_feature = torch.cat([latent_dict["variable", "in", "original_clause"], latent_dict["original_clause", "has", "variable"]], dim=0)
        edge_batch = torch.cat([batch_dict["variable", "in", "original_clause"], batch_dict["original_clause", "has", "variable"]], dim=0)

        out = torch.cat([
            u,
            scatter_reduce(node_feature, node_batch, reduce="mean", dim=0, dim_size=n_graphs),
            scatter_reduce(edge_feature, edge_batch, reduce="mean", dim=0, dim_size=n_graphs)
        ], dim=1)
        latent_dict["instance"] = self.global_mlp(out)
        return latent_dict

    def forward(self, encoded_dict, latent_dict, edge_index_dict, batch_dict):
        """Run one edge -> node -> global step and return the new latent dict.

        A fresh dict of [encoded | latent] tensors is built first. The
        sub-models mutate that new dict, not the caller's ``latent_dict``.
        """
        latent_dict = {k: torch.cat([encoded_dict[k], v], dim=1) for k, v in latent_dict.items()}
        latent_dict = self.edge_model(latent_dict, edge_index_dict, batch_dict)
        latent_dict = self.node_model(latent_dict, edge_index_dict, batch_dict)
        latent_dict = self.global_model(latent_dict, batch_dict)
        return latent_dict

class Model(BaseModel):
    """Graph-Q-SAT encoder + processor exposed through the ``BaseModel`` hooks.

    Only variable–clause graphs (``args.graph == "vcg"``) are supported.
    """

    def __init__(self, args) -> None:
        # Reject unsupported graph types before BaseModel initialisation runs.
        if args.graph != "vcg":
            raise ValueError("Model is only compatible with VCG")
        super().__init__(args)

        self.encoder = Encoder(args.dim)
        self.core = Process(args.dim)

    def get_latent_dict(self, graph):
        """Collect the ``.latent`` tensor of every node, edge and global store."""
        keys = (
            "variable",
            "original_clause",
            "instance",
            ("variable", "in", "original_clause"),
            ("original_clause", "has", "variable"),
        )
        return {key: graph[key].latent for key in keys}

    def get_edge_index_dict(self, graph):
        """Return only the variable->clause ``edge_index``.

        ``Process`` reuses it for the reverse direction as well.
        """
        v2c = ("variable", "in", "original_clause")
        return {v2c: graph[v2c].edge_index}

    def get_kwargs(self, graph):
        """Precompute per-call inputs for ``step``.

        Returns:
            dict with ``"encoded_dict"`` (output of ``self.encoder``) and
            ``"batch_dict"``, which maps each node/edge type to a long tensor
            giving the graph index of every node/edge. A graph that is not a
            ``Batch`` is treated as a single graph, so every entry is 0.
        """
        encoded_dict = self.encoder(graph)

        node_types = ("variable", "original_clause")
        edge_types = (
            ("variable", "in", "original_clause"),
            ("original_clause", "has", "variable"),
        )

        if isinstance(graph, Batch):
            batch_dict = {key: graph[key].batch for key in node_types + edge_types}
        else:
            device = graph["variable"].x.device
            batch_dict = {}
            for key in node_types:
                batch_dict[key] = torch.zeros(graph[key].num_nodes, dtype=torch.long, device=device)
            for key in edge_types:
                batch_dict[key] = torch.zeros(graph[key].num_edges, dtype=torch.long, device=device)

        return {"encoded_dict": encoded_dict, "batch_dict": batch_dict}

    def step(self, latent_dict, edge_index_dict, **kwargs):
        """Advance the latent state by one ``Process`` step.

        Expects the keys produced by ``get_kwargs``.
        """
        return self.core(kwargs["encoded_dict"], latent_dict, edge_index_dict, kwargs["batch_dict"])