import torch
import torch.nn as nn

from models.training import sinusoidal_embedding, get_beta_schedule


class ConditionalDiffusionModel(nn.Module):
    """MLP that maps a sample, a scalar condition and a timestep to an output.

    The module is made of three sub-networks:

    * ``t_embedding_layer`` -- refines the sinusoidal timestep embedding
      (``t_emb_dim -> 32 -> t_emb_dim``).
    * ``conditioning_network`` -- lifts one scalar condition per sample to a
      ``condition_dim``-wide embedding (``1 -> 16 -> 32 -> condition_dim``).
    * ``network`` -- fully connected ReLU stack over the concatenation
      ``[x, c_emb, t_emb]`` that ends in a linear layer of width ``input_dim``.

    The schedule tensors ``beta``, ``alpha`` (``1 - beta``) and ``alpha_bar``
    (cumulative product of ``alpha``) are registered as buffers, so they follow
    the module across devices and are stored in its ``state_dict``.
    """

    def __init__(
        self, input_dim, condition_dim, beta_schedule_args=None, layer_sizes=(128, 64)
    ):
        """Build the sub-networks and register the schedule buffers.

        Args:
            input_dim: Width of ``x`` and of the model output.
            condition_dim: Width of the embedding produced by
                ``conditioning_network``. A message is printed when it is 0.
            beta_schedule_args: Mapping forwarded to ``get_beta_schedule``.
                It must contain a ``"timesteps"`` entry; any other keys are
                interpreted by ``get_beta_schedule``. Although the parameter
                defaults to ``None`` for signature compatibility, a value is
                required.
            layer_sizes: Hidden-layer widths of the main network, in order.

        Raises:
            TypeError: If ``beta_schedule_args`` is ``None``.
        """
        super().__init__()

        # Fail with an actionable message instead of the opaque
        # "'NoneType' object is not subscriptable" raised by the lookup below.
        if beta_schedule_args is None:
            raise TypeError(
                "beta_schedule_args is required and must be a mapping "
                "containing a 'timesteps' entry, got None"
            )

        self.timesteps = beta_schedule_args["timesteps"]

        # Schedule tensors are buffers: saved and moved with the module, but
        # not trainable parameters.
        self.register_buffer("beta", get_beta_schedule(beta_schedule_args))
        self.register_buffer("alpha", 1.0 - self.beta)
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))

        # Width of the timestep embedding fed to the main network.
        self.t_emb_dim = 4

        if condition_dim == 0:
            print("Conditioning dim is 0")

        # Expand the raw timestep embedding to 32 features, then project it
        # back to its original width.
        self.t_embedding_layer = nn.Sequential(
            nn.Linear(self.t_emb_dim, 32),
            nn.ReLU(),
            nn.Linear(32, self.t_emb_dim),
        )

        # Takes a single scalar condition per sample (in_features == 1).
        self.conditioning_network = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, condition_dim),
        )

        # The main network consumes [x, c_emb, t_emb] concatenated on dim 1.
        network_layers = []
        prev_size = input_dim + condition_dim + self.t_emb_dim
        for size in layer_sizes:
            network_layers.append(nn.Linear(prev_size, size))
            network_layers.append(nn.ReLU())
            prev_size = size

        # Final projection back to the width of x (no activation).
        network_layers.append(nn.Linear(prev_size, input_dim))

        self.network = nn.Sequential(*network_layers)

    def forward(self, x, c, t):
        """Run the network on ``x`` given condition ``c`` and timestep ``t``.

        Args:
            x: Tensor whose dim 1 has size ``input_dim`` (concatenated on
                dim 1 and fed to a ``Linear`` expecting that width);
                presumably ``(batch, input_dim)``.
            c: Condition values, one scalar per sample. ``unsqueeze(1)``
                followed by ``Linear(1, ...)`` implies a 1-D ``(batch,)``
                tensor.
            t: Timesteps handed to ``sinusoidal_embedding``. Assumed to be
                ``(batch,)`` so that the embedding lines up with ``x`` --
                TODO confirm against ``sinusoidal_embedding``.

        Returns:
            Output of the main network; last dimension has size ``input_dim``.
        """
        t_emb = self.t_embedding_layer(sinusoidal_embedding(t, self.t_emb_dim))
        # Add a feature axis so each scalar condition becomes a 1-wide row.
        c_emb = self.conditioning_network(c.unsqueeze(1))
        x_input = torch.cat([x, c_emb, t_emb], dim=1)
        return self.network(x_input)
