import numpy as np
import torch
import torch.nn as nn

import unittest

class TimeMLP(nn.Module):
    '''
    Inject a timestep embedding into a feature map by projection and addition.

    The embedding ``t`` goes through a single linear layer (``self.mlp``).
    Two trailing singleton axes are appended so the projection broadcasts
    over the last two axes of ``x``. The projection is then added to ``x``
    and the sum is squashed with tanh.

    Note: ``hidden_dim`` is accepted but currently unused; the projection maps
    ``embedding_dim`` straight to ``out_dim``.
    '''

    def __init__(self, embedding_dim, hidden_dim, out_dim):
        super().__init__()
        projection = nn.Linear(embedding_dim, out_dim)
        # Wrapped in nn.Sequential so parameter names stay "mlp.0.*".
        self.mlp = nn.Sequential(projection)
        self.act = nn.Tanh()

    def forward(self, x, t):
        '''
        :param x: Feature map; presumably (batch, out_dim, H, W) — TODO confirm.
        :param t: Time embedding whose last axis has size ``embedding_dim``.
        :return: ``tanh(x + mlp(t)[..., None, None])``.
        '''
        shift = self.mlp(t)[..., None, None]
        return self.act(x + shift)


class LinearDenoising(nn.Module):
    """
    Single-layer linear denoiser for DDPM with a learned timestep embedding.

    The input is concatenated with a learned embedding of its timestep. The
    result goes through ONE linear layer back to ``data_shape`` features and
    is squashed with tanh. A timestep is always required.

    Note: because of the final tanh, predictions are confined to (-1, 1).
    """

    def __init__(self, timesteps: int, time_embedding_dim: int, data_shape: int, hidden_dim: int = 10):
        """
        Initialize the LinearDenoising model.

        :param timesteps: Number of timesteps in DDPM. This is the size of the
            embedding table, so valid ``t`` values are ``0 .. timesteps - 1``.
        :param time_embedding_dim: Dimension of the time embeddings.
        :param data_shape: Flattened size of one input sample (e.g., 28*28 for MNIST).
        :param hidden_dim: Currently unused; kept for interface compatibility
            (the network has no hidden layer).
        """
        super(LinearDenoising, self).__init__()

        self.timesteps = timesteps
        self.time_embedding_dim = time_embedding_dim
        self.data_shape = data_shape  # e.g., 784 for MNIST

        # Learned lookup table: one vector per discrete timestep.
        self.time_embedding = nn.Embedding(timesteps, time_embedding_dim)

        # Single linear map. It is wrapped in nn.Sequential so parameter names
        # stay "denoise.0.*".
        self.denoise = nn.Sequential(
            nn.Linear(self.data_shape + time_embedding_dim, self.data_shape),
        )

        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to predict noise.

        :param x: Input data of shape (batch_size, data_shape).
        :param t: Integer timesteps of shape (batch_size,). A 0-d tensor or a
            Python int applies the same timestep to every sample. A Python
            list of ints is also accepted. Must be provided.
        :return: Predicted noise of shape (batch_size, data_shape), in (-1, 1).
        :raises ValueError: If ``t`` is None.
        """
        if t is None:
            raise ValueError("Timestep t must be provided.")

        # nn.Embedding needs an integer tensor; accept plain ints/lists too.
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=torch.long, device=x.device)
        # Concatenation along dim=1 needs a batch axis on the embedding, so a
        # single scalar timestep is broadcast to the whole batch.
        if t.dim() == 0:
            t = t.expand(x.shape[0])

        t_emb = self.time_embedding(t)  # (batch_size, time_embedding_dim)

        # (batch_size, data_shape + time_embedding_dim)
        features = torch.cat([x, t_emb], dim=1)

        noise_pred = self.denoise(features)  # (batch_size, data_shape)
        return self.activation(noise_pred)



class MultilayerLinearDenoising(nn.Module):
    """
    Multi-layer ReLU MLP denoiser for DDPM conditioned on the raw timestep.

    The timestep is appended to the input as one extra scalar feature, with no
    embedding. The result goes through four ReLU-activated hidden layers and a
    linear output layer back to ``data_shape`` features. A timestep is always
    required.
    """

    def __init__(self, timesteps: int, time_embedding_dim: int, data_shape: int, hidden_dim: int = 10):
        """
        Initialize the MultilayerLinearDenoising model.

        :param timesteps: Number of timesteps in DDPM. Stored only; not used
            by the network.
        :param time_embedding_dim: Stored only; this model does not embed ``t``.
        :param data_shape: Flattened size of one input sample (e.g., 28*28 for MNIST).
        :param hidden_dim: Width of each of the four hidden layers.
        """
        super(MultilayerLinearDenoising, self).__init__()

        self.timesteps = timesteps
        self.time_embedding_dim = time_embedding_dim
        self.data_shape = data_shape  # e.g., 784 for MNIST

        # Input layer sees the data features plus one scalar timestep feature.
        self.fc1 = nn.Linear(data_shape + 1, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        # Output layer maps back to the data dimension (no activation).
        self.fc5 = nn.Linear(hidden_dim, data_shape)

        # Activation applied after each hidden layer.
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to predict noise.

        :param x: Input data of shape (batch_size, data_shape).
        :param t: Timesteps of shape (batch_size,). A 0-d tensor or a Python
            number applies the same timestep to every sample. Must be provided.
        :return: Predicted noise of shape (batch_size, data_shape).
        :raises ValueError: If ``t`` is None.
        """
        if t is None:
            raise ValueError("Timestep t must be provided.")

        if not torch.is_tensor(t):
            t = torch.as_tensor(t, device=x.device)
        # A single scalar timestep is broadcast to the whole batch.
        if t.dim() == 0:
            t = t.expand(x.shape[0])

        # Cast explicitly rather than relying on torch.cat's dtype promotion.
        # Integer timesteps then become x's float dtype, and a float64 t
        # cannot upcast the input away from fc1's weight dtype.
        # NOTE(review): t is fed unnormalized (e.g. 0..timesteps-1), which can
        # dwarf the data features in scale. Dividing by self.timesteps may train
        # better, but it would change outputs, so it is left as-is.
        t_feature = t.to(dtype=x.dtype).unsqueeze(1)  # (batch_size, 1)
        h = torch.cat((x, t_feature), dim=1)  # (batch_size, data_shape + 1)

        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            h = self.relu(layer(h))

        return self.fc5(h)  # (batch_size, data_shape)



if __name__ == '__main__':
    # Smoke test: build a small LinearDenoising model and run one batch.
    timesteps = 1000
    time_embedding_dim = 128
    data_shape = 6  # flattened size of one sample (would be 28*28 for MNIST)
    hidden_dim = 10
    batch_size = 4

    model = LinearDenoising(
        timesteps=timesteps,
        time_embedding_dim=time_embedding_dim,
        data_shape=data_shape,  # the constructor's parameter is data_shape
        hidden_dim=hidden_dim,
    )

    # The model expects float tensors of shape (batch_size, data_shape) and
    # integer timesteps of shape (batch_size,), not numpy arrays or lists.
    x = torch.randn(batch_size, data_shape)
    t = torch.randint(0, timesteps, (batch_size,))

    output = model(x, t)
    print(output.shape)  # torch.Size([batch_size, data_shape])