from huggingface_hub import InferenceClient
from openai import OpenAI

bio_prompt = """
You are given a piece of text that is a part of a biography of an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

If this is not possible because the text is just a fragment of a sentence, return "Abstain".
If the text already claims a lack of knowledge about the topic, return "Abstain".
Only return the cleaned up text. Do not include any other text:
"""

fp_prompt = """
You are given a piece of text that is a part of a false presupposition task which includes outputting a list of items. 
This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

The resulting list of items should be separated by semicolons with no other text.
If this list it not possible to generate, return "Abstain".
"""

hist_prompt = """
You are given a piece of text that is a part of a historical event task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

If this is not possible because the text is just a fragment of a sentence, return "Abstain".
If the text already claims a lack of knowledge about the topic, return "Abstain".
Only return the cleaned up text. Do not include any other text:
"""

refs_prompt = """
You are given a piece of text that is a part of a reference task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

If this is not possible because the text is just a fragment of a sentence, return "Abstain".
If the text already claims a lack of knowledge about the topic, return "Abstain".
Only return the cleaned up text. Do not include any other text:
"""

gpqa_prompt = """
You are given a piece of text that is a part of a graduate level question answering task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}
Only return the cleaned up text. Do not include any other text:
"""

popqa_prompt =  """
You are given a piece of text that is a part of a paragraph which details facts related to an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

If this is not possible because the text is just a fragment of a sentence, return "Abstain".
If the text already claims a lack of knowledge about the topic, return "Abstain".
Only return the cleaned up text. Do not include any other text:
"""
# Lookup table from task key to its clean-up prompt template. Every template
# has a single ``{text}`` placeholder that ``clean_up_text`` fills in with
# ``str.format``; an unknown key raises ``KeyError`` at lookup time.
task_to_prompt = dict(
    bio=bio_prompt,
    fp=fp_prompt,
    hist=hist_prompt,
    refs=refs_prompt,
    gpqa=gpqa_prompt,
    popqa=popqa_prompt,
)

def clean_up_text(text: str, api: str, task: str, model: str = "gpt-4.1-mini", **kwargs) -> str:
    """Clean up disfluencies in the draft response in consensus decoding.

    The prompt template registered for ``task`` in ``task_to_prompt`` is filled
    with ``text`` and sent to an LLM. The LLM is asked to fix errors, make the
    text coherent and remove redundant information. Most templates also tell it
    to answer "Abstain" in some cases.

    Args:
        text: The text to clean up.
        api: Backend used to run the clean-up prompt. Use ``"openai"`` for the
            OpenAI chat-completions API, or ``"hf"`` for a Hugging Face
            tokenizer/model pair passed in via ``kwargs``.
        task: Key into ``task_to_prompt`` that selects the prompt template:
            - ``"bio"``: biography
            - ``"fp"``: false presupposition
            - ``"hist"``: historical event
            - ``"refs"``: reference
            - ``"gpqa"``: graduate question answering
            - ``"popqa"``: paragraph question answering
        model: OpenAI model name. Only used when ``api == "openai"``.
        **kwargs: Backend-specific options.
            - ``client`` (``"openai"`` only, optional): an existing OpenAI
              client to reuse. A new ``OpenAI()`` is built when omitted.
            - ``tokenizer`` and ``hf_model`` (required for ``"hf"``): the
              tokenizer and model used for generation.
            - ``max_new_tokens`` (``"hf"`` only, default 500): generation budget.

    Returns:
        A string of the cleaned up text with surrounding whitespace stripped.
        It is an empty string if the OpenAI response carries no text content.

    Raises:
        ValueError: If ``api`` is not recognised. Also raised when
            ``api == "hf"`` and ``tokenizer`` or ``hf_model`` is missing.
        KeyError: If ``task`` is not a key of ``task_to_prompt``.
    """
    if api == "openai":
        # Reuse a caller-supplied client when given; constructing a new one
        # per call is wasteful when cleaning many texts in a loop.
        client = kwargs.get("client")
        if client is None:
            client = OpenAI()

        clean_up_prompt = task_to_prompt[task].format(text=text)
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": clean_up_prompt},
            ],
        )
        # ``message.content`` is Optional in the OpenAI SDK (e.g. refusals);
        # treat a missing body as empty text instead of crashing on ``.strip()``.
        content = completion.choices[0].message.content
        return (content or "").strip()

    if api == "hf":
        tokenizer = kwargs.get("tokenizer")
        # Kept under its own name so the OpenAI ``model`` argument is not shadowed.
        hf_model = kwargs.get("hf_model")
        if tokenizer is None or hf_model is None:
            raise ValueError("For 'hf', both 'tokenizer' and 'hf_model' must be provided.")

        clean_up_prompt = task_to_prompt[task].format(text=text)
        messages = [{"role": "user", "content": clean_up_prompt}]
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(hf_model.device)

        outputs = hf_model.generate(
            input_ids,
            max_new_tokens=kwargs.get("max_new_tokens", 500),
            do_sample=False,  # no sampling: repeatable clean-up output
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=[tokenizer.eos_token_id],
        )
        # Decode only the newly generated tokens, dropping the echoed prompt.
        return tokenizer.decode(
            outputs[0][input_ids.shape[-1]:],
            skip_special_tokens=True,
        ).strip()

    raise ValueError(f"Invalid API: {api}")
