import unittest

import torch as t

from hypo_interp.utils.test_utils import gaussian_kernel, hsic, permutation_test


class TestPermutationTest(unittest.TestCase):
    """Tests for ``permutation_test`` and its helpers ``gaussian_kernel`` and ``hsic``."""

    # Number of samples (rows) in the shared fixtures.
    N_SAMPLES = 50
    # Number of features (columns) in the 2-D fixtures.
    N_FEATURES = 10

    def setUp(self):
        """Build reproducible random fixtures shared by every test."""
        # Seed torch's global RNG so the fixtures are identical on every run.
        # If permutation_test draws its permutations from the global
        # generator, that is made deterministic too (assumption - confirm).
        t.manual_seed(0)
        n, d = self.N_SAMPLES, self.N_FEATURES
        self.X = t.randn(n, d)
        self.Y = t.randn(n, d)
        # Same feature count as X but a different number of samples.
        self.Y_diff_shape = t.randn(n + 10, d)
        self.X_1d = t.randn(n)
        self.X_3d = t.randn(n, d, 2)
        self.sigma = 1.0

    def test_proper_shape(self):
        """permutation_test returns a dict with the three expected keys."""
        result = permutation_test(self.X, self.Y)
        self.assertIsInstance(result, dict)
        self.assertIn("hsic", result)
        self.assertIn("p_value", result)
        self.assertIn("simulated_statistics", result)

    def test_gaussian_kernel(self):
        """The kernel of X with itself is square in the number of samples."""
        kernel_matrix = gaussian_kernel(self.X, self.X, self.sigma)
        # assertEqual (rather than assertTrue(a == b)) shows the actual shape on failure.
        self.assertEqual(kernel_matrix.shape, (self.N_SAMPLES, self.N_SAMPLES))

    def test_hsic(self):
        """hsic returns a torch tensor."""
        hsic_value = hsic(self.X, self.Y, self.sigma)
        self.assertIsInstance(hsic_value, t.Tensor)

    def test_hsic_edge_cases(self):
        """hsic rejects an empty input with a RuntimeError."""
        # Build the fixture outside assertRaises so that only the call under
        # test is allowed to produce the expected error.
        empty_tensor = t.tensor([])
        with self.assertRaises(RuntimeError):
            hsic(empty_tensor, self.Y, self.sigma)

    def test_different_shape(self):
        """Inputs with mismatched sample counts raise a RuntimeError."""
        with self.assertRaises(RuntimeError):
            permutation_test(self.X, self.Y_diff_shape)

    def test_permutation_test_1d_input(self):
        """1-D inputs are expected to emit a UserWarning."""
        with self.assertWarns(UserWarning):
            permutation_test(self.X_1d, self.X_1d)

    def test_permutation_test_more_than_2d_input(self):
        """Inputs with more than two dimensions raise a RuntimeError."""
        with self.assertRaises(RuntimeError):
            permutation_test(self.X_3d, self.X_3d)

    def test_different_number_of_permutations(self):
        """One simulated statistic is produced per requested permutation."""
        for num_permutations in (100, 500, 1000):
            # subTest reports each failing value instead of stopping at the first.
            with self.subTest(num_permutations=num_permutations):
                result = permutation_test(
                    self.X, self.Y, num_permutations=num_permutations
                )
                self.assertEqual(
                    len(result["simulated_statistics"]), num_permutations
                )

    def test_output_structure(self):
        """The result values have the expected Python / torch types."""
        result = permutation_test(self.X, self.Y)
        self.assertIsInstance(result["hsic"], float)
        self.assertIsInstance(result["p_value"], float)
        self.assertIsInstance(result["simulated_statistics"], t.Tensor)


if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main(module="__main__")
