from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate, FewShotPromptTemplate
import os
# Placeholder values (left empty in the repository). setdefault keeps whatever
# is already exported in the calling environment instead of overwriting it
# with an empty string, which would blank out a real API key or cache path.
os.environ.setdefault("HF_HOME", "")
os.environ.setdefault("TORCH_HOME", "")
os.environ.setdefault("OPENAI_API_KEY", '')

from langchain import PromptTemplate, LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI

import re

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

import random

# Shared instruction preamble used by every prompt variant in get_prompt().
# NOTE: the indentation and trailing spaces inside the literal are part of the
# text that ends up in the prompt, so they are left untouched.
base_prompt = """You are a helpful assistant and expert in natural language processing labeling. You are given a list of utterances and each one has a 
            specific intent associated with it. Group together utterances that have the same intent and generate an intent for the group.
            Make sure each intent is only between one and three words, but make them as short and reusable as possible. Make sure they follow the same format as the given examples, and don't create very similar intents. Try to reuse intents as much as possible, with the objective of having the least amount of intents possible.
            You should not group the examples that were given to you as context (the ones below CONTEXT EXAMPLES). 
            Only respond to the examples that are given to you as test examples (the examples below TEST EXAMPLES)."""

def get_prompt(args, prompt_name):
    """Assemble the instruction prompt selected by ``prompt_name``.

    Args:
        args: Namespace read for ``dataset`` (only for "hard_samples_prompt"),
            ``feedback_discovered_categories`` and, when that flag is truthy,
            ``discovered_categories_path``.
        prompt_name: One of "base_prompt", "simple_guidelines" or
            "hard_samples_prompt".

    Returns:
        The prompt text. When ``args.feedback_discovered_categories`` is
        truthy, a reuse notice and the raw contents of the discovered
        categories file are appended.

    Raises:
        ValueError: If ``prompt_name`` is not one of the supported names.
    """
    known_names = ("base_prompt", "simple_guidelines", "hard_samples_prompt")
    if prompt_name not in known_names:
        raise ValueError(f"Prompt {prompt_name} not found.")

    if prompt_name == "base_prompt":
        instructions = base_prompt
    elif prompt_name == "simple_guidelines":
        instructions = f"""
            {base_prompt}
            You must follow the following guidelines when predicting the intent:
            1) Identify Core Request: Look for verbs signaling the user's intent, such as "update", "change", "want to know", "maintainance" etc.
            2) Spot Key Subjects: Identify important objects or subjects like "pin number", "song", "time", "schedule", etc.
            3) Contextual Clues: Observe the surrounding details or scenario which can hint at the intent.
            4) Geographical/Time Terms: Note references to specific locations or time.
            5) Question vs. Declaration: Determine if the utterance is asking a question or stating a need/desire.
            6) Try to make the intent as short as possible, but make sure it is still descriptive.
            7) You can generate your prediction using some words from the utterance.
            8) Make sure the intent you choose is the most accurate intent for the utterance.
            9) Try to not generate intents that are too similar to each other.
            10) REUSE DISCOVERED INTENTS AS MUCH AS POSSIBLE, but you can be fine-grained if needed. 
            11) Try to be generate intents in an extractive manner, using words from the utterance.
        """
    else:
        # "hard_samples_prompt": wrap the stored guidelines in the base prompt.
        guidelines = get_hard_prompt(data = None, dataset_name = args.dataset)
        instructions = f"""
            {base_prompt}
            GUIDELINES: 
            {guidelines}
        """

    if not args.feedback_discovered_categories:
        return instructions

    reuse_notice = """
        It is very important that you don't create a new category if you have already discovered a category that is similar to the one you are trying to create.
        In order to solve the task properly you must try to reuse the already discovered categories as much as possible, but make sure that the categories are not very generic. Try to be fine-grained when discovering new intents.
        For that purpose, here is the known categories. Reuse them as much as possible:
        DISCOVERED CATEGORIES"""

    # NOTE(review): the file contents are appended right after the header with
    # no separator; presumably the file starts with a newline - confirm.
    with open(args.discovered_categories_path, 'r') as handle:
        known_categories = handle.read()

    return instructions + reuse_notice + known_categories
    
