import re
import torch

# Lazily captures the text between an <answer> tag and the nearest following
# </answer>; DOTALL lets the captured text span newlines.
answer_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

def get_answer(completion):
    """Extract the contents of the first ``<answer>...</answer>`` pair.

    Args:
        completion: decoded model output to scan.

    Returns:
        The text captured between the tags (possibly empty), or None when
        no complete tag pair is present.
    """
    found = answer_pattern.search(completion)
    if found is None:
        return None
    return found.group(1)


def get_reward(completion, oracle_answer):
    """Score a single completion against the reference answer.

    The tagged answer is normalized by keeping only its uppercase
    characters, then compared for equality with ``oracle_answer``.

    Args:
        completion: decoded model output.
        oracle_answer: reference answer to compare the normalized text to.

    Returns:
        A tuple ``(reward, answer, is_correct, is_answered)``:
        - reward: 1.0 if the normalized answer equals ``oracle_answer``,
          else 0.0.
        - answer: the normalized string, or None when no answer tag was found.
        - is_correct: whether the normalized answer equals ``oracle_answer``.
        - is_answered: whether exactly one uppercase character remained.
    """
    extracted = get_answer(completion)
    if extracted is None:
        return 0.0, None, False, False

    letters = "".join(filter(str.isupper, extracted))
    matched = bool(letters == oracle_answer)
    return (1.0 if matched else 0.0), letters, matched, len(letters) == 1


def postprocess_examples(examples, tokenizer, is_vision, min_length=0):
    """Decode, score, and annotate a group of rollouts for one question.

    Every entry in ``examples`` must share the same ``question``,
    ``num_options`` and ``answer`` as the first entry (checked with
    ``assert``). Each example dict is mutated in place:

    - ``idx`` gets the suffix ``_<i>`` (its position in ``examples``);
    - ``completion``, ``is_answered``, ``is_correct`` and ``model_answer``
      are set;
    - ``entropy`` is replaced by the mean of its values as a tensor.

    Args:
        examples: non-empty list of dicts with at least the keys
            ``question``, ``answer``, ``num_options``, ``generated_ids``,
            ``idx`` (string) and ``entropy``.
        tokenizer: object providing ``decode(ids, skip_special_tokens=...)``;
            when ``is_vision`` is true its ``.tokenizer`` attribute is used
            instead.
        is_vision: select ``tokenizer.tokenizer`` for decoding.
        min_length: rollouts with fewer than this many generated ids get
            zero reward and are marked unanswered/incorrect; 0 disables.

    Returns:
        ``(rewards, answers, oracle_answer)``: per-example rewards, per-example
        normalized answers (None when missing or discarded), and the shared
        reference answer.
    """
    first = examples[0]
    question = first['question']
    oracle_answer = first['answer']
    # NOTE(review): presumably a processor wrapping a text tokenizer in the
    # vision case — confirm against callers.
    decoder = tokenizer.tokenizer if is_vision else tokenizer

    rewards = []
    answers = []
    num_answered = 0

    for i, example in enumerate(examples):
        assert example['question'] == question
        assert example['num_options'] == first['num_options']
        assert example['answer'] == oracle_answer

        generated_ids = example['generated_ids']
        example['idx'] += f"_{i}"

        completion = decoder.decode(generated_ids, skip_special_tokens=True)
        example['completion'] = completion

        reward, answer, is_correct, is_answered = get_reward(completion, oracle_answer)

        # Too-short rollouts are discarded: no reward, no answer.
        if min_length > 0 and len(generated_ids) < min_length:
            reward = 0.0
            answer = None
            is_correct = False
            is_answered = False

        # Count after the override so the logged tally agrees with the
        # per-example 'is_answered' flags stored below.
        num_answered += int(is_answered)

        rewards.append(reward)
        answers.append(answer)
        example['is_answered'] = is_answered
        example['is_correct'] = is_correct
        example['model_answer'] = answer
        example['entropy'] = torch.tensor(example['entropy']).mean()

    print('-' * 100)
    print('Question:')
    print(examples[0]['question'])
    print('Completion:')
    print(examples[0]['completion'])
    print('Model Answer:', examples[0]['model_answer'])
    print(f"rollout answer='{examples[0]['answer']}', returns={sum(rewards):.2f}, num_answered={num_answered}, generated_ids={max(len(e['generated_ids']) for e in examples)}")
    return rewards, answers, oracle_answer

