import torch
from torch import Tensor
import unittest
from ..datasets.uniform_rand_coords import (
    UniformRandCoords,
    UniformRandCoordsWithSimulation,
)
from timeit import timeit


class TestUniformRandCoords(unittest.TestCase):
    """Tests for the uniform random coordinate datasets."""

    # def test_simple(self):
    #     import smlmpyplot as splt

    #     ds = UniformRandCoordsWithSimulation(length=1, density=10.0, n_frames=3, seed=0)
    #     x, y = ds[0]

    #     splt.imshow(y.mean(dim=0), img_extent=2 * [[0, 6400]])
    #     splt.gt_coordinates_show(x)
    #     splt.legend()
    #     splt.savefig()

    # def test_gpu(self, N=10):
    #     for device in ["cpu", "cuda:0"]:
    #         ds = UniformRandCoords(
    #             length=N, density=10.0, n_frames=3, seed=0, device=device
    #         )

    #         def fun():
    #             [_ for _ in ds]

    #         print(device, timeit(fun, number=10))

    def test_density(self, N=1000, rtol=0.1):
        """Check that the empirical emitter density matches the requested one.

        For each target density, draw ``N`` seeded samples from
        ``UniformRandCoords`` and count the rows whose column 4 is positive.
        The mean count per sample, divided by the field area, must lie within
        ``rtol`` (relative) of the target density.

        Args:
            N: number of dataset samples to average over.
            rtol: allowed relative deviation between the empirical and the
                target density.
        """
        # Field area: 6400 x 6400 (presumably nm) scaled by 1e-6, i.e. µm^2.
        # NOTE(review): assumes `density` is emitters per µm^2 over the whole
        # field for one sample -- TODO confirm against the dataset definition.
        s = 6400**2 * 1e-6
        for d_gt in [0.2, 2.0]:
            with self.subTest(density=d_gt):
                ds = UniformRandCoords(length=N, density=d_gt, n_frames=3, seed=0)
                # Column 4 > 0 is treated as "emitter present" (presumably an
                # intensity column -- TODO confirm the coordinate layout).
                counts = torch.stack([(x[:, 4] > 0.0).sum() for x in ds])
                d = (counts.float().mean() / s).item()
                # The original test only printed these values; assert instead
                # so that a density mismatch actually fails the test.
                self.assertAlmostEqual(
                    d,
                    d_gt,
                    delta=rtol * d_gt,
                    msg=f"empirical density {d:.4f} vs target {d_gt}",
                )


if __name__ == "__main__":
    # Run every TestCase in this module when it is executed directly
    # (e.g. via `python -m`); `module="__main__"` is unittest's default.
    unittest.main(module="__main__")