def random_select_few_shot_samples(samples_list, num_few_shot):
    """Pick ``num_few_shot`` samples using a fixed-seed shuffle.

    The selection is reproducible: the same input always yields the same
    samples. The shuffle runs on a copy with a private ``random.Random(42)``,
    so neither the caller's sequence nor the global ``random`` state is
    modified.

    Args:
        samples_list: Sequence of candidate samples.
        num_few_shot: Number of samples wanted.

    Returns:
        A new list with at most ``num_few_shot`` samples. It is shorter when
        fewer candidates are available and empty when ``num_few_shot <= 0``.
    """
    shuffled = list(samples_list)
    # Same permutation as random.seed(42) + random.shuffle, without touching
    # the module-level generator shared with the rest of the program.
    random.Random(42).shuffle(shuffled)
    # Clamp at 0 so a negative count yields [] instead of slicing from the end.
    return shuffled[:max(num_few_shot, 0)]

def smart_few_shot_samples(samples_list, test_batch, num_few_shot):
    """Select few-shot samples for a test batch.

    ``test_batch`` is currently ignored; selection is delegated to
    ``random_select_few_shot_samples``.
    """
    # TODO: extract embeddings of samples_list and use knn to select the closest to the ones in the test batch
    selected = random_select_few_shot_samples(samples_list, num_few_shot)
    return selected

def get_hard_prompt(data=None, dataset_name=None, prompt_path=None):
    """Return the guidelines prompt previously stored on disk.

    ``create_hard_examples_and_prompt`` writes the generated prompt to
    ``prompts/<dataset_name>/<config id>/prompt.txt``. This function reads it
    back so it can be embedded in a larger prompt.

    Args:
        data: Unused. Accepted because ``get_prompt`` passes it.
        dataset_name: Dataset whose stored prompt should be loaded. Used to
            locate the file when ``prompt_path`` is not given.
        prompt_path: Explicit path of the prompt file. Takes precedence over
            ``dataset_name``.

    Returns:
        The text of the stored prompt.

    Raises:
        ValueError: If neither ``dataset_name`` nor ``prompt_path`` is given,
            or if several stored prompts exist for the dataset and no
            explicit ``prompt_path`` disambiguates them.
        FileNotFoundError: If no stored prompt exists for the dataset.
    """
    if prompt_path is None:
        if dataset_name is None:
            raise ValueError("get_hard_prompt needs either dataset_name or prompt_path.")

        dataset_dir = os.path.join("prompts", str(dataset_name))
        config_ids = sorted(os.listdir(dataset_dir)) if os.path.isdir(dataset_dir) else []
        candidates = [
            path
            for path in (os.path.join(dataset_dir, config_id, "prompt.txt") for config_id in config_ids)
            if os.path.isfile(path)
        ]

        if not candidates:
            raise FileNotFoundError(
                f"No stored prompt found under {dataset_dir}; "
                "run create_hard_examples_and_prompt first."
            )
        if len(candidates) > 1:
            # Refuse to guess which data configuration the caller means.
            raise ValueError(
                f"Several stored prompts found for {dataset_name}: {candidates}. "
                "Pass prompt_path to choose one."
            )
        prompt_path = candidates[0]

    with open(prompt_path, 'r') as f:
        return f.read()

