import copy
import json

import dgl
import torch
from torch.nn import MultiheadAttention
from dgl.nn.pytorch.conv import EGATConv
from hyperGraph_utils.HGNNPlus_Model import HGNNPlusEncoder


def state_to_device(state, device):
    """Return a new 8-element list holding ``state`` with its tensor slots moved to ``device``.

    Slots 0 (``product_dgl``), 1 (``product_fp``), 2 (``Placed_RcNodeIdx``) and
    6 (``mask``) are sent through ``.to(device)``.  Slots 3, 4, 5 and 7
    (``RcNodeIdx_list``, ``LgIdx_list``, ``tag``, ``t``) are passed through by
    reference.  Any elements beyond index 7 are dropped.
    """
    movable_slots = {0, 1, 2, 6}
    # Iterating range(8) in order preserves the original access / .to() call order.
    return [state[slot].to(device) if slot in movable_slots else state[slot] for slot in range(8)]


def ListState_to_device(list_state, device):
    """Move every state in ``list_state`` to ``device`` via ``state_to_device``; return a new list."""
    return list(map(lambda one_state: state_to_device(one_state, device), list_state))


class EGATLayer(torch.nn.Module):
    """One EGATConv step whose multi-head outputs are projected back to ``hidden_dimension``.

    Node and edge outputs of ``EGATConv`` are flattened to
    ``(-1, hidden_dimension * num_heads)``. Each is then passed through its own
    linear layer and an ELU activation.
    """

    def __init__(self, hidden_dimension, num_heads):
        super().__init__()
        self.hidden_dimension = hidden_dimension
        self.num_heads = num_heads
        concat_width = hidden_dimension * num_heads
        # Registration order (egat_layer, mlp_node, mlp_edge) is kept as-is so that
        # parameter initialisation order and state_dict keys match.
        self.egat_layer = EGATConv(in_node_feats=hidden_dimension,
                                   in_edge_feats=hidden_dimension,
                                   out_node_feats=hidden_dimension,
                                   out_edge_feats=hidden_dimension,
                                   num_heads=num_heads)
        self.mlp_node = torch.nn.Linear(concat_width, hidden_dimension)
        self.mlp_edge = torch.nn.Linear(concat_width, hidden_dimension)

    def forward(self, graph, nfeats, efeats):
        """Return ``(node_feats, edge_feats)``, each projected to width ``hidden_dimension``."""
        concat_width = self.hidden_dimension * self.num_heads
        # local_scope keeps any graph data written during the conv from leaking to the caller.
        with graph.local_scope():
            heads_n, heads_e = self.egat_layer(graph, nfeats, efeats)
            out_n = torch.nn.functional.elu(self.mlp_node(heads_n.reshape(-1, concat_width)))  # nxd
            out_e = torch.nn.functional.elu(self.mlp_edge(heads_e.reshape(-1, concat_width)))  # Exd
        return out_n, out_e


class EGATModel(torch.nn.Module):
    """A stack of ``num_layers`` ``EGATLayer`` blocks with optional residual connections.

    When ``residual`` is true, each layer's node and edge outputs are added to
    that layer's inputs before being fed to the next layer.
    """

    def __init__(self, hidden_dimension, num_heads, num_layers, residual=True):
        super().__init__()
        self.residual = residual
        stacked = [EGATLayer(hidden_dimension=hidden_dimension, num_heads=num_heads)
                   for _ in range(num_layers)]
        self.layers = torch.nn.ModuleList(stacked)

    def forward(self, graph, nfeats, efeats):
        """Run every layer in order and return the final ``(node_feats, edge_feats)``."""
        with graph.local_scope():
            for egat in self.layers:
                out_n, out_e = egat(graph, nfeats, efeats)
                if self.residual:
                    out_n = out_n + nfeats
                    out_e = out_e + efeats
                nfeats, efeats = out_n, out_e
        return nfeats, efeats  # nxd, Exd


