import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

from copy import deepcopy
from tqdm.auto import trange
from scipy.stats import norm
from tqdm import tqdm

from abc import ABC, abstractmethod

from dataclasses import dataclass

@dataclass
class WGFConfig:
    """
    Wasserstein Gradient Flow configuration

    Plain settings container read by the flow classes of this module.

    ``device`` is a real dataclass field (it carries a type annotation), so it
    can be passed to the constructor and shows up in ``repr``/``==``. It is
    declared last on purpose: the positional order of all the other fields is
    therefore unchanged.
    """
    # Order of the transport cost; the flows visible in this module hard-code
    # p=2 and do not read this value -- TODO confirm where it is consumed.
    p: float = 2.0
    n_epochs: int = 100 # Inner iterations used to compute one flow step
    n_steps: int = 100 # Steps for the full gradient flow
    step: int = 0 # Index of the current flow step (written by WGFSinkhorn.all_measures)
    lr: float = 1.0 # Step size of the inner optimisation
    tau: float = 0.01 # Time step
    n_projections: int = 1000 # Number of random directions for the sliced distance
    divergence: callable = None # Called as divergence(x, target, epsilon) by WGFwithTarget
    target_measure: list = None # Iterable of 2-element points
    sinkhorn_epsilon: float = None # Forwarded as the third argument of `divergence`
    functional: callable = None # Functional evaluated on the particles by SWGF
    # Evaluated once, when the class is defined: CUDA if available, else CPU.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class WGF(ABC):
    """
    Wasserstein Gradient Flows

    Abstract interface: a concrete flow must implement both methods below
    before it can be instantiated.
    """

    @abstractmethod
    def next_measure(self, current_measure):
        """Return the measure that follows ``current_measure`` in the flow."""

    @abstractmethod
    def all_measures(self, initial_measure):
        """Return the measures of the flow started at ``initial_measure``."""

# =============================================================================
# Sliced Wasserstein Gradient Flows
# Credit: Clément Bonet.
# Source: https://github.com/clbonet/Sliced-Wasserstein_Gradient_Flows

