import copy
import torch
from torch import nn, Tensor
from typing import Sequence, Union, List
from meta_diffusion.model.embedder import get_embedder_by_task
from meta_diffusion.model.out_layer import get_out_layer_by_task
from meta_diffusion.model.encoder.gnn_layer import GNNSparseBlock, GNNDenseBlock


class GNNEncoder(nn.Module):
    """Multi-task GNN encoder.

    Each call runs, in order: the task's embedder, the GNN blocks shared by
    all tasks, the task's own GNN blocks, and the task's output layer. After
    every GNN block, the time signal ``t`` returned by the embedder can be
    projected by a ``ReLU -> Linear(hidden_dim, hidden_dim)`` layer and added
    to the node and/or edge features.

    If ``"ATSP"`` is in the task pool, every GNN block is built with
    ``asym=True`` and :meth:`forward` routes *every* task through
    :meth:`asym_forward`; otherwise every task goes through
    :meth:`sym_forward`.

    Submodules registered per task ``name``: ``{name}_embed``,
    ``{name}_out`` and ``{name}_blocks``.
    """

    def __init__(
        self,
        sparse: bool,
        shared_block_layers: Sequence[int],
        separate_block_layers: Sequence[int],
        time_flag: bool = True,
        hidden_dim: int = 256,
        aggregation: str = "sum",
        norm: str = "layer",
        learn_norm: bool = True,
        track_norm: bool = False,
        task: Sequence[str] = ("MIS", "MCl", "MCut"),
    ):
        """Build embedders, output layers, time projections and GNN blocks.

        Args:
            sparse: Build ``GNNSparseBlock`` (True) or ``GNNDenseBlock``
                (False); also forwarded to every embedder / output layer.
            shared_block_layers: One shared GNN block per entry; each entry
                is that block's ``num_layers``.
            separate_block_layers: Same, for the blocks owned by each task.
            time_flag: Forwarded to every task embedder.
            hidden_dim: Feature width used by blocks, embedders, output
                layers and the time projections.
            aggregation, norm, learn_norm, track_norm: Forwarded unchanged
                to every GNN block.
            task: Task names; each selects its embedder / output layer via
                ``get_embedder_by_task`` / ``get_out_layer_by_task``.
        """
        super().__init__()

        # Copy so the instance never aliases the caller's sequence (the old
        # default was a single mutable list shared by every instance).
        self.task_pool: List[str] = list(task)
        self.shared_block_num = len(shared_block_layers)
        self.separate_block_num = len(separate_block_layers)
        self.hidden_dim = hidden_dim
        # A single "ATSP" task switches the whole encoder (all blocks, all
        # tasks) onto the asymmetric path.
        self.asym = "ATSP" in self.task_pool

        # NOTE: submodules are created and registered in the same order as
        # the original implementation so that parameter order (optimizer
        # state dicts) and seeded initialisation stay unchanged.

        # Task-specific embedders and output layers.
        for task_name in self.task_pool:
            self.add_module(
                f"{task_name}_embed",
                get_embedder_by_task(task_name)(hidden_dim, sparse, time_flag),
            )
            # The literal 2 is presumably the number of output channels
            # (e.g. binary logits) — TODO confirm against out_layer.
            self.add_module(
                f"{task_name}_out",
                get_out_layer_by_task(task_name)(hidden_dim, 2, sparse),
            )

        def make_time_layers(count: int) -> nn.ModuleList:
            # One ReLU -> Linear(hidden_dim, hidden_dim) projection of t per block.
            return nn.ModuleList([
                nn.Sequential(nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
                for _ in range(count)
            ])

        # "_1" projections follow the shared blocks, "_2" the task blocks.
        self.node_time_layers_1 = make_time_layers(self.shared_block_num)
        self.edge_time_layers_1 = make_time_layers(self.shared_block_num)
        self.node_time_layers_2 = make_time_layers(self.separate_block_num)
        self.edge_time_layers_2 = make_time_layers(self.separate_block_num)

        block_cls = GNNSparseBlock if sparse else GNNDenseBlock

        def make_blocks(layer_counts: Sequence[int]) -> nn.ModuleList:
            # One GNN block per entry of layer_counts, all with the same config.
            return nn.ModuleList([
                block_cls(
                    num_layers=num_layers,
                    hidden_dim=hidden_dim,
                    aggregation=aggregation,
                    norm=norm,
                    learn_norm=learn_norm,
                    track_norm=track_norm,
                    asym=self.asym,
                )
                for num_layers in layer_counts
            ])

        self.shared_blocks = make_blocks(shared_block_layers)
        for task_name in self.task_pool:
            self.add_module(f"{task_name}_blocks", make_blocks(separate_block_layers))

    def _iter_blocks(self, task: str):
        """Yield ``(gnn_block, node_time_layer, edge_time_layer)`` triples.

        Shared blocks come first (paired with the ``*_time_layers_1``
        projections), then ``task``'s own blocks (paired with the
        ``*_time_layers_2`` projections). The task-specific list is looked
        up lazily, only after the shared blocks have been consumed.
        """
        yield from zip(
            self.shared_blocks, self.node_time_layers_1, self.edge_time_layers_1
        )
        yield from zip(
            getattr(self, f"{task}_blocks"),
            self.node_time_layers_2,
            self.edge_time_layers_2,
        )

    def forward(
        self, task: str, focus_on_node: bool, focus_on_edge: bool,
        nodes_feature: Tensor, x: Tensor, t: Tensor, edges_feature: Tensor,
        e: Tensor, edge_index: Tensor
    ) -> Sequence[Tensor]:
        """Encode one input for ``task``; return ``(x, e)`` from its output layer.

        Routes to :meth:`asym_forward` when the encoder was built with an
        ``"ATSP"`` task, otherwise to :meth:`sym_forward`. The positional
        order here (``x, t, edges_feature, e``) differs from the two
        implementations (``x, edges_feature, e, t``); all arguments are
        forwarded by keyword so both orders stay consistent.

        Args:
            task: Name of a task from the pool given at construction.
            focus_on_node: Add the projected time signal to node features
                after each block (unsupported on the asymmetric path).
            focus_on_edge: Add the projected time signal to edge features
                after each block.
            nodes_feature, x, t, edges_feature, e: Passed to the task
                embedder.
            edge_index: Passed to every GNN block.

        Raises:
            AttributeError: ``task`` is not in the task pool (no
                ``{task}_embed`` submodule).
            NotImplementedError: ``focus_on_node`` on the asymmetric path.
        """
        impl = self.asym_forward if self.asym else self.sym_forward
        return impl(
            task=task, focus_on_node=focus_on_node, focus_on_edge=focus_on_edge,
            nodes_feature=nodes_feature, x=x, edges_feature=edges_feature,
            e=e, t=t, edge_index=edge_index,
        )

    def asym_forward(
        self, task: str, focus_on_node: bool, focus_on_edge: bool,
        nodes_feature: Tensor, x: Tensor, edges_feature: Tensor,
        e: Tensor, t: Tensor, edge_index: Tensor
    ) -> Sequence[Tensor]:
        """Asymmetric path: embed, run every block's ``asym_forward``, output.

        Only edge-level time conditioning is supported here.

        Raises:
            NotImplementedError: ``focus_on_node`` is set and at least one
                GNN block exists.
        """
        # Node count is taken before embedding; x is allowed to be None here.
        nodes_num = None if x is None else x.shape[0]

        embedder = getattr(self, f"{task}_embed")
        # The asymmetric embedder also returns `d`, which every block
        # receives unchanged.
        x, e, d, t = embedder(nodes_feature, x, edges_feature, e, t)

        for gnn_block, _node_time_layer, edge_time_layer in self._iter_blocks(task):
            x, e = gnn_block.asym_forward(
                x=x, e=e, d=d,
                edge_index=edge_index,
                edges_feature=edges_feature,
                nodes_num=nodes_num,
            )
            if focus_on_node:
                raise NotImplementedError(
                    "focus_on_node is not supported on the asymmetric path"
                )
            if focus_on_edge and t is not None:
                e = e + edge_time_layer(t)

        out_layer = getattr(self, f"{task}_out")
        x, e = out_layer(x, e)
        return x, e

    def sym_forward(
        self, task: str, focus_on_node: bool, focus_on_edge: bool,
        nodes_feature: Tensor, x: Tensor, edges_feature: Tensor,
        e: Tensor, t: Tensor, edge_index: Tensor
    ) -> Sequence[Tensor]:
        """Symmetric path: embed, run shared then task blocks, output.

        After each block, when the embedder returned a non-None ``t``, the
        projected time signal is added to node features (``focus_on_node``)
        and/or edge features (``focus_on_edge``).
        """
        embedder = getattr(self, f"{task}_embed")
        x, e, t = embedder(nodes_feature, x, edges_feature, e, t)

        for gnn_block, node_time_layer, edge_time_layer in self._iter_blocks(task):
            # Call the module (not .forward) so registered hooks run.
            x, e = gnn_block(x=x, e=e, edge_index=edge_index)
            if t is None:
                continue
            if focus_on_node:
                x = x + node_time_layer(t)
            if focus_on_edge:
                e = e + edge_time_layer(t)

        out_layer = getattr(self, f"{task}_out")
        x, e = out_layer(x, e)
        return x, e