from dataclasses import dataclass, field, MISSING
from typing import List

import torch
import torch.nn as nn

from .layers import LlamaRMSNorm


@dataclass
class VAEConfig:
    """Layer-size hyper-parameters for the VAE.

    The field names mirror the constructor parameters of ``VAEEncoder`` and
    ``VAEDecoder``; presumably the values are forwarded to them unchanged
    (confirm at the call site).
    """

    # Size of the vectors fed to the model.
    input_dim: int = 256
    # Widths of the hidden layers. A default_factory is used so every config
    # instance owns its own list. A wider stack, [1024, 2048, 1024, 512], was
    # tried previously and is kept here only as a reference.
    hidden_dims: List[int] = field(default_factory=lambda: [256, 256])
    # Size of the latent code.
    latent_dim: int = 256


class VAEEncoder(nn.Module):
    """MLP encoder that outputs the mean and log-variance of the latent code.

    The trunk is a stack of ``Linear -> LlamaRMSNorm -> LeakyReLU`` stages, one
    per entry of ``hidden_dims``. Two independent linear heads (``fc_mu`` and
    ``fc_var``) project the trunk output to ``latent_dim``.
    """

    def __init__(self, input_dim: int, hidden_dims: list[int], latent_dim: int):
        """Build the trunk and the two heads.

        Args:
            input_dim: Size of the last dimension of the input.
            hidden_dims: Output width of each trunk stage, in order. The list
                is not modified.
            latent_dim: Output width of both heads.
        """
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Prefix the input width so that neighbouring entries form the
        # (in_features, out_features) pair of each stage.
        widths = [input_dim] + hidden_dims
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                LlamaRMSNorm(fan_out),
                nn.LeakyReLU(),
            ]
        self.encoder = nn.Sequential(*stages)

        trunk_width = widths[-1]
        self.fc_mu = nn.Linear(trunk_width, latent_dim)
        self.fc_var = nn.Linear(trunk_width, latent_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, log_var)`` computed from the shared trunk output."""
        hidden = self.encoder(x)
        return self.fc_mu(hidden), self.fc_var(hidden)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return only ``mu``; the log-variance head is not evaluated."""
        return self.fc_mu(self.encoder(x))


class VAEDecoder(nn.Module):
    """MLP decoder mapping a latent code back to ``input_dim`` outputs.

    ``hidden_dims`` is traversed in reverse order. Each hidden stage is
    ``Linear -> LlamaRMSNorm -> LeakyReLU`` and a final plain ``Linear``
    produces the output, with no activation applied to it.
    """

    def __init__(self, input_dim: int, hidden_dims: list[int], latent_dim: int):
        """Build the decoder stack.

        Args:
            input_dim: Width of the decoder output.
            hidden_dims: Hidden widths given in encoder order; they are
                applied here back to front. The caller's list is left
                untouched.
            latent_dim: Width of the latent code the decoder consumes.
        """
        super().__init__()
        self.latent_dim = latent_dim

        # Work on a reversed *copy*: an in-place ``reverse()`` would flip the
        # list owned by the caller (e.g. a shared config), so every further
        # model built from it would see a different layer order.
        widths = [latent_dim] + list(reversed(hidden_dims))

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend(
                [
                    nn.Linear(fan_in, fan_out),
                    LlamaRMSNorm(fan_out),
                    nn.LeakyReLU(),
                ]
            )
        # Output projection back to the input width.
        modules.append(nn.Linear(widths[-1], input_dim))
        self.decoder = nn.Sequential(*modules)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent code ``z`` into the raw output of the last layer."""
        return self.decoder(z)


class VAE(nn.Module):
    """Variational auto-encoder wrapper around an encoder and a decoder.

    The encoder must return ``(mu, logvar)`` when called and expose
    ``input_dim`` and ``latent_dim`` attributes, which are copied onto this
    module.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Mirror the encoder's sizes for convenient access.
        self.input_dim = encoder.input_dim
        self.latent_dim = encoder.latent_dim

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, 1)``.

        ``std`` is ``exp(0.5 * logvar)``. A sample is drawn on every call; the
        method does not check the module's training flag.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(decoder(z), mu, logvar)`` for a freshly sampled ``z``."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(latent)
        return reconstruction, mu, logvar


