from typing import Any, Sequence, Callable
from abc import ABC, abstractmethod

import torch
from torch import nn, Tensor
from torch import distributions as D

import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort that can occur when several libraries each bundle their own OpenMP
# copy -- confirm it is still needed on the target platforms.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# NOTE(review): this is normally a debugging switch (it asks CUDA to run
# kernel launches synchronously) and is set for the whole process as an
# import side effect -- confirm it is intended outside of debugging.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import matplotlib.pyplot as plt
from tqdm import trange
from geomloss import SamplesLoss
import torch.nn.functional as F


def mmd_metric(traj_samples, traj_data):
    """Return the energy-kernel ``SamplesLoss`` between two trajectory sets.

    Both inputs are assumed to share a common time grid, i.e. to be of shape
    ``[*, T, D]``. The leading (number-of-trajectories) dimension may differ
    between the two sets.

    ``SamplesLoss`` expects point clouds of shape ``[N, D]`` (N points of
    dimensionality D), so every trajectory is flattened into a single point.

    Returns:
        The loss value as a Python float.
    """
    mmd = SamplesLoss("energy")
    # Flatten each set with its OWN trajectory count: reusing the count of
    # `traj_samples` for `traj_data` fails (or silently regroups rows) as
    # soon as the two sets contain a different number of trajectories.
    samples_flat = traj_samples.reshape(traj_samples.shape[0], -1)
    data_flat = traj_data.reshape(traj_data.shape[0], -1)
    return mmd(samples_flat, data_flat).item()

#def mmd_metric(traj_samples, traj_data):
#    mmd = SamplesLoss("energy")
#    traj_samples = torch.as_tensor(traj_samples, dtype=torch.float32)
#    traj_data = torch.as_tensor(traj_data, dtype=torch.float32)
#    return mmd(traj_samples, traj_data).item()


def sinkhorn_dist(traj_samples, traj_data):
    """Return the Sinkhorn ``SamplesLoss`` (p=2, blur=0.05) between two trajectory sets.

    Every trajectory is flattened into a single point, so the inputs are
    turned into point clouds of shape ``[N, T * D]`` and ``[M, T * D]``.
    The number of trajectories (leading dimension) may differ between the
    two sets; the remaining dimensions are assumed to match.

    Returns:
        The loss value as a Python float.
    """
    sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)
    # Flatten each set with its OWN trajectory count: reusing the count of
    # `traj_samples` for `traj_data` breaks when the set sizes differ.
    samples_flat = traj_samples.reshape(traj_samples.shape[0], -1)
    data_flat = traj_data.reshape(traj_data.shape[0], -1)
    return sinkhorn(samples_flat, data_flat).item()


def solve_sde(
        sde: Callable[[Tensor, Tensor], tuple[Tensor, Tensor]],
        z: Tensor,
        ts: float,
        tf: float,
        n_steps: int
) -> Tensor:
    """Integrate an SDE with fixed-step Euler-Maruyama and return the whole path.

    Args:
        sde: Callable mapping ``(z, t)`` to the pair ``(drift, volatility)``;
            both are combined element-wise with ``z``.
        z: Initial state.
        ts: Start time.
        tf: Final time. May be smaller than ``ts``; the drift step then uses
            a negative ``dt`` while the noise scale uses ``sqrt(|dt|)``.
        n_steps: Number of integration steps (must be non-zero).

    Returns:
        Tensor of shape ``[n_steps + 1, *z.shape]`` whose first entry is the
        initial state.
    """
    # Left endpoints of the n_steps sub-intervals. The grid is created on the
    # device of `z` so `sde` is never handed a CPU time for a GPU state.
    tt = torch.linspace(ts, tf, n_steps + 1, device=z.device)[:-1]
    dt = (tf - ts) / n_steps
    dt_2 = abs(dt) ** 0.5  # standard deviation of one Brownian increment

    path = [z]
    for t in tt:
        f, g = sde(z, t)
        w = torch.randn_like(z)
        z = z + f * dt + g * w * dt_2
        path.append(z)

    return torch.stack(path)

