# script: 
# python3 dataset_generation/generate.py --n 1000

# Finite state grammars obey the following rules:
    # 6 nodes, including in node S0, out node S0', and four nodes S1-S4
    # S0 is always connected to both S1 and S3, S2 and S4 are always connected to S0'
    # S1 is always connected to S2 (to avoid isomorphisms and the null case)
    # the remaining middle edges (S1 -> S3, S1 -> S4, S2 -> S1, S2 -> S3, S2 -> S4, S3 -> S1, S3 -> S2, S3 -> S4, S4 -> S1, S4 -> S2, S4 -> S3) can either exist or not, for a total of 2^11 combinations
    # each of S1-S4 can have self-loops (e.g. S1 -> S1), for a total of 2^4 combinations
    # Paths are limited to maximum 8 letters
    # Grammar should be able to generate at least 37 unique strings with length <= 8
    # Each edge is labeled with a letter drawn uniformly at random from the 26 uppercase letters (letters may repeat across edges), so a grammar with E edges has 26^E possible labelings
    # We sample 1000 of these problems and, for each problem, sample 37 strings as in Fallshore and Schooler (1993) (or more, depending on how large we want the dataset to be): 15 for training and the rest held out as grammatical test strings. The LLM is then tested on the same task (22 distractors, 22 correct strings).
    # Although these design decisions follow Fallshore and Schooler (1993), the original FSG was developed by Reber and Lewis (1977).

import numpy as np
from tqdm import tqdm
import argparse, pickle, json, os, time, copy, random


