import timeit
import torch.nn.functional as F
import torch
import numpy as np
from sklearn.metrics import adjusted_rand_score
import unittest

from zmq import device

from slot_attention.losses.ari import adjusted_rand_index

class TestAdjustedRandIndex(unittest.TestCase):
    """Compare ``adjusted_rand_index`` with scikit-learn's ``adjusted_rand_score``.

    Every correctness test builds integer label masks, one-hot encodes them
    for the torch implementation and checks the result against scikit-learn
    evaluated on the raw integer labels.
    """

    # All masks in this suite use the labels 0, 1 and 2.
    NUM_CLASSES = 3

    def setUp(self):
        self.adjusted_rand_index = adjusted_rand_index

    def _one_hot(self, mask, device='cpu'):
        """Return ``mask`` one-hot encoded as a float32 tensor on ``device``."""
        # F.one_hot only accepts int64 indices; NumPy's default integer type
        # is 32-bit on some platforms, so the dtype is forced explicitly.
        labels = torch.tensor(mask, dtype=torch.long)
        one_hot = F.one_hot(labels, num_classes=self.NUM_CLASSES)
        return one_hot.to(torch.float32).to(device)

    def _check_against_sklearn(self, true_mask, pred_mask):
        """Assert that the torch ARI equals scikit-learn's to 4 decimals.

        ``true_mask`` and ``pred_mask`` are integer NumPy arrays holding
        either a single sample (1-D) or one sample per row (2-D).
        """
        true_mask_oh = self._one_hot(true_mask)
        pred_mask_oh = self._one_hot(pred_mask)

        if true_mask.ndim == 1:
            # A single sample is passed to the torch implementation with a
            # leading axis of size 1.
            true_mask_oh = true_mask_oh.unsqueeze(0)
            pred_mask_oh = pred_mask_oh.unsqueeze(0)
            ari_sklearn = adjusted_rand_score(true_mask, pred_mask)
        else:
            # scikit-learn scores one sample at a time: pair row i of the
            # ground truth with row i of the prediction.
            ari_sklearn = np.stack([
                adjusted_rand_score(true_row, pred_row)
                for true_row, pred_row in zip(true_mask, pred_mask)
            ])

        ari_torch = self.adjusted_rand_index(
            true_mask_oh, pred_mask_oh, ignore_background=False).numpy()

        print(f'ari_torch: {ari_torch}')
        print(f'ari_sklearn: {ari_sklearn}')

        np.testing.assert_almost_equal(ari_torch, ari_sklearn, decimal=4)

    def test_perfect_match(self):
        """Same partition under permuted label ids."""
        print('test_perfect_match')
        true_mask = np.array([0, 0, 1, 1, 2, 2])
        pred_mask = np.array([1, 1, 2, 2, 0, 0])
        self._check_against_sklearn(true_mask, pred_mask)

    def test_no_perfect_match(self):
        """Partitions that only partially agree."""
        print('test_no_perfect_match')
        true_mask = np.array([0, 0, 0, 1, 2, 2])
        pred_mask = np.array([1, 2, 2, 2, 0, 0])
        self._check_against_sklearn(true_mask, pred_mask)

    def test_no_match(self):
        """Partitions that largely disagree."""
        print('test_no_match')
        true_mask = np.array([0, 1, 2, 0, 1, 2])
        pred_mask = np.array([2, 2, 0, 0, 1, 1])
        self._check_against_sklearn(true_mask, pred_mask)

    def test_no_perfect_match_nd(self):
        """Two samples at once; each row is scored independently."""
        print('test_no_perfect_match_nd')
        true_mask = np.array([
            [0, 0, 0, 1, 2, 2],
            [0, 1, 0, 1, 2, 2]])
        pred_mask = np.array([[1, 2, 2, 2, 0, 0],
                              [1, 2, 2, 2, 0, 0]])
        self._check_against_sklearn(true_mask, pred_mask)

    def test_no_perfect_match_nd_big(self):
        """Time 500 calls on a batch of 64 random masks (benchmark only)."""
        print('timeit on test_no_perfect_match_nd_big')
        true_mask = np.random.randint(0, 3, size=(64, 6))
        pred_mask = np.random.randint(0, 3, size=(64, 6))
        # Named ``device_name`` so it does not shadow the module-level
        # ``device`` import.
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f'device: {device_name}')
        true_mask_oh = self._one_hot(true_mask, device_name)
        pred_mask_oh = self._one_hot(pred_mask, device_name)

        def run_once():
            result = self.adjusted_rand_index(
                true_mask_oh, pred_mask_oh, ignore_background=False)
            if device_name == 'cuda':
                # CUDA kernels are launched asynchronously; wait for them so
                # the timing covers execution and not just the launch.
                torch.cuda.synchronize()
            return result

        # Warm-up call so one-time initialisation is not part of the timing.
        run_once()

        t = timeit.timeit(run_once, number=500)
        print(f'needed seconds for ari_torch: {t:.3f}')


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()