def create_hard_examples_and_prompt(args, data):
    """Generate (or reuse) the hard examples and guidelines prompt for a dataset.

    A few labeled training utterances per known intent are sent to GPT-4,
    which is asked to answer with a list of hard ``Utterance/Intent`` pairs
    followed by a ``PROMPT:`` section. Both parts are written to
    ``prompts/<dataset>/<config id>/hard_examples.txt`` and ``prompt.txt``.
    When both files already exist and ``args.force_recompute_samples`` is
    falsy, nothing is regenerated.

    Args:
        args: Namespace providing ``dataset``, ``N_hard_samples``,
            ``samples_x_class``, ``force_recompute_samples``,
            ``labeled_ratio`` and ``known_cls_ratio``.
        data: Object providing ``train_labeled_loader.dataset`` (iterable of
            samples with ``'labels'`` and ``'input_ids'``), ``tokenizer`` and
            ``train_label_map`` (class name -> class id).

    Returns:
        None. Results are only written to disk.

    Raises:
        ValueError: If the model answer contains no ``PROMPT:`` section.
    """
    dataset_name = args.dataset
    # num_few_shot = args.num_few_shot_examples
    N_hard_samples = args.N_hard_samples
    samples_x_class = args.samples_x_class
    force_recompute = args.force_recompute_samples
    labeled_ratio = args.labeled_ratio
    known_cls_ratio = args.known_cls_ratio

    data_config_id = f"{labeled_ratio}_labeled_{known_cls_ratio}_known_intents_{N_hard_samples}_hard_samples_{samples_x_class}_samples_x_class"

    path_store_prompts = f"prompts/{dataset_name}/{data_config_id}"
    os.makedirs(path_store_prompts, exist_ok=True)

    hard_examples_path = os.path.join(path_store_prompts, 'hard_examples.txt')
    prompt_path = os.path.join(path_store_prompts, 'prompt.txt')

    # Cached results are reused as-is; there is nothing to return, so the
    # files do not need to be read back here.
    if os.path.exists(hard_examples_path) and os.path.exists(prompt_path) and not force_recompute:
        return

    print("Generating prompt and samples...")

    # Dictionary of {class id: [utterances]}
    classes = {}
    for sample in data.train_labeled_loader.dataset:
        label = int(sample['labels'])
        utterance = data.tokenizer.decode(sample['input_ids'], skip_special_tokens = True)
        classes.setdefault(label, []).append(utterance)

    # Define class id to class name mapping
    class_map = {v: k for k, v in data.train_label_map.items()}

    # Keep up to samples_x_class utterances per class. A class present in the
    # label map but without any labeled sample contributes no examples instead
    # of raising KeyError.
    dict_class_examples = {
        class_name: classes.get(class_id, [])[:samples_x_class]
        for class_id, class_name in class_map.items()
    }

    # Generate prompt
    prompt_template = '''You are a helpful assistant and an expert in natural language processing and specialize in open set discovery intent classification. This task involves assigning textual utterances to specific intents, some of which are pre-defined (known) and others are not (unknown).        
        As an expert in open set discovery intent classification, your task is to provide a list of {N_hard_samples} pairs of utterances and intents that are the most difficult to classify. These pairs should consist of one utterance and its corresponding intent. The intent can be either pre-defined (known) or unknown.  
        
        To select the most difficult pairs, consider the following criteria:
        1. Ambiguity: Choose utterances that can be interpreted in multiple ways, making it challenging to determine the correct intent.
        2. Contextual complexity: Select pairs where the intent is highly dependent on the context provided in the utterance or surrounding dialogue.
        3. Lack of explicit keywords: Include pairs with utterances that do not contain obvious keywords or explicit indicators of the intent.
        4. Similarities among intents: Include pairs where intents have overlapping or similar meanings, making it challenging to differentiate between them.

        Once you have selected the difficult pairs, and have acquired sufficient context about the problem, provide a detailed prompt for an AI language model to solve the task of open set intent discovery, maximizing the model's performance in the task. Provide effective guidelines about how to solve the task in the prompt. 

        EXAMPLES:
        
        {train_examples}

        You must respond using the following format:    
        HARD EXAMPLES ({N_hard_samples} pairs): Utterance: <utterance>, Intent: <intent>      
        PROMPT: <prompt>
        '''

    # Each example is a string of the form "Utterance: <utterance>, Intent: <intent>"
    train_examples = '\n'.join(
        f'Utterance: {utterance}, Intent: {class_name}'
        for class_name, utterances in dict_class_examples.items()
        for utterance in utterances
    )

    long_prompt = PromptTemplate(template=prompt_template, input_variables=["train_examples", "N_hard_samples"])
    llm_openai = OpenAI(model_name="gpt-4", temperature=0.9)
    llm_chain = LLMChain(prompt=long_prompt, llm=llm_openai)
    verbose = False  # flip manually to inspect the fully formatted prompt
    if verbose:
        print(long_prompt.format(train_examples=train_examples, N_hard_samples=N_hard_samples))

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2))
    def completion_with_backoff(**kwargs):
        return llm_chain.run(**kwargs)

    output = completion_with_backoff(train_examples=train_examples, N_hard_samples=N_hard_samples)

    # Validate the answer before caching anything: without the marker,
    # str.find returns -1 and an arbitrary slice of the answer would be stored
    # (and later reused) as the prompt.
    marker = 'PROMPT:'
    marker_pos = output.find(marker)
    if marker_pos == -1:
        raise ValueError(
            "The model answer does not contain a 'PROMPT:' section; nothing was saved."
        )

    # Extracting hard samples
    hard_samples = re.findall(r'Utterance: (.*?), Intent: (.*?)\n', output)

    # Extracting the generated prompt (everything after the marker)
    prompt = output[marker_pos + len(marker):].strip()

    # Save results in txt
    with open(hard_examples_path, 'w') as f:
        f.write('\n'.join([f'Utterance: {utterance}, Intent: {intent}' for utterance, intent in hard_samples]))
    with open(prompt_path, 'w') as f:
        f.write(prompt)