def jvp(f: Callable[[Tensor], Any], x: Tensor, v: Tensor) -> tuple[Any, Any]:
    """Evaluate ``f`` at ``x`` together with its Jacobian-vector product along ``v``.

    Thin wrapper around ``torch.autograd.functional.jvp``. The result is
    kept differentiable (``create_graph=True``) exactly when gradient mode is
    enabled at call time.

    Args:
        f: Function of a single tensor returning a tensor or a tuple of tensors.
        x: Point at which ``f`` is evaluated.
        v: Direction of the Jacobian-vector product; same shape as ``x``.

    Returns:
        ``(outputs, jvp)`` as returned by ``torch.autograd.functional.jvp``;
        each element mirrors the structure of what ``f`` returns.
    """
    return torch.autograd.functional.jvp(
        f, x, v,
        create_graph=torch.is_grad_enabled()
    )

def t_dir(f: Callable[[Tensor], Any], t: Tensor) -> tuple[Any, Any]:
    """Evaluate ``f`` at ``t`` together with its directional derivative along all-ones.

    Calls ``jvp`` with the tangent ``torch.ones_like(t)``.

    Returns:
        ``(f(t), derivative)`` with the same structure as the output of ``f``.
    """
    return jvp(f, t, torch.ones_like(t))

def grad(f: Callable[[Tensor], Tensor], x: Tensor) -> tuple[Tensor, Tensor]:
    """Evaluate ``f`` at ``x`` and the gradient of ``f(x).sum()`` with respect to ``x``.

    The caller's tensor is never modified: the computation runs on a clone.
    The returned gradient stays differentiable (``create_graph=True``)
    exactly when gradient mode is enabled at call time.

    Args:
        f: Function of a single tensor returning a tensor.
        x: Point at which ``f`` is evaluated.

    Returns:
        ``(y, gradient)`` where ``y = f(x)`` and ``gradient`` has the shape of ``x``.
    """
    # Read the caller's grad mode BEFORE forcing it on below.
    create_graph = torch.is_grad_enabled()

    # The gradient is needed even when the caller runs under no_grad().
    with torch.enable_grad():
        # Clone so that toggling requires_grad does not leak to the caller.
        x = x.clone()

        if not x.requires_grad:
            x.requires_grad = True

        y = f(x)

        (gradient, ) = torch.autograd.grad(y.sum(), x, create_graph=create_graph)

    return y, gradient

class SDE(nn.Module, ABC):
    """Abstract module describing an SDE through its drift and volatility.

    Subclasses implement ``drift`` and ``vol``; calling the module returns
    both, evaluated at the same ``(z, t, *args)``.
    """

    @abstractmethod
    def drift(self, z: Tensor, t: Tensor, *args: Any) -> Tensor:
        """Return the drift term at state ``z`` and time ``t``."""
        raise NotImplementedError

    @abstractmethod
    def vol(self, z: Tensor, t: Tensor, *args: Any) -> Tensor:
        """Return the volatility term at state ``z`` and time ``t``."""
        raise NotImplementedError

    def forward(self, z: Tensor, t: Tensor, *args: Any) -> tuple[Tensor, Tensor]:
        """Return the pair ``(drift, vol)``; the drift is evaluated first."""
        return self.drift(z, t, *args), self.vol(z, t, *args)

class PriorSDE(SDE):
    """SDE with a time-independent MLP drift and a per-dimension volatility.

    The drift is one MLP acting on the full latent state. The volatility is
    built from ``latent_size`` small independent networks, each seeing only
    its own latent coordinate and ending in a sigmoid, so every volatility
    entry lies in (0, 1).
    """

    def __init__(self, latent_size: int, hidden_size: int):
        super().__init__()

        drift_layers = [
            nn.Linear(latent_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, hidden_size),
            nn.Softplus(),
            nn.Linear(hidden_size, latent_size),
        ]
        self.drift_net = nn.Sequential(*drift_layers)

        # One scalar-to-scalar network per latent coordinate.
        per_dim_nets = []
        for _ in range(latent_size):
            per_dim_nets.append(
                nn.Sequential(
                    nn.Linear(1, hidden_size),
                    nn.Softplus(),
                    nn.Linear(hidden_size, 1),
                    nn.Sigmoid(),
                )
            )
        self.vol_nets = nn.ModuleList(per_dim_nets)

    def drift(self, z: Tensor, t: Tensor, *args) -> Tensor:
        """Return the drift at ``z``; ``t`` and ``args`` are ignored."""
        return self.drift_net(z)

    def vol(self, z: Tensor, t: Tensor, *args) -> Tensor:
        """Return the volatility at ``z``; ``t`` and ``args`` are ignored.

        ``z`` is split into width-1 slices along dim 1 (assumes ``z`` is
        ``(batch, latent_size)`` -- TODO confirm) and slice ``i`` is passed
        through the ``i``-th volatility network.
        """
        columns = z.split(1, dim=1)
        outputs = []
        for net, column in zip(self.vol_nets, columns):
            outputs.append(net(column))
        return torch.cat(outputs, dim=1)

