
import numpy as np
import random
from prompt import Prompt
from copy import deepcopy

class Dataset_Generator:
    """Build rule-following samples for LeetCode 1417, "Reformat The String".

    ``gen_data_from_len`` draws a random string of lowercase letters and
    digits and solves it. ``rfft_IO`` renders a step-by-step trace of
    ``self.rule`` on such a datapoint using the text templates defined in
    ``__init__``.
    """

    def __init__(self) -> None:
        """Set the task metadata and the ``str.format`` trace templates."""
        self.id = 1417
        self.name = 'reformat_string'
        self.description = "reformat the string so that no two adjacent characters have the same type."
        self.url = 'https://leetcode.com/problems/reformat-the-string/description/'

        # Reference algorithm, shown verbatim to the model as the rule.
        self.rule = '''
def reformat_string(s):
    # Initialize variables
    alphas, digits, n = [], [], len(s)
    # First loop
    for i, c in enumerate(s):
        if c.isdigit():
            digits.append(c)
        else:
            alphas.append(c)
        # Judge
        if abs(len(digits) - len(alphas)) + i > n:
            ans = ''
            # Return
            return ans
    ans, last_alpha = [], False 
    # Second loop
    while alphas or digits:
        if not last_alpha and len(alphas) >= len(digits):
            m, last_alpha = (alphas, True)
        else:
            m, last_alpha = (digits, False)
        ans.append(m.pop())
    ans = "".join(ans)
    # Return
    return ans'''
        # Step 1: variable initialisation (s, len(s)).
        self.initialize = '''
1. Initialize variables
```
alphas, digits, n = [], [], len(s)
```
alphas = []
digits = []
s = {}
n = len(s) = {}
'''
        # Step 2 header, emitted once per character (s, i, c).
        self.loop_1 = '''
2. First loop
```
for i, c in enumerate(s):
```
s = {}
i, c = {}, {}
'''
        # Character is a digit: append to ``digits``.
        self.judge_1_1 = '''
```
if c.isdigit():
```
c = {}
c is digit
enter
```
digits.append(c)
``` 
digits = {}
digits = digits + ['{}'] = {} + ['{}'] = {}'''
        # Character is not a digit: append to ``alphas``.
        # The appended element is quoted, matching judge_1_1.
        self.judge_1_2 = '''
```
if c.isdigit():
```
c = {}
c is not digit
do not enter
```
else:
alphas.append(c)
```
alphas = {}
alphas = alphas + ['{}'] = {} + ['{}'] = {}'''
        # Early exit: the count gap can no longer be closed.
        self.judge_2_1 = '''
```
if abs(len(digits) - len(alphas)) + i > n:
```
digits = {}
alphas = {}
abs(len(digits) - len(alphas)) + i = abs({} - {}) + {} = {} + {} = {} > n = {}
enter
```
ans = ''
return ans
```
3. Return
ans = ''
So the answer is ''.'''
        # Early-exit check failed: keep scanning.
        self.judge_2_2 = '''
```
if abs(len(digits) - len(alphas)) + i > n:
```
digits = {}
alphas = {}
abs(len(digits) - len(alphas)) + i = abs({} - {}) + {} = {} + {} = {} <= n = {}
do not enter'''
        # One-time initialisation before the second loop. Kept separate from
        # loop_2 so it is not repeated on every iteration.
        self.init_loop_2 = '''
```
ans, last_alpha = [], False 
```
ans = []
last_alpha = False'''
        # Step 3 header, emitted once per iteration
        # (alphas, digits, alphas or digits).
        self.loop_2 = '''
3. Second loop
```
while alphas or digits:
```
alphas or digits = {} or {} = {}
enter'''
        # Take the next character from ``alphas``.
        self.judge_3_1 = '''
```
if not last_alpha and len(alphas) >= len(digits):
```
last_alpha = {}
not last_alpha = {}
alphas = {}
digits = {}
len(alphas) = {}
len(digits) = {}
len(alphas) >= len(digits)
not last_alpha and len(alphas) >= len(digits) = {} and {} = {}
enter
```
m, last_alpha = (alphas, True)
```
m = alphas = {}
last_alpha = True
```
ans.append(m.pop())
```
ans = {}
ans = ans + [m.pop()] = {} + ['{}'] = {}'''
        # Take the next character from ``digits``.
        self.judge_3_2 = '''
```
if not last_alpha and len(alphas) >= len(digits):
```
last_alpha = {}
not last_alpha = {}
alphas = {}
digits = {}
not last_alpha and len(alphas) >= len(digits) = {} and {} = {}
do not enter
```
else:
m, last_alpha = (digits, False)
```
m = digits = {}
last_alpha = False
```
ans.append(m.pop())
```
ans = {}
ans = ans + [m.pop()] = {} + ['{}'] = {}'''
        # Steps 4 and 5: loop exit, join and final answer.
        self.complete_loop = '''
4. Complete loop
```
while alphas or digits:
```
alphas or digits = {} or {} = {}
do not enter
```
ans = "".join(ans)
```
ans = "".join(ans) = "".join({}) = "{}"
5. Return
ans = '{}'
So the answer is '{}'.'''

    @staticmethod
    def _reformat(s: str) -> str:
        """Solve the task the same way ``self.rule`` does.

        Returns the reformatted string, or ``''`` when the letter and digit
        counts differ by more than one.
        """
        alphas, digits, n = [], [], len(s)
        for i, c in enumerate(s):
            (digits if c.isdigit() else alphas).append(c)
            # The remaining n - i - 1 characters cannot close this gap.
            if abs(len(digits) - len(alphas)) + i > n:
                return ''
        ans, last_alpha = [], False
        while alphas or digits:
            if not last_alpha and len(alphas) >= len(digits):
                m, last_alpha = alphas, True
            else:
                m, last_alpha = digits, False
            ans.append(m.pop())
        return "".join(ans)

    def gen_data_from_len(self, length: int) -> dict:
        """Return a random datapoint whose string has ``length`` characters.

        The string is drawn with replacement from ``a``-``z`` and ``0``-``9``.

        Returns:
            ``{"question": prompt text, "gt": ground-truth answer, "s": string}``
        """
        charset = [chr(i) for i in range(ord('a'), ord('z')+1)] + [str(i) for i in range(0, 10)]
        s = "".join(random.choices(charset, k=length))
        question = f"A string '{s}', find a permutation of the string where no letter is followed by another letter and no digit is followed by another digit. That is, no two adjacent characters have the same type."
        return {"question": question,
                "gt": self._reformat(s),
                "s": s}

    def rfft_IO(self, data: dict) -> dict:
        """Render the rule-following input/output pair for one datapoint.

        Args:
            data: a datapoint from ``gen_data_from_len``; the keys
                ``"question"`` and ``"s"`` are read.

        Returns:
            ``{"input": instruction + rule + question,
               "output": step-by-step trace of the rule on ``s``,
               "answer": final answer string}``
        """
        instruction = "Follow the given rule to solve the question.\nrule:"
        rfft_input = instruction + self.rule + "\n\nQ: " + data["question"]

        s = data["s"]
        n = len(s)
        alphas, digits = [], []
        # Each template is rendered immediately, so later mutation of the
        # lists cannot alter text that was already emitted.
        trace = [self.initialize.format(s, n)]

        for i, c in enumerate(s):
            trace.append(self.loop_1.format(s, i, c))
            if c.isdigit():
                trace.append(self.judge_1_1.format(
                    c, digits, c, digits, c, digits + [c]))
                digits.append(c)
            else:
                trace.append(self.judge_1_2.format(
                    c, alphas, c, alphas, c, alphas + [c]))
                alphas.append(c)

            gap = abs(len(digits) - len(alphas))
            infeasible = gap + i > n
            judge = self.judge_2_1 if infeasible else self.judge_2_2
            trace.append(judge.format(
                digits, alphas, len(digits), len(alphas), i,
                gap, i, gap + i, n))
            if infeasible:
                return {"input": rfft_input,
                        "output": "".join(trace),
                        "answer": ''}

        # Emitted once, before the loop, so the trace never re-asserts the
        # initial values after they have changed.
        trace.append(self.init_loop_2)
        ans, last_alpha = [], False
        while alphas or digits:
            trace.append(self.loop_2.format(alphas, digits, alphas or digits))
            enough_alphas = len(alphas) >= len(digits)
            if not last_alpha and enough_alphas:
                trace.append(self.judge_3_1.format(
                    last_alpha, not last_alpha, alphas, digits,
                    len(alphas), len(digits),
                    not last_alpha, enough_alphas, True,
                    alphas, ans, ans, alphas[-1], ans + [alphas[-1]]))
                m, last_alpha = alphas, True
            else:
                trace.append(self.judge_3_2.format(
                    last_alpha, not last_alpha, alphas, digits,
                    not last_alpha, enough_alphas, False,
                    digits, ans, ans, digits[-1], ans + [digits[-1]]))
                m, last_alpha = digits, False
            ans.append(m.pop())

        # Emitted after the loop so an empty input still gets its
        # "Complete loop" / "Return" section.
        answer = "".join(ans)
        trace.append(self.complete_loop.format(
            alphas, digits, alphas or digits, ans, answer, answer, answer))
        return {"input": rfft_input,
                "output": "".join(trace),
                "answer": answer}

# Shared generator instance. Kept at module level so it can still be imported.
Generator = Dataset_Generator()

# Demo: only runs when the file is executed as a script, so importing the
# module no longer generates and prints a random sample as a side effect.
if __name__ == "__main__":
    length = 10
    data = Generator.gen_data_from_len(length)
    print(data)

    sample = Generator.rfft_IO(data)
    print(sample["input"], sample["output"])