import os 
import sys
from typing import List

import torch as th
import torch.nn.functional as F

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from model.discriminator import Discriminator


def vae_loss(recon: th.Tensor | None, input: th.Tensor | None, mu: th.Tensor | None, logvar: th.Tensor | None) -> List[th.Tensor]:
    if recon is None or input is None or mu is None or logvar is None:
        return 0, 0
    recon_loss = F.mse_loss(recon, input)
    kld_loss = th.mean(-0.5 * th.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0)
    return recon_loss, kld_loss

def gradient_penalty(disc: Discriminator, real_obs: th.Tensor, fake_obs: th.Tensor, real_next_obs: th.Tensor | None = None, fake_next_obs: th.Tensor | None = None) -> th.Tensor:
    # Random weight term for interpolation between real and fake samples
    alpha = th.rand(real_obs.size(0), 1, device=real_obs.device)
    alpha = alpha.expand(real_obs.size())

    # Get random interpolation between real and fake data
    interpolates = (alpha * real_obs + ((1 - alpha) * fake_obs)).requires_grad_(True)

    if real_next_obs is not None and fake_next_obs is not None:
        interpolates_next = (alpha * real_next_obs + ((1 - alpha) * fake_next_obs)).requires_grad_(True)
        disc_interpolates = disc.forward(interpolates, interpolates_next)
    else:
        disc_interpolates = disc.forward(interpolates)

    # Calculate gradients of interpolates with respect to disc scores
    gradients = th.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=th.ones(disc_interpolates.size(), device=real_obs.device),
                                  create_graph=True, retain_graph=True)[0]

    # Calculate the gradient penalty
    gradients = gradients.view(gradients.size(0), -1)
    gradients_norm = th.sqrt(th.sum(gradients ** 2, dim=1) + 1e-12)
    gradient_penalty = ((gradients_norm - 1) ** 2).mean()
    
    return gradient_penalty

def kl_div(mean1: th.Tensor, logvar1: th.Tensor, mean2: th.Tensor, logvar2: th.Tensor) -> th.Tensor:
    """Mean element-wise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).

    Uses the closed form for two univariate Gaussians, applied element-wise:

        0.5 * (logvar2 - logvar1 + (var1 + (mean1 - mean2)^2) / var2 - 1)

    The result is averaged over every element (not summed over any dim).
    ``1e-8`` is added to ``var2`` to guard the division when ``exp(logvar2)``
    underflows.
    """
    var_p = logvar1.exp()
    var_q = logvar2.exp()
    squared_gap = (mean1 - mean2) ** 2
    scaled = (var_p + squared_gap) / (var_q + 1e-8)
    elementwise = 0.5 * (logvar2 - logvar1 + scaled - 1)
    return elementwise.mean()