import torch.nn as nn
import torch.nn.functional as F
from .downstream_base import DownstreamModelBase
import math

import torch

from types import SimpleNamespace




class UNetAnomalyDetector(DownstreamModelBase):
    """
    1D U-Net that reconstructs its input for reconstruction-based anomaly detection.
      Input : x  (batch_size, n_channels, seq_len)
      Output: x̂ (batch_size, n_channels, seq_len)

    The encoder applies three convolutions with two stride-2 max-pools in
    between (time axis shrinks by 4). The decoder mirrors it with two
    transposed convolutions and concatenates the matching encoder feature map
    (skip connection) after each upsampling step.
    """

    # Total temporal downsampling of the encoder: two MaxPool1d(stride=2) stages.
    _DOWNSAMPLE_FACTOR = 4

    def _build_model(self):
        """Create all layers; reads ``self.n_channels`` for the input/output width."""
        # Encoder (attribute names are kept as-is so existing checkpoints still load)
        self.envc1 = nn.Conv1d(self.n_channels, 64, kernel_size=3, padding=1)
        self.envc2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.envc3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)

        # Shared, parameter-free downsampling layer (halves the time axis)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        # Decoder: each conv takes the upsampled map concatenated with its skip,
        # hence the doubled input channel count (128+128, 64+64).
        self.upconv2 = nn.ConvTranspose1d(256, 128, kernel_size=2, stride=2)
        self.dec2 = nn.Conv1d(256, 128, kernel_size=3, padding=1)

        self.upconv1 = nn.ConvTranspose1d(128, 64, kernel_size=2, stride=2)
        self.dec1 = nn.Conv1d(128, 64, kernel_size=3, padding=1)

        # Final 1x1 convolution to restore the original channel count
        self.final_conv = nn.Conv1d(64, self.n_channels, kernel_size=1)

    def forward(self, batch_x, batch_f=None, batch_mask=None):
        """
        Reconstruct ``batch_x``.

        :param batch_x: (batch_size, n_channels, seq_len)
        :param batch_f: (batch_size, n_channels, n_features); not used by this model
        :param batch_mask: (batch_size, seq_len); not used by this model
        :return (batch_size, n_channels, seq_len)
        """
        # The two pool/upsample round trips only line up with the skip tensors
        # when the length is a multiple of 4; otherwise torch.cat would fail.
        # Right-pad by repeating the last sample, then crop the output back.
        seq_len = batch_x.size(-1)
        pad = (-seq_len) % self._DOWNSAMPLE_FACTOR
        if pad:
            batch_x = F.pad(batch_x, (0, pad), mode="replicate")

        # Encoder
        x1 = F.relu(self.envc1(batch_x))  # (batch, 64, seq_len)
        x2 = self.pool(x1)  # Downsample (batch, 64, seq_len/2)

        x2 = F.relu(self.envc2(x2))  # (batch, 128, seq_len/2)
        x3 = self.pool(x2)  # Downsample (batch, 128, seq_len/4)

        x3 = F.relu(self.envc3(x3))  # (batch, 256, seq_len/4)

        # Decoder
        x = self.upconv2(x3)  # Upsample (batch, 128, seq_len/2)
        x = torch.cat([x, x2], dim=1)  # Skip connection -> (batch, 256, seq_len/2)
        x = F.relu(self.dec2(x))  # (batch, 128, seq_len/2)

        x = self.upconv1(x)  # Upsample (batch, 64, seq_len)
        x = torch.cat([x, x1], dim=1)  # Skip connection -> (batch, 128, seq_len)
        x = F.relu(self.dec1(x))  # (batch, 64, seq_len)

        predicted_x = self.final_conv(x)  # (batch, n_channels, padded seq_len)
        if pad:
            # Drop the positions that only exist because of the padding above
            predicted_x = predicted_x[..., :seq_len]
        return predicted_x