class PosteriorEncoder(nn.Module):
    """GRU encoder turning an observed, time-stamped series into a context tensor."""

    def __init__(self, data_dim: int, hidden_size: int):
        super().__init__()
        # The extra input feature is the timestamp appended to each observation.
        self.gru = nn.GRU(
            input_size=data_dim + 1,
            hidden_size=hidden_size,
            batch_first=True,
        )

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        """Encode observations ``x`` taken at times ``t``.

        Args:
            x: Observations, ``(batch, seq, data_dim)``.
            t: Timestamps, ``(batch, seq, 1)``.

        Returns:
            Tensor of shape ``(batch, seq + 1, hidden_size)``: index 0 along
            dim 1 holds the GRU's final hidden state (first layer), followed
            by the ``seq`` per-step outputs.
        """
        gru_input = torch.cat((x, t), dim=-1)  # (batch, seq, data_dim + 1)
        step_states, final_state = self.gru(gru_input)
        summary = final_state[0].unsqueeze(1)  # (batch, 1, hidden_size)
        return torch.cat((summary, step_states), dim=1)

class PosteriorAffine(nn.Module):
    def __init__(self, latent_size: int, hidden_size: int):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(hidden_size + 1, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * latent_size),
        )
        self.sm = nn.Softmax(dim=-1)

    def get_coeffs(self, ctx: Tensor, t: Tensor) -> tuple[Tensor, Tensor]:
        l = ctx.shape[1] - 1

        h, out = ctx[:, 0], ctx[:, 1:]
        #ts = torch.linspace(0, 1, l)[None, :]
        ts = torch.linspace(0, 1, l, device=t.device)[None, :]
        c = self.sm(-(l * (ts - t)) ** 2)
        out = (out * c[:, :, None]).sum(dim=1)
        ctx_t = torch.cat([h + out, t], dim=1)

        m, log_s = self.net(ctx_t).chunk(chunks=2, dim=1)
        s = torch.exp(log_s)

        return m, s

    def forward(
            self,
            ctx: Tensor,
            t: Tensor,
            return_t_dir: bool = False
    ) -> tuple[Tensor, Tensor] | tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
        if return_t_dir:
            def f(t_in: Tensor) -> Tensor:
                return self.get_coeffs(ctx, t_in)

            return t_dir(f, t)
        else:
            return self.get_coeffs(ctx, t)

