import unittest
import random

import torch

from recognizers.dataset_generation.generate_datasets import (
    get_saved_language,
    generate_perturbed_string
)
from recognizers.hand_picked_languages.rayuela_util import to_rayuela_cfg

from rayuela.cfg.parser import Parser
from rayuela.base.symbol import Sym as RayuelaSym

class TestStringsum(unittest.TestCase):
    """Cross-check a saved language's string labels against rayuela.

    Each test loads a saved language and draws perturbed strings from it.
    For every string, it asserts that ``language.parent.uncached_label``
    equals the ``.value`` returned by rayuela's ``Parser.sum`` over the
    same grammar (converted with ``to_rayuela_cfg``).
    """

    # Settings shared by every cross-check in this class.
    DTYPE = torch.float16
    DEVICE = torch.device("cpu")
    NUM_SAMPLES = 200
    SEED = 123
    LENGTH_RANGE = (0, 40)

    def _assert_matches_rayuela(self, sampler: str) -> None:
        """Assert that labels from *sampler*'s language match rayuela's.

        :param sampler: Path passed to ``get_saved_language``. It is
            resolved relative to the current working directory.
        """
        generator = random.Random(self.SEED)
        language = get_saved_language(sampler, self.DTYPE, self.DEVICE)
        # Read the alphabet size before restricting the length range,
        # matching the order used by the original tests.
        alphabet_size = language.alphabet_size()
        language = language.with_length_range(self.LENGTH_RANGE)
        # The grammar is the same for every sample, so convert it and build
        # the reference parser once rather than once per iteration.
        parser = Parser(to_rayuela_cfg(language.parent.grammar))
        for sample_index in range(self.NUM_SAMPLES):
            s, _ = generate_perturbed_string(
                language, self.LENGTH_RANGE, alphabet_size, generator
            )
            # subTest reports which sample/string failed instead of only
            # the first mismatch.
            with self.subTest(sample=sample_index, string=s):
                parse = language.parent.uncached_label(s)
                rayuela_parse = parser.sum([RayuelaSym(sym) for sym in s]).value
                self.assertEqual(parse, rayuela_parse, "Must match the rayuela result")

    def test_negative_examples_test_cfg(self) -> None:
        """Perturbed strings from the random CFG agree with rayuela."""
        self._assert_matches_rayuela("../src/random-cfg-sampler.pt")

    def test_negative_examples_marked_reversal(self) -> None:
        """Perturbed strings from the saved sampler agree with rayuela."""
        # NOTE(review): despite its name, this test loads the same random-CFG
        # sampler as test_negative_examples_test_cfg, so it currently
        # duplicates that test. It presumably should point at a
        # marked-reversal sampler; confirm the intended path.
        self._assert_matches_rayuela("../src/random-cfg-sampler.pt")

if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase classes above.
    unittest.main()