import torch

from torch.nn import Module, Sequential, Embedding, Linear, ReLU, LogSoftmax
from typing import List, TypeVar

from ae_rw_diffusion.unet import UNet


# Constrained TypeVars spanning the CPU and CUDA variants of the legacy float / long tensor types. They are used for
# type annotations below.
_FLOAT_TENSOR_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_LONG_TENSOR_TYPES = (torch.LongTensor, torch.cuda.LongTensor)

FloatTensor = TypeVar('FloatTensor', *_FLOAT_TENSOR_TYPES)
LongTensor = TypeVar('LongTensor', *_LONG_TENSOR_TYPES)


class UNetAdapter(Module):
    """
    Network that works as an adapter for the U-Net on discrete random-walk inputs.

    Node indices (equivalent to one-hot vectors) are projected down to a lower-dimensional embedding space, the embedded
    walks are passed through the U-Net together with their time steps, and the result is projected back up to
    log-probabilities over the graph nodes.
    """
    def __init__(self, hidden_channels: List[int], num_nodes: int, node_embedding_dim: int,
                 num_time_steps: int, time_embedding_dim: int, num_res_blocks: int = 1, kernel_size: int = 3) -> None:
        """
        Initialization of UNetAdapter.

        Args:
            hidden_channels: List[int]
                Hidden channels for every level of the U-Net.
            num_nodes: int
                Number of nodes in the graph (required). The embedding table holds num_nodes + 1 entries (one extra
                entry for the mask state), while the output layer only scores the num_nodes real nodes.
            node_embedding_dim: int
                Dimensionality of the node embeddings. It is also passed to the U-Net as its first argument
                (presumably its channel count), and the U-Net output must have this many channels.
            num_time_steps: int
                Total number of time steps.
            time_embedding_dim: int
                Dimensionality of the time step embeddings.
            num_res_blocks: int (optional, default: 1)
                Number of residual blocks in a row inside ContractingBlock and ExpansiveBlock.
            kernel_size: int (optional, default: 3)
                Kernel size of the 1D convolution.
        """
        super().__init__()

        # Initialize the U-Net that operates in the node_embedding_dim-dimensional embedding space
        self.unet = UNet(node_embedding_dim, hidden_channels, num_time_steps, time_embedding_dim,
                         num_res_blocks=num_res_blocks, kernel_size=kernel_size)

        # One embedding per node plus one extra embedding for the mask state
        self.node_embeddings = Embedding(num_nodes + 1, node_embedding_dim)

        # ReLU followed by a linear projection from the embedding space to one score per real node (the mask state
        # is never predicted). The attribute name is kept so existing state_dict keys ("linear.1.*") remain valid.
        self.linear = Sequential(
            ReLU(),
            Linear(node_embedding_dim, num_nodes)
        )

        # Log-softmax over dim 2, i.e. the node axis of the (batch_size, walk_length, num_nodes) scores
        self.log_softmax = LogSoftmax(dim=2)

    def forward(self, x: LongTensor, time_steps: LongTensor) -> FloatTensor:
        """
        Forward pass of UNetAdapter.

        Args:
            x: LongTensor, shape: (batch_size, walk_length)
                Node indices of the random walks. Valid values are 0..num_nodes, since the embedding table has
                num_nodes + 1 rows. The extra index (presumably num_nodes) denotes the mask state; TODO confirm.
            time_steps: LongTensor, shape: (batch_size,)
                Time steps, forwarded unchanged to the U-Net.

        Returns:
            x: FloatTensor, shape: (batch_size, walk_length, num_nodes)
                Log-probabilities over the num_nodes graph nodes for every position of every walk.
        """
        # (batch_size, walk_length) -> (batch_size, walk_length, node_embedding_dim)
        x = self.node_embeddings(x)

        # Move the embedding dimension to axis 1: (batch_size, node_embedding_dim, walk_length).
        # Assumes the U-Net's 1D convolutions expect a channels-first layout, so walk_length is the convolved
        # (length) axis. TODO confirm against UNet.
        x = torch.transpose(x, 1, 2)

        # Put x with time_steps through the U-Net. Its output must again have node_embedding_dim channels for the
        # projection below to apply.
        x = self.unet(x, time_steps)

        # Back to (batch_size, walk_length, node_embedding_dim) so the linear layer acts on the feature axis
        x = torch.transpose(x, 1, 2)

        # ReLU + linear projection of the U-Net features up to num_nodes unnormalized scores per position
        x = self.linear(x)

        # Normalize over the node axis; LogSoftmax returns log-probabilities (not probabilities)
        x = self.log_softmax(x)

        return x