class BaseStateEncoderOneState(torch.nn.Module):
    """Encode a single state into the base embedding consumed by the policy/value heads.

    The product graph ``state[0]`` is always encoded:

    1. Initial node features ``ndata['x']`` and edge features ``edata['e']`` are
       linearly projected to ``d``.
    2. A 2-row embedding indexed by ``state[2]`` is added to the node features.
    3. The result goes through an ``EGATModel``.
    4. Nodes and edges are mean-pooled. When ``have_fp`` is set, an encoding of
       the fingerprint ``state[1]`` is appended.
    5. The pooled vector is fused into a ``1 x d`` graph-level row.

    The return value depends on ``state[5]`` (``tag``):

    * ``tag == 0`` -- node embeddings with the graph-level row appended last,
      shape ``(n + 1) x d``.
    * otherwise -- a ``1 x d`` row fusing the graph-level row with the sum of
      ``lg_embedding[state[4]]``, or with the learned ``init_placed_lg`` when
      ``state[4]`` is empty.

    Args:
        hidden_dimension: Embedding width ``d``.
        num_egat_heads: Attention heads per EGAT layer.
        num_egat_layers: Number of stacked EGAT layers.
        lg_hypergraph_dhg: Stored on the module; not used by this class.
        residual: Whether the EGAT stack uses residual connections.
        have_fp: Whether the fingerprint ``state[1]`` is encoded and pooled in.
        fp_dim: Input width of the fingerprint (default 2048, the former constant).
        node_in_dim: Width of ``ndata['x']`` (default 81, the former constant).
        edge_in_dim: Width of ``edata['e']`` (default 17, the former constant).
    """

    def __init__(self, hidden_dimension, num_egat_heads, num_egat_layers, lg_hypergraph_dhg, residual=True,
                 have_fp=True, *, fp_dim=2048, node_in_dim=81, edge_in_dim=17):
        super(BaseStateEncoderOneState, self).__init__()
        self.lg_hypergraph_dhg = lg_hypergraph_dhg
        # ablation switches
        self.residual = residual
        self.have_fp = have_fp

        # NOTE: submodule registration order is unchanged so parameter init order
        # and state_dict keys stay compatible with existing checkpoints.

        # encode FP
        if have_fp:
            self.dense_FP = torch.nn.Linear(fp_dim, hidden_dimension)

        # encode product, tag=0
        self.dense_init_nfeats = torch.nn.Linear(node_in_dim, hidden_dimension)
        self.dense_init_efeats = torch.nn.Linear(edge_in_dim, hidden_dimension)
        self.product_graph_encoder = EGATModel(hidden_dimension, num_egat_heads, num_egat_layers, residual=residual)
        # Two-row table indexed by state[2] (cast to long); added to each node's features.
        self.Placed_rc_embedding = torch.nn.Embedding(2, hidden_dimension)
        # Pooled nodes + pooled edges (+ fingerprint) -> d.
        num_pooled = 3 if have_fp else 2
        self.dense_product_graph_level = torch.nn.Linear(num_pooled * hidden_dimension, hidden_dimension)

        # encode hypergraph, tag=1
        # Learned stand-in used when state[4] (LgIdx_list) is empty.
        self.init_placed_lg = torch.nn.Parameter(torch.FloatTensor(1, hidden_dimension))  # 1xd
        torch.nn.init.xavier_normal_(self.init_placed_lg)

        self.dense = torch.nn.Linear(2 * hidden_dimension, hidden_dimension)

    @staticmethod
    def _mean_pool(feats):
        """Average the rows of ``feats`` into a ``1 x d`` row; return zeros if there are no rows.

        ``torch.mean`` over an empty dimension returns NaN, which would propagate
        through every later layer (e.g. for a graph with no edges).
        """
        if feats.shape[0] == 0:
            return feats.new_zeros((1,) + tuple(feats.shape[1:]))
        return torch.mean(feats, dim=0, keepdim=True)

    def forward(self, state, lg_embedding):
        """Return the ``(n+1) x d`` (tag 0) or ``1 x d`` (other tags) base embedding for ``state``."""
        product_dgl = state[0]
        product_fp = state[1]
        Placed_RcNodeIdx = state[2]
        LgIdx_list = state[4]
        tag = state[5]

        # encode FP
        if self.have_fp:
            encoded_product_FP = torch.nn.functional.elu(self.dense_FP(product_fp))  # 1xd

        # encode product graph
        init_x = self.dense_init_nfeats(product_dgl.ndata['x'])  # nxd
        init_e = self.dense_init_efeats(product_dgl.edata['e'])  # Exd
        placed_rc = self.Placed_rc_embedding(Placed_RcNodeIdx.type(torch.long))  # nxd
        product_nfeats, product_efeats = self.product_graph_encoder(graph=product_dgl, nfeats=init_x + placed_rc,
                                                                    efeats=init_e)  # nxd, Exd

        # graph-level embedding from pooled nodes/edges (+ fingerprint)
        pooled = [self._mean_pool(product_nfeats), self._mean_pool(product_efeats)]  # 1xd each
        if self.have_fp:
            pooled.append(encoded_product_FP)
        product_graph_level = torch.cat(pooled, dim=1)  # 1x3d or 1x2d
        product_graph_level = torch.nn.functional.elu(self.dense_product_graph_level(product_graph_level))  # 1xd

        if tag == 0:
            return torch.cat([product_nfeats, product_graph_level], dim=0)  # (n+1)xd

        # tag != 0: fuse with the summed embeddings selected by LgIdx_list
        if len(LgIdx_list) == 0:
            placed_lg = self.init_placed_lg  # 1xd
        else:
            placed_lg = torch.sum(lg_embedding[LgIdx_list], dim=0, keepdim=True)  # 1xd
        fused = torch.cat([product_graph_level, placed_lg], dim=1)  # 1x2d
        return torch.nn.functional.elu(self.dense(fused))  # 1xd


