import glob
from types import SimpleNamespace
import torch
import torch.nn as nn
import math
import pickle
import os
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from mixed_diffusion.conditioning.archetype_conditioning import ArchetypeConditioning


class MLPBlock(nn.Module):
    """Two-layer perceptron whose output is added back onto its input.

    The inner network expands ``dim`` to ``hidden_dim`` (Linear, BatchNorm1d,
    SiLU) and contracts back to ``dim`` (Linear, BatchNorm1d), so the block
    preserves the feature width of whatever it is given.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Layers are created in network order and stored under ``self.net``
        # so parameter names stay ``net.0``, ``net.1``, ``net.3``, ``net.4``.
        expand = [
            nn.Linear(dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SiLU(),
        ]
        contract = [
            nn.Linear(hidden_dim, dim),
            nn.BatchNorm1d(dim),
        ]
        self.net = nn.Sequential(*expand, *contract)

    def forward(self, x):
        """Return ``x`` plus the inner network applied to ``x``."""
        residual = self.net(x)
        return x + residual


class TabularDiffusionMLP(nn.Module):
    """Residual-MLP network for tabular diffusion with optional conditioning.

    The forward pass projects the (flattened) input to ``hidden_dim``, embeds
    the timestep with a sinusoidal encoding followed by a small MLP,
    concatenates both with an optional conditioning vector, runs the result
    through ``num_blocks`` residual ``MLPBlock`` layers and projects back to
    ``input_dim``.

    Args:
        config: Dict-like configuration. ``config["noise_step"]`` is required
            and stored as ``self.noise_step``. When archetype conditioning is
            built, ``"archetype_dropout_prob"`` (default ``0.15``) and
            ``"device"`` (default ``"mps"``) are read via ``config.get``.
        input_dim: Width of the flattened input and of the output.
        hidden_dim: Width of the input projection and of the time embedding.
        condition_dim: Width of the conditioning vector concatenated to the
            features (``0`` for an unconditional model).
        num_blocks: Number of residual MLP blocks.
        use_archetype_conditioning: Enable label-based conditioning through
            ``ArchetypeConditioning``.
        num_archetypes: Number of archetype labels. The conditioning module
            is only built when this is not ``None`` and
            ``use_archetype_conditioning`` is true.
    """

    def __init__(
        self,
        config,
        input_dim=1,
        hidden_dim=64,
        condition_dim=0,
        num_blocks=4,
        use_archetype_conditioning=False,
        num_archetypes=None,
    ):
        super().__init__()
        self.noise_step = config["noise_step"]
        self.time_dim = 16  # Dimension of sinusoidal embedding
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_blocks = num_blocks
        self.use_archetype_conditioning = use_archetype_conditioning

        # Time embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Initial projection
        print(f"Input dimension: {input_dim}")
        print(f"Hidden dimension: {hidden_dim}")
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Archetype conditioning setup. The module is only built when the
        # number of archetypes is known; otherwise the attribute stays None
        # and forward()/apply_cfg() guard against using it.
        if use_archetype_conditioning and num_archetypes is not None:
            self.archetype_conditioning = ArchetypeConditioning(
                num_archetypes=num_archetypes,
                condition_dim=condition_dim,
                dropout_prob=config.get("archetype_dropout_prob", 0.15),
                # Configurable so the model is not tied to Apple-silicon
                # hosts; "mps" remains the default for backward compatibility.
                device=config.get("device", "mps"),
            )
        else:
            self.archetype_conditioning = None

        # Combined dimension (data + time embedding + condition)
        combined_dim = hidden_dim + hidden_dim + condition_dim

        # MLP blocks with residual connections
        self.blocks = nn.ModuleList(
            [MLPBlock(combined_dim, combined_dim * 2) for _ in range(num_blocks)]
        )

        # Final projection
        self.output_proj = nn.Sequential(
            nn.Linear(combined_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def _get_sinusoidal_embeddings(self, timesteps):
        """Generate sinusoidal time embeddings.

        Args:
            timesteps: Tensor of timesteps; it is indexed as ``[:, None]``,
                so a 1-D tensor of length ``batch`` is expected.

        Returns:
            Tensor of shape ``(batch, self.time_dim)`` holding the sine half
            followed by the cosine half (zero-padded by one column when
            ``time_dim`` is odd).
        """
        half_dim = self.time_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb)
        emb = timesteps.float()[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

        if self.time_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)

        return emb

    def forward(self, x, t, condition=None, archetype_labels=None):
        """Run the network on ``x`` at timestep ``t``.

        Args:
            x: Input tensor. Inputs with more than two dimensions are
                flattened to ``(batch, -1)`` and the output is reshaped back.
            t: Timesteps, one per batch element.
            condition: Optional precomputed conditioning tensor, used when
                archetype conditioning is not applied.
            archetype_labels: Optional labels passed to the archetype
                conditioning module. They take precedence over ``condition``
                when that module is enabled and was built.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If the concatenated features do not match the width
                the network was built for (for example, a missing or
                mis-sized condition when ``condition_dim > 0``).
        """
        batch_size = x.shape[0]
        needs_reshape = x.dim() > 2
        x_flat = x.reshape(batch_size, -1) if needs_reshape else x

        # Time embedding
        t_emb = self.time_mlp(self._get_sinusoidal_embeddings(t))

        # Initial projection
        x_projected = self.input_proj(x_flat)

        # Handle conditioning
        features = [x_projected, t_emb]
        if (
            self.use_archetype_conditioning
            and self.archetype_conditioning is not None
            and archetype_labels is not None
        ):
            # Use archetype conditioning
            features.append(
                self.archetype_conditioning(archetype_labels, training=self.training)
            )
        elif condition is not None:
            # Use traditional conditioning
            features.append(condition)
        h = torch.cat(features, dim=1)

        # Fail early with a readable message instead of an opaque matmul
        # shape error inside the first block. The expected width is read from
        # the output projection so it also holds when num_blocks == 0.
        expected_dim = self.output_proj[0].in_features
        if h.shape[1] != expected_dim:
            raise ValueError(
                f"Expected {expected_dim} concatenated features "
                f"(input + time + condition) but got {h.shape[1]}; "
                "check that a conditioning input of width condition_dim "
                "was provided"
            )

        # Process through residual MLP blocks
        for block in self.blocks:
            h = block(h)

        # Final projection to output dimensionality
        output = self.output_proj(h)

        # Reshape output to match input shape
        return output.reshape_as(x) if needs_reshape else output

    def apply_cfg(self, x, t, archetype_labels, cfg_scale=1.5):
        """
        Apply Classifier-Free Guidance (CFG) during sampling.

        Delegates to ``self.archetype_conditioning.apply_cfg``, passing this
        model as the network to evaluate.

        Args:
            x: Noisy input x_t
            t: Timestep t
            archetype_labels: Target archetype labels
            cfg_scale: CFG scale w (typically 1.5-3.0)

        Returns:
            CFG-guided prediction ε̂

        Raises:
            ValueError: If archetype conditioning is disabled or its module
                was never built (``num_archetypes`` was ``None``).
        """
        if not self.use_archetype_conditioning or self.archetype_conditioning is None:
            raise ValueError("CFG requires archetype conditioning to be enabled")

        return self.archetype_conditioning.apply_cfg(
            x, t, self, archetype_labels, cfg_scale
        )
