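# This script generates Hamming-distance datapoints (question, ground truth, operands) and formats them as input/output pairs in several answer styles.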

import numpy as np
import random
from prompt_lc import Prompt
# with open("names.txt", "r") as f:
#     NAMES = f.read().strip().split()

class Dataset_Generator:
    '''
    Generator of Hamming-distance datapoints and of several textual
    renderings (rule-following, natural-language, chain-of-thought) of
    their step-by-step solutions.
    '''

    def __init__(self) -> None:
        self.name = 'hamming_distance'

    def gen_data_from_len(self, length: int) -> dict:
        '''
        Return a random datapoint whose operands are `length` bits long.

        For `length >= 2` both operands are drawn from
        `[2**(length-1), 2**length - 1]`, so each has exactly `length` bits.
        For `length == 1` each operand is drawn from `{0, 1}`.

        The returned dict has the keys
        `"question"` (problem statement with x and y in `bin()` form),
        `"gt"` (the Hamming distance between x and y) and
        `"number"` (the tuple `(x, y)` of the drawn integers).

        Raises `ValueError` if `length` is smaller than 1.
        '''
        # example question: The Hamming distance between two integers is the number of positions at which the corresponding bits are different.\nGiven two integers x and y, return the Hamming distance between them.\nx = 0b1001 and y = 0b1100
        # example answer: 2
        if length < 1:
            raise ValueError(f"length must be at least 1, got {length}")
        x = random.randint(2**(length-1), 2**length-1)
        y = random.randint(2**(length-1), 2**length-1)
        if length == 1:
            # The range above collapses to {1} for a single bit, so redraw to
            # allow 0 as well. The redraw happens after the general draw so
            # that the amount of randomness consumed per call is unchanged.
            x = random.randint(0, 1)
            y = random.randint(0, 1)
        question = (
            'The Hamming distance between two integers is the number of positions at which the corresponding bits are different.\n'
            'Given two integers x and y, return the Hamming distance between them.\n'
            f'x = {bin(x)} and y = {bin(y)}'
        )
        # XOR has a 1 exactly where the operands differ. Counting those bits
        # covers every position up to the longer operand, so pairs such as
        # (0, 1) are counted correctly.
        gt = bin(x ^ y).count("1")
        return {"question": question,
                "gt": gt,
                "number": (x, y)}

    def rfft_IO(self, data: dict) -> dict:
        '''
        Return the rfft input-output pair of the given datapoint.

        `data` must provide `"question"` and `"number"` (the `(x, y)` pair).
        The input is the instruction text, the rule taken from the
        `hamming_distance` prompt and the question. The output is assembled
        from that prompt's templates: one fragment per processed bit
        position, then the break fragment and the result fragment.

        Returns a dict with the keys `"input"`, `"output"` and `"answer"`;
        the answer is recomputed from `data["number"]`.
        '''
        instruction = "Follow the given rule to solve the question.\nrule:"
        P = Prompt("hamming_distance")
        prompt_text = instruction + P.rule + "\n\nQ: " + data["question"]
        # rfft output
        (x, y) = data["number"]
        ans = 0
        pieces = [P.initialize.format(bin(x), bin(y))]
        # Walk both operands from the least significant bit until both are 0.
        while x != 0 or y != 0:
            if x % 2 != y % 2:
                ans += 1
                pieces.append(P.one_iteration_2_1_different.format(
                    bin(x), bin(y), x % 2, y % 2, ans-1, ans,
                    bin(x), bin(y), bin(x // 2), bin(y // 2)))
            else:
                pieces.append(P.one_iteration_2_1_same.format(
                    bin(x), bin(y), x % 2, y % 2,
                    bin(x), bin(y), bin(x // 2), bin(y // 2)))
            x //= 2
            y //= 2
        pieces.append(P.one_iteration_2_1_break.format())
        pieces.append(P.return_result.format(ans, ans))
        return {"input": prompt_text,
                "output": "".join(pieces),
                "answer": ans}

    def natural_IO(self, data: dict) -> dict:
        '''
        Return a natural-language input-output pair of the given datapoint.

        `data` must provide `"question"` and `"number"` (the `(x, y)` pair).
        The output compares the last bits of x and y, shifts both right by
        one bit, reports the shifted values with the running distance, and
        repeats until both operands are 0.

        Returns a dict with the keys `"input"` (the unchanged question),
        `"output"` and `"answer"`; the answer is recomputed from
        `data["number"]`.
        '''
        question = data["question"]
        (x, y) = data["number"]
        lines = []
        ans = 0
        while x != 0 or y != 0:
            if x % 2 != y % 2:
                ans += 1
                lines.append(f"Since the last bit of x is {x % 2} and the last bit of y is {y % 2}, the Hamming distance increases by 1.\n")
            else:
                lines.append(f"Since the last bit of x is {x % 2} and the last bit of y is {y % 2}, the Hamming distance remains the same.\n")
            x //= 2
            y //= 2
            lines.append(f"x = {bin(x)} and y = {bin(y)}, distance={ans}\n")
        lines.append(f"So the answer is {ans}\n")
        return {"input": question,
                "output": "".join(lines),
                "answer": ans}

    def cot_IO(self, data: dict) -> dict:
        '''
        Return a chain-of-thought input-output pair of the given datapoint.

        `data` must provide `"question"` and `"number"` (the `(x, y)` pair).
        The output has one sentence per bit position, starting with the
        "last" position and calling every following one "next", and ends
        with the total count.

        Returns a dict with the keys `"input"` (the unchanged question),
        `"output"` and `"answer"`; the answer is recomputed from
        `data["number"]`.
        '''
        question = data["question"]
        (x, y) = data["number"]
        lines = []
        ans = 0
        first = True
        while x != 0 or y != 0:
            a1 = x % 2
            a2 = y % 2
            pos = "last" if first else "next"
            if a1 == a2:
                lines.append(f"At the {pos} position, x has {a1} and y has {a2}, so they are the same.\n")
            else:
                lines.append(f"At the {pos} position, x has {a1} and y has {a2}, so they are different. Count this as 1.\n")
                ans += 1
            x //= 2
            y //= 2
            first = False
        lines.append("The total count of different positions is the Hamming distance between x and y.\n")
        lines.append(f"So the answer is {ans}\n")
        return {"input": question,
                "output": "".join(lines),
                "answer": ans}
        
if __name__ == "__main__":
    # Demo run: draw one 4-bit datapoint and print its natural-language form.
    generator = Dataset_Generator()
    sample = generator.gen_data_from_len(4)
    rendered = generator.natural_IO(sample)
    print(rendered['input'])
    print(rendered['output'])