import unittest
import random

import torch

from recognizers.string_sampling.sample_dataset import (
    get_automaton_language,
    generate_perturbed_string
)
from recognizers.hand_picked_languages.rayuela_util import to_rayuela_pda

from rayuela.pda.parser import Parser

class TestStringsum(unittest.TestCase):

    def test_positive_examples_unmarked_reversal(self) -> None:
        dtype = torch.float16
        device = torch.device("cpu")
        num_samples = 500
        generator = random.Random(123)
        sampler = "../src/unmarked-reversal-sampler.pt"
        length_range = (0, 80)
        language = get_automaton_language(sampler, dtype, device)
        language = language.with_length_range(length_range)
        for _ in range(num_samples):
            s, _ = language.sample(generator, False, False)
            accept = language.parent.automaton.accepts(s, dtype, device)
            self.assertTrue(accept, "Stringsum must be true for strings sampled from the automaton")

    def test_positive_examples_marked_reversal(self) -> None:
        dtype = torch.float16
        device = torch.device("cpu")
        num_samples = 500
        generator = random.Random(123)
        sampler = "../src/marked-reversal-sampler.pt"
        length_range = (0, 80)
        language = get_automaton_language(sampler, dtype, device)
        language = language.with_length_range(length_range)
        for _ in range(num_samples):
            s, _ = language.sample(generator, False, False)
            accept = language.parent.automaton.accepts(s, dtype, device)
            self.assertTrue(accept, "Stringsum must be true for strings sampled from the automaton")

    def test_positive_examples_random_pda(self) -> None:
        dtype = torch.float16
        device = torch.device("cpu")
        num_samples = 100
        generator = random.Random(123)
        sampler = "../src/random-pda-sampler.pt"
        length_range = (0, 50)
        language = get_automaton_language(sampler, dtype, device)
        language = language.with_length_range(length_range)
        for _ in range(num_samples):
            s, _ = language.sample(generator, False, False)
            accept = language.parent.automaton.accepts(s, dtype, device)
            self.assertTrue(accept, "Stringsum must be true for strings sampled from the automaton")

    def test_epsilon(self) -> None:
        dtype = torch.float16
        device = torch.device("cpu")
        sampler = "../src/unmarked-reversal-sampler.pt"
        training_length_range = (0, 80)
        language = get_automaton_language(sampler, dtype, device)
        language = language.with_length_range(training_length_range)
        accept = language.parent.automaton.accepts((), dtype, device)
        self.assertTrue(accept, "Automaton accepts epsilon")

    def test_negative_examples_unmarked_reversal(self) -> None:
        dtype = torch.float16
        device = torch.device("cpu")
        num_samples = 20
        generator = random.Random(123)
        sampler = "../src/unmarked-reversal-sampler.pt"
        length_range = (0, 80)
        language = get_automaton_language(sampler, dtype, device)
        alphabet_size = language.alphabet_size()
        language = language.with_length_range(length_range)
        for _ in range(num_samples):
            s, _ = generate_perturbed_string(language, length_range, alphabet_size, generator)
            rayuela_pda = to_rayuela_pda(language.parent.automaton)
            parser = Parser(rayuela_pda)
            parse = language.parent.automaton.accepts(s, dtype, device)
            rayuela_parse = parser.parse(s).value
            self.assertEqual(parse, rayuela_parse, "Must match the rayuela result")

    def test_negative_examples_marked_reversal(self) -> None:
        dtype = torch.float16
        device = torch.device("cpu")
        num_samples = 20
        generator = random.Random(123)
        sampler = "../src/marked-reversal-sampler.pt"
        length_range = (0, 80)
        language = get_automaton_language(sampler, dtype, device)
        alphabet_size = language.alphabet_size()
        language = language.with_length_range(length_range)
        for _ in range(num_samples):
            s, _ = generate_perturbed_string(language, length_range, alphabet_size, generator)
            rayuela_pda = to_rayuela_pda(language.parent.automaton)
            parser = Parser(rayuela_pda)
            parse = language.parent.automaton.accepts(s, dtype, device)
            rayuela_parse = parser.parse(s).value
            self.assertEqual(parse, rayuela_parse, "Must match the rayuela result")

    def test_negative_examples_random_pda(self) -> None:
        dtype = torch.float16
        device = torch.device("cpu")
        num_samples = 50
        generator = random.Random(123)
        sampler = "../src/random-pda-sampler.pt"
        length_range = (0, 50)
        language = get_automaton_language(sampler, dtype, device)
        alphabet_size = language.alphabet_size()
        language = language.with_length_range(length_range)
        for _ in range(num_samples):
            s, _ = generate_perturbed_string(language, length_range, alphabet_size, generator)
            rayuela_pda = to_rayuela_pda(language.parent.automaton)
            parser = Parser(rayuela_pda)
            parse = language.parent.automaton.accepts(s, dtype, device)
            rayuela_parse = parser.parse(s).value
            self.assertEqual(parse, rayuela_parse, "Must match the rayuela result")

# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()