def vae_loss(x_hat, x, mu, logvar):
    """Compute the VAE objective currently selected for training.

    This is the single switch point for the objective. It forwards to
    ``beta_vae_loss`` using that function's default ``beta``. The other
    objectives in this module that were tried here, numbered as in the
    original experiments:

        1. standard_vae_loss(x_hat, x, mu, logvar)
        2. beta_vae_loss(x_hat, x, mu, logvar)            <- active
        3. annealed_vae_loss(x_hat, x, mu, logvar, kl_weight=0.1)
        4. focal_vae_loss(x_hat, x, mu, logvar, gamma=2.0)
           (flagged "Never use" by the author)
        5. mmd_vae_loss(x_hat, x, mu)
        6. wasserstein_vae_loss(x_hat, x, mu, logvar)
        7. info_vae_loss(x_hat, x, mu, logvar, alpha=1.0, lambda_=0.01)
        8. binary_vae_loss(x_hat, x, mu, logvar, beta=1.0)
        9. exact_reconstruction_vae_loss(x_hat, x, mu, logvar,
                                         beta=1.0, gamma=10.0)
    """
    return beta_vae_loss(x_hat, x, mu, logvar)


def clip_loss(mu1: torch.Tensor, mu2: torch.Tensor, temp=0.1):
    """Symmetric CLIP-style contrastive loss between two batches of vectors.

    Row ``i`` of ``mu1`` and row ``i`` of ``mu2`` form the positive pair;
    every other row in the batch acts as a negative.

    Args:
        mu1: First batch. Dimension 1 is squeezed if it has size 1, so this
            presumably expects ``[bsz, 1, dim]`` or ``[bsz, dim]`` (confirm
            with the caller).
        mu2: Second batch, same layout as ``mu1``.
        temp: Temperature dividing the cosine similarities.

    Returns:
        Scalar tensor: the mean of the row-wise and column-wise cross
        entropies of the ``[bsz, bsz]`` similarity matrix against its
        diagonal.
    """
    # Normalise each input on its own so the dot products below are cosine
    # similarities of (mu1[i], mu2[j]).
    mu1 = nn.functional.normalize(mu1, p=2, dim=-1)
    mu2 = nn.functional.normalize(mu2, p=2, dim=-1)

    sim = (
        torch.matmul(mu1.squeeze(1), mu2.squeeze(1).transpose(0, 1)) / temp
    )  # [bsz, bsz]

    # Direction 1: for each mu1 row, pick the matching mu2 row.
    loss_rows = -nn.functional.log_softmax(sim, dim=1).diag().mean()
    # Direction 2: for each mu2 row, pick the matching mu1 row.
    loss_cols = -nn.functional.log_softmax(sim, dim=0).diag().mean()

    return (loss_rows + loss_cols) / 2


def siamese_loss(x_hat1, x1, mu1, logvar1, x_hat2, x2, mu2, logvar2, alpha=1.0):
    """Combine the VAE losses of two branches with a contrastive term.

    The two ``vae_loss`` values are summed and divided by
    ``2 * x1.size(0)`` (presumably twice the batch size, i.e. a per-sample
    average over both branches). ``alpha * clip_loss(mu1, mu2)`` is then
    added.

    Args:
        x_hat1, x1, mu1, logvar1: Reconstruction, target, latent mean and
            latent log-variance of the first branch.
        x_hat2, x2, mu2, logvar2: The same quantities for the second branch.
        alpha: Weight of the contrastive term.
    """
    branch_sum = vae_loss(x_hat1, x1, mu1, logvar1) + vae_loss(
        x_hat2, x2, mu2, logvar2
    )
    averaged_vae = branch_sum / (2 * x1.size(0))
    contrastive = clip_loss(mu1, mu2)
    return averaged_vae + alpha * contrastive


def standard_vae_loss(x_hat, x, mu, logvar):
    """Plain VAE objective: summed BCE-with-logits plus KL to N(0, I).

    Args:
        x_hat: Reconstruction logits (no sigmoid applied).
        x: Targets, same shape as ``x_hat``.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor, summed over all elements (not averaged).
    """
    reconstruction = nn.functional.binary_cross_entropy_with_logits(
        x_hat, x, reduction="sum"
    )
    # Closed-form KL( N(mu, exp(logvar)) || N(0, 1) ), summed over elements.
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction + kl_divergence


