from typing import Any, Sequence, Callable
from abc import ABC, abstractmethod

import torch
from torch import nn, Tensor
from torch import distributions as D

import os
# NOTE(review): lets duplicate OpenMP runtimes coexist (the usual workaround for
# "OMP: Error #15" when torch and MKL/numpy each ship libiomp). This is a hack that
# can mask crashes or silently wrong results — confirm it is still needed.
# NOTE(review): this is set after `import torch` above; confirm it still takes effect
# for whichever library loads the second OpenMP runtime.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Debugging aid: makes CUDA kernel launches synchronous so errors are reported at
# the failing call. It slows GPU execution; consider removing it for real runs.
# Presumably it still applies here because CUDA is initialized lazily (after this
# line) — TODO confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import matplotlib.pyplot as plt
from tqdm import trange
from geomloss import SamplesLoss
import torch.nn.functional as F

class PriorInitDistribution(nn.Module):
    """Learnable diagonal Gaussian with trainable mean and log standard deviation.

    The mean ``m`` and log-scale ``log_s`` are stored as ``(1, latent_size)``
    parameters. The distribution returned by :meth:`forward` therefore has
    ``batch_shape == (1,)`` and ``event_shape == (latent_size,)``.

    Args:
        latent_size: Dimensionality of the latent vector.
        log_s_init: Initial value of every entry of ``log_s``. The initial
            standard deviation is ``exp(log_s_init)``. Integer values are
            accepted and converted to float.
        m_init: Initial value of every entry of the mean ``m``. Defaults to
            ``1.0``, the value this module has always started from.
    """

    def __init__(
        self,
        latent_size: int,
        log_s_init: float = -2.5,
        m_init: float = 1.0,
    ):
        super().__init__()
        # torch.full infers dtype from the fill value. An int fill would give an
        # int64 tensor, which nn.Parameter rejects (ints cannot require grad), so
        # force a float fill to get the default floating dtype.
        self.m = nn.Parameter(torch.full((1, latent_size), float(m_init)))
        # The scale is kept in log space so it stays positive under
        # unconstrained gradient updates.
        self.log_s = nn.Parameter(torch.full((1, latent_size), float(log_s_init)))

    def forward(self) -> D.Distribution:
        """Return ``Independent(Normal(m, exp(log_s)), 1)`` built from the current parameters."""
        return D.Independent(D.Normal(self.m, torch.exp(self.log_s)), 1)


# -------------------------------
# Identity Observation
# -------------------------------
class IdentityObservation(nn.Module):
    """Observation model ``p(x | z) = Normal(z, eps)`` with independent dimensions.

    The mean is the latent ``z`` itself, and every entry shares the fixed
    standard deviation ``eps``. The last dimension of ``z`` is reinterpreted
    as the event dimension.

    Args:
        eps: Fixed observation-noise standard deviation. ``Normal`` requires a
            positive scale, so a non-positive value will be rejected when
            :meth:`forward` builds the distribution (with argument validation
            enabled).
    """

    def __init__(self, eps: float = 0.05):
        super().__init__()
        self.eps = eps

    def get_coeffs(self, z: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(mean, std)`` of the observation distribution for latent ``z``.

        ``mean`` is ``z`` itself (the same tensor, not a copy). ``std`` has
        ``z``'s shape and every entry equals ``eps``.
        """
        m = z
        # ones_like(...) * eps (rather than full_like) keeps the original type
        # promotion: an integer z still yields a floating-point std.
        s = torch.ones_like(m) * self.eps
        return m, s

    def forward(self, z: Tensor) -> D.Distribution:
        """Return ``Independent(Normal(z, eps), 1)`` for latent ``z``."""
        m, s = self.get_coeffs(z)
        return D.Independent(D.Normal(m, s), 1)