class FSG:
    """A 6-state finite state grammar with a fixed outer wiring and random middle edges.

    States are numbered 0-5: 0 is the source S0, 5 is the sink S0', and 1-4
    are the middle states S1-S4. Only the 4x4 block of edges among S1-S4
    varies between grammars.

    Attributes:
        random_edges: 4x4 0/1 matrix; ``random_edges[a][b] == 1`` means an edge
            from S(a+1) to S(b+1).
        adjacency_matrix: 6x6 float matrix of all edges (fixed + middle).
        isomorphisms: relabelings of ``random_edges`` (under S1<->S3 and/or
            S2<->S4) that still contain the S1 -> S2 edge.
        strings: set by ``generate_strings``; source-to-sink paths as digit strings.
        edge_to_letter: not set here; must be assigned from outside (the
            ``assign_letters`` function does this) before translating/sampling.
    """

    def __init__(self, random_edges=None):
        """Build a grammar from ``random_edges``, or draw random middle edges if omitted.

        When drawn randomly, each middle edge / self-loop is present with
        probability 1/2, and S1 -> S2 is then forced on.
        """
        if random_edges is not None:
            self.random_edges = random_edges
        else:
            drawn = np.random.randint(2, size=(4, 4))
            drawn[0][1] = 1  # S1 -> S2 always exists (avoids isomorphic duplicates and the null grammar)
            self.random_edges = drawn

        self.build_full_fsg()
        self.construct_isomorphisms()

    def build_full_fsg(self):
        """Embed ``random_edges`` into the full 6x6 ``adjacency_matrix``.

        ``adjacency_matrix[a][b] == 1`` means a directed edge a -> b; edges are
        one-way, so the matrix need not be symmetric.
        """
        full = np.zeros((6, 6))
        # Fixed wiring: the source feeds S1 and S3; S2 and S4 feed the sink.
        for src, dst in ((0, 1), (0, 3), (2, 5), (4, 5)):
            full[src][dst] = 1
        # Rows/columns 1-4 carry the variable middle block.
        full[1:5, 1:5] = self.random_edges
        self.adjacency_matrix = full

    def construct_isomorphisms(self):
        """Fill ``self.isomorphisms`` with the equivalent relabelings of this grammar.

        Swapping S1 <-> S3 keeps the grammar's shape (both are fed by the source),
        and so does swapping S2 <-> S4 (both feed the sink). The candidates are
        therefore: identity, (1 3), (2 4), and (1 3)(2 4).
        """
        self.isomorphisms = []
        for swaps in ([], [(1, 3)], [(2, 4)], [(1, 3), (2, 4)]):
            self.construct_isomorphism(swaps)

    def construct_isomorphism(self, L):
        """Relabel ``random_edges`` by the state swaps in ``L`` and keep it if S1 -> S2 survives.

        Args:
            L: list of (i, j) pairs of middle-state numbers (1-4), applied in order.
        """
        # Compose the swaps into one permutation of the 0-based middle indices;
        # permuting rows and columns together is the same as swapping them pairwise.
        order = list(range(4))
        for a, b in L:
            order[a - 1], order[b - 1] = order[b - 1], order[a - 1]
        relabeled = np.asarray(self.random_edges)[np.ix_(order, order)]

        # Grammars compared here always contain S1 -> S2, so a relabeling
        # without that edge can never match one of them.
        if relabeled[0][1] == 1:
            self.isomorphisms.append(relabeled)

    def fsg_isomorphic(self, other):
        """Return True if ``other.random_edges`` equals one of this grammar's relabelings."""
        return any(
            np.array_equal(other.random_edges, candidate)
            for candidate in self.isomorphisms
        )

    def generate_strings(self, max_length=8):
        """Enumerate every source-to-sink path as a string of state digits.

        Each path starts with "0" (S0) and ends with "5" (S0'). A partial path
        whose digit string already has ``max_length`` characters is not
        extended, so complete paths have at most ``max_length`` digits, i.e.
        at most ``max_length - 1`` edges. Paths are listed in depth-first
        order, trying successor states in increasing number.

        The list is stored in ``self.strings`` and also returned.
        """
        # NOTE(review): the file header mentions "maximum 8 letters", but with
        # max_length=8 a complete path carries at most 7 edge letters — confirm intent.
        found = []
        pending = [(0, "0")]  # stack of (current state, digits so far)
        while pending:
            state, digits = pending.pop()
            if state == 5:  # reached the sink S0'
                found.append(digits)
                continue
            if len(digits) >= max_length:
                continue
            successors = [
                nxt for nxt in range(6) if self.adjacency_matrix[state][nxt] == 1
            ]
            # Push in reverse so the smallest successor is popped (explored) first.
            for nxt in reversed(successors):
                pending.append((nxt, digits + str(nxt)))

        self.strings = found
        return found

    def translate_string_to_letter(self, string):
        """Convert a digit path into letters, one letter per edge, via ``self.edge_to_letter``."""
        return "".join(
            self.edge_to_letter[(int(src), int(dst))]
            for src, dst in zip(string, string[1:])
        )

    def sample_strings(self, n):
        """Draw ``n`` distinct paths from ``self.strings`` and return them as letter strings.

        Requires ``generate_strings`` to have run and ``edge_to_letter`` to be set.
        Paths are drawn without replacement; if several edges share a letter,
        two different paths may still yield the same letter string.
        """
        picked = np.random.choice(self.strings, n, replace=False)
        return list(map(self.translate_string_to_letter, picked))

def generate_finite_state_grammars(n, num_tries=50, min_strings=37):
    """Generate up to ``n`` mutually non-isomorphic grammars with enough strings.

    For each of the ``n`` slots, up to ``num_tries`` random ``FSG``s are drawn.
    The first draw that (1) is not isomorphic to any grammar already kept and
    (2) yields at least ``min_strings`` paths from ``generate_strings`` is kept.

    Args:
        n: number of grammars requested.
        num_tries: random draws allowed per slot before giving up.
        min_strings: minimum number of generated paths a grammar must have
            (37 matches the script's default number of strings sampled per grammar).

    Returns:
        List of accepted ``FSG`` objects. If a slot exhausts its tries, a
        message is printed and the grammars found so far are returned, so the
        list may be shorter than ``n``.
    """
    fsgs = []
    for _ in tqdm(range(n)):
        for _attempt in range(num_tries):
            fsg = FSG()

            # Reject relabelings of a grammar we already kept.
            if any(fsg.fsg_isomorphic(other) for other in fsgs):
                continue

            # Reject grammars too small to sample enough strings from.
            if len(fsg.generate_strings()) >= min_strings:
                fsgs.append(fsg)
                break
        else:
            # No draw was accepted for this slot, whether because of duplicates
            # or too few strings; stop instead of silently skipping the slot.
            print(
                f"Could not generate unique FSG with at least {min_strings} strings "
                f"after {num_tries} tries, try decreasing n"
            )
            return fsgs

    return fsgs