def beta_vae_loss(x_hat, x, mu, logvar, beta=10.0):
    """Beta-VAE objective: summed BCE-with-logits plus ``beta`` times the KL.

    A larger ``beta`` puts more weight on the KL regulariser relative to the
    reconstruction term; a smaller one favours reconstruction.

    Args:
        x_hat: Reconstruction logits (no sigmoid applied).
        x: Targets, same shape as ``x_hat``.
        mu: Latent means.
        logvar: Latent log-variances.
        beta: Multiplier applied to the KL term.

    Returns:
        Scalar tensor, summed over all elements.
    """
    reconstruction = nn.functional.binary_cross_entropy_with_logits(
        x_hat, x, reduction="sum"
    )
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction + beta * kl_divergence


def annealed_vae_loss(x_hat, x, mu, logvar, kl_weight):
    """VAE objective whose KL weight is supplied by the caller.

    Intended for KL annealing: the training loop passes a ``kl_weight`` that
    starts near 0 and is raised towards 1. The schedule itself lives outside
    this function.

    Args:
        x_hat: Reconstruction logits (no sigmoid applied).
        x: Targets, same shape as ``x_hat``.
        mu: Latent means.
        logvar: Latent log-variances.
        kl_weight: Current multiplier for the KL term.

    Returns:
        Scalar tensor, summed over all elements.
    """
    reconstruction = nn.functional.binary_cross_entropy_with_logits(
        x_hat, x, reduction="sum"
    )
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction + kl_weight * kl_divergence


def focal_vae_loss(x_hat, x, mu, logvar, gamma=2.0):
    """Focal-weighted reconstruction loss plus an unweighted KL term.

    Each element's BCE is scaled by ``(1 - p_t) ** gamma``, where ``p_t`` is
    the predicted probability of the true label, so confidently correct
    elements contribute less. The author's note says this targets inputs
    with imbalanced bit distributions.

    Args:
        x_hat: Reconstruction logits.
        x: Targets; entries equal to 1 are treated as the positive class.
        mu: Latent means.
        logvar: Latent log-variances.
        gamma: Focusing exponent; 0 reduces to plain summed BCE.

    Returns:
        Scalar tensor, summed over all elements.
    """
    probabilities = torch.sigmoid(x_hat)
    # Probability the model assigns to the correct label of each element.
    p_true = torch.where(x == 1, probabilities, 1 - probabilities)
    modulation = (1 - p_true).pow(gamma)

    per_element_bce = nn.functional.binary_cross_entropy_with_logits(
        x_hat, x, reduction="none"
    )
    reconstruction = (modulation * per_element_bce).sum()

    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl_divergence


def mmd_vae_loss(x_hat, x, z):
    """MMD-VAE objective: summed BCE plus an RBF-kernel MMD to the prior.

    The KL term of a standard VAE is replaced by the Maximum Mean Discrepancy
    between ``z`` and fresh standard-normal samples of the same shape.

    Args:
        x_hat: Reconstruction logits.
        x: Targets, same shape as ``x_hat``.
        z: Latent vectors (samples, not the ``mu``/``logvar`` pair). They are
            handed to ``compute_rbf_kernel``, which squeezes dimension 1 and
            sums over the feature dimension.

    Returns:
        Scalar tensor. The BCE part is summed over elements while the MMD
        part is built from kernel means, so the two terms are on different
        scales.
    """
    reconstruction = nn.functional.binary_cross_entropy_with_logits(
        x_hat, x, reduction="sum"
    )

    # Reference samples from the N(0, I) prior.
    prior_z = torch.randn_like(z)

    # Each call picks its own median-heuristic bandwidth.
    k_z = compute_rbf_kernel(z, z)
    k_prior = compute_rbf_kernel(prior_z, prior_z)
    k_cross = compute_rbf_kernel(z, prior_z)

    mmd_loss = k_z.mean() + k_prior.mean() - 2 * k_cross.mean()

    return reconstruction + mmd_loss