class VAEAnomalyDetector(DownstreamModelBase):
    """
    Conv1D-VAE for reconstruction-based anomaly detection.
      Input : x  [B, C, L]
      Output: x̂ [B, C, L]

    The encoder downsamples the time axis by 4, averages the bottleneck over
    time and maps it to a single latent vector per sample. The decoder tiles
    that vector along the (downsampled) time axis and upsamples it back.
    The KL term of the most recent forward pass is stored in ``self.last_kl``.
    """
    def _build_model(self):
        """Read hyper-parameters from ``self.downstream_args`` and create all layers."""
        # Fall back to defaults when no downstream arguments were provided.
        args = self.downstream_args or {}
        self.latent_dim = int(args.get("latent_dim", 64))
        self.hidden_c   = int(args.get("hidden_c", 128))   # bottleneck channels
        self.normalize  = bool(args.get("normalize", True))
        # β-VAE weight; only stored here, not used inside this class
        # (presumably consumed by the training loss together with last_kl).
        self.beta       = float(args.get("beta", 1.0))

        # Encoder blocks: keep length with padding, then 2x MaxPool -> L/4
        self.enc1 = nn.Sequential(
            nn.Conv1d(self.n_channels, 64, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, 64), nn.ReLU(inplace=True),
            nn.Conv1d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, 64), nn.ReLU(inplace=True),
        )
        self.pool1 = nn.MaxPool1d(2)

        self.enc2 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, 128), nn.ReLU(inplace=True),
            nn.Conv1d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, 128), nn.ReLU(inplace=True),
        )
        self.pool2 = nn.MaxPool1d(2)

        # GroupNorm(8, hidden_c) requires hidden_c to be divisible by 8.
        self.enc3 = nn.Sequential(
            nn.Conv1d(128, self.hidden_c, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, self.hidden_c), nn.ReLU(inplace=True),
        )

        # μ, logσ² heads over the time-averaged bottleneck ([B, hidden_c])
        self.mu_head     = nn.Linear(self.hidden_c, self.latent_dim)
        self.logvar_head = nn.Linear(self.hidden_c, self.latent_dim)

        # Decoder: z → bottleneck feature (C=hidden_c, L=L/4), then upsample x2
        self.fc_dec = nn.Linear(self.latent_dim, self.hidden_c)

        self.up2 = nn.ConvTranspose1d(self.hidden_c, 128, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, 128), nn.ReLU(inplace=True),
        )

        self.up1 = nn.ConvTranspose1d(128, 64, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv1d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, 64), nn.ReLU(inplace=True),
        )

        self.final = nn.Conv1d(64, self.n_channels, kernel_size=1)

        self.eps = 1e-5       # added to the variance in _norm to avoid division by zero
        self.last_kl = None   # set by forward(); None until the first call

    @staticmethod
    def _pad_to_multiple_of(x, m: int):
        """Right-pad the last dim of ``x`` [B, C, L] to a multiple of ``m``.

        Padding repeats the last sample (replicate mode).
        Returns ``(padded_x, r)`` where ``r`` is the number of added positions.
        """
        B, C, L = x.shape
        r = (m - (L % m)) % m
        if r == 0:
            return x, 0
        return F.pad(x, (0, r), mode="replicate"), r

    @staticmethod
    def _match_length(y, target_len: int):
        """Zero-pad ``y`` on the right so its last dim equals ``target_len``.

        A negative difference crops instead (F.pad semantics). With the input
        padded to a multiple of 4 the lengths already agree, so this is only a
        safety net.
        """
        diff = target_len - y.size(-1)
        if diff == 0:
            return y
        return F.pad(y, (0, diff))

    def _norm(self, x):
        """Standardize ``x`` per sample and channel over the time axis.

        Returns ``(x_normalized, mean, std)`` with ``mean``/``std`` of shape
        [B, C, 1]. The statistics are detached, so no gradient flows through
        them. With ``self.normalize`` off, ``x`` is returned unchanged together
        with zeros/ones so that the later de-normalization is a no-op.
        """
        if not self.normalize:
            B, C, _ = x.shape
            device, dtype = x.device, x.dtype
            mean = torch.zeros((B, C, 1), device=device, dtype=dtype)
            std  = torch.ones((B, C, 1),  device=device, dtype=dtype)
            return x, mean, std
        mean = x.mean(dim=-1, keepdim=True).detach()
        var  = x.var (dim=-1, keepdim=True, unbiased=False).detach()
        std  = torch.sqrt(var + self.eps)
        return (x - mean) / std, mean, std

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``.

        Noise is sampled on every call, regardless of train/eval mode, so
        reconstructions are stochastic at inference time as well.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _decode(self, z, bottleneck_len: int, mid_len: int, full_len: int):
        """Map latent ``z`` [B, Z] to a reconstruction [B, C, full_len]."""
        db = self.fc_dec(z)                      # [B, hidden_c]
        db = db.unsqueeze(-1)                    # [B, hidden_c, 1]
        db = db.expand(-1, -1, bottleneck_len)   # tile to length L/4

        y = self.up2(db)                         # ~ [B, 128, L/2]
        y = self.dec2(self._match_length(y, mid_len))

        y = self.up1(y)                          # ~ [B, 64, L]
        y = self.dec1(self._match_length(y, full_len))

        return self.final(y)                     # [B, C, L]

    def forward(self, batch_x, batch_f=None, batch_mask=None):
        """
        batch_x   : [B, C, L]
        batch_f   : unused
        batch_mask: optional; unsqueezed to [B, 1, L] and used to blend the
                    reconstruction with the input (assumes shape [B, L])
        return    : x_hat [B, C, L]
        side      : self.last_kl (scalar tensor on same device)
        """
        x = batch_x

        # Normalize & pad to multiple of 4
        x_n, mean, std = self._norm(x)
        x_n, r = self._pad_to_multiple_of(x_n, 4)

        # Encoder (two downsamples → L/4)
        h1 = self.enc1(x_n)              # [B, 64, L]
        h1p = self.pool1(h1)             # [B, 64, L/2]
        h2 = self.enc2(h1p)              # [B, 128, L/2]
        h2p = self.pool2(h2)             # [B, 128, L/4]
        hb = self.enc3(h2p)              # [B, hidden_c, L/4]

        # Global average pooling over time to get a sample-level latent
        hb_pool = hb.mean(dim=-1)        # [B, hidden_c]
        mu     = self.mu_head(hb_pool)       # [B, Z]
        logvar = self.logvar_head(hb_pool)   # [B, Z]
        z = self.reparameterize(mu, logvar)  # [B, Z]

        # KL divergence (per batch mean)
        # KL(N(μ,σ²) || N(0,1)) = -0.5 * Σ(1 + logσ² - μ² - σ²)
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)  # [B]
        self.last_kl = kl.mean()

        # Decoder: z → [B, C, L] (padded length)
        y = self._decode(z, hb.size(-1), h2.size(-1), h1.size(-1))

        # Remove pad
        if r > 0:
            y = y[..., :-r]

        # De-normalize
        x_hat = y * std + mean

        # Optional mask: keep original where mask==0
        if batch_mask is not None:
            m = batch_mask.to(x_hat.dtype).unsqueeze(1)  # [B,1,L]
            x_hat = x_hat * m + x * (1 - m)

        return x_hat



# Registry mapping a short model name to its detector class.
# NOTE(review): ``LSTMAEAnomaly`` is neither defined nor imported in the
# visible part of this module, so referencing it directly raises NameError at
# import time. It is therefore registered only when the name actually exists;
# restore the class/import (or drop the entry) once its status is confirmed.
AVAILABLE_ANOMALYORS = {
    "UNet": UNetAnomalyDetector,
    "VAE": VAEAnomalyDetector,
}
if "LSTMAEAnomaly" in globals():
    # Keep the original key order ("LSTM" first) when the class is available.
    AVAILABLE_ANOMALYORS = {
        "LSTM": globals()["LSTMAEAnomaly"],
        **AVAILABLE_ANOMALYORS,
    }