import numpy as np
import scipy.stats as stats
import sys
from tqdm import tqdm

import ot

sys.path.append("..")
import utils


def wasserstein(x, y, n_projections=100):
    """Return the sliced Wasserstein distance between ``x`` and ``y``.

    Thin wrapper around ``ot.sliced_wasserstein_distance``: the two sample
    sets and ``n_projections`` are forwarded as-is, and the result is
    converted to a plain scalar with ``.item()``.
    """
    distance = ot.sliced_wasserstein_distance(
        x,
        y,
        n_projections=n_projections,
    )
    return distance.item()


def calc_wds(replay, test_pos, n_projections=10_000):
    """Compute the sliced Wasserstein distance from ``test_pos`` to each replay entry.

    Time is treated as a feature dimension: the first two axes of every
    tensor are merged into one sample axis and the remaining axes are merged
    into one feature axis before the distance is computed.

    Args:
        replay: Mapping from a key to a tensor. Per the original author's
            note each value has shape ``(seeds, N, T, 2)``; only the first
            ``test_pos.shape[2]`` steps along axis 2 are used.
        test_pos: 4-D tensor whose last axis has size 2 (checked by an
            assert). Must support torch-style ``flatten(start_dim, end_dim)``.
        n_projections: Forwarded unchanged to ``wasserstein``.

    Returns:
        dict mapping every key of ``replay`` to the value returned by
        ``wasserstein`` for that entry.
    """
    # treating time as a dimension
    assert test_pos.ndim == 4 and test_pos.shape[-1] == 2, (
        f"test_pos must be 4-D with last axis of size 2, got shape {tuple(test_pos.shape)}"
    )
    t_lim = test_pos.shape[2]  # seeds, N, T

    # The reference samples are identical for every replay entry, so flatten
    # them once instead of on every loop iteration.
    reference = test_pos.flatten(0, 1).flatten(1)

    wds = {}
    for k, v in tqdm(replay.items()):
        # each v has shape (seeds, N, T, 2); truncate time to match test_pos
        candidate = v[:, :, :t_lim].flatten(0, 1).flatten(1)
        wds[k] = wasserstein(reference, candidate, n_projections=n_projections)
    return wds
