import time
import torch
import torch.nn as nn
import math
import pickle
import os
from datetime import datetime

import json

from tqdm import tqdm

from models.training import sinusoidal_embedding, get_beta_schedule


class MLPBlock(nn.Module):
    """Two-layer SiLU perceptron whose output is added back onto its input."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Widen to hidden_dim, apply SiLU, then project back to dim so the
        # result can be summed with the block's input.
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(expand, nn.SiLU(), contract)

    def forward(self, x):
        """Return ``x`` plus the MLP's transformation of ``x``."""
        delta = self.net(x)
        return x + delta


class ConditionalResnetDiffusionModel(nn.Module):
    """Residual-MLP network taking a sample, a condition and a timestep.

    The constructor registers the diffusion schedule (``beta``, ``alpha`` and
    the cumulative product ``alpha_bar``) as buffers and builds:

    * an optional MLP that embeds a scalar condition into ``condition_dim``
      features,
    * an MLP applied to the sinusoidal timestep embedding,
    * a linear projection of the input sample,
    * ``num_blocks`` residual ``MLPBlock`` layers over the concatenated
      features, and
    * an output head projecting back to ``input_dim``.

    Args:
        input_dim: Feature width of the input sample and of the output.
        condition_dim: Width of the condition features fed to the residual
            blocks.
        hidden_dim: Width of the sample projection and of the time features.
        num_blocks: Number of residual ``MLPBlock`` layers.
        beta_schedule_args: Mapping forwarded to ``get_beta_schedule``; it
            must contain a ``"timesteps"`` entry. Required despite the
            ``None`` default.
        condition_embed: When truthy, embed the condition with a small MLP;
            otherwise the condition is passed through ``nn.Identity``.

    Raises:
        TypeError: If ``beta_schedule_args`` is ``None``.
    """

    def __init__(
        self,
        input_dim=1,
        condition_dim=4,
        hidden_dim=64,
        num_blocks=4,
        beta_schedule_args=None,
        condition_embed=True,
    ):
        super().__init__()

        # Fail with an actionable message instead of the opaque
        # "'NoneType' object is not subscriptable" raised by the lookup below.
        if beta_schedule_args is None:
            raise TypeError(
                "beta_schedule_args is required: pass a mapping with at "
                "least a 'timesteps' entry"
            )

        self.timesteps = beta_schedule_args["timesteps"]

        # Schedule tensors are buffers so they follow the module across
        # devices and are stored in the state dict without being trained.
        self.register_buffer("beta", get_beta_schedule(beta_schedule_args))
        self.register_buffer("alpha", 1.0 - self.beta)
        self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))
        self.time_dim = 16  # Dimension of sinusoidal embedding
        self.input_dim = input_dim

        if condition_embed:
            # Embed a single scalar condition into condition_dim features.
            self.conditioning_network = nn.Sequential(
                nn.Linear(1, 16),
                nn.ReLU(),
                nn.Linear(16, 32),
                nn.ReLU(),
                nn.Linear(32, condition_dim),
            )
            if condition_dim == 0:
                print("-" * 90)
                print("Conditioning dim is 0!")
                print("-" * 90)
        else:
            # NOTE(review): forward() feeds c.unsqueeze(1) through this
            # identity, so the concatenated width appears to match
            # combined_dim only when condition_dim == 1 (assuming c is 1-D)
            # -- confirm against callers.
            self.conditioning_network = nn.Identity()

        # Time embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Initial projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Combined dimension: sample features + time features + condition
        combined_dim = hidden_dim + hidden_dim + condition_dim

        # MLP blocks with residual connections
        self.blocks = nn.ModuleList(
            [MLPBlock(combined_dim, combined_dim * 2) for _ in range(num_blocks)]
        )

        # Final projection
        self.output_proj = nn.Sequential(
            nn.Linear(combined_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x, c, t):
        """Run the network on a batch.

        Args:
            x: Input sample; projected by ``nn.Linear(input_dim, hidden_dim)``
                and concatenated along dim 1, so it is expected to be
                ``(batch, input_dim)``.
            c: Condition values; a trailing feature axis is added with
                ``unsqueeze(1)`` before embedding, so with the default
                embedding network it is expected to be ``(batch,)``.
            t: Timesteps handed to ``sinusoidal_embedding`` -- presumably
                ``(batch,)``; confirm against that helper.

        Returns:
            Tensor with ``input_dim`` features per batch row.
        """
        # Time embedding
        t_emb = sinusoidal_embedding(t, self.time_dim)
        t_emb = self.time_mlp(t_emb)

        # Condition embedding
        c_emb = self.conditioning_network(c.unsqueeze(1))

        # Initial projection
        x_emb = self.input_proj(x)

        # Concatenate sample, condition and time features. The order is part
        # of the learned weight layout and must not change.
        h = torch.cat([x_emb, c_emb, t_emb], dim=1)

        # Process through residual MLP blocks
        for block in self.blocks:
            h = block(h)

        # Final projection to output dimensionality
        return self.output_proj(h)