def compute_rbf_kernel(x1, x2, bandwidth=None):
    """Compute the RBF (Gaussian) kernel matrix between two sets of vectors.

    Args:
        x1: Tensor of shape ``[n, dim]``; a size-1 dimension 1 is squeezed
            first, so ``[n, 1, dim]`` is accepted as well.
        x2: Tensor of shape ``[m, dim]`` (or ``[m, 1, dim]``).
        bandwidth: Kernel width ``sigma``. When ``None`` it is derived from
            the median squared distance as ``sqrt(median / 2)``.

    Returns:
        ``[n, m]`` tensor with entries
        ``exp(-||x1_i - x2_j||^2 / (2 * sigma^2))``.
    """
    x1 = x1.squeeze(1)
    x2 = x2.squeeze(1)
    x1_squared = x1.pow(2).sum(1).unsqueeze(1)
    x2_squared = x2.pow(2).sum(1).unsqueeze(0)

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. Round-off can push entries
    # slightly below zero, which would later turn the sqrt into NaN, so the
    # result is clamped to be non-negative.
    squared_dist = x1_squared + x2_squared - 2 * torch.matmul(x1, x2.transpose(0, 1))
    squared_dist = squared_dist.clamp_min(0.0)

    # Default bandwidth (median heuristic).
    if bandwidth is None:
        median = torch.median(squared_dist.detach())
        bandwidth = torch.sqrt(median / 2.0)
        # torch.median returns the lower middle value, so it is 0 for e.g. a
        # two-point self-kernel (half the entries are the zero diagonal).
        # Floor the bandwidth so that case yields a sharp 0/1 kernel instead
        # of 0/0 = NaN.
        bandwidth = bandwidth.clamp_min(torch.finfo(bandwidth.dtype).eps)

    return torch.exp(-squared_dist / (2 * bandwidth * bandwidth))


def wasserstein_vae_loss(x_hat, x, mu, logvar):
    """WAE-style objective: summed BCE plus MMD between latents and the prior.

    A latent sample is drawn from ``N(mu, exp(logvar))`` by
    reparameterisation and compared, via ``compute_mmd``, with standard-normal
    samples of the same shape. No KL term is used.

    Args:
        x_hat: Reconstruction logits.
        x: Targets, same shape as ``x_hat``.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor.
    """
    reconstruction = nn.functional.binary_cross_entropy_with_logits(
        x_hat, x, reduction="sum"
    )

    # Reparameterised sample from the approximate posterior.
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    latent = mu + noise * sigma

    # Reference samples from the prior, then the kernel discrepancy.
    prior_sample = torch.randn_like(latent)
    discrepancy = compute_mmd(latent, prior_sample)

    return reconstruction + discrepancy


def compute_mmd(x, y, alpha=1.0):
    """Compute a Maximum Mean Discrepancy estimate with an RBF kernel.

    The kernel is ``k(a, b) = exp(-alpha * ||a - b||^2)``. The returned value
    is ``(sum(K_xx) + sum(K_yy)) / (B * (B - 1)) - 2 * sum(K_xy) / B**2``
    with ``B = x.size(0)``. The diagonal entries of ``K_xx`` and ``K_yy`` are
    included in the sums.

    Args:
        x: Samples of shape ``[B, dim]``; a size-1 dimension 1 is squeezed
            first, so ``[B, 1, dim]`` is accepted as well.
        y: Samples with the same layout and the same ``B`` as ``x``.
        alpha: Kernel sharpness multiplying the squared distance.

    Returns:
        Scalar tensor.

    Raises:
        ZeroDivisionError: If ``B == 1``, since the normaliser
            ``B * (B - 1)`` is zero.
    """
    B = x.size(0)
    x = x.squeeze(1)
    y = y.squeeze(1)

    # Gram matrices of inner products.
    xx = torch.matmul(x, x.T)
    yy = torch.matmul(y, y.T)
    xy = torch.matmul(x, y.T)

    # Squared norms broadcast along rows: rx[i, j] = ||x_j||^2.
    rx = xx.diag().unsqueeze(0).expand_as(xx)
    ry = yy.diag().unsqueeze(0).expand_as(yy)

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b for each pairing.
    K = torch.exp(-alpha * (rx.t() + rx - 2 * xx))
    L = torch.exp(-alpha * (ry.t() + ry - 2 * yy))
    P = torch.exp(-alpha * (rx.t() + ry - 2 * xy))

    beta = 1.0 / (B * (B - 1))
    gamma = 2.0 / (B * B)

    return beta * (torch.sum(K) + torch.sum(L)) - gamma * torch.sum(P)


