import numpy as np
import random
from copy import deepcopy

class Dataset_Generator:
    """Generate "valid palindrome" questions and three kinds of answer traces.

    A datapoint from ``gen_data_from_len`` (a dict with keys "question" and
    "s") can be rendered as:

    * ``rfft_IO``    -- a line-by-line execution trace of ``self.rule``, built
      from the template strings defined in ``__init__``;
    * ``natural_IO`` -- a commented, code-like walkthrough;
    * ``cot_IO``     -- a compact chain-of-thought comparison.

    All three return a dict with keys "input", "output" and "answer" (bool).
    """

    def __init__(self) -> None:
        # problem metadata (the url points to the "valid-palindrome" problem)
        self.id = 125
        # NOTE(review): "isselfalindrome" looks like a mangled "isPalindrome";
        # kept unchanged because the name and the rule text are emitted as data.
        self.name = "isselfalindrome"
        # NOTE(review): this description does not mention the lower-casing and
        # non-alphanumeric filtering that self.rule applies.
        self.description = '''Given a string s, return true if it is a palindrome, or false otherwise.'''
        self.url = '''https://leetcode.com/problems/valid-palindrome/description/'''
        # Reference implementation shown to the model. rfft_IO mirrors it step by step.
        self.rule = '''
def isselfalindrome(self, s: str) -> bool:
    c_list = []
    while s:
        c = s[-1]
        s = s[:-1]
        if c.isalnum():
            c_list.append(c.lower())
    while len(c_list) > 1:
        first_char = c_list.pop(0)
        last_char = c_list.pop(-1)
        if first_char != last_char:
            return False
    return True'''
        # ---- trace templates consumed by rfft_IO via str.format ----
        # Each '{}' placeholder is filled positionally. See the matching
        # .format(...) call in rfft_IO for the argument order.
        self.initialize = '''1. Initialize
s = '{}'
c_list = []
2. Loop1
'''
        self.enter1 = '''```
while s:
```
s = '{}'
enter the loop
2.1 One iteration
'''
        self.not_enter1 = '''```
while s:
```
s = '{}'
do not enter
'''
        self.last_c = '''```
c = s[-1]
s = s[:-1]
```
s = '{}'
now,
c = '{}'
s = '{}'
'''
        self.if_isalnum = '''```
if c.isalnum():
    c_list.append(c.lower())
```
c = '{}'
c.isalnum() == True
enter
c_list = {}
c.lower() = '{}'
now,
c_list = {}
'''
        self.not_if_isalnum = '''```
if c.isalnum():
    c_list.append(c.lower())
```
c = '{}'
c.isalnum() == False
do not enter
'''
        self.loop2 = '''3. Loop2
'''
        self.enter2 = '''```
while len(c_list) > 1:
```
c_list = {}
enter the loop
3.1 One iteration
'''
        self.not_enter2 = '''```
while len(c_list) > 1:
```
c_list = {}
do not enter
'''
        self.get_char = '''```
first_char = c_list.pop(0)
last_char = c_list.pop(-1)
```
c_list = {}
first_char = '{}'
last_char = '{}'
now,
c_list = {}
'''
        # terminal template: no trailing newline because the trace ends here
        self.if_neq = '''```
if first_char != last_char:
    return False
```
first_char = '{}'
last_char = '{}'
(first_char != last_char) == True
enter if
return False
So the answer is False'''
        self.if_eq = '''```
if first_char != last_char:
    return False
```
first_char = '{}'
last_char = '{}'
(first_char != last_char) == False
do not enter
'''
        self.return_result = '''```
return True
```
So the answer is True
'''

    def gen_data_from_len(self, length: int) -> dict:
        '''
        return datapoint of given length,
        datapoint is a dict of keys including "question","s"

        The string is first built as a mirrored palindrome of random letters
        (both cases), spaces, dots and colons. With probability 0.5, some
        positions are then overwritten with random characters. The overwrite
        is random, so the result may still be a palindrome. The ground-truth
        answer is computed by rfft_IO / natural_IO / cot_IO, not here.

        Raises ValueError if length is negative.
        '''
        import string
        if length < 0:
            raise ValueError(f"length must be non-negative, got {length}")
        # upper case + lower case + space + dot + colon
        alphabet = string.ascii_lowercase + string.ascii_uppercase + ' .:'
        half_length = (length + 1) // 2
        s = ''.join(random.choice(alphabet) for _ in range(half_length))
        # Mirror the first (length - half_length) characters. When length is
        # odd, the middle character is not repeated.
        repeat_length = length - half_length
        s += s[:repeat_length][::-1]

        # modify some chars to (probably) make it not a palindrome
        if random.random() < 0.5:
            num = min(max(2, int(length * 0.3)), length)
            idx = random.sample(range(length), num)
            chars = list(s)
            for i in idx:
                chars[i] = random.choice(alphabet)
            s = ''.join(chars)

        question = f'''s = "{s}". Is s a palindrome?'''
        return {"question": question,
                "s": s}

    def rfft_IO(self, data: dict) -> dict:
        '''
        data: a datapoint from gen_data_from_len
        return rfft input-output of given data

        Executes self.rule on data["s"] and renders every step with the
        templates from __init__. "input" is instruction + rule + question,
        "output" is the trace, and "answer" is the bool the rule returns.
        '''
        instruction = "Follow the given rule to solve the question.\nrule:"
        # named `prompt` rather than `input` to avoid shadowing the builtin
        prompt = instruction + self.rule + "\n\nQ: " + data["question"]
        s = data['s']
        c_list = []
        steps = [self.initialize.format(s)]

        # Loop1: consume s from its last character and keep lower-cased
        # alphanumerics, so c_list ends up holding the cleaned string reversed.
        while s:
            steps.append(self.enter1.format(s))
            steps.append(self.last_c.format(s, s[-1], s[:-1]))
            c = s[-1]
            s = s[:-1]
            if c.isalnum():
                low = c.lower()
                # format before appending so the trace shows c_list before and after
                steps.append(self.if_isalnum.format(c, c_list, low, c_list + [low]))
                c_list.append(low)
            else:
                steps.append(self.not_if_isalnum.format(c))
        steps.append(self.not_enter1.format(s))
        steps.append(self.loop2)

        # Loop2: compare and discard the two outermost characters until at
        # most one is left; stop at the first mismatch.
        while len(c_list) > 1:
            steps.append(self.enter2.format(c_list))
            steps.append(self.get_char.format(c_list, c_list[0], c_list[-1], c_list[1:-1]))
            first_char = c_list.pop(0)
            last_char = c_list.pop(-1)
            if first_char != last_char:
                steps.append(self.if_neq.format(first_char, last_char))
                return {"input": prompt,
                        "output": ''.join(steps),
                        "answer": False}
            steps.append(self.if_eq.format(first_char, last_char))
        steps.append(self.not_enter2.format(c_list))
        steps.append(self.return_result)
        return {"input": prompt,
                "output": ''.join(steps),
                "answer": True}

    def natural_IO(self, data: dict) -> dict:
        '''
        data: a datapoint from gen_data_from_len
        return natural-language (commented walkthrough) input-output of given data

        The string is cleaned (alphanumerics only, lower-cased). Then the i-th
        characters from the left and from the right are compared for
        i = 1 .. ceil(n / 2), stopping at the first mismatch.
        '''
        question = data["question"]
        s = ''.join(c.lower() for c in data['s'] if c.isalnum())
        output = "# remove all non-alphanumeric characters and convert to lowercase\n"
        output += f"s = {s}\n"
        answer = True
        half = (len(s) + 1) // 2
        for i in range(1, half + 1):
            left, right = s[i - 1], s[-i]
            output += f"# check if the {i}-th character from left is equal to the {i}-th character from right\n"
            output += f"{i}-th character from left: {left}\n{i}-th character from right: {right}\n"
            if left != right:
                output += "not equal, so the answer is False"
                answer = False
                break
            output += "equal\n"
        if answer:
            output += "all characters are equal, so the answer is True"
        return {"input": question,
                "output": output,
                "answer": answer}

    def cot_IO(self, data: dict) -> dict:
        '''
        data: a datapoint from gen_data_from_len
        return cot input-output of given data

        Same comparison as natural_IO, rendered as one line per compared pair.
        '''
        question = data["question"]
        s = ''.join(c.lower() for c in data['s'] if c.isalnum())
        output = "Remove all non-alphanumeric characters and convert to lowercase: "
        output += f"'{s}'\n"
        answer = True
        half = (len(s) + 1) // 2
        output += 'Compare characters from both ends:\n'
        for i in range(1, half + 1):
            left, right = s[i - 1], s[-i]
            ordinal = 'first' if i == 1 else 'next'
            op = '==' if left == right else '!='
            output += f"{ordinal} character from left: '{left}' {op} {ordinal} character from right: '{right}'\n"
            if left != right:
                output += "So the answer is False"
                answer = False
                break
        if answer:
            output += "All characters are equal, so the answer is True"
        return {"input": question,
                "output": output,
                "answer": answer}

if __name__ == "__main__":
    # Smoke demo: build one random length-6 datapoint and show its CoT rendering.
    generator = Dataset_Generator()
    sample = generator.cot_IO(generator.gen_data_from_len(6))
    print(sample['input'], '\n')
    print(sample['output'])