# The script is used to generate datasets.

import numpy as np
np.random.seed(42)
import random
random.seed(42)
from prompt_lc import Prompt
# Whitespace-separated tokens from names.txt, read once at import time.
# An explicit encoding keeps decoding independent of the platform locale.
with open("names.txt", "r", encoding="utf-8") as f:
    NAMES = f.read().split()

class Dataset_Generator:
    '''
    Build synthetic datapoints for a family of small LeetCode-style tasks and
    the matching step-by-step ("rfft") input/output traces.

    The task is chosen by ``self.name``; both public methods dispatch on it.
    All randomness comes from the module-level ``random`` generator.
    '''

    def __init__(self, name: str = 'detect_capital') -> None:
        '''
        :param name: task identifier, e.g. ``"hamming_distance"``. Defaults to
            ``'detect_capital'`` (the value that used to be hard-coded), so
            ``Dataset_Generator()`` behaves as before; callers may still
            reassign ``self.name`` after construction.
        '''
        self.name = name

    def gen_data_from_len(self, length: int) -> dict:
        '''
        Return a random datapoint whose size is controlled by ``length``.

        The datapoint is a dict with at least ``"question"`` (prompt text) and
        ``"gt"`` (ground-truth answer), plus the raw task inputs (e.g.
        ``"nums"``, ``"s"``/``"k"``, ``"words"``) that ``rfft_IO`` reads.
        What ``length`` controls depends on the task (array length, upper
        bound of the integers, string length, number of operations, ...).

        Returns ``None`` when ``self.name`` is not a task handled here.
        '''
        if self.name == "max_consecutive_ones":
            # example question: Given a binary array nums = [0,1,0,1,1,1], return the maximum number of consecutive 1's in the array.
            # example answer: 3
            nums = random.choices([1, 0], k=length)
            if length > 0 and random.choice([1, 0]) == 1:
                # Half of the time, plant a run of 1s covering at least half of
                # the array so long answers are well represented. The answer is
                # computed from the final array below, so random 1s adjacent to
                # the planted run are counted correctly.
                run_len = random.randint((length + 1) // 2, length)
                start = random.randint(0, length - run_len)
                nums[start:start + run_len] = [1] * run_len
            ans = cnt = 0
            for num in nums:
                cnt = cnt + 1 if num == 1 else 0
                ans = max(ans, cnt)
            question = ('Given a binary array nums = [' + ', '.join(map(str, nums))
                        + '], return the maximum number of consecutive 1\'s in the array.')
            return {"question": question,
                    "gt": str(ans),
                    "nums": nums}
        elif self.name == "hamming_distance":
            # example question: The Hamming distance between two integers is the number of positions at which the corresponding bits are different.\nGiven two integers x and y, return the Hamming distance between them.\nx = 1 and y = 4.
            # example answer: 2
            x = random.randint(0, length)
            y = random.randint(0, length)
            question = f'The Hamming distance between two integers is the number of positions at which the corresponding bits are different.\nGiven two integers x and y, return the Hamming distance between them.\n'
            question += f'x = {x} and y = {y}'
            # Differing bits are exactly the set bits of x XOR y.
            gt = bin(x ^ y).count('1')
            return {"question": question,
                    "gt": gt,
                    "number": (x, y)}
        elif self.name == "license_key_formatting":
            # example question: You are given a license key represented as a string s that consists of only alphanumeric characters and dashes. The string is separated into n + 1 groups by n dashes. You are also given an integer k.\nWe want to reformat the string s such that each group contains exactly k characters, except for the first group, which could be shorter than k but still must contain at least one character. Furthermore, there must be a dash inserted between two groups, and you should convert all lowercase letters to uppercase.\nReturn the reformatted license key.\ns = "5F3Z-2e-9-w", k = 4
            # example answer: 5F3Z-2E9W
            import string
            rand = random.choices([0, 1], k=length)
            # k is drawn from [1, n_chars - 1]; max(1, ...) keeps the range
            # non-empty when the key has at most one letter.
            k = random.randint(1, max(1, sum(rand) - 1))
            s = ''.join(random.choice(string.ascii_letters) if bit else '-' for bit in rand)
            question = f'You are given a license key represented as a string s that consists of only alphanumeric characters and dashes. The string is separated into n + 1 groups by n dashes. You are also given an integer k.\nWe want to reformat the string s such that each group contains exactly k characters, except for the first group, which could be shorter than k but still must contain at least one character. Furthermore, there must be a dash inserted between two groups, and you should convert all lowercase letters to uppercase.\nReturn the reformatted license key.\n'
            question += f's = "{s}", k = {k}'
            chars = s.replace('-', '')
            # The first group takes the remainder (or a full k when it divides
            # evenly); every later group has exactly k characters.
            head = len(chars) % k or k
            groups = [chars[:head]] + [chars[i:i + k] for i in range(head, len(chars), k)]
            gt = '-'.join(groups)
            return {"question": question,
                    "gt": gt.upper(),
                    "s": s,
                    "k": k}
        elif self.name == "keyboard_row":
            # example question: Given an array of strings words, return the words that can be typed using letters of the alphabet on only one row of American keyboard.\nIn the American keyboard:\nthe first row consists of the characters "qwertyuiop",\nthe second row consists of the characters "asdfghjkl", and\nthe third row consists of the characters "zxcvbnm".\nwords = ["Hello","Alaska","Dad","Peace"]
            # example answer: ["Alaska","Dad"]
            import string
            rand = random.choices([0, 1], k=length)
            words = []
            gt = []
            for i in range(length):
                if rand[i]:
                    # Mixed-row word: a random prefix plus one letter from each
                    # row guarantees it cannot be typed on a single row.
                    # max(0, ...) keeps randint valid when length < 3.
                    words.append(''.join(random.choices(string.ascii_letters, k=random.randint(0, max(0, length - 3)))))
                    words[-1] += random.choice('qwertyuiopQWERTYUIOP') + random.choice('asdfghjklASDFGHJKL') + random.choice('zxcvbnmZXCVBNM')
                else:
                    # Single-row word: all letters from one row, so it belongs
                    # to the answer.
                    row = random.choice(['qwertyuiopQWERTYUIOP', 'asdfghjklASDFGHJKL', 'zxcvbnmZXCVBNM'])
                    words.append(''.join(random.choices(row, k=random.randint(1, length))))
                    gt.append(words[-1])
            question = f'Given an array of strings words, return the words that can be typed using letters of the alphabet on only one row of American keyboard.\nIn the American keyboard:\nthe first row consists of the characters "qwertyuiop",\nthe second row consists of the characters "asdfghjkl", and\nthe third row consists of the characters "zxcvbnm".\n'
            question += f'words = {str(words)}'
            return {"question": question,
                    "gt": gt,
                    "words": words}
        elif self.name == "detect_capital":
            # example question: We define the usage of capitals in a word to be right when one of the following cases holds:\nAll letters in this word are capitals, like "USA".\nAll letters in this word are not capitals, like "leetcode".\nOnly the first letter in this word is capital, like "Google".\nGiven a string word, return true if the usage of capitals in it is right.\nword = "USA"
            # example answer: True
            import string
            if random.choice([1, 0]) == 1:
                # Deliberately well-formed word: 1 = all caps, 2 = all lower,
                # 3 = capitalized first letter.
                tempRand = random.choice([1, 2, 3])
                if tempRand == 1:
                    word = ''.join(random.choices(string.ascii_uppercase, k=length))
                else:
                    word = ''.join(random.choices(string.ascii_lowercase, k=length))
                if tempRand == 3:
                    word = random.choice(string.ascii_uppercase) + word[1:]
            else:
                word = ''.join(random.choices(string.ascii_uppercase + string.ascii_lowercase, k=length))
            question = f'We define the usage of capitals in a word to be right when one of the following cases holds:\nAll letters in this word are capitals, like "USA".\nAll letters in this word are not capitals, like "leetcode".\nOnly the first letter in this word is capital, like "Google".\nGiven a string word, return true if the usage of capitals in it is right.\n'
            question += f'word = {word}'
            n_upper = sum(ch.isupper() for ch in word)
            gt = n_upper in (0, len(word)) or (n_upper == 1 and word[0].isupper())
            return {"question": question,
                    "gt": gt,
                    "word": word}
        elif self.name == "design_hashset":
            # example question: Design a hashset with 3 functions: add, remove, contains \noperation =  ["MyHashSet", "add", "add", "contains", "contains", "add", "contains", "remove", "contains"]\nkeys = [[], [1], [2], [1], [3], [2], [2], [2], [2]]\n
            # example answer: [null, null, null, true, false, null, true, null, false]
            oper = random.choices(['add', 'remove', 'contains'], k=length)
            # A small key space makes repeated keys (and thus hits) likely;
            # max(1, ...) keeps it non-empty when length < 3.
            keys = random.choices(range(max(1, length // 3)), k=length)
            question = f'Design a hashset with 3 functions: add, remove, contains \n'
            question += f'operation = {str(oper)}\nkeys = {str(keys)}\n'
            ans = []
            hashset = set()
            for op, key in zip(oper, keys):
                if op == 'add':
                    hashset.add(key)
                    ans.append(None)
                elif op == 'remove':
                    hashset.discard(key)
                    ans.append(None)
                elif op == 'contains':
                    ans.append(key in hashset)
            return {"question": question,
                    "gt": ans,
                    "oper": oper,
                    "keys": keys}
        elif self.name == "jewels_and_stones":
            # example question: Given the strings a = 'aA' and b = 'aAAbbbb', count how many characters in a have happened in b
            # example answer: 3
            import string
            characters = list(string.ascii_letters)
            random.shuffle(characters)
            # a holds distinct letters; max(2, ...) keeps randint valid for length < 2.
            a = ''.join(characters[:random.randint(2, max(2, min(50, length)))])
            b = ''.join(random.choices(string.ascii_letters, k=length))
            question = f'Given the strings a and b, count how many characters in b also appear in a.\n'
            question += f'a = {a}\nb = {b}\n'
            gt = sum(ch in a for ch in b)
            return {"question": question,
                    "gt": gt,
                    "a": a,
                    "b": b}
        elif self.name == "rotate_string":
            # example question: Given two strings s = 'abcde' and goal = 'bcdea', return true if and only if s can become goal after some number of shifts on s.\nA shift on s consists of moving the leftmost character of s to the rightmost position.\n
            # example answer: True
            import string
            s = ''.join(random.choices(string.ascii_letters, k=length))
            if random.choice([0, 1]) == 1:
                # Positive example: an actual rotation of s.
                mid = random.randint(0, length)
                goal = s[mid:] + s[:mid]
            else:
                # Usually negative: a permutation of s (may still be a rotation).
                temp = list(s)
                random.shuffle(temp)
                goal = ''.join(temp)
            question = f'Given two strings s and goal, return true if and only if s can become goal after some number of shifts on s.\nA shift on s consists of moving the leftmost character of s to the rightmost position.\n'
            question += f's = {s}\ngoal = {goal}\n'
            # goal is a rotation of s iff it has the same length and occurs in s + s.
            gt = len(s) == len(goal) and goal in s + s
            return {"question": question,
                    "gt": gt,
                    "s": s,
                    "goal": goal}
        elif self.name == "most_common_word":
            # example question: Given a string paragraph and a string array of the banned words banned, return the most frequent word that is not banned.\nThe words in paragraph are case-insensitive and the answer should be returned in lowercase.\n
            # example answer: 'ball'
            from collections import Counter
            words = ["Alice", "Bob", "Carol", "Dave", "Eve", "Francis", "Garce", "Hans", "Isabella", "Jason", "Kate", "Louis", "Margaret", "Nathan", "Olivia", "Paul", "Queen", "Richard", "Susan", "Thomas", "Uma", "Vivian", "Winnie", "Xander", "Yasmine", "Zach"]
            paragraph_words = random.choices(words, k=length)
            # Ban distinct words, but never all of them, so at least one word
            # is always eligible to be the answer.
            n_banned = random.randint(1, max(1, min(len(words) - 1, length // 2)))
            banned = sorted(w.lower() for w in random.sample(words, k=n_banned))
            banned_set = set(banned)
            counts = Counter(w for w in paragraph_words if w.lower() not in banned_set)
            top = counts.most_common(2)
            if not top or (len(top) == 2 and top[0][1] == top[1][1]):
                # Make the answer unique: add one more occurrence of a leading
                # word (or of any allowed word when nothing unbanned occurred).
                winner = top[0][0] if top else random.choice([w for w in words if w.lower() not in banned_set])
                paragraph_words.append(winner)
                counts[winner] += 1
            gt = counts.most_common(1)[0][0].lower()
            paragraph = ' '.join(paragraph_words)
            question = f'Given a string paragraph and a string array of the banned words banned, return the most frequent word that is not banned.\nThe words in paragraph are case-insensitive and the answer should be returned in lowercase.\n'
            question += f'paragraph = "{paragraph}"\nbanned = {banned}\n'
            return {"question": question,
                    "gt": gt,
                    "paragraph": paragraph,
                    "banned": banned}
        # Unknown task name: no datapoint.
        return None

    def rfft_IO(self, data: dict) -> dict:
        '''
        Return the rfft input/output pair for a datapoint.

        :param data: a datapoint as produced by ``gen_data_from_len`` (for
            ``coin_flip`` it must provide ``"question"`` and ``"flips"``).
            It is not modified: list fields are copied before being consumed.
        :returns: dict with ``"input"`` (instruction + rule + question),
            ``"output"`` (the trace rendered from the task's ``Prompt``
            templates) and ``"answer"``. Returns ``None`` for a task name not
            handled here.
        '''
        instruction = "Follow the given rule to solve the question.\nrule:"
        if self.name == "coin_flip":
            P = Prompt("coin_flip")
            rule = P.rule
            prompt_input = instruction + rule + "\n\nQ: " + data["question"]
            # rfft output
            flips = data["flips"]
            heads_up = True
            output = P.initialize.format(flips)
            for flip in flips:
                if flip:
                    output += P.one_iteration_2_1_flip.format(heads_up, not heads_up)
                    heads_up = not heads_up
                else:
                    output += P.one_iteration_2_1_no_flip.format(heads_up)
            answer = "Yes" if heads_up else "No"
            output += P.return_result.format(heads_up, answer)
            return {"input": prompt_input,
                    "output": output,
                    "answer": answer}
        elif self.name == "max_consecutive_ones":
            P = Prompt("max_consecutive_ones")
            rule = P.rule
            prompt_input = instruction + rule + "\n\nQ: " + data["question"]
            # rfft output; copy because the loop pops from the front
            nums = list(data["nums"])
            ans = cnt = 0
            output = P.initialize.format(nums)
            while nums:
                num = nums.pop(0)
                if num == 1:
                    cnt += 1
                    output += P.one_iteration_2_1_count.format([num] + nums, num, nums, cnt, ans, max(ans, cnt))
                else:
                    cnt = 0
                    output += P.one_iteration_2_1_to0.format([num] + nums, num, nums, cnt, ans, max(ans, cnt))
                ans = max(ans, cnt)
            output += P.one_iteration_2_1_break.format()
            output += P.return_result.format(ans, ans)
            return {"input": prompt_input,
                    "output": output,
                    "answer": ans}
        elif self.name == "hamming_distance":
            P = Prompt("hamming_distance")
            rule = P.rule
            prompt_input = instruction + rule + "\n\nQ: " + data["question"]
            # rfft output: compare lowest bits, then shift both right
            (x, y) = data["number"]
            ans = 0
            output = P.initialize.format(x, y)
            while x != 0 or y != 0:
                if x % 2 != y % 2:
                    ans += 1
                    output += P.one_iteration_2_1_different.format(x, y, x % 2, y % 2, ans, x // 2, y // 2)
                else:
                    output += P.one_iteration_2_1_same.format(x, y, x % 2, y % 2, x // 2, y // 2)
                x //= 2
                y //= 2
            output += P.one_iteration_2_1_break.format()
            output += P.return_result.format(ans, ans)
            return {"input": prompt_input,
                    "output": output,
                    "answer": ans}
        elif self.name == "license_key_formatting":
            P = Prompt("license_key_formatting")
            rule = P.rule
            prompt_input = instruction + rule + "\n\nQ: " + data["question"]
            # rfft output: scan s right-to-left, prepending a dash every k chars
            s = data["s"]
            k = data["k"]
            ans = ''
            cnt = 0
            output = P.initialize.format(s, k)
            while s != '':
                if s[-1] != '-':
                    output += P.one_iteration_2_1_in.format(s, s[-1])
                    if cnt % k == 0 and ans != '':
                        output += P.one_iteration_2_2_in.format(cnt, k, ans, '-' + ans, s[-1] + '-' + ans, s[-1:], cnt + 1)
                        ans = '-' + ans
                    else:
                        output += P.one_iteration_2_2_out.format(cnt, k, cnt % k, ans, s[-1] + ans, s[-1:], cnt + 1)
                    ans = s[-1] + ans
                    cnt += 1
                else:
                    output += P.one_iteration_2_1_out.format(s, s[:-1])
                s = s[:-1]
            output += P.one_iteration_2_break.format()
            output += P.return_result.format(ans, ans.upper(), ans.upper())
            return {"input": prompt_input,
                    "output": output,
                    "answer": ans.upper()}
        elif self.name == "keyboard_row":
            P = Prompt("keyboard_row")
            rule = P.rule
            prompt_input = instruction + rule + "\n\nQ: " + data["question"]
            # rfft output; copy because the loop pops from the front
            words = list(data["words"])
            ans = []
            k1 = 'qwertyuiopQWERTYUIOP'
            k2 = 'asdfghjklASDFGHJKL'
            k3 = 'zxcvbnmZXCVBNM'
            output = P.initialize.format(words)
            while words:
                word = words.pop(0)
                lenWord = len(word)
                wordBackup = word
                cnt1 = cnt2 = cnt3 = 0
                # name of the last row counter incremented, shown in the trace
                flag = ''
                output += P.one_iteration_2_1_head.format([word] + words, word, lenWord, wordBackup)
                while word != '':
                    if word[0] in k1:
                        cnt1 += 1
                        output += P.one_iteration_2_2_if1_true.format(word, word[0], cnt1)
                        flag = 'cnt1'
                    else:
                        output += P.one_iteration_2_2_if1_false.format(word, word[0])
                    if word[0] in k2:
                        cnt2 += 1
                        output += P.one_iteration_2_2_if2_true.format(word, word[0], cnt2)
                        flag = 'cnt2'
                    else:
                        output += P.one_iteration_2_2_if2_false.format(word, word[0])
                    if word[0] in k3:
                        cnt3 += 1
                        output += P.one_iteration_2_2_if3_true.format(word, word[0], cnt3, word[1:])
                        flag = 'cnt3'
                    else:
                        output += P.one_iteration_2_2_if3_false.format(word, word[0], word[1:])
                    word = word[1:]
                if cnt1 == lenWord or cnt2 == lenWord or cnt3 == lenWord:
                    ans.append(wordBackup)
                    output += P.one_iteration_2_1_if_true.format(cnt1, cnt2, cnt3, lenWord, flag, ans)
                else:
                    output += P.one_iteration_2_1_if_false.format(cnt1, cnt2, cnt3, lenWord)
            output += P.return_result.format(ans, ans)
            return {"input": prompt_input,
                    "output": output,
                    "answer": ans}
        elif self.name == "detect_capital":
            P = Prompt("detect_capital")
            rule = P.rule
            prompt_input = instruction + rule + "\n\nQ: " + data["question"]
            # rfft output: count capitals, then apply the three rules
            word = data["word"]
            cnt = 0
            # word[:1] instead of word[0] so an empty word does not crash
            firstLetter = word[:1]
            lenWord = len(word)
            output = P.initialize.format(word, firstLetter, lenWord)
            while word != '':
                if word[0].isupper():
                    cnt += 1
                    output += P.one_iteration_2_1_in.format(word, word[0], cnt, word[1:])
                else:
                    output += P.one_iteration_2_1_out.format(word, word[0], word[1:])
                word = word[1:]
            output += P.one_iteration_2_1_break.format()
            ans = False
            if cnt == 0 or cnt == lenWord or (cnt == 1 and firstLetter.isupper()):
                ans = True
                output += P.return_true.format(cnt, lenWord, firstLetter, firstLetter.isupper())
            else:
                output += P.return_false.format(cnt, lenWord, firstLetter, firstLetter.isupper())
            return {"input": prompt_input,
                    "output": output,
                    "answer": ans}
        elif self.name == "design_hashset":
            P = Prompt("design_hashset")
            rule = P.rule
            prompt_input = instruction + rule + "\n\nQ: " + data["question"]
            # rfft output; copies because both lists are consumed from the front
            oper = list(data["oper"])
            keys = list(data["keys"])
            ans = []
            hashset = set()
            output = P.initialize.format(str(oper), str(keys))
            while oper:
                if oper[0] == 'add':
                    hashset.add(keys.pop(0))
                    ans.append(None)
                    output += P.one_iteration_2_1_if1_true.format(oper, hashset, ans, oper[1:], keys)
                elif oper[0] == 'remove':
                    hashset.discard(keys.pop(0))
                    ans.append(None)
                    output += P.one_iteration_2_1_if2_true.format(oper, hashset, ans, oper[1:], keys)
                elif oper[0] == 'contains':
                    ans.append(keys[0] in hashset)
                    output += P.one_iteration_2_1_if3_true.format(oper, keys[0], hashset, keys[0] in hashset, ans, oper[1:], keys[1:])
                    keys.pop(0)
                oper.pop(0)
            output += P.one_iteration_2_1_break.format()
            output += P.return_result.format(ans, ans)
            return {"input": prompt_input,
                    "output": output,
                    "answer": ans}
        elif self.name == "jewels_and_stones":
            P = Prompt("jewels_and_stones")
            rule = P.rule
            prompt_input = instruction + rule + "\n\nQ: " + data["question"]
            # rfft output
            a = data["a"]
            b = data["b"]
            output = P.initialize.format(a, b)
            ans = [ch in a for ch in b].count(True)
            output += P.return_result.format([ch for ch in b], [ch in a for ch in b], ans, ans)
            return {"input": prompt_input,
                    "output": output,
                    "answer": ans}
        elif self.name == "rotate_string":
            P = Prompt("rotate_string")
            rule = P.rule
            prompt_input = instruction + rule + "\n\nQ: " + data["question"]
            # rfft output
            s = data["s"]
            goal = data["goal"]
            ans = False
            # s rotated right by one: the loop stops once s has been shifted
            # left into this state.
            sBackup = s[-1] + s[:-1]
            # NOTE(review): this loop only compares left-rotations by one or
            # more characters and never the unrotated s, so goal == s (zero
            # shifts) yields ans = False; for a string of one repeated
            # character the loop does not run at all. The trace shape is
            # defined by the Prompt("rotate_string") templates (not visible
            # here), so fixing this needs a matching rule change — TODO.
            output = P.initialize.format(s, goal)
            while s != sBackup:
                if s[1:] + s[0] == goal:
                    ans = True
                    output += P.one_iteration_2_1_if_true.format(s, sBackup, s[1:] + s[0], ans, s[1:] + s[0])
                else:
                    output += P.one_iteration_2_1_if_false.format(s, sBackup, s[1:] + s[0], goal, s[1:] + s[0])
                s = s[1:] + s[0]
            output += P.one_iteration_2_1_break.format(sBackup)
            output += P.return_result.format(ans, ans)
            return {"input": prompt_input,
                    "output": output,
                    "answer": ans}
        # Unknown task name: no rfft pair.
        return None