def info_vae_loss(x_hat, x, mu, logvar, alpha=1.0, lambda_=0.01):
    """InfoVAE-style objective mixing BCE, a damped KL and an MMD term.

    The result is
    ``BCE + (1 - alpha) * KL + lambda_ * batch_size * MMD``, where
    ``batch_size`` is ``x.size(0)`` and the MMD compares a reparameterised
    latent sample with standard-normal samples. With the default
    ``alpha=1.0`` the KL term has zero weight.

    Args:
        x_hat: Reconstruction logits.
        x: Targets, same shape as ``x_hat``.
        mu: Latent means.
        logvar: Latent log-variances.
        alpha: Mutual-information weight; the KL term is scaled by
            ``1 - alpha``.
        lambda_: Weight of the MMD divergence from the prior, further
            multiplied by the batch size.

    Returns:
        Scalar tensor.
    """
    n_samples = x.size(0)

    reconstruction = nn.functional.binary_cross_entropy_with_logits(
        x_hat, x, reduction="sum"
    )
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()

    # Reparameterised posterior sample, then prior samples for the MMD.
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    latent = mu + noise * sigma
    prior_sample = torch.randn_like(latent)
    discrepancy = compute_mmd(latent, prior_sample)

    return reconstruction + (1 - alpha) * kl_divergence + lambda_ * n_samples * discrepancy


def binary_vae_loss(x_hat, x, mu, logvar, beta=1.0):
    """VAE objective with an extra penalty on indecisive outputs.

    On top of summed BCE and ``beta``-weighted KL, a sharpness term
    ``p * (1 - p)`` is averaged over dimension 1, summed over the rest, and
    added with a fixed weight of 0.1. That term peaks at ``p = 0.5``, so it
    pushes probabilities towards 0 or 1.

    Args:
        x_hat: Reconstruction logits with at least two dimensions.
        x: Targets, same shape as ``x_hat``.
        mu: Latent means.
        logvar: Latent log-variances.
        beta: Multiplier applied to the KL term.

    Returns:
        Scalar tensor.
    """
    reconstruction = nn.functional.binary_cross_entropy_with_logits(
        x_hat, x, reduction="sum"
    )
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()

    probabilities = torch.sigmoid(x_hat)
    indecision = probabilities * (1 - probabilities)
    sharpness = indecision.mean(dim=1).sum()

    return reconstruction + beta * kl_divergence + 0.1 * sharpness


def exact_reconstruction_vae_loss(x_hat, x, mu, logvar, beta=1.0, gamma=10.0):
    """VAE objective with an added penalty per wrongly predicted bit.

    Predictions are thresholded at probability 0.5 and every mismatch with
    ``x`` adds ``gamma`` to the loss, on top of summed BCE and
    ``beta``-weighted KL.

    Args:
        x_hat: Reconstruction logits with at least two dimensions.
        x: Binary targets, same shape as ``x_hat``.
        mu: Latent means.
        logvar: Latent log-variances.
        beta: Multiplier applied to the KL term.
        gamma: Cost added per mismatching bit.

    Returns:
        Scalar tensor.
    """
    # Hard 0/1 decisions. NOTE(review): these come from a comparison, so the
    # mismatch penalty likely contributes no gradient - confirm that this is
    # intended.
    hard_bits = (torch.sigmoid(x_hat) > 0.5).float()
    mismatches_per_row = (hard_bits != x).float().sum(dim=1)
    mismatch_penalty = gamma * mismatches_per_row

    reconstruction = nn.functional.binary_cross_entropy_with_logits(
        x_hat, x, reduction="sum"
    )
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()

    return reconstruction + mismatch_penalty.sum() + beta * kl_divergence
