import dgl
import torch
from torch.nn import MultiheadAttention
from dgl.nn.pytorch.conv import EGATConv
from hyperGraph_utils.HGNNPlus_Model import HGNNPlusEncoder
from hyperGraph_utils.utils import batch_hypergraph


def collate_LGC(list_of_states):
    """Collate a list of per-sample states into batched model inputs.

    Each state is a 5-tuple ``(product_dgl, product_fp, LgIdx_list, mask, t)``.
    The trailing ``t`` entry is discarded.

    Returns:
        A 4-tuple:
        - the DGL graphs merged with ``dgl.batch``;
        - the fingerprint tensors concatenated along dim 0
          (presumably each is 1xd, giving bxd — TODO confirm);
        - the per-sample index lists, kept as a plain Python list of lists;
        - the mask tensors concatenated along dim 0
          (presumably each is 1x(v+1), giving bx(v+1) — TODO confirm).
    """
    graphs, fingerprints, lg_index_lists, masks, _ = (list(column) for column in zip(*list_of_states))

    return (
        dgl.batch(graphs),
        torch.cat(fingerprints, dim=0),
        lg_index_lists,
        torch.cat(masks, dim=0),
    )


class EGATLayer(torch.nn.Module):
    """A single EGATConv step whose multi-head outputs are projected back to ``hidden_dimension``.

    Node and edge features enter and leave with width ``hidden_dimension``. The
    per-head outputs of ``EGATConv`` are flattened to
    ``hidden_dimension * num_heads`` columns. They are then passed through a
    Linear layer and ELU.
    """

    def __init__(self, hidden_dimension, num_heads):
        super().__init__()
        self.hidden_dimension = hidden_dimension
        self.num_heads = num_heads
        # Submodules are created in a fixed order (conv, node MLP, edge MLP).
        # Keeping this order keeps parameter registration and random init stable.
        self.egat_layer = EGATConv(
            in_node_feats=hidden_dimension,
            in_edge_feats=hidden_dimension,
            out_node_feats=hidden_dimension,
            out_edge_feats=hidden_dimension,
            num_heads=num_heads,
        )
        concat_width = hidden_dimension * num_heads
        self.mlp_node = torch.nn.Linear(concat_width, hidden_dimension)
        self.mlp_edge = torch.nn.Linear(concat_width, hidden_dimension)

    def _merge_heads(self, per_head_feats, projection):
        """Flatten the heads into one row per item, then apply ``projection`` and ELU."""
        flat = per_head_feats.reshape(-1, self.hidden_dimension * self.num_heads)
        return torch.nn.functional.elu(projection(flat))

    def forward(self, graph, nfeats, efeats):
        """Return updated ``(node_feats, edge_feats)``, each with ``hidden_dimension`` columns."""
        with graph.local_scope():
            node_heads, edge_heads = self.egat_layer(graph, nfeats, efeats)
            node_out = self._merge_heads(node_heads, self.mlp_node)  # nxd
            edge_out = self._merge_heads(edge_heads, self.mlp_edge)  # Exd
        return node_out, edge_out


class EGATModel(torch.nn.Module):
    """A stack of ``num_layers`` EGATLayer blocks, with optional residual connections.

    When ``residual`` is true, each layer's input is added to its output for
    both node and edge features.
    """

    def __init__(self, hidden_dimension, num_heads, num_layers, residual=True):
        super().__init__()
        self.residual = residual
        stacked = [EGATLayer(hidden_dimension=hidden_dimension, num_heads=num_heads) for _ in range(num_layers)]
        self.layers = torch.nn.ModuleList(stacked)

    def forward(self, graph, nfeats, efeats):
        """Run every layer in order and return the final ``(node_feats, edge_feats)``."""
        node_h, edge_h = nfeats, efeats
        with graph.local_scope():
            for block in self.layers:
                next_node_h, next_edge_h = block(graph, node_h, edge_h)
                if self.residual:
                    # Skip connection: new output + previous representation.
                    next_node_h = next_node_h + node_h
                    next_edge_h = next_edge_h + edge_h
                node_h, edge_h = next_node_h, next_edge_h
        return node_h, edge_h  # nxd, Exd


