

import numpy as np

class DigitCompletion:

    def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0):
        super().__init__()
        self.max_number = max_number
        self.max_diff = max_diff
        self.max_num_in_response = max_num_in_response
        assert self.max_num_in_response < 10
        assert self.max_number > 0
        assert self.max_diff > 0
        self.max_number_length = len(str(max_number))

        self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length

        self.np_rng = np.random.default_rng(seed=seed)

    def __str__(self):
        return (
            f"Prompt length: {self.prompt_length}. Response length: {self.response_length}, "
            f"Max number: {self.max_number}. Max diff: {self.max_diff}, "
            f"Max number in response: {self.max_num_in_response}"
        )

    def get_state(self):
        return {"rng": self.np_rng}

    def set_state(self, state):
        assert "rng" in state, "rng must be inside state"
        self.np_rng = state["rng"]

    @property
    def prompt_length(self):
        return self._prompt_length

    @property
    def response_length(self):

        return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2

    def add(self, a, b):
        return (a + b) % self.max_number

    def get_all_prompts(self):
        all_prompts = []
        for first_num in range(self.max_number + 1):
            for diff in range(0, self.max_diff + 1):
                second_num = self.add(first_num, diff)
                for num_to_complete in range(self.max_num_in_response + 1):
                    prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}"
                    all_prompts.append(prompt)
        return all_prompts

    def sample_str_prompts(self):

        first_num = self.np_rng.integers(self.max_number + 1)
        diff = self.np_rng.integers(self.max_diff + 1)
        second_num = self.add(first_num, diff)
        num_to_complete = self.np_rng.integers(self.max_num_in_response + 1)
        prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}"
        return prompt

    def sample_batch_str_prompts(self, batch_size):
        str_prompts = []
        for _ in range(batch_size):
            str_prompts.append(self.sample_str_prompts())
        return str_prompts

def compute_attention_mask(prompts, pad_token_id):
    mask = np.ones_like(prompts)
    mask[prompts == pad_token_id] = 0
    return mask

def compute_position_id_with_mask(mask):
    return np.clip(np.cumsum(mask, axis=-1) - 1, a_min=0, a_max=None)

def generate_ground_truth_response(prompt: str):
    num, info = prompt.split(":")
    num1, num2 = num.split(",")
    max_number, num_to_gen = info.split(",")
    num1 = int(num1)
    num2 = int(num2)
    max_number = int(max_number)
    num_to_gen = int(num_to_gen)
    diff = (num2 - num1) % max_number
    results = []
    last_num = num2
    for _ in range(num_to_gen):
        curr = (last_num + diff) % max_number
        results.append(str(curr))
        last_num = curr
    response = ",".join(results)
    return response

def compute_reward(prompt: str, response: str, sequence_reward=1.0):
    response_length = len(response)
    ground_truth_response = generate_ground_truth_response(prompt)
    per_token_reward = sequence_reward / (len(ground_truth_response) + 1)

    reward = np.zeros(response_length, dtype=np.float32)

    ground_truth_idx = 0
    for i in range(response_length):
        if ground_truth_idx == len(ground_truth_response):
            break

        ground_truth_response_token = ground_truth_response[ground_truth_idx]
        response_token = response[i]
        if ground_truth_response_token == response_token:
            reward[i] = per_token_reward
            ground_truth_idx += 1
        else:

            break

    return reward, {"ground_truth_response": ground_truth_response}

if __name__ == "__main__":
    task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5)
    print(task.sample_str_prompts())

    prompt = "7,8:20,0"
    response = ""
    print(compute_reward(prompt, response))

    prompt = "7,8:20,0"
    response = "E000"
    print(compute_reward(prompt, response))

    prompt = "9,10:20,2"
    response = "11,12,13"
    print(compute_reward(prompt, response))
