import numpy as np
import ot
import torch
import torch.distributions as TD


class DiscreteEOT_l2sq_sampler:
    """Conditional sampler backed by a discrete coupling matrix ``G``.

    Row ``i`` of ``G`` is used as (unnormalized) weights over the rows of
    ``Y``; sampling "given source index i" draws column indices from that row
    and maps them to points of ``Y``.
    """

    @staticmethod
    def discrete_sample_conditional(Y, G, i_x, n_pts, return_indices=False):
        """Draw ``n_pts`` independent samples conditioned on source index ``i_x``.

        Args:
            Y: target points; indexed with the sampled column indices.
            G: 2-D weight matrix. Row ``i_x`` should be non-negative with a
                positive sum, otherwise the normalized probabilities are invalid.
            i_x: row of ``G`` to condition on.
            n_pts: number of draws.
            return_indices: if True, return the sampled column indices instead
                of the corresponding entries of ``Y``.

        Returns:
            A LongTensor of shape ``(n_pts,)`` of indices, or ``Y`` indexed by it.
        """
        row = G[i_x]
        # Categorical also normalizes internally; normalizing here keeps the
        # probabilities explicit.
        probs = row / torch.sum(row)
        numbers = TD.Categorical(probs=probs).sample((n_pts,))
        if return_indices:
            return numbers
        return Y[numbers]

    def __init__(self, X, Y, G, device="cpu"):
        """Store float32 copies of the source points, target points and coupling.

        Args:
            X, Y, G: array-likes or tensors; each is copied, detached, cast to
                float32 and moved to ``device``.
            device: torch device the stored tensors live on.
        """
        self.device = device
        # torch.as_tensor (unlike torch.tensor) does not warn when given a
        # tensor; detach().clone() still gives the sampler its own storage.
        self.X = torch.as_tensor(X).float().detach().clone().to(self.device)
        self.Y = torch.as_tensor(Y).float().detach().clone().to(self.device)
        self.G = torch.as_tensor(G).float().detach().clone().to(self.device)

    def sample(self, x_samples):
        """Sample for arbitrary (non-indexed) source points; not supported here."""
        raise NotImplementedError()

    def sample_by_indices(self, x_indices, return_indices=False):
        """Draw one conditional sample for each source index in ``x_indices``.

        Returns:
            A LongTensor of shape ``(len(x_indices),)`` when ``return_indices``
            is True, otherwise ``self.Y`` indexed by those indices. An empty
            ``x_indices`` yields an empty result instead of raising.
        """
        # One draw per index, in order, so the random stream consumed is the
        # same as sampling row by row.
        picks = [
            self.discrete_sample_conditional(self.Y, self.G, x_idx, 1, return_indices=True)
            for x_idx in x_indices
        ]
        if picks:
            numbers = torch.cat(picks, dim=0)
        else:
            # torch.cat rejects an empty list; use an empty index tensor instead.
            numbers = torch.empty(0, dtype=torch.long, device=self.G.device)
        if return_indices:
            return numbers
        return self.Y[numbers]

    def sample_by_index(self, x_index, n, return_indices=False):
        """Draw ``n`` conditional samples for the single source index ``x_index``."""
        return self.discrete_sample_conditional(self.Y, self.G, x_index, n, return_indices=return_indices)


def store_discrete_ot(path, model):
    """Save a sampler's ``X``, ``Y`` and ``G`` tensors to ``path`` with ``torch.save``.

    Each tensor is detached and moved to CPU first, and the result is written
    as a dict keyed by ``"X"``, ``"Y"``, ``"G"`` (the layout
    ``load_discrete_ot`` reads back).

    Args:
        path: destination accepted by ``torch.save``.
        model: object exposing tensor attributes ``X``, ``Y`` and ``G``.
    """
    payload = {key: getattr(model, key).detach().cpu() for key in ("X", "Y", "G")}
    torch.save(payload, path)


def load_discrete_ot(path, device="cpu"):
    """Rebuild a ``DiscreteEOT_l2sq_sampler`` from a file written by ``store_discrete_ot``.

    Args:
        path: file previously produced by ``torch.save`` holding ``"X"``,
            ``"Y"`` and ``"G"`` entries.
        device: torch device the reconstructed sampler stores its tensors on.

    Returns:
        A new ``DiscreteEOT_l2sq_sampler``.
    """
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    checkpoint = torch.load(path)
    x_pts, y_pts, coupling = (checkpoint[key] for key in ("X", "Y", "G"))
    return DiscreteEOT_l2sq_sampler(x_pts, y_pts, coupling, device=device)


class DiscreteEOT_l2sq:
    """Entropic OT solver between uniform point clouds with cost 0.5 * ||x - y||^2.

    ``solve`` builds the cost matrix with POT's ``ot.dist``, runs
    ``ot.sinkhorn`` with the configured options and wraps the resulting
    coupling in a ``DiscreteEOT_l2sq_sampler``.
    """

    # Supported values of the ``dtype`` option and the torch dtype each maps to.
    _DTYPES = {"torch32": torch.float32, "torch64": torch.float64}

    def _cast(self, x):
        """Return an independent copy of ``x`` as a ``self.dtype`` tensor on ``self.device``.

        Raises:
            ValueError: if ``self.dtype`` is not one of the supported names.
        """
        try:
            dtype = self._DTYPES[self.dtype]
        except KeyError:
            raise ValueError(
                f"Unsupported dtype {self.dtype!r}; expected one of {sorted(self._DTYPES)}"
            ) from None
        # Building directly in the target dtype avoids a float32 round-trip for
        # Python lists under "torch64", and as_tensor does not warn on tensor
        # inputs; detach().clone() guarantees the result owns its storage.
        return torch.as_tensor(x, dtype=dtype).detach().clone().to(self.device)

    def __init__(
        self,
        verbose=False,
        method="sinkhorn_log",
        stopThr=1e-09,
        numItermax=10000,
        dtype="torch32",
        device="cpu",
    ):
        """Store solver options.

        Args:
            verbose, method, stopThr, numItermax: forwarded to ``ot.sinkhorn``.
            dtype: "torch32" or "torch64"; precision of the tensors handed to POT.
            device: torch device inputs are moved to.
        """
        self.verbose = verbose
        self.method = method
        self.stopThr = stopThr
        self.numItermax = numItermax
        self.dtype = dtype
        self.device = device

    def solve(self, X, Y, eps):
        """Solve entropic OT between uniform measures on ``X`` and ``Y``.

        Args:
            X: source points; expected 2-D ``(n, d)`` as required by ``ot.dist``.
            Y: target points with the same feature dimension as ``X``.
            eps: entropic regularization passed to ``ot.sinkhorn`` as ``reg``.

        Returns:
            A ``DiscreteEOT_l2sq_sampler`` over the cast points and the coupling.
        """
        _X, _Y = self._cast(X), self._cast(Y)
        # ot.dist defaults to the squared Euclidean metric, so M = 0.5 * ||x - y||^2.
        M = 0.5 * ot.dist(_X, _Y)
        # Take sizes from the cast tensors so list / array-like inputs work too.
        xL, yL = _X.shape[0], _Y.shape[0]
        wX, wY = self._cast(np.ones(xL) / xL), self._cast(np.ones(yL) / yL)
        G = ot.sinkhorn(
            wX, wY, M, eps, method=self.method, numItermax=self.numItermax, stopThr=self.stopThr, verbose=self.verbose
        )
        return DiscreteEOT_l2sq_sampler(_X, _Y, G, device=self.device)
