import random
from transformers import set_seed
import re
from .instruction_constants import *

# Seed the RNGs at import time (a module-level side effect) so that the
# random.choice picks in get_task_description / get_format_description are
# reproducible across runs.
# NOTE(review): relies on transformers.set_seed seeding Python's `random`
# module (as the transformers docs describe) -- confirm on version changes.
set_seed(42)

def get_task_description(dataset_name):
    """Pick one task-description template for ``dataset_name`` at random.

    Args:
        dataset_name: Key looked up in ``TASK_DESCRIPTIONS`` (imported from
            ``instruction_constants``).

    Returns:
        One candidate chosen with ``random.choice`` from the entries
        registered under ``dataset_name``, or ``None`` when the key is absent.
    """
    if dataset_name not in TASK_DESCRIPTIONS:
        return None
    candidates = TASK_DESCRIPTIONS[dataset_name]
    return random.choice(candidates)

def get_format_description(dataset_name, instruction_format):
    """Pick one format-requirement template at random.

    Args:
        dataset_name: First-level key into ``FORMAT_DESCRIPTIONS``.
        instruction_format: Second-level key under that dataset's entry.

    Returns:
        A ``random.choice`` among the templates registered for the pair, or
        ``None`` when either key is missing.
    """
    if dataset_name not in FORMAT_DESCRIPTIONS:
        return None
    by_format = FORMAT_DESCRIPTIONS[dataset_name]
    if instruction_format not in by_format:
        return None
    return random.choice(by_format[instruction_format])

def get_few_shot_examples(dataset_name, num_shots):
    """Placeholder few-shot provider; always returns ``None``.

    Both arguments are accepted for interface compatibility but are ignored.
    """
    del dataset_name, num_shots  # unused until a real implementation exists
    return None

def attach_gsm8k_prompt(example, dataset_name, instruction_format, few_shot_pool=None, num_shots=0):
    """Build the GSM8K prompts and store them on ``example``.

    Two prompts are written:
      * ``example["zs_prompt"]``: instruction header followed by the question.
      * ``example["full_prompt"]``: the same header, then (when
        ``num_shots > 0`` and a pool is given) few-shot demonstrations and a
        ``Task Question:`` marker, then the question.

    Args:
        example: Mapping with a ``"question"`` entry; mutated in place.
        dataset_name: Key used to look up the task and format descriptions.
        instruction_format: Format key; ``"no-format"`` omits the format section.
        few_shot_pool: Optional object whose ``get_few_shot_examples(num_shots)``
            returns demonstration text (assumed ``str``, since it is concatenated).
        num_shots: Number of demonstrations requested.

    Returns:
        The same ``example``.
    """
    # Keep this call order: both helpers may draw from the shared `random` RNG.
    task_description = get_task_description(dataset_name)
    format_description = get_format_description(dataset_name, instruction_format)

    header = f'Follow the instruction to complete the task:\n{task_description}\n\n'
    wants_format = instruction_format != "no-format"
    if wants_format and format_description:
        header += f'Format requirement: {format_description}\n\n'

    prompt_prefix = header
    if num_shots > 0 and few_shot_pool:
        prompt_prefix = header + few_shot_pool.get_few_shot_examples(num_shots) + "Task Question:\n"

    question_line = f'Question: {example["question"]}'
    example["full_prompt"] = prompt_prefix + question_line
    example["zs_prompt"] = header + question_line
    return example

def attach_gsm8k_answer(example):
    """Extract the final integer answer of a GSM8K solution into ``example["label"]``.

    GSM8K solutions end with a ``#### <answer>`` marker. The text after the
    last ``####`` is stripped, thousands separators (``,``) are removed, and
    the value is normalized through ``int`` (e.g. ``"007"`` -> ``"7"``).
    Whitespace after the marker is optional.

    Args:
        example: Mapping with an ``"answer"`` string; mutated in place.

    Returns:
        The same ``example`` with ``"label"`` set to the answer as a string.

    Raises:
        ValueError: if ``"####"`` is missing or the text after it is not an integer.
    """
    answer = example["answer"]
    _, marker, tail = answer.rpartition("####")
    if not marker:
        # Without the marker there is no well-defined final answer; fail loudly
        # instead of parsing an arbitrary suffix of the solution text.
        raise ValueError(f"GSM8K answer has no '####' marker: {answer!r}")
    example["label"] = str(int(tail.strip().replace(",", "")))
    return example

def map_gsm8k(example, instruction_format, few_shot_pool=None, num_shots=0):
    """Attach the GSM8K prompts and the extracted answer label to ``example``.

    Returns the updated ``example``.
    """
    prompted = attach_gsm8k_prompt(
        example,
        "gsm8k",
        instruction_format,
        few_shot_pool=few_shot_pool,
        num_shots=num_shots,
    )
    return attach_gsm8k_answer(prompted)