def test_FSG():
    """Run sanity checks on ``FSG``: isomorphism detection and path enumeration.

    This function is not called by the script's main section; invoke it
    manually after changing ``FSG``.
    """
    def test_FSG_class():
        # Random grammars: structural and self-isomorphism invariants.
        for _ in range(20):
            fsg = FSG()

            # the adjacency matrix embeds random_edges as its middle 4x4 block
            assert np.array_equal(fsg.adjacency_matrix[1:5, 1:5], fsg.random_edges)

            # every grammar is isomorphic to itself
            assert fsg.fsg_isomorphic(fsg)

            # every stored relabeling is recognized as isomorphic
            for isomorphism in fsg.isomorphisms:
                assert fsg.fsg_isomorphic(FSG(isomorphism))

        # hand-checked S1 <-> S3 relabelings
        fsg1 = FSG(np.array([[1, 1, 0, 0],
                             [0, 1, 1, 0],
                             [0, 1, 1, 0],
                             [0, 0, 0, 1]]))
        fsg2 = FSG(np.array([[1, 1, 0, 0],
                             [1, 1, 0, 0],
                             [0, 1, 1, 0],
                             [0, 0, 0, 1]]))
        assert fsg1.fsg_isomorphic(fsg2)

        fsg1 = FSG(np.array([[0, 1, 0, 0],
                             [0, 0, 1, 1],
                             [0, 1, 0, 0],
                             [1, 0, 0, 1]]))
        fsg2 = FSG(np.array([[0, 1, 0, 0],
                             [1, 0, 0, 1],
                             [0, 1, 0, 0],
                             [0, 0, 1, 1]]))
        assert fsg1.fsg_isomorphic(fsg2)

        # hand-checked combined S1 <-> S3, S2 <-> S4 relabeling
        fsg1 = FSG(np.array([[1, 1, 0, 0],
                             [0, 1, 1, 0],
                             [0, 1, 0, 1],
                             [1, 0, 0, 0]]))
        fsg2 = FSG(np.array([[0, 1, 0, 1],
                             [0, 0, 1, 0],
                             [0, 0, 1, 1],
                             [1, 0, 0, 1]]))
        assert fsg1.fsg_isomorphic(fsg2)

    test_FSG_class()

    def test_generate_strings():
        fsg = FSG()
        print(fsg.adjacency_matrix)
        strings = fsg.generate_strings()
        print(strings)

        # all strings are of length <= 8 (state digits, including S0 and S0')
        assert all(len(string) <= 8 for string in strings)

        # every string is a complete source-to-sink path
        assert all(string[0] == "0" and string[-1] == "5" for string in strings)

        # consecutive states are always joined by an edge of the grammar
        for string in strings:
            for src, dst in zip(string, string[1:]):
                assert fsg.adjacency_matrix[int(src)][int(dst)] == 1

        # paths are distinct, and S0 -> S1 -> S2 -> S0' always exists
        assert len(set(strings)) == len(strings)
        assert "0125" in strings

        # the result is also cached on the instance
        assert fsg.strings == strings

    test_generate_strings()

def assign_letters(fsgs):
    """Label every edge of each grammar with a random uppercase letter, in place.

    For each FSG, ``edge_to_letter`` becomes a dict mapping every edge
    ``(i, j)`` of ``adjacency_matrix`` (scanned row by row) to a letter drawn
    uniformly from 'A'-'Z' with ``np.random.randint``. Draws are independent,
    so different edges may share a letter.
    """
    first_code = ord("A")
    for grammar in fsgs:
        adjacency = grammar.adjacency_matrix
        grammar.edge_to_letter = {
            (row, col): chr(first_code + np.random.randint(26))
            for row in range(6)
            for col in range(6)
            if adjacency[row][col] == 1
        }

