import unittest
import random

from recognizers.hand_picked_languages.majority import Majority
from recognizers.hand_picked_languages.parity import Parity
from recognizers.hand_picked_languages.first import First
from recognizers.hand_picked_languages.k_sparse_parity import KSparseParity
from recognizers.hand_picked_languages.k_sparse_majority import KSparseMajority

from recognizers.string_sampling.sample_dataset import (
    generate_example
)

class TestHandpicked(unittest.TestCase):
    """Check that sampled examples carry labels consistent with their language.

    Only ``test_k_sparse_majority`` is currently active; the other tests are
    kept below as commented-out code.
    """

    # NOTE(review): the four tests below are disabled and this file does not
    # record why -- confirm before re-enabling or deleting them.

    # def test_first(self) -> None:
    #     generator = random.Random(12345)
    #     language = First()
    #     length = 50
    #     language = language.with_length_range(length)
    #     for _ in range(10000):
    #         s, label, _, _, _ = generate_example(
    #             language,
    #             length,
    #             2,
    #             False,
    #             0.5,
    #             True,
    #             False,
    #             False,
    #             False,
    #             generator,
    #             None
    #         )
    #         self.assertEqual(len(s), length)
    #         if label:
    #             self.assertEqual(s[0], 1)
    #         else:
    #             self.assertEqual(s[0], 0)

    # def test_parity(self) -> None:
    #     generator = random.Random(12345)
    #     language = Parity()
    #     length = 1
    #     language = language.with_length_range(length)
    #     for _ in range(10):
    #         s, label, _, _, _ = generate_example(
    #             language,
    #             length,
    #             2,
    #             False,
    #             0.5,
    #             True,
    #             False,
    #             False,
    #             False,
    #             generator,
    #             None
    #         )
    #         self.assertEqual(len(s), length)
    #         r = s.count(1) % 2
    #         if label:
    #             self.assertEqual(r, 1)
    #         else:
    #             self.assertEqual(r, 0)

    # def test_k_sparse_parity(self) -> None:
    #     generator = random.Random(12345)
    #     language = KSparseParity()
    #     length = 50
    #     language = language.with_length_range(length, 5, generator)
    #     k = language.k
    #     idx =  language.idx
    #     for _ in range(10):
    #         s, label, _, _, _ = generate_example(
    #             language,
    #             length,
    #             2,
    #             False,
    #             0.5,
    #             True,
    #             False,
    #             False,
    #             False,
    #             generator,
    #             None
    #         )
    #         self.assertEqual(len(s), length)
    #         s_k =  tuple([bit for i, bit in enumerate(s) if i in language.idx])
    #         r = s_k.count(1) % 2
    #         if label:
    #             self.assertEqual(r, 1)
    #         else:
    #             self.assertEqual(r, 0)

    # def test_majority(self) -> None:
    #     generator = random.Random(12345)
    #     language = Majority()
    #     length = 500
    #     language = language.with_length_range(length)
    #     for _ in range(10000):
    #         s, label, _, _, _ = generate_example(
    #             language,
    #             length,
    #             2,
    #             False,
    #             0.5,
    #             True,
    #             False,
    #             False,
    #             False,
    #             generator,
    #             None
    #         )
    #         self.assertEqual(len(s), length)
    #         c1 = s.count(1)
    #         c0 = s.count(0)
    #         if label:
    #             self.assertTrue(c1 > c0)
    #         else:
    #             self.assertTrue(c1 <= c0)

    def test_k_sparse_majority(self) -> None:
        """Labels must match the majority over the positions in ``language.idx``.

        Ten examples of length 50 are drawn with a fixed seed. For each one,
        the symbols at the positions listed in ``language.idx`` are collected:
        a truthy label requires strictly more 1s than 0s among them, while a
        falsy label requires that the 1s do not outnumber the 0s (a tie
        counts as a negative example).
        """
        generator = random.Random(12345)
        language = KSparseMajority()
        length = 50
        language = language.with_length_range(length, 5, generator)
        for _ in range(10):
            # NOTE(review): the arguments are passed positionally because the
            # parameter names of generate_example are not visible from this
            # file -- consider switching to keywords once confirmed.
            s, label, _, _, _ = generate_example(
                language,
                length,
                2,
                False,
                0.5,
                True,
                False,
                False,
                False,
                generator,
                None
            )
            self.assertEqual(len(s), length)
            # Keep only the symbols at the positions the language inspects.
            s_k = [bit for i, bit in enumerate(s) if i in language.idx]
            c1 = s_k.count(1)
            c0 = s_k.count(0)
            if label:
                self.assertGreater(
                    c1, c0,
                    f"truthy label but selected positions hold "
                    f"{c1} ones and {c0} zeros"
                )
            else:
                self.assertLessEqual(
                    c1, c0,
                    f"falsy label but selected positions hold "
                    f"{c1} ones and {c0} zeros"
                )

# Run the tests in this module with unittest's command-line runner when the
# file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()