# summeval
def clean_summary(text):
    """Detokenize and tidy a machine-generated summary.

    The steps run in order:
      * collapse whitespace runs to single spaces;
      * re-attach a possessive ``'s`` to the preceding word;
      * capitalize the first character and any word following ``.``/``!``/``?``;
      * drop spaces before ``,``/``.``/``;`` and normalize the space after them;
      * glue ``$`` to a following digit;
      * convert a single word quoted as two backticks ... two apostrophes into
        ``"word"``;
      * remove padding just inside brackets.

    Args:
        text: Raw summary string.

    Returns:
        The cleaned string. Empty or whitespace-only input yields ``""``.
    """
    text = re.sub(r'\s+', ' ', text.strip())
    if not text:
        # Nothing to clean; text[0] below would raise IndexError on "".
        return text
    text = re.sub(r"\s+'s", "'s", text)
    text = re.sub(r'([.!?])\s+(\w)', lambda match: match.group(1) + ' ' + match.group(2).upper(), text)
    text = text[0].capitalize() + text[1:]
    text = re.sub(r'\s([,.;])', r'\1', text)
    text = re.sub(r'([,.;])\s+', r'\1 ', text)
    text = re.sub(r'\$\s(\d)', r'$\1', text)
    text = re.sub(r"``\s(\w+)\s''", r'"\1"', text)
    text = re.sub(r'([(<{\[])\s', r'\1', text)
    text = re.sub(r'\s([)>}\]])', r'\1', text)
    return text

def attach_summeval_prompt(example, dataset_name, score_type, instruction_format, num_shots=0):
    """Build the SummEval scoring prompt and store it in ``example["full_prompt"]``.

    Args:
        example: Mapping with ``"text"`` (source article) and
            ``"machine_summary"``; mutated in place.
        dataset_name: Dataset key (``"summeval"`` from ``map_summeval``). The
            task description is looked up under ``f"{dataset_name}_{score_type}"``
            and the format description under ``dataset_name``.
        score_type: Score dimension name; selects the task-description key.
        instruction_format: Format key; ``"no-format"`` omits the format section.
        num_shots: When > 0, few-shot examples are requested and inserted only
            if the provider returns any text.

    Returns:
        The same ``example``.
    """
    task_description = get_task_description(f'{dataset_name}_{score_type}')
    format_description = get_format_description(dataset_name, instruction_format)
    machine_summary = clean_summary(example["machine_summary"])

    full_prompt = f'Follow the instruction to complete the task:\n{task_description}\n\n'

    if instruction_format != "no-format" and format_description:
        full_prompt += f'Format requirement: {format_description}\n\n'

    if num_shots > 0:
        few_shot_text = get_few_shot_examples(dataset_name, num_shots)
        # The provider may return None (no examples available). Append only
        # real text so the literal string "None" never leaks into the prompt.
        if few_shot_text:
            full_prompt += f'{few_shot_text}\n\n'

    full_prompt += (f'Now here are the news article and the summary.\n'
                    f'News article: {example["text"]}\n'
                    f'Summary: {machine_summary}')

    example["full_prompt"] = full_prompt
    return example

def attach_summeval_answer(example, score_type):
    """Copy the ``score_type`` field of ``example`` into ``example["label"]``.

    Mutates ``example`` in place and returns it.
    """
    score = example[score_type]
    example["label"] = score
    return example

def map_summeval(example, instruction_format, score_type, num_shots=0):
    """Attach the SummEval prompt and the ``score_type`` label to ``example``.

    Returns the updated ``example``.
    """
    with_prompt = attach_summeval_prompt(
        example, "summeval", score_type, instruction_format, num_shots=num_shots
    )
    return attach_summeval_answer(with_prompt, score_type)


# ace05-en
def attach_ace05_prompt(example, dataset_name, instruction_format, few_shot_pool=None, num_shots=0):
    """Build an event-argument-extraction prompt into ``example["full_prompt"]``.

    The task and format descriptions are looked up under the ``"eae"`` key.
    ``dataset_name``, ``few_shot_pool`` and ``num_shots`` are accepted but not
    used by this variant.

    Args:
        example: Mapping with ``"info_dict"`` (``"valid roles"``, ``"event type"``,
            ``"event description"``), ``"passage"`` and
            ``"gold_event"["trigger text"]``; mutated in place.
        instruction_format: Format key; ``"no-format"`` omits the format section.

    Returns:
        The same ``example``.
    """
    eae_key = "eae"
    task_description = get_task_description(eae_key)
    format_description = get_format_description(eae_key, instruction_format)
    info = example["info_dict"]
    roles_text = ", ".join(info["valid roles"])

    sections = [f'Follow the instruction to complete the task:\n{task_description}\n\n']
    if instruction_format != "no-format" and format_description:
        sections.append(f'Format requirement: {format_description}\n\n')
    sections.append('Here are the source sentence, event type, trigger word, and roles of interest.\n')
    sections.append(f'Source sentence: {example["passage"]}\n')
    sections.append(f'Event type: {info["event type"]}. {info["event description"]}\n')
    sections.append(f'Trigger word: {example["gold_event"]["trigger text"]}\n')
    sections.append(f'Roles of interest: {roles_text}')

    example["full_prompt"] = "".join(sections)
    return example