class BaseLGCEncoder(torch.nn.Module):
    """Shared trunk that turns a batch of states into one bxd embedding per sample.

    The per-sample embedding concatenates the following pieces. It then
    projects them back to ``hidden_dimension`` with a Linear layer and ELU.
      * the mean of the product graph's EGAT node embeddings;
      * the mean of the product graph's EGAT edge embeddings;
      * an ELU-encoded 2048-dimensional fingerprint vector (only when ``have_fp``);
      * the sum of the embeddings of the hypergraph vertices listed for that
        sample in ``batch_lg_list``. A learned placeholder vector is used when
        the list is empty.

    Args:
        hidden_dimension: width ``d`` of every hidden representation.
        num_egat_heads, num_egat_layers: configuration of the product-graph EGAT stack.
        lg_hypergraph_dhg: hypergraph whose ``num_v`` vertices each get a learned embedding.
        num_of_LayerHypergraph: number of hypergraph-encoder layers (used when ``have_structure``).
        drop_rate: dropout passed to the hypergraph encoder.
        residual: enable residual connections in the EGAT stack.
        have_fp: include the fingerprint branch (ablation switch).
        have_structure: refine vertex embeddings with the hypergraph encoder (ablation switch).
    """

    def __init__(self, hidden_dimension, num_egat_heads, num_egat_layers, lg_hypergraph_dhg, num_of_LayerHypergraph=1,
                 drop_rate=0, residual=True, have_fp=True, have_structure=True):
        super(BaseLGCEncoder, self).__init__()
        # NOTE: submodule/parameter creation order is kept unchanged on purpose.
        # It determines random initialisation and parameter registration order.
        if have_structure:
            # NOTE(review): the hypergraph object is a plain attribute, not a module or buffer.
            # ``model.to(device)`` will presumably not move it — confirm the caller places it
            # on the same device as the model.
            self.lg_hypergraph_dhg = lg_hypergraph_dhg
            self.hypergraph_encoder = HGNNPlusEncoder(
                layer_info=[hidden_dimension for _ in range(num_of_LayerHypergraph + 1)],
                drop_rate=drop_rate, res=False)
        # One learned embedding per hypergraph vertex.
        self.lg_emb = torch.nn.Parameter(
            torch.empty(lg_hypergraph_dhg.num_v, hidden_dimension, dtype=torch.float32))  # vxd
        torch.nn.init.xavier_normal_(self.lg_emb)

        # Placeholder used for samples whose index list is empty.
        self.init_placed_lg = torch.nn.Parameter(torch.empty(1, hidden_dimension, dtype=torch.float32))  # 1xd
        torch.nn.init.xavier_normal_(self.init_placed_lg)

        # ablation
        self.residual = residual
        self.have_fp = have_fp
        self.have_structure = have_structure

        # encode FP
        if have_fp:
            self.dense_FP = torch.nn.Linear(2048, hidden_dimension)

        # encode product: raw node features have 81 columns, raw edge features have 17
        self.dense_init_nfeats = torch.nn.Linear(81, hidden_dimension)
        self.dense_init_efeats = torch.nn.Linear(17, hidden_dimension)
        self.product_graph_encoder = EGATModel(hidden_dimension, num_egat_heads, num_egat_layers, residual=residual)

        # node-mean + edge-mean + placed-sum (+ fingerprint) -> d
        if have_fp:
            self.dense_base = torch.nn.Linear(4 * hidden_dimension, hidden_dimension)
        else:
            self.dense_base = torch.nn.Linear(3 * hidden_dimension, hidden_dimension)

    def forward(self, batch_product_dgl, batch_product_fp, batch_lg_list):
        """Encode a batch of states.

        Args:
            batch_product_dgl: batched DGL graph with ``ndata['x']`` (81 columns)
                and ``edata['e']`` (17 columns).
            batch_product_fp: fingerprint tensor with 2048 columns, one row per sample.
                Only read when ``have_fp``.
            batch_lg_list: one index sequence per sample. Each is a list, tuple,
                array or 1-D tensor of row indices into the vertex embeddings.

        Returns:
            A bxd tensor of state embeddings.
        """
        with batch_product_dgl.local_scope():
            # encode lg hypergraph
            if self.have_structure:
                lg_emb = self.hypergraph_encoder(hg=self.lg_hypergraph_dhg, X=self.lg_emb)  # vxd
            else:
                lg_emb = self.lg_emb  # vxd

            # batch_placed_lg: sum of the listed vertex embeddings per sample
            batch_placed_lg = []
            for lg_list in batch_lg_list:
                if len(lg_list) == 0:
                    batch_placed_lg.append(self.init_placed_lg)
                    continue
                # Gather with an explicit 1-D LongTensor on the embedding's device.
                # Indexing directly with a Python tuple would be treated as
                # multi-dimensional indexing rather than a row gather.
                idx = torch.as_tensor(lg_list, dtype=torch.long, device=lg_emb.device)
                batch_placed_lg.append(torch.sum(lg_emb[idx], dim=0, keepdim=True))  # 1xd
            batch_placed_lg = torch.cat(batch_placed_lg, dim=0)  # bxd

            # encode FP
            if self.have_fp:
                encoded_product_fp = torch.nn.functional.elu(self.dense_FP(batch_product_fp))  # bxd

            # encode product to node/edge embeddings
            init_x = self.dense_init_nfeats(batch_product_dgl.ndata['x'])  # nxd
            init_e = self.dense_init_efeats(batch_product_dgl.edata['e'])  # Exd
            node_emb, edge_emb = self.product_graph_encoder(graph=batch_product_dgl, nfeats=init_x,
                                                            efeats=init_e)  # bnxd, bExd

            # product graph-level embedding: per-graph mean over nodes and over edges
            batch_product_dgl.ndata['node_emb'] = node_emb
            batch_product_dgl.edata['edge_emb'] = edge_emb
            batch_graph_emb_by_node = dgl.readout_nodes(graph=batch_product_dgl, feat='node_emb', op="mean")  # bxd
            batch_graph_emb_by_edge = dgl.readout_edges(graph=batch_product_dgl, feat='edge_emb', op="mean")  # bxd

            if self.have_fp:
                batch_graph_emb = torch.cat([batch_graph_emb_by_node, batch_graph_emb_by_edge, encoded_product_fp,
                                             batch_placed_lg], dim=1)  # bx4d
            else:
                batch_graph_emb = torch.cat([batch_graph_emb_by_node, batch_graph_emb_by_edge, batch_placed_lg],
                                            dim=1)  # bx3d

            batch_graph_emb = self.dense_base(batch_graph_emb)  # bxd
            batch_graph_emb = torch.nn.functional.elu(batch_graph_emb)  # bxd
            return batch_graph_emb  # bxd