class RetroA2CNet(torch.nn.Module):
    """Policy and value network evaluated one state at a time.

    The network holds a learnable ``v x d`` table ``lg_emb``, where
    ``v = lg_hypergraph_dhg.num_v``. When ``have_structure`` is set, the table
    is refined by an ``HGNNPlusEncoder`` over ``lg_hypergraph_dhg``. The result
    is passed, together with each state, to ``BaseStateEncoderOneState``.
    The heads are selected by the state's tag (``state[5]``):

    * tag 0 -- the policy scores every row of the ``(n + 1) x d`` base
      embedding, giving ``1 x (n + 1)``. The value head reads only the last row.
    * tag 1 -- the policy maps the ``1 x d`` base embedding to ``1 x (v + 1)``,
      and the value head reads that same row.

    In both cases the policy adds ``mask * -1e9`` (``mask = state[6]``) to the
    scores before the optional softmax. Any other tag raises ``ValueError``.
    """

    def __init__(self, hidden_dimension, num_egat_heads, num_egat_layers, lg_hypergraph_dhg, num_of_LayerHypergraph=1,
                 residual=True, have_fp=True, have_structure=True):
        super(RetroA2CNet, self).__init__()
        # NOTE: registration order is unchanged to keep parameter init order / state_dict keys.
        self.have_structure = have_structure
        self.lg_hypergraph_dhg = lg_hypergraph_dhg
        self.lg_emb = torch.nn.Parameter(torch.FloatTensor(lg_hypergraph_dhg.num_v, hidden_dimension))  # vxd
        torch.nn.init.xavier_normal_(self.lg_emb)
        if have_structure:
            self.hypergraph_encoder = HGNNPlusEncoder(
                layer_info=[hidden_dimension for _ in range(num_of_LayerHypergraph + 1)],
                drop_rate=0, res=False)

        self.base_encoder_for_oneState = BaseStateEncoderOneState(hidden_dimension=hidden_dimension,
                                                                  num_egat_heads=num_egat_heads,
                                                                  num_egat_layers=num_egat_layers,
                                                                  lg_hypergraph_dhg=lg_hypergraph_dhg,
                                                                  residual=residual, have_fp=have_fp)

        # policy tag = 0
        self.policy_dense0_tag0 = torch.nn.Linear(hidden_dimension, hidden_dimension)
        self.policy_dense1_tag0 = torch.nn.Linear(hidden_dimension, 1)
        # policy tag = 1
        self.policy_dense0_tag1 = torch.nn.Linear(hidden_dimension, hidden_dimension)
        self.policy_dense1_tag1 = torch.nn.Linear(hidden_dimension, lg_hypergraph_dhg.num_v + 1)
        # value tag = 0
        self.value_dense0_tag0 = torch.nn.Linear(hidden_dimension, hidden_dimension)
        self.value_dense1_tag0 = torch.nn.Linear(hidden_dimension, 1)
        # value tag = 1
        self.value_dense0_tag1 = torch.nn.Linear(hidden_dimension, hidden_dimension)
        self.value_dense1_tag1 = torch.nn.Linear(hidden_dimension, 1)

    def _compute_lg_embedding(self):
        """Return the ``v x d`` table handed to the base encoder (hypergraph-refined if enabled)."""
        if self.have_structure:
            return self.hypergraph_encoder(hg=self.lg_hypergraph_dhg, X=self.lg_emb)  # vxd
        return self.lg_emb  # vxd

    @staticmethod
    def _check_tag(tag):
        """Raise ``ValueError`` unless ``tag`` is 0 or 1 (previously such tags silently returned None)."""
        if not (tag == 0 or tag == 1):
            raise ValueError(f"Unsupported state tag {tag!r}; expected 0 or 1")

    @staticmethod
    def _apply_mask(scores, mask, logits):
        """Add ``mask * -1e9`` to ``scores`` and, unless ``logits``, softmax over the last dim."""
        scores = scores + mask * -1e9
        if not logits:
            scores = torch.nn.functional.softmax(scores, dim=-1)
        return scores

    def policy_for_one_state(self, state, lg_embedding=None, logits=False):
        """Return masked policy scores for one state (probabilities unless ``logits``).

        Shape is ``1 x (n + 1)`` for tag 0 and ``1 x (v + 1)`` for tag 1. The
        embedding table is computed on demand when ``lg_embedding`` is None.

        Raises:
            ValueError: if ``state[5]`` is neither 0 nor 1.
        """
        if lg_embedding is None:
            lg_embedding = self._compute_lg_embedding()

        tag = state[5]
        mask = state[6]
        self._check_tag(tag)

        base_emb = self.base_encoder_for_oneState(state, lg_embedding)

        if tag == 0:
            # base_emb: (n+1)xd -> one score per row -> 1x(n+1)
            y = torch.nn.functional.elu(self.policy_dense0_tag0(base_emb))
            y = self.policy_dense1_tag0(y).reshape(1, -1)
        else:
            # base_emb: 1xd -> 1x(v+1)
            y = torch.nn.functional.elu(self.policy_dense0_tag1(base_emb))
            y = self.policy_dense1_tag1(y)
        return self._apply_mask(y, mask, logits)

    def value_for_one_state(self, state, lg_embedding=None):
        """Return the ``1 x 1`` value estimate for one state.

        ``lg_embedding`` is now optional, matching ``policy_for_one_state``.

        Raises:
            ValueError: if ``state[5]`` is neither 0 nor 1.
        """
        if lg_embedding is None:
            lg_embedding = self._compute_lg_embedding()

        tag = state[5]
        self._check_tag(tag)

        base_emb = self.base_encoder_for_oneState(state, lg_embedding)

        if tag == 0:
            # Only the last row of the (n+1)xd base embedding feeds the value head.
            y = torch.nn.functional.elu(self.value_dense0_tag0(base_emb[-1:]))  # 1xd
            return self.value_dense1_tag0(y)  # 1x1
        y = torch.nn.functional.elu(self.value_dense0_tag1(base_emb))  # 1xd
        return self.value_dense1_tag1(y)  # 1x1

    def policy(self, list_of_state):
        """Return one policy tensor per state, sharing a single embedding-table computation."""
        lg_embedding = self._compute_lg_embedding()
        return [self.policy_for_one_state(state=_state, lg_embedding=lg_embedding)
                for _state in list_of_state]  # [1x(n+1), 1x(v+1), ...]

    def value(self, list_of_state):
        """Return one ``1 x 1`` value per state, sharing a single embedding-table computation."""
        lg_embedding = self._compute_lg_embedding()
        return [self.value_for_one_state(state=_state, lg_embedding=lg_embedding)
                for _state in list_of_state]  # [1x1, 1x1, ...]