def attach_ace05_prompt_text_ee(example, dataset_name, instruction_format, few_shot_pool=None, num_shots=0):
    """Build the event-argument-extraction prompts for an ACE05 example.

    Sets ``example["full_prompt"]`` (with few-shot demonstrations when
    requested) and ``example["zs_prompt"]`` (zero-shot variant with the same
    header and question). Every space-delimited occurrence of the trigger text
    in the passage is wrapped as ``[t] trigger [/t]``. This includes an
    occurrence at the very start or end of the passage.

    Args:
        example: Mapping with ``"info_dict"`` (``"valid roles"``, ``"event type"``,
            ``"event description"``), ``"passage"`` (str) and
            ``"gold_event"["trigger text"]``; mutated in place.
        dataset_name: Unused; kept for signature parity with the other
            ``attach_*_prompt`` functions.
        instruction_format: Format key for the ``"eae"`` format descriptions;
            ``"no-format"`` omits the format section.
        few_shot_pool: Optional object whose
            ``get_few_shot_examples(num_shots, event_type)`` returns
            demonstration text (assumed ``str``, since it is concatenated).
        num_shots: Number of demonstrations; used only when > 0 and a pool is given.

    Returns:
        The same ``example``.
    """
    info_dict = example["info_dict"]
    valid_roles_str = ", ".join(info_dict["valid roles"])

    # Adjacent literals are concatenated implicitly. A backslash continuation
    # inside the string would embed the next line's indentation into the prompt.
    full_prompt = (
        'You are an argument extractor designed to check for the presence of arguments regarding specific roles for an event in a sentence.\n'
        f'Task Description: Identify all arguments related to the role {valid_roles_str} in the sentence. '
        'These arguments should have the semantic role corresponding to the given event trigger by the word span between [t] and [/t].\n'
    )

    full_prompt += f'The event of interest is {info_dict["event type"]}. {info_dict["event description"]} Roles of interest: {valid_roles_str}\n\n'

    format_description = get_format_description("eae", instruction_format)
    if instruction_format != "no-format" and format_description:
        full_prompt += f'Format requirement: {format_description}\n\n'

    zs_prompt = full_prompt

    if num_shots > 0 and few_shot_pool:
        full_prompt += few_shot_pool.get_few_shot_examples(num_shots, info_dict["event type"])

    trigger = example["gold_event"]["trigger text"]
    # Pad with spaces so that a trigger at the start or end of the passage also
    # matches the space-delimited pattern. The replacement begins and ends with
    # a space, so [1:-1] removes exactly the padding again.
    padded = ' ' + example["passage"] + ' '
    passage_with_tag = padded.replace(f' {trigger} ', f' [t] {trigger} [/t] ')[1:-1]

    question = f'Question\nText: {passage_with_tag}'
    full_prompt += question
    zs_prompt += question

    example["full_prompt"] = full_prompt
    example["zs_prompt"] = zs_prompt
    return example

def get_ace05_few_shot_examples(num_shots):
    """Placeholder ACE05 few-shot provider; always returns ``None``.

    ``num_shots`` is accepted for interface compatibility but ignored.
    """
    del num_shots  # unused until a real implementation exists
    return None

def attach_ace05_answer(example):
    """Collect the fields needed for ACE05 scoring into ``example["label"]``.

    The label dict holds ``doc_id``, ``wnd_id``, ``trigger``, ``arguments``,
    ``tokens`` (copied from ``example``) and ``valid_roles`` (taken from
    ``example["info_dict"]["valid roles"]``). Mutates and returns ``example``.
    """
    passthrough_keys = ("doc_id", "wnd_id", "trigger", "arguments", "tokens")
    label = {key: example[key] for key in passthrough_keys}
    label["valid_roles"] = example["info_dict"]["valid roles"]
    example["label"] = label
    return example
def map_ace05(example, instruction_format, few_shot_pool=None, num_shots=0):
    """Attach the ACE05 argument-extraction prompts and scoring label to ``example``.

    Returns the updated ``example``.
    """
    prompted = attach_ace05_prompt_text_ee(
        example,
        "ace05-en",
        instruction_format,
        few_shot_pool=few_shot_pool,
        num_shots=num_shots,
    )
    return attach_ace05_answer(prompted)