class LGCA2CNet(torch.nn.Module):
    """Actor-critic network with a shared BaseLGCEncoder trunk.

    ``policy`` yields a distribution (or masked logits) over ``num_v + 1``
    outputs. ``value`` yields one scalar per sample. The extra output slot
    presumably encodes a non-vertex choice — confirm against the environment.
    """

    def __init__(self, hidden_dimension, num_egat_heads, num_egat_layers, lg_hypergraph_dhg, num_of_LayerHypergraph,
                 drop_rate=0, residual=True, have_fp=True, have_structure=True):
        super().__init__()
        # Shared trunk. Created first so that parameter init order matches the heads below.
        self.base_encoder = BaseLGCEncoder(
            hidden_dimension=hidden_dimension,
            num_egat_heads=num_egat_heads,
            num_egat_layers=num_egat_layers,
            lg_hypergraph_dhg=lg_hypergraph_dhg,
            num_of_LayerHypergraph=num_of_LayerHypergraph,
            drop_rate=drop_rate,
            residual=residual,
            have_fp=have_fp,
            have_structure=have_structure,
        )

        num_outputs = lg_hypergraph_dhg.num_v + 1
        # Policy head: d -> d -> (v+1)
        self.dense_policy0 = torch.nn.Linear(hidden_dimension, hidden_dimension)
        self.dense_policy1 = torch.nn.Linear(hidden_dimension, num_outputs)
        # Value head: d -> d -> 1
        self.dense_value0 = torch.nn.Linear(hidden_dimension, hidden_dimension)
        self.dense_value1 = torch.nn.Linear(hidden_dimension, 1)

    @staticmethod
    def _two_layer_head(features, first, second):
        """Apply ``second(elu(first(features)))``."""
        return second(torch.nn.functional.elu(first(features)))

    def policy(self, batch_product_dgl, batch_product_fp, batch_lg_list, batch_mask, logits=False):
        """Return masked policy scores of shape bx(v+1).

        ``batch_mask`` is multiplied by -1e9 and added to the logits. Entries where
        the mask is 1 (assuming a 0/1 mask) are therefore effectively excluded.
        Returns the masked logits if ``logits`` is true, otherwise their softmax.
        """
        shared = self.base_encoder(batch_product_dgl, batch_product_fp, batch_lg_list)  # bxd
        scores = self._two_layer_head(shared, self.dense_policy0, self.dense_policy1)  # bx(v+1)
        scores = scores + batch_mask * -1e9  # bx(v+1)
        return scores if logits else torch.nn.functional.softmax(scores, dim=-1)

    def value(self, batch_product_dgl, batch_product_fp, batch_lg_list):
        """Return state-value estimates of shape bx1."""
        shared = self.base_encoder(batch_product_dgl, batch_product_fp, batch_lg_list)  # bxd
        return self._two_layer_head(shared, self.dense_value0, self.dense_value1)  # bx1