def construct_fake_members(fsg, sampled_strings, num_fakes=22, max_attempts=1000):
    """Build ungrammatical "fake" strings by perturbing grammatical ones.

    Paths from ``fsg.strings`` are visited in random order. Each is translated
    to letters, and one randomly chosen position is overwritten with a random
    letter that labels some edge of the grammar. A perturbation is accepted
    only if it is not a letter string the grammar generates, is not in
    ``sampled_strings``, and is not already a fake. At most one fake is taken
    per base path.

    Args:
        fsg: grammar providing ``strings`` (state-digit paths), ``edge_to_letter``
            and ``translate_string_to_letter``.
        sampled_strings: letter strings shown to the model; excluded from the fakes.
        num_fakes: number of fake strings to return.
        max_attempts: perturbations tried per base path before moving on.

    Returns:
        List of ``num_fakes`` distinct, ungrammatical letter strings.

    Raises:
        ValueError: if fewer than ``num_fakes`` fakes could be produced.
    """
    strings_base = copy.copy(fsg.strings)
    random.shuffle(strings_base)

    # Candidates must be compared with LETTER strings: fsg.strings holds
    # state-digit paths, which can never equal a letter string.
    forbidden = {fsg.translate_string_to_letter(path) for path in fsg.strings}
    forbidden.update(sampled_strings)

    # Sorted so the choice does not depend on set iteration order
    # (string hashing is randomized per process).
    available_letters = sorted(set(fsg.edge_to_letter.values()))

    fake_members = []
    for path in strings_base:
        if len(fake_members) >= num_fakes:
            break

        base = fsg.translate_string_to_letter(path)
        for _ in range(max_attempts):
            # perturb one position with a letter used somewhere in the grammar
            position = np.random.randint(len(base))
            letter = str(np.random.choice(available_letters))
            fake_string = base[:position] + letter + base[position + 1:]

            if fake_string not in forbidden:
                fake_members.append(fake_string)
                forbidden.add(fake_string)  # also prevents duplicate fakes
                break
        else:
            print("Could not find a fake member, moving on to next string candidate")

    if len(fake_members) < num_fakes:
        raise ValueError("Could not generate enough fake members")

    return fake_members


if __name__ == "__main__":
    start_time = time.time()

    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="dataset")
    parser.add_argument("--n", type=int, default=1000)
    parser.add_argument("--sample_strings_per_fsg", type=int, default=37)
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="seed for numpy's and Python's global RNGs (omit for an unseeded run)",
    )
    args = parser.parse_args()

    # Both RNGs are used: numpy draws grammars, letters and samples;
    # `random` shuffles the fake-member candidates.
    if args.seed is not None:
        np.random.seed(args.seed)
        random.seed(args.seed)

    fsgs = generate_finite_state_grammars(args.n)  # 2000 runs in around 12 seconds

    assign_letters(fsgs)

    dataset = []
    for fsg in fsgs:
        sampled_strings = fsg.sample_strings(args.sample_strings_per_fsg)

        # The first 15 sampled strings are for training; the rest are held-out
        # grammatical test strings.
        train = sampled_strings[:15]
        test = sampled_strings[15:]

        # ungrammatical distractors
        fakes = construct_fake_members(fsg, sampled_strings)

        dataset.append({"fsg": fsg, "sampled_strings": sampled_strings, "train": train, "test": test, "fakes": fakes})

    os.makedirs(args.output_dir, exist_ok=True)

    # save as a pickle file (keeps the FSG objects themselves)
    with open(os.path.join(args.output_dir, "dataset.pkl"), "wb") as f:
        pickle.dump(dataset, f)
    print(f"Dataset pickle saved in {args.output_dir}")

    # Save as a JSON file. FSG objects are not JSON-serializable, so their
    # fields are exported instead; tuple edge keys are stringified.
    dataset_json = [
        {
            "strings": d["fsg"].strings,
            "adjacency_matrix": d["fsg"].adjacency_matrix.tolist(),
            "edge_to_letter": {str(k): v for k, v in d["fsg"].edge_to_letter.items()},
            "sampled_strings": d["sampled_strings"],
            "train": d["train"],
            "test": d["test"],
            "fakes": d["fakes"],
        }
        for d in dataset
    ]

    with open(os.path.join(args.output_dir, "dataset.json"), "w", encoding="utf-8") as f:
        json.dump(dataset_json, f)

    print(f"Dataset json saved in {args.output_dir}")

    print(f"Time taken: {time.time() - start_time:.2f} seconds")

    # test_FSG()