class MatchingSDE(nn.Module):
    """Latent SDE model trained with three per-sample loss terms.

    Components:
        p_init_distr: Module that, called without arguments, returns the
            prior distribution over the initial latent state.
        p_sde: Prior SDE providing ``drift(z, t)`` and ``vol(z, t)``.
        p_observe: Module mapping a latent state to an observation
            distribution (an object exposing ``log_prob``).
        q_enc: Encoder mapping ``(xs, ts)`` to a context tensor.
        q_affine: Module mapping ``(ctx, t)`` to the mean and scale of the
            Gaussian posterior marginal at time ``t``; with
            ``return_t_dir=True`` it also returns their time derivatives.
    """

    def __init__(
            self,
            p_init_distr: nn.Module,
            p_sde: SDE,
            p_observe: nn.Module,
            q_enc: nn.Module,
            q_affine: nn.Module
    ):
        super().__init__()

        self.p_init_distr = p_init_distr
        self.p_sde = p_sde
        self.p_observe = p_observe
        self.q_enc = q_enc
        self.q_affine = q_affine

    def loss_prior(self, ctx: Tensor) -> Tensor:
        """Return KL(q(z_0) || p(z_0)) for every sample in the batch.

        The posterior at time 0 is a diagonal Gaussian whose mean and scale
        come from ``q_affine(ctx, 0)``.
        """
        bs = ctx.shape[0]

        t0 = torch.zeros(bs, 1, device=ctx.device)

        m0, s0 = self.q_affine(ctx, t0)
        q_z0 = D.Independent(D.Normal(m0, s0), 1)

        p_z0 = self.p_init_distr()

        loss_prior = D.kl_divergence(q_z0, p_z0)

        return loss_prior

    def loss_diff(self, ctx: Tensor, t: Tensor) -> Tensor:
        """Return the drift-matching loss at times ``t`` for every sample.

        A latent ``z = m + s * eps`` is sampled from the posterior marginal;
        the posterior drift is assembled from the time derivatives of
        ``(m, s)``, the score ``-eps / s`` and the squared volatility of the
        prior SDE, and compared with the prior drift, weighted by ``1 / g^2``.
        """
        (m, s), (dm, ds) = self.q_affine(ctx, t, return_t_dir=True)

        eps = torch.randn_like(m)
        z = m + s * eps

        def g2_in(z_in):
            return self.p_sde.vol(z_in, t) ** 2

        # g2: squared volatility at z; d_g2: gradient of g2.sum() w.r.t. z.
        g2, d_g2 = grad(g2_in, z)

        q_dz = dm + ds * eps          # time derivative of m + s * eps at fixed eps
        q_score = - eps / s           # score of N(m, s^2) evaluated at z
        q_drift = q_dz + 0.5 * g2 * q_score + 0.5 * d_g2

        p_drift = self.p_sde.drift(z, t)

        # NOTE(review): there is no guard against g2 being (numerically) zero;
        # this relies on the volatility staying strictly positive.
        loss_diff = 0.5 * (q_drift - p_drift) ** 2 / g2
        loss_diff = loss_diff.sum(dim=1)

        return loss_diff

    def loss_recon(self, ctx: Tensor, x: Tensor, t: Tensor) -> Tensor:
        """Return the negative log-likelihood of ``x`` observed at times ``t``.

        A latent is sampled from the posterior marginal at ``t`` and scored
        under the distribution returned by ``p_observe``.
        """
        m, s = self.q_affine(ctx, t)

        eps = torch.randn_like(m)
        z = m + s * eps

        p_x = self.p_observe(z)

        loss_recon = -p_x.log_prob(x)

        return loss_recon

    def forward(self, xs: Tensor, ts: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute the training loss and its three components.

        Args:
            xs: Observed series; assumed ``(batch, seq, data_dim)`` -- TODO confirm.
            ts: Observation times; assumed ``(batch, seq, 1)`` -- TODO confirm.

        Returns:
            ``(loss, loss_prior, loss_diff, loss_recon)`` where ``loss`` is
            the sum of the three components. No reduction over the batch is
            applied here.
        """
        bs = xs.shape[0]
        n = xs.shape[1]

        ctx = self.q_enc(xs, ts)

        # prior loss
        loss_prior = self.loss_prior(ctx)

        # diffusion loss: one time per sample, uniform between the first and
        # the last observation time of that sample
        t = torch.rand(bs, 1, device=ts.device) * (ts[:, -1] - ts[:, 0]) + ts[:, 0]

        loss_diff = self.loss_diff(ctx, t)

        # reconstruction loss: one randomly chosen observation per sample.
        # The index tensors are created on the data's device, like every
        # other tensor built in this class.
        rng = torch.arange(bs, device=ts.device)
        u = torch.randint(n, [bs], device=ts.device)
        t_u = ts[rng, u]
        x_u = xs[rng, u]

        loss_recon = self.loss_recon(ctx, x_u, t_u)

        # full loss
        loss = loss_prior + loss_diff + loss_recon

        return loss, loss_prior, loss_diff, loss_recon
