import glob
from types import SimpleNamespace
import torch
import torch.nn as nn
import math
import pickle
import os
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from mixed_diffusion.conditioning.archetype_conditioning import ArchetypeConditioning


class SelfAttention(nn.Module):
    """Scaled dot-product self-attention over several heads.

    Args:
        dim: Width of the incoming (and outgoing) token features.
        num_heads: How many attention heads to run in parallel.
        head_dim: Per-head width; falls back to ``dim // num_heads`` when
            omitted (or falsy).
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=4, head_dim=None, dropout=0.0):
        super().__init__()
        if not head_dim:
            head_dim = dim // num_heads
        width = num_heads * head_dim

        self.num_heads = num_heads
        # Standard 1/sqrt(d_head) scaling of the attention logits.
        self.scale = head_dim**-0.5

        # One bias-free projection yields queries, keys and values at once.
        self.to_qkv = nn.Linear(dim, 3 * width, bias=False)
        self.to_out = nn.Sequential(nn.Linear(width, dim), nn.Dropout(dropout))

    def _split_heads(self, tensor, batch, tokens):
        """Reshape ``[batch, tokens, heads * d]`` to ``[batch, heads, tokens, d]``."""
        return tensor.reshape(batch, tokens, self.num_heads, -1).transpose(1, 2)

    def forward(self, x):
        """Attend over the token axis of ``x`` (``[batch, tokens, dim]``).

        Returns a tensor with the same shape as ``x``.
        """
        batch, tokens, _ = x.shape

        queries, keys, values = (
            self._split_heads(part, batch, tokens)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = (queries @ keys.transpose(-1, -2)) * self.scale
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values, then merge the heads back into one axis.
        mixed = weights @ values
        merged = mixed.transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(merged)


class FeedForward(nn.Module):
    """Two-layer position-wise MLP: expand, GELU, contract.

    Dropout with probability ``dropout`` follows both the activation and the
    final linear layer.

    Args:
        dim: Input and output feature width.
        hidden_dim: Width of the intermediate (expanded) representation.
        dropout: Dropout probability used in both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        # Layer positions inside ``net`` are part of the checkpoint layout.
        self.net = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``; the shape is preserved."""
        transformed = self.net(x)
        return transformed


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then MLP, each with a residual.

    Args:
        dim: Token feature width.
        num_heads: Number of attention heads handed to ``SelfAttention``.
        mlp_ratio: Hidden width of the MLP expressed as a multiple of ``dim``
            (truncated to an integer).
        dropout: Dropout probability forwarded to both sub-layers.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        mlp_width = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = SelfAttention(dim, num_heads=num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_width, dropout=dropout)

    def forward(self, x):
        """Run one layer over ``x``; the output has the same shape as ``x``."""
        # LayerNorm is applied to each sub-layer's input (pre-norm); the
        # residual path itself stays un-normalised.
        after_attention = x + self.attn(self.norm1(x))
        after_mlp = after_attention + self.ff(self.norm2(after_attention))
        return after_mlp


class TabularDiffusionTransformer(nn.Module):
    """Transformer that maps a noisy sample and a timestep to a prediction.

    The flattened sample, the timestep embedding and (optionally) a
    conditioning vector are each turned into one ``hidden_dim``-wide token.
    The resulting 2- or 3-token sequence is processed by ``num_blocks``
    ``TransformerBlock`` layers, and the first (sample) token is projected
    back to ``input_dim``.

    Args:
        args: Configuration object. ``args.noise_step`` is required;
            ``args.archetype_dropout_prob`` (default 0.15) and ``args.device``
            (default ``"cuda"``) are read only when archetype conditioning is
            built.
        input_dim: Number of features per sample after flattening.
        hidden_dim: Token width used throughout the transformer.
        num_blocks: Number of transformer blocks.
        use_archetype_conditioning: Enable the archetype condition token.
        num_archetypes: Number of archetypes; the conditioning module is only
            built when this is not ``None``.
        condition_dim: Width of the archetype embedding before projection.
    """

    def __init__(
        self,
        args,
        input_dim=1,
        hidden_dim=64,
        num_blocks=4,
        use_archetype_conditioning=False,
        num_archetypes=None,
        condition_dim=64,
    ):
        super().__init__()
        self.noise_step = args.noise_step
        self.time_dim = 16  # Dimension of sinusoidal embedding
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_blocks = num_blocks
        self.num_heads = 4  # Number of attention heads
        self.use_archetype_conditioning = use_archetype_conditioning

        # Time embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Initial projection
        print(f"Input dimension: {input_dim}")
        print(f"Hidden dimension: {hidden_dim}")
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Archetype conditioning setup
        if use_archetype_conditioning and num_archetypes is not None:
            self.archetype_conditioning = ArchetypeConditioning(
                num_archetypes=num_archetypes,
                condition_dim=condition_dim,
                dropout_prob=getattr(args, "archetype_dropout_prob", 0.15),
                device=getattr(args, "device", "cuda"),
            )
            # Add condition token to sequence
            self.sequence_length = 3  # feature + time + condition
        else:
            # Also reached when conditioning was requested without
            # num_archetypes; forward()/apply_cfg() report that explicitly.
            self.archetype_conditioning = None
            self.sequence_length = 2  # feature + time

        # Position embedding for transformer
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, self.sequence_length, hidden_dim)
        )

        # Time embedding projection
        self.time_proj = nn.Linear(hidden_dim, hidden_dim)

        # Condition embedding projection (if using archetype conditioning)
        if use_archetype_conditioning:
            self.condition_proj = nn.Linear(condition_dim, hidden_dim)

        # Transformer blocks
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    dim=hidden_dim, num_heads=self.num_heads, mlp_ratio=4.0, dropout=0.1
                )
                for _ in range(num_blocks)
            ]
        )

        # Layer norm
        self.norm = nn.LayerNorm(hidden_dim)

        # Final projection
        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def _get_sinusoidal_embeddings(self, timesteps):
        """Return ``[len(timesteps), time_dim]`` sinusoidal embeddings.

        ``timesteps`` must be a 1-D tensor. The first half of each row holds
        sines and the second half cosines over geometrically spaced
        frequencies; an odd ``time_dim`` is padded with one zero column.
        """
        half_dim = self.time_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb)
        emb = timesteps.float()[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if self.time_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
        return emb

    def _project_condition(self, condition):
        """Bring a raw ``condition`` tensor to ``hidden_dim`` width.

        Tensors that already have ``hidden_dim`` features pass through
        untouched. For any other width a linear projection is created the
        first time that width is seen and registered on the module, so the
        same weights are reused on every call and show up in
        ``parameters()`` / ``state_dict()``.

        Because the projection is created lazily, run one forward pass with
        the conditioning tensor before building the optimizer or loading a
        checkpoint that contains these weights.
        """
        width = condition.shape[-1]
        if width == self.hidden_dim:
            return condition

        # Looked up with getattr so instances created before this attribute
        # existed (e.g. unpickled models) keep working.
        projections = getattr(self, "_extra_condition_projs", None)
        if projections is None:
            projections = nn.ModuleDict()
            self._extra_condition_projs = projections

        key = str(width)
        if key not in projections:
            projections[key] = nn.Linear(width, self.hidden_dim)
        return projections[key].to(condition.device)(condition)

    def forward(self, x, t, condition=None, archetype_labels=None):
        """Run the denoiser.

        Args:
            x: Input of shape ``[batch, input_dim]``; tensors with more axes
                are flattened per sample and the output is reshaped back.
            t: 1-D tensor of timesteps, one per sample.
            condition: Optional conditioning tensor ``[batch, width]``. Used
                only when the archetype branch is not taken.
            archetype_labels: Optional labels passed to the archetype
                conditioning module; takes precedence over ``condition``
                when archetype conditioning is enabled.

        Returns:
            Tensor shaped like ``x``.

        Raises:
            ValueError: ``archetype_labels`` were supplied but the
                conditioning module was never built (``num_archetypes`` was
                ``None`` at construction).
            RuntimeError: The token sequence is longer than the positional
                embedding table (a condition was supplied to a model built
                without a condition slot).
        """
        batch_size = x.shape[0]
        needs_flatten = x.dim() > 2
        flat_x = x.reshape(batch_size, -1) if needs_flatten else x

        # Time embedding
        t_emb = self.time_mlp(self._get_sinusoidal_embeddings(t))

        # One token for the sample, one for the timestep: [batch, 1, hidden_dim]
        tokens = [
            self.input_proj(flat_x).unsqueeze(1),
            self.time_proj(t_emb).unsqueeze(1),
        ]

        # Optional third token carrying the condition
        if self.use_archetype_conditioning and archetype_labels is not None:
            if self.archetype_conditioning is None:
                raise ValueError(
                    "archetype_labels were provided but the archetype "
                    "conditioning module was not built; pass num_archetypes "
                    "when constructing the model"
                )
            condition_emb = self.archetype_conditioning(
                archetype_labels, training=self.training
            )
            tokens.append(self.condition_proj(condition_emb).unsqueeze(1))
        elif condition is not None:
            tokens.append(self._project_condition(condition).unsqueeze(1))

        sequence = torch.cat(tokens, dim=1)  # [batch, 2 or 3, hidden_dim]

        # Add positional embeddings. The table may hold a condition slot that
        # this call does not use (e.g. an unconditional pass through a
        # conditional model), so only the leading positions are added.
        seq_len = sequence.shape[1]
        available = self.pos_embedding.shape[1]
        if seq_len > available:
            raise RuntimeError(
                f"Sequence has {seq_len} tokens but the model only has "
                f"{available} positional embeddings; build the model with "
                "archetype conditioning enabled and num_archetypes set to "
                "reserve a condition token"
            )
        sequence = sequence + self.pos_embedding[:, :seq_len]

        # Process through transformer blocks
        for block in self.blocks:
            sequence = block(sequence)

        # Apply layer norm
        sequence = self.norm(sequence)

        # The feature token is the first token in the sequence
        output = self.output_proj(sequence[:, 0])

        # Reshape output to match input shape
        return output.reshape_as(x) if needs_flatten else output

    def apply_cfg(self, x, t, archetype_labels, cfg_scale=1.5):
        """
        Apply Classifier-Free Guidance (CFG) during sampling.

        Delegates to the archetype conditioning module's ``apply_cfg``,
        passing this model as the predictor.

        Args:
            x: Noisy input x_t
            t: Timestep t
            archetype_labels: Target archetype labels
            cfg_scale: CFG scale w (typically 1.5-3.0)

        Returns:
            CFG-guided prediction ε̂

        Raises:
            ValueError: Archetype conditioning is disabled or its module was
                never built (``num_archetypes`` was ``None``).
        """
        if not self.use_archetype_conditioning or self.archetype_conditioning is None:
            raise ValueError("CFG requires archetype conditioning to be enabled")

        return self.archetype_conditioning.apply_cfg(
            x, t, self, archetype_labels, cfg_scale
        )