def emd1D(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
    """
    One-dimensional transport cost between two weighted point sets.

    The support points live on the last axis of ``u_values`` / ``v_values``.
    The cost is the integral of ``|F_u^{-1}(t) - F_v^{-1}(t)|^p`` over the
    merged grid of both cumulative weight vectors (no ``1/p`` root is taken).

    Parameters
    ----------
    u_values, v_values : torch.Tensor
        Support points; the last axis indexes the points.
    u_weights, v_weights : torch.Tensor, optional
        Weights of the points. When omitted, uniform weights ``1/n`` (resp.
        ``1/m``) are created with the dtype and device of ``u_values``.
    p : int or float
        Exponent applied to the absolute quantile difference.
    require_sort : bool
        When True, values are sorted along the last axis and the weights are
        re-indexed with the resulting permutation. Pass False only if the
        inputs are already sorted.

    Returns
    -------
    torch.Tensor
        The cost, reduced over the last axis.
    """
    ref_dtype, ref_device = u_values.dtype, u_values.device
    n_u = u_values.shape[-1]
    n_v = v_values.shape[-1]

    def _uniform(count):
        # Equal mass 1/count on each of the `count` points.
        return torch.full((count,), 1/count, dtype=ref_dtype, device=ref_device)

    if u_weights is None:
        u_weights = _uniform(n_u)
    if v_weights is None:
        v_weights = _uniform(n_v)

    if require_sort:
        u_values, u_order = torch.sort(u_values, -1)
        v_values, v_order = torch.sort(v_values, -1)
        # Carry each weight along with its (now sorted) support point.
        u_weights = u_weights[..., u_order]
        v_weights = v_weights[..., v_order]

    u_cdf = torch.cumsum(u_weights, -1)
    v_cdf = torch.cumsum(v_weights, -1)

    # Every level at which either CDF jumps, in increasing order.
    levels = torch.sort(torch.cat((u_cdf, v_cdf), -1), -1).values

    # Quantile functions evaluated at those levels; the clamp keeps the index
    # valid when a level is not reached by the cumulative sum.
    u_pos = torch.searchsorted(u_cdf, levels).clamp(0, n_u - 1)
    v_pos = torch.searchsorted(v_cdf, levels).clamp(0, n_v - 1)
    u_quantiles = torch.gather(u_values, -1, u_pos)
    v_quantiles = torch.gather(v_values, -1, v_pos)

    # Width of each level interval, with an implicit level 0 in front.
    padded = torch.nn.functional.pad(levels, (1, 0))
    widths = padded[..., 1:] - padded[..., :-1]

    gap = u_quantiles - v_quantiles
    if p == 1:
        integrand = torch.abs(gap)
    elif p == 2:
        integrand = torch.square(gap)
    else:
        integrand = torch.pow(torch.abs(gap), p)
    return torch.sum(widths * integrand, dim=-1)


def sliced_cost(Xs, Xt, projections=None,u_weights=None,v_weights=None,p=1):
    """
    Mean of the 1-D transport costs of ``Xs`` and ``Xt`` over a set of slices.

    When ``projections`` is given, both inputs are right-multiplied by it and
    transposed; otherwise the inputs are transposed as they are. ``emd1D``
    then works along the last axis of the transposed tensors, and the
    resulting costs are averaged.

    Parameters
    ----------
    Xs, Xt : torch.Tensor
        Source and target points (presumably one point per row -- confirm
        against the caller).
    projections : torch.Tensor, optional
        Matrix whose columns are the projection directions.
    u_weights, v_weights : torch.Tensor, optional
        Point weights forwarded to ``emd1D``.
    p : int or float
        Exponent forwarded to ``emd1D``.

    Returns
    -------
    torch.Tensor
        Scalar tensor holding the averaged cost.
    """
    if projections is None:
        source_1d = Xs.T
        target_1d = Xt.T
    else:
        source_1d = (Xs @ projections).T
        target_1d = (Xt @ projections).T

    per_slice = emd1D(
        source_1d,
        target_1d,
        u_weights=u_weights,
        v_weights=v_weights,
        p=p,
    )
    return torch.mean(per_slice)


def sliced_wasserstein(Xs, Xt, num_projections, device,
                       u_weights=None, v_weights=None, p=1):
    """
    Sliced transport cost between ``Xs`` and ``Xt`` with random directions.

    Draws ``num_projections`` Gaussian directions with NumPy's global random
    generator, normalises each one to unit Euclidean norm, and delegates to
    ``sliced_cost``.

    Parameters
    ----------
    Xs, Xt : torch.Tensor
        Source and target points; axis 1 of ``Xs`` gives the feature count.
    num_projections : int
        Number of random directions to draw.
    device : torch.device or str
        Device the direction matrix is moved to.
    u_weights, v_weights : torch.Tensor, optional
        Point weights forwarded to ``sliced_cost``.
    p : int or float
        Exponent forwarded to ``sliced_cost``.

    Returns
    -------
    torch.Tensor
        Whatever ``sliced_cost`` returns for the drawn directions.
    """
    feature_dim = Xs.shape[1]

    # One direction per column: shape (feature_dim, num_projections).
    gaussian = np.random.normal(size=(feature_dim, num_projections))
    directions = F.normalize(torch.from_numpy(gaussian), p=2, dim=0)
    # Match the dtype of the data before moving to the requested device.
    directions = directions.type(Xs.dtype).to(device)

    return sliced_cost(
        Xs,
        Xt,
        projections=directions,
        u_weights=u_weights,
        v_weights=v_weights,
        p=p,
    )


class SWGF(WGF):
    """
    Gradient flow whose steps minimise a sliced transport cost plus a
    functional, using the settings held in a ``WGFConfig``-like object.
    """

    def __init__(self, WGFConfig):
        """Keep the configuration object and start with no cached measures."""
        self.measures = []
        self.c = WGFConfig

    def next_measure(self, current_x):
        """
        Compute one step of the flow starting from ``current_x``.

        A copy of ``current_x`` is optimised with SGD (momentum 0.9) for
        ``n_epochs`` iterations on ``cost(copy, current_x) + 2 * tau * F(copy)``,
        where the cost is ``sliced_wasserstein`` when there is more than one
        column and ``emd1D`` otherwise (both with ``p=2``).

        Parameters
        ----------
        current_x : torch.Tensor
            Two-dimensional tensor of particles.

        Returns
        -------
        torch.Tensor
            The optimised particles, detached, on the configured device. An
            empty input is returned detached without any optimisation.
        """
        if len(current_x) == 0:
            return current_x.detach()

        cfg = self.c
        device = cfg.device
        _n_points, dim = current_x.size()
        tau = cfg.tau
        functional = cfg.functional

        anchor = current_x.to(device)
        # Independent leaf tensor that the optimiser is allowed to update.
        particles = anchor.clone().detach().requires_grad_(True).to(device)
        sgd = torch.optim.SGD([particles], lr=cfg.lr, momentum=0.9)

        def transport_term():
            # Fresh random slices are drawn on every call when dim > 1.
            if dim > 1:
                return sliced_wasserstein(particles, anchor, cfg.n_projections, device, p=2)
            return emd1D(particles.reshape(1, -1), anchor.reshape(1, -1), p=2)

        for _epoch in range(cfg.n_epochs):
            sgd.zero_grad()
            # The proximal mapping
            objective = transport_term() + 2 * tau * functional(particles)
            objective.backward()
            sgd.step()

        return particles.detach()

    def all_measures(self, initial_x):
        """
        Return the list of measures of the flow started at ``initial_x``.

        The list is computed on the first call and cached; later calls return
        the cached list and ignore ``initial_x``. The first entry is
        ``initial_x`` itself; each following entry is the result of
        ``next_measure`` moved to the CPU.
        """
        if self.measures:
            return self.measures

        state = initial_x
        self.measures.append(state)
        for _ in tqdm(range(self.c.n_steps), desc="Computing full gradient flow", ascii="░▒█"):
            state = self.next_measure(state)
            self.measures.append(state.cpu())
        return self.measures

# =============================================================================
# Wasserstein Gradient Flows with Target Measure
# Credit: Feydy et al., Genevay et al. and the ott-jax developers
# Source: https://github.com/ott-jax/ott

from typing import Any, Callable, Tuple

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from IPython import display

import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot, sinkhorn_divergence

def reg_ot_cost(x, y, epsilon=None):
    """Return the OT cost and OT output given a geometry

    A ``PointCloud`` geometry is built from ``x`` and ``y`` with the given
    ``epsilon``, wrapped in a ``LinearProblem`` and solved with a default
    ``Sinkhorn`` solver. The result is the pair
    ``(output.reg_ot_cost, output)``.
    """
    solver = sinkhorn.Sinkhorn()
    geometry = pointcloud.PointCloud(x, y, epsilon=epsilon)
    problem = linear_problem.LinearProblem(geometry)
    output = solver(problem)
    return output.reg_ot_cost, output

def sink_div(x, y, epsilon=None):
    """Return the Sinkhorn divergence cost and OT output given point clouds.

    Since ``y`` is fixed, we can use the flag ``static_b=True`` to avoid
    computing the ``reg_ot_cost(y, y)`` term.

    The result is the pair ``(output.divergence, output)``.
    """
    options = dict(x=x, y=y, epsilon=epsilon, static_b=True)
    output = sinkhorn_divergence.sinkhorn_divergence(pointcloud.PointCloud, **options)
    return output.divergence, output

class WGFwithTarget(WGF):
    """
    Wasserstein Gradient Flows with Target Measure Version 1

    Moves a point cloud towards ``target_measure`` by plain gradient descent
    on the configured ``divergence``.
    """

    def __init__(self, WGFConfig):
        """
        Keep the configuration, wrap ``divergence`` in a jitted
        value-and-gradient function (``has_aux=True``) and convert the target
        measure, an iterable of 2-element points, to a JAX array.
        """
        self.c = WGFConfig
        value_and_grad = jax.value_and_grad(self.c.divergence, has_aux=True)
        self.divergence_vg = jax.jit(value_and_grad)
        self.target_measure = jnp.array(
            [(first, second) for first, second in self.c.target_measure]
        )
        self.measures = []

    def next_measure(self, current_x):
        """
        Run ``n_epochs`` gradient-descent updates with step size ``lr``.

        ``current_x`` is an iterable of 2-element points; it is converted to
        a JAX array first. Returns the updated array.
        """
        points = jnp.array([(first, second) for first, second in current_x])
        epochs = tqdm(range(self.c.n_epochs), desc="Computing next measure", leave=False,  ascii="░▒█")
        for _ in epochs:
            # Only the gradient is used; the value and auxiliary output are dropped.
            (_cost, _aux), gradient = self.divergence_vg(
                points, self.target_measure, self.c.sinkhorn_epsilon
            )
            points = points - gradient * self.c.lr
        return points

    def all_measures(self, initial_x):
        """
        Return the list of measures of the flow started at ``initial_x``.

        Computed on the first call and cached afterwards; later calls ignore
        ``initial_x``. The first entry is ``initial_x`` as passed in.
        """
        if self.measures:
            return self.measures

        state = initial_x
        self.measures.append(state)
        for _ in tqdm(range(self.c.n_steps), desc="Computing full gradient flow", ascii="░▒█"):
            state = self.next_measure(state)
            self.measures.append(state)
        return self.measures


# =============================================================================
# WGF with Target Measure Version 2

from ott.geometry import pointcloud
from ott.solvers import linear

class WGFSinkhorn(WGF):
    """
    Wasserstein Gradient Flows with Target Measure

    Each step solves a linear OT problem between the current points and the
    target, then moves every point part of the way towards the weighted
    average of the target points it is coupled with.
    """

    def __init__(self, WGFConfig):
        """Keep the configuration and convert the target measure to a JAX array."""
        self.measures = []
        self.c = WGFConfig
        self.target_measure = jnp.array(
            [(first, second) for first, second in self.c.target_measure]
        )

    def next_measure(self, current_x):
        """
        Move ``current_x`` a fraction ``1 / (n_steps - step)`` towards the target.

        ``step`` is read from the configuration object (``all_measures`` writes
        it before each call). Returns the updated JAX array.
        """
        # Computed first, as in the original ordering: with step == n_steps
        # this division fails before any OT problem is solved.
        fraction = 1.0/(self.c.n_steps - self.c.step)
        points = jnp.array([(first, second) for first, second in current_x])

        geometry = pointcloud.PointCloud(points, self.target_measure, cost_fn=None)
        solution = jax.jit(linear.solve)(geometry)
        coupling = solution.matrix

        # Row-normalised coupling applied to the target: one averaged target
        # location per current point.
        row_mass = jnp.sum(coupling, axis=1, keepdims=True)
        projected = jnp.dot(coupling, self.target_measure) / row_mass

        return points - fraction * (points - projected)

    def all_measures(self, initial_x):
        """
        Return the list of measures of the flow started at ``initial_x``.

        Computed on the first call and cached afterwards; later calls ignore
        ``initial_x``. The step index is stored in ``self.c.step`` before
        each call to ``next_measure``.
        """
        if not self.measures:
            state = initial_x
            self.measures.append(state)
            for index in tqdm(range(self.c.n_steps), desc="Computing full gradient flow", ascii="░▒█"):
                self.c.step = index
                state = self.next_measure(state)
                self.measures.append(state)
        return self.measures