# --------------------- OTHER TEMPLATES --------------------- #
def get_few_shot_prompt():
    """Build the few-shot template that labels a single utterance.

    Returns:
        A ``FewShotPromptTemplate`` holding four utterance/intent examples
        and expecting one ``input`` variable (the utterance to label).
    """
    demo_pairs = (
        ("I want to deposit money", "deposit"),
        ("I want to withdraw money", "withdraw"),
        ("I want to transfer money", "transfer"),
        ("I want to check my balance", "balance"),
    )
    shots = [{"utterance": text, "intent": label} for text, label in demo_pairs]

    # Layout used to render each example.
    example_layout = """
    Utterance: {utterance}
    Intent: {intent}
    """

    return FewShotPromptTemplate(
        examples=shots,
        example_prompt=PromptTemplate(
            input_variables=["utterance", "intent"],
            template=example_layout,
        ),
        prefix="Return the intent of each utterance as a one to three word intent.",
        suffix="Utterance: {input}\nIntent:",
        input_variables=["input"],
        example_separator="\n",
    )


def get_group_uttereance_prompt():
    """
    Build the few-shot template for grouping utterances by shared intent.

    Returns a ``FewShotPromptTemplate`` with one worked example (six numbered
    utterances and their grouping) that expects a single ``input`` variable.
    """
    # Task instructions placed before the worked example.
    task_instructions = """
    You are given a list of utterances and each utterance has a specific intent associated with it.
    Group together utterances that have the same intent and generate an intent for the group.
    Make sure each intent is only between one and four words. There can be any number of groups made."""

    # Layout of one example: the numbered list followed by its grouping.
    example_layout = """
    {numbered_utterances}

    {grouped_intents}
    """

    sample_utterances = """
    1. Schedule a meeting for 3pm tomorrow.
    2. Can you give me a trivia question about guitars?
    3. What is a fun fact about lizards?
    4. Can you schedule a meeting to talk with Jim later today?
    5. Set an alarm for 5am tomorrow.
    6. Please set an alarm for 8am on Saturday."""

    sample_grouping = """
    Group A: 2,3
    Intent: fun_fact

    Group B:1,4
    Intent: schedule_meeting

    Group C: 5,6
    Intent: set_alarm
    """

    example_prompt = PromptTemplate(
        input_variables=["numbered_utterances", "grouped_intents"],
        template=example_layout,
    )

    return FewShotPromptTemplate(
        examples=[
            {
                "numbered_utterances": sample_utterances,
                "grouped_intents": sample_grouping,
            }
        ],
        example_prompt=example_prompt,
        prefix=task_instructions,
        suffix="{input}",
        input_variables=["input"],
        example_separator="\n",
    )