import os
from langchain import PromptTemplate, LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
import openai                
from huggingface_hub import InferenceClient
from sklearn.cluster import DBSCAN


from src.dataloaders.dataloader_fixed import MyData
from sklearn.cluster import KMeans
import pickle

import time

from src.utils.utils import (
                            clustering_score,
                            hash_dict,
                            parse_output
                            )


from sklearn.preprocessing import LabelEncoder

import pandas as pd
from datetime import datetime

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from omegaconf import OmegaConf
from tqdm import tqdm
import re
import wandb
import random

from embeddings import Embedder, EmbeddingDB

    
class IntentGPT():

    def __init__(self, args):
        """Set up data, LLM client, output directories, embedding stores and the few-shot pool.

        Args:
            args: Experiment configuration, read through attribute access
                (``args.dataset``, ``args.use_wandb``, ...). It is also mutated:
                ``discovered_categories_path`` is set on it when
                ``feedback_discovered_categories`` is enabled.
        """
        self.args = args
        self.data = self.get_data()

        # Inverse of the dataset label map: class id -> class name.
        self.class_map = {v: k for k, v in self.data.all_label_map.items()}

        self.init_llm()
        self.create_directories()

        if self.args.use_wandb:
            wandb.init(name=f"{self.readable_id}_{self.exp_id}", project=self.args.project, config=dict(self.args))

        if self.use_embeddings():
            self.embedder = Embedder(self.args.intent_embeddings)

        if self.args.semantic_feedback_discovery or self.args.semantic_few_shot:
            self.test_intents_embeddings = EmbeddingDB(self.embedder, db_file="test_intents_embeddings.sqlite")

        if self.args.feedback_discovered_categories:
            self.args.discovered_categories_path = f"{self.results_path}/discovered_categories.txt"

            # Always define known_intents: get_prompt() and run() read and extend it
            # whenever feedback is enabled, even when we start from an empty list.
            if self.args.start_with_known_categories:
                self.known_intents = [str(it) for it in self.data.train_label_map.keys()]
            else:
                self.known_intents = []

            # (Re)create the discovered-categories file; it is empty when no intents are known yet.
            with open(self.args.discovered_categories_path, 'w') as f:
                for item in self.known_intents:
                    f.write(f"{item}\n")

            if self.args.semantic_feedback_discovery:
                # Embedding database over the known intents, extended in run() as new intents appear.
                # NOTE(review): assumes EmbeddingDB accepts an empty initial list — confirm.
                self.known_intents_embeddings = EmbeddingDB(self.embedder, self.known_intents, db_file="known_intents_embeddings.sqlite")

        self.init_few_shot_pool()

        # With a fixed few-shot set, sample it once here and reuse it for every batch.
        if not self.args.sample_few_shot_every_iteration:
            self.current_few_shot_samples = self.get_few_shot_examples()

    def use_embeddings(self):
        return self.args.intent_embeddings or self.args.use_embeddings_for_labels or self.args.semantic_few_shot or self.args.semantic_feedback_discovery
    
    def get_hard_prompt(self):
        """Load from disk, or generate with GPT-4, the task prompt and optional hard samples.

        Results are cached under ``prompts/<dataset>/<data_config_id>/`` as
        ``prompt.txt`` and, when ``compute_hard_samples`` is set,
        ``hard_examples.txt``. The cache is used only when every file it needs
        exists and ``force_recompute`` is false; otherwise the prompt is
        regenerated from labeled training examples and the files are rewritten.

        Returns:
            tuple: ``(hard_samples, prompt)``. ``hard_samples`` is the list of
            ``(utterance, intent)`` tuples found by the regex, or ``None`` when
            ``compute_hard_samples`` is false. ``prompt`` is the prompt text.
        """
        dataset_name = self.args.dataset
        compute_hard_samples = self.args.compute_hard_samples
        N_hard_samples = self.args.N_hard_samples
        samples_x_class = self.args.samples_x_class
        force_recompute = self.args.force_recompute
        labeled_ratio = self.args.labeled_ratio
        known_cls_ratio = self.args.known_cls_ratio

        if compute_hard_samples:
            data_config_id = f"{labeled_ratio}_labeled_{known_cls_ratio}_known_intents_{N_hard_samples}_hard_samples_{samples_x_class}_samples_x_class"
        else:
            data_config_id = f"{labeled_ratio}_labeled_{known_cls_ratio}_known_intents_hard_prompt_{samples_x_class}_samples_x_class"

        path_store_prompts = f"prompts/{dataset_name}/{data_config_id}"
        os.makedirs(path_store_prompts, exist_ok=True)

        prompt_path = os.path.join(path_store_prompts, 'prompt.txt')
        hard_examples_path = os.path.join(path_store_prompts, 'hard_examples.txt') if compute_hard_samples else None
        hard_samples = None

        # One "Utterance: ..., Intent: ..." pair per line; the pattern needs the trailing newline.
        hard_sample_pattern = r'Utterance: (.*?), Intent: (.*?)\n'

        # The cache is only usable when every file this configuration needs is present.
        # A prompt.txt without hard_examples.txt triggers regeneration instead of a crash.
        cache_complete = os.path.exists(prompt_path) and (
            not compute_hard_samples or os.path.exists(hard_examples_path)
        )

        if cache_complete and not force_recompute:
            if compute_hard_samples:
                with open(hard_examples_path, 'r') as f:
                    content = f.read()
                # The file is written without a final newline, so add one for the last pair to match.
                if not content.endswith('\n'):
                    content += '\n'
                hard_samples = re.findall(hard_sample_pattern, content)

            with open(prompt_path, 'r') as f:
                prompt = f.read()

            return hard_samples, prompt
        else:

            print("Generating prompt...")

            # Dictionary of {class id: [utterances]} built from the labeled training set.
            classes = {}
            for sample in self.data.train_labeled_loader.dataset:
                label = int(sample['labels'])
                utterance = self.data.tokenizer.decode(sample['input_ids'], skip_special_tokens = True)
                classes.setdefault(label, []).append(utterance)

            # Class id -> class name mapping.
            class_map = {v: k for k, v in self.data.train_label_map.items()}

            # Keep at most samples_x_class utterances per class; a class with no
            # labeled utterance contributes nothing instead of raising KeyError.
            dict_class_examples = {}
            for class_id, class_name in class_map.items():
                dict_class_examples[class_name] = classes.get(class_id, [])[:samples_x_class]

            if compute_hard_samples:
                prompt_template = '''You are a helpful assistant and an expert in natural language processing and specialize in open set discovery intent classification. This task involves assigning textual utterances to specific intents, some of which are pre-defined (known) and others are not (unknown). \nAs an expert in open set discovery intent classification, your task is to provide a list of {N_hard_samples} pairs of utterances and intents that are the most difficult to classify. These pairs should consist of one utterance and its corresponding intent. Make sure that the final list contains high variety of intents. The hard samples you propose need to be extracted from the original list.
                
                To select the most difficult pairs, consider the following criteria:
                1. Ambiguity: Choose utterances that can be interpreted in multiple ways, making it challenging to determine the correct intent.
                2. Contextual complexity: Select pairs where the intent is highly dependent on the context provided in the utterance or surrounding dialogue.
                3. Lack of explicit keywords: Include pairs with utterances that do not contain obvious keywords or explicit indicators of the intent.
                4. Similarities among intents: Include pairs where intents have overlapping or similar meanings, making it challenging to differentiate between them.

                Once you have selected the difficult pairs, and have acquired sufficient context about the problem, provide a detailed prompt for an AI language model to solve the task of open set intent discovery, maximizing the model's performance in the task. Provide effective guidelines about how to solve the task in the prompt. 

                EXAMPLES:
                
                {train_examples}

                You must respond using the following format:    
                HARD EXAMPLES ({N_hard_samples} pairs): Utterance: <utterance>, Intent: <intent>      
                PROMPT: <prompt>
                '''
            else:
                # The prompt-design prompt is kept in a shared file.
                prompt_design_prompt = 'prompts/common_prompts/initial_prompt_gpt4.txt'
                with open(prompt_design_prompt, 'r') as f:
                    prompt_template = f.read()

            # Each example is a string of the form "Utterance: <utterance>, Intent: <intent>"
            train_examples = '\n'.join(
                f'Utterance: {utterance}, Intent: {class_name}'
                for class_name, utterances in dict_class_examples.items()
                for utterance in utterances
            )

            if compute_hard_samples:
                long_prompt = PromptTemplate(template=prompt_template, input_variables=["train_examples", "N_hard_samples"])
            else:
                long_prompt = PromptTemplate(template=prompt_template, input_variables=["train_examples"])

            llm_openai = OpenAI(model_name="gpt-4", temperature=0.2)
            llm_chain = LLMChain(prompt=long_prompt, llm=llm_openai)
            if self.args.verbose:
                if compute_hard_samples:
                    print(long_prompt.format(train_examples=train_examples, N_hard_samples=N_hard_samples))
                else:
                    print(long_prompt.format(train_examples=train_examples))

            @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2))
            def completion_with_backoff(**kwargs):
                return llm_chain.run(**kwargs)

            if compute_hard_samples:
                output = completion_with_backoff(train_examples=train_examples, N_hard_samples=N_hard_samples)

                # Same newline normalisation as the cached path, so a pair on the
                # very last line of the completion is not dropped.
                searchable = output if output.endswith('\n') else output + '\n'
                hard_samples = re.findall(hard_sample_pattern, searchable)

                # Save results in txt
                with open(hard_examples_path, 'w') as f:
                    f.write('\n'.join([f'Utterance: {utterance}, Intent: {intent}' for utterance, intent in hard_samples]))
            else:
                output = completion_with_backoff(train_examples=train_examples)

            # Keep what follows the "PROMPT:" marker. str.find returns -1 when the
            # marker is absent; in that case use the whole completion rather than
            # slicing from a meaningless offset.
            marker = 'PROMPT:'
            marker_pos = output.find(marker)
            if marker_pos == -1:
                prompt = output.strip()
            else:
                prompt = output[marker_pos + len(marker):].strip()

            with open(prompt_path, 'w') as f:
                f.write(prompt)

            return hard_samples, prompt

    def init_few_shot_pool(self):
        self.class_map_train = {v: k for k, v in self.data.train_label_map.items()}

        few_shot_pool_samples = None

        if self.args.prompt_type == 'hard_prompt':
            hard_samples, self.hard_prompt = self.get_hard_prompt()
            few_shot_pool_samples = hard_samples

        if few_shot_pool_samples == None:
            self.few_shot_pool = []
            for i in range(self.data.train_labeled_loader.dataset.__len__()):
                sample = self.data.train_labeled_loader.dataset[i]
                utterance = self.data.tokenizer.decode(sample['input_ids'], skip_special_tokens = True)
                label = self.class_map_train[int(sample['labels'])]
                self.few_shot_pool.append((utterance, label))
        else:
            self.few_shot_pool = hard_samples

        if self.args.semantic_few_shot:
            # create and embedding database for few-shot samples
            few_shot_pool = []
            for sample in self.few_shot_pool:
                utterance = str(sample[0].replace('"', ''))
                few_shot_pool.append(utterance)
            self.few_shot_embeddings = EmbeddingDB(self.embedder, few_shot_pool, db_file="few_shot_embeddings.sqlite")

    def create_directories(self):
        """Create the results directory for this run and store the config in it.

        Sets ``self.timestamp``, ``self.exp_id`` (config hash plus timestamp),
        ``self.readable_id`` (human-readable summary of the main options) and
        ``self.results_path``, then writes ``config.yaml`` into that directory.
        """
        # --------------- EXPERIMENT INFO --------------------
        self.timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.exp_id = hash_dict(dict(self.args)) +"_"+str(self.timestamp)

        # Results live under results_v2/<dataset>/<readable_id>_<exp_id>
        dataset_results_dir = f"results_v2/{self.args.dataset}"
        feedback_string = 'feedback' if self.args.feedback_discovered_categories else 'no_feedback'
        self.readable_id = f"{self.args.prompt_type}_{self.args.num_few_shot_examples}_shot_{self.args.known_cls_ratio}_clsratio_{feedback_string}_temp_{self.args.temperature}_{self.args.gpt_model}"
        self.results_path = f"{dataset_results_dir}/{self.readable_id}_{self.exp_id}"
        os.makedirs(self.results_path, exist_ok=True)

        if self.args.verbose:
            # Report the experiment directory actually written to, not its parent.
            print(f"Results will be saved in {self.results_path}")

        # Store config file
        with open(os.path.join(self.results_path, 'config.yaml'), 'w') as f:
            OmegaConf.save(self.args, f)

    @retry(wait=wait_random_exponential(min=2, max=200), stop=stop_after_attempt(100))
    def completion_with_backoff(self, **kwargs):
        """Run ``self.llm_chain`` with the given template variables under a retry policy.

        The call is wrapped by the ``retry`` decorator, configured with
        ``wait_random_exponential(min=2, max=200)`` and ``stop_after_attempt(100)``.

        Args:
            **kwargs: Passed unchanged to ``self.llm_chain.run``.

        Returns:
            Whatever ``self.llm_chain.run`` returns.
        """
        return self.llm_chain.run(**kwargs)
    
    def get_prompt_template(self):
        return """{prompt}\nCONTEXT EXAMPLES:\n{few_shot_examples}\nRESPONSE FORMAT:\nID: <i>, Utterance: <content>, Intent: <intent> Use the same ID in the test example.\nTEST EXAMPLES (to predict):\n{test_examples}\nYour response:"""

    def init_llm(self):
        """Create the prompt template and the model client selected by ``args.gpt_model``.

        Always sets ``self.long_prompt``. Then:
        - a model name containing ``'gpt'`` gets a ``ChatOpenAI``-backed
          ``LLMChain`` in ``self.llm_chain``;
        - ``'toolkit-llama'`` gets an ``InferenceClient`` in ``self.tgi_client``;
        - any other name configures the ``openai`` module for the Anyscale
          endpoint and stores the model name itself in ``self.llm_chain``.
        """
        # -------------- OPENAI/LANGCHAIN API---------------
        self.long_prompt = PromptTemplate(
            template=self.get_prompt_template(),
            input_variables=["few_shot_examples", "test_examples", "prompt"],
        )

        if 'gpt' in self.args.gpt_model:
            chat_model = ChatOpenAI(model_name=self.args.gpt_model, temperature=self.args.temperature)
            self.llm_chain = LLMChain(prompt=self.long_prompt, llm=chat_model)
            return

        if self.args.gpt_model == 'toolkit-llama':
            self.tgi_client = InferenceClient(model="")
            return

        openai.api_key = ''
        openai.api_base = "https://api.endpoints.anyscale.com/v1"
        self.llm_chain = self.args.gpt_model

    def get_test_samples(self, i, batch, b_size, return_formatted=True):
        test_examples = []        
        gt_labels = []          
        ids = [] 
        for j in range(b_size):
            id = i * self.args.batch_size + j
            ids.append(id)
            utterance = self.data.tokenizer.decode(batch['input_ids'][j], skip_special_tokens = True)
            label = self.class_map[int(batch['labels'][j])]
            test_examples.append((id, utterance))

            if self.args.verbose:
                print(f"test example {id}: {utterance}")
            gt_labels.append(label)

        if return_formatted:
            test_examples = '\n'.join([f"ID:{id}, {utterance}" for id, utterance in test_examples])

        return test_examples, gt_labels, ids        

    def get_few_shot_examples(self):
        if self.args.k_nearest_few_shot == 0:
            return []
        if self.args.semantic_few_shot:
            test_samples_embeddings = []
            for sample in self.current_test_samples:
                utterance = sample[1]
                embedding = self.test_intents_embeddings.get_embedding(utterance)
                test_samples_embeddings.append(embedding)

            distances, indices = self.few_shot_embeddings.get_k_nearest(test_samples_embeddings, self.args.k_nearest_few_shot)
    
            few_shot_examples = []
            for i in range(len(indices)):
                for j, ind in enumerate(indices[i]):
                    few_shot_examples.append((self.few_shot_pool[ind][0], self.few_shot_pool[ind][1], distances[i][j]))

            # Keep only unique sentences, updating the distance if the current one is smaller    
            unique_examples = {}
            for sentence, intent, distance in few_shot_examples:
                if sentence not in unique_examples:
                    unique_examples[sentence] = (intent, distance)
                else:
                    # Update if current distance is smaller
                    if distance < unique_examples[sentence][1]:
                        unique_examples[sentence] = (intent, distance)
            
            # Sort by distance and keep the top k examples
            few_shot_examples = sorted(unique_examples.items(), key=lambda item: item[1][1])[:self.args.num_few_shot_examples]
            few_shot_examples = [(sentence, intent) for sentence, (intent, distance) in few_shot_examples]

        else:
            few_shot_examples = []
            random.shuffle(self.few_shot_pool)
            for i in range(self.args.num_few_shot_examples):
                few_shot_examples.append(self.few_shot_pool[i])

        ret = '\n'.join([f"Utterance: {utterance.strip()}, Intent: {intent}" for index, (utterance, intent) in enumerate(few_shot_examples)])
        return ret
    
    def get_prompt(self):
        prompt_type = self.args.prompt_type
        # This task involves assigning textual utterances to specific intents, some of which are pre-defined and others are not. You have to assign utterances to known intents, or create new intents if none of the known intents fit with the utterance. 

        base_prompt = """You are a helpful assistant and an expert in natural language processing and specialize in the task of intent detection."""
        additional_description = """Make sure each intent is only between one and three words, and as short and reusable as possible. Use the same format as the context examples. Don't classify the examples below CONTEXT EXAMPLES. Only classify the test examples below TEST EXAMPLES. You are prohibited to assign intents to 'unknown'. Instead, create a new intent."""
        if prompt_type == "base_prompt":
            prompt = f"""{base_prompt} The intent can be one of the pre-defined intents or a new one that you create based on the context and knowledge about the problem and specific data domain. You should never assign an utterance to 'unknown'. For each utterance, analyze the context and the specific request or action implied. If the utterance matches a known intent, assign it to that intent. If it doesn't match any known intent, create a new intent that accurately represents the request or action implied by the utterance. Remember, the goal is to understand the user's intent as accurately as possible. Be aware of the known intents and reuse them as much as possible, but don't hesitate to create new intents when necessary. {additional_description}""" 
            
        elif prompt_type == "simple_guidelines":
            prompt = f"""
                {base_prompt}
                You must follow the following guidelines when predicting the intent:
                1) Identify Core Request: Look for verbs signaling the user's intent, such as "update", "change", "want to know", "maintainance" etc.
                2) Spot Key Subjects: Identify important objects or subjects like "pin number", "song", "time", "schedule", etc.
                3) Contextual Clues: Observe the surrounding details or scenario which can hint at the intent.
                4) Geographical/Time Terms: Note references to specific locations or time.
                5) Question vs. Declaration: Determine if the utterance is asking a question or stating a need/desire.
                6) Try to make the intent as short as possible, but make sure it is still descriptive.
                7) You can generate your prediction using some words from the utterance.
                8) Make sure the intent you choose is the most accurate intent for the utterance.
                9) Try to not generate intents that are too similar to each other.
                10) REUSE DISCOVERED INTENTS AS MUCH AS POSSIBLE, but you can be fine-grained if needed. 
                11) Try to be generate intents in an extractive manner, using words from the utterance.\n
                {additional_description}
            """
            
        elif prompt_type == "hard_prompt":
            prompt = f"""{base_prompt}\n{self.hard_prompt}\n{additional_description}"""
        else:
            raise ValueError(f"Prompt {prompt_type} not found.")
        
        if self.args.feedback_discovered_categories:
            if self.args.semantic_feedback_discovery:
                self.current_sfd = self.get_semantic_feedback_discovery()
            else:
                feedback_discovery = self.known_intents
              
            if self.args.semantic_feedback_discovery_after_test_sample:
                prompt += """Don't discover new intents if you have already discovered an intent that is similar to the one you are trying to create. After every sample in TEST EXAMPLES, I will give you some known intents that are semantically similar to the one you are trying to predict, but you are encouraged to create a new intent if none the intents fit with the utterance. You must decide if reusing one of those intents or create a new one."""
            else:
                if self.args.semantic_feedback_discovery:
                    prompt += """Don't discover a new intent if you have already discovered one that is similar. The following is a list of some of the discovered intents so far, that can be semantically similar to the examples you are trying to predict: Try to reuse the intents as much as possible, with the objective of having the least amount of intents possible."""

                    feedback_discovery = []
                    for intent in self.current_sfd:
                        feedback_discovery.append(intent)
                    feedback_discovery = list(set(feedback_discovery))

                else:
                    prompt += """Don't discover a new intent if you have already discovered one that is similar. Make sure that the intents are not very generic, you can be fine-grained. Use the following list of known intents to keep reference, reuse them as much as possible: """

                discovered_categories = ', '.join([f"{intent}" for intent in feedback_discovery])

                # Add discovered categories to the prompt
                prompt += discovered_categories

        return prompt

    def get_semantic_feedback_discovery(self):
        # Compute embeddings of test samples
        test_samples_embeddings = []
        for sample in self.current_test_samples:
            utterance = sample[1]
            embedding = self.test_intents_embeddings.get_embedding(utterance)
            test_samples_embeddings.append(embedding)

        # Extract K nearest neighbors of known intents for each test sample
        distances, indices = self.known_intents_embeddings.get_k_nearest(test_samples_embeddings, self.args.k_nearest_feedback) 

        feedback_discovery = []
        for i in range(len(indices)):
            for j, ind in enumerate(indices[i]):
                feedback_discovery.append((self.known_intents[ind], distances[i][j]))
        
        # Keep only unique sentences, updating the distance if the current one is smaller    
        unique_examples = {}
        for intent, distance in feedback_discovery:
            if intent not in unique_examples:
                unique_examples[intent] = (intent, distance)
            else:
                # Update if current distance is smaller
                if distance < unique_examples[intent][1]:
                    unique_examples[intent] = (intent, distance)
        
        # Sort by distance and keep the top k examples
        feedback_discovery = sorted(unique_examples.items(), key=lambda item: item[1][1])[:self.args.number_semantic_feedback]
        feedback_discovery = [intent for intent, distance in feedback_discovery]
        return feedback_discovery

    def llama_completion(self, prompt):
        while True:
            try:
                if self.args.gpt_model == 'toolkit-llama':
                    inference_args = {
                        "prompt": prompt,
                        "max_new_tokens": 1000,
                        "temperature": self.args.temperature,
                        # "stop_sequences": ["\n\n\n"],
                    }
                    out = self.tgi_client.text_generation(**inference_args)

                else:
                    completion = openai.ChatCompletion.create(
                        model=self.args.gpt_model,
                        messages=[{"role": "user", "content": prompt}],
                        temperature=self.args.temperature,
                        top_p = 0.9 if self.args.temperature > 0.0 else 1.0,
                    )
                    out = completion.choices[0].message.content
                return out
            except openai.error.RateLimitError as e:
                print("RateLimitError, Sleeping for 100 seconds...")
                time.sleep(100)
            except openai.error.APIError as e:
                print(f"APIError, {e}\nSleeping for 100 seconds...")
                time.sleep(100)
            except Exception as e:
                print(f"{e}, Sleeping for 100 seconds...")
                time.sleep(100)
            
    def run(self):
        """Predict intents for every test batch and collect the results.

        For each batch this builds the prompt (few-shot examples, instructions,
        test examples), queries the configured model, parses the answer with
        ``parse_output`` and attaches ground-truth labels by position. With
        ``feedback_discovered_categories`` enabled, newly predicted intents
        are added to ``self.known_intents`` (and to the embedding database for
        semantic feedback) and the discovered-categories file is rewritten.
        After each batch the distinct ground-truth intents seen so far are
        written to ``gt_intents.txt``.

        Results accumulate in ``self.final_output`` as a list of dicts.
        """
        self.final_output = []
        # Wrap the dataloader, not the enumerate object, so tqdm can see its length.
        for i, batch in enumerate(tqdm(self.data.test_dataloader)):
            b_size = batch['input_ids'].shape[0]

            test_samples, gt_labels, _ = self.get_test_samples(i, batch, b_size, return_formatted=False)
            self.current_test_samples, self.current_gt_labels = test_samples, gt_labels

            if self.args.sample_few_shot_every_iteration:
                few_shot_examples = self.get_few_shot_examples()
            else:
                few_shot_examples = self.current_few_shot_samples

            prompt = self.get_prompt()
            if self.args.semantic_feedback_discovery_after_test_sample and self.args.semantic_feedback_discovery:
                # current_sfd is one flat, batch-level list of intent names, so the
                # same list is shown after every sample. Indexing it per sample
                # would join the characters of a single intent name.
                similar_intents = ', '.join(self.current_sfd)
                test_examples = '\n'.join(
                    f"ID:{sample_id}, {utterance}, Similar Intents: {similar_intents}"
                    for sample_id, utterance in test_samples
                )
            else:
                test_examples = '\n'.join(f"ID:{sample_id}, {utterance}" for sample_id, utterance in test_samples)

            if self.args.verbose:
                full_prompt = self.long_prompt.format(few_shot_examples=few_shot_examples, test_examples=test_examples, prompt=prompt)
                print(full_prompt)
                print(f" Prompt length: {len(full_prompt)}")

            if 'gpt' in self.args.gpt_model:
                output = self.completion_with_backoff(few_shot_examples=few_shot_examples, test_examples=test_examples, prompt=prompt)
            else:
                llm_prompt = self.long_prompt.format(few_shot_examples=few_shot_examples, test_examples=test_examples, prompt=prompt)
                output = self.llama_completion(llm_prompt)

            parsed_output = parse_output(output)

            if self.args.feedback_discovered_categories:
                # dict.fromkeys de-duplicates while keeping the order of first appearance.
                discovered_categories = list(dict.fromkeys(item['intent'] for item in parsed_output))
                for intent in discovered_categories:
                    if intent not in self.known_intents:
                        self.known_intents.append(intent)
                        if self.args.semantic_feedback_discovery: # Add the discovered embedding to the database
                            self.known_intents_embeddings.get_embedding(intent)

                # Update the file that stores the discovered categories (DB)
                with open(self.args.discovered_categories_path, 'w') as f:
                    for item in self.known_intents:
                        f.write(f"{item}\n")

            if len(parsed_output) != len(gt_labels):
                print(f"Warning: batch {i} returned {len(parsed_output)} predictions for {len(gt_labels)} test samples.")

            # Attach ground truth by position. zip stops at the shorter sequence,
            # so surplus rows from the model are dropped rather than raising IndexError.
            for prediction, gt_label in zip(parsed_output, gt_labels):
                prediction['gt_label'] = gt_label
                if self.args.verbose:
                    print(f"Predicted sample {prediction['id']}: {prediction['utterance']}, {prediction['intent']}, {prediction['gt_label']}")
                self.final_output.append(prediction)

            # Save the distinct ground-truth intents seen so far, in order of first appearance.
            gt_intents = list(dict.fromkeys(item['gt_label'] for item in self.final_output))
            with open(os.path.join(self.results_path, 'gt_intents.txt'), 'w') as f:
                for item in gt_intents:
                    f.write(f"{item}\n")

    def get_embeddings(self, text_list, embeddings_path):
        
        embeddings = None

        # Check if the embeddings have already been computed
        if os.path.exists(embeddings_path):
            with open(embeddings_path, 'rb') as handle:
                embedding_dict = pickle.load(handle)

            # Valiate if all embeddings are there
            try:
                embeddings = [embedding_dict[label] for label in text_list]
            except KeyError:
                embeddings = None
        
        if embeddings is None:
            unique_labels = list(set(text_list))
            embeddings = self.embedder.embed_texts(unique_labels)
            embedding_dict = {label: embedding for label, embedding in zip(unique_labels, embeddings)}
            with open(embeddings_path, 'wb') as handle:
                pickle.dump(embedding_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
            embeddings = [embedding_dict[label] for label in text_list]
        return embeddings
    
    def get_clusters(self, y_pred, y_true):
        """Turn predicted and ground-truth intent names into integer cluster labels.

        Without ``intent_embeddings`` both lists are label-encoded. With it,
        predictions are embedded, DBSCAN estimates the number of clusters, and
        KMeans with that many clusters assigns the predicted labels. Ground
        truth is clustered the same way when ``use_embeddings_for_labels`` is
        set and label-encoded otherwise.

        Args:
            y_pred: Predicted intent names.
            y_true: Ground-truth intent names, aligned with ``y_pred``.

        Returns:
            tuple: ``(labels_pred, labels_true)`` — note the order.
        """
        if self.args.intent_embeddings:
            pred_embeddings_path = os.path.join(self.results_path, f'pred_embeddings_{self.args.intent_embeddings}.pickle')
            pred_embeddings = self.get_embeddings(y_pred, pred_embeddings_path)

            # Estimate the number of clusters with DBSCAN; label -1 marks noise points.
            dbscan = DBSCAN(eps=self.args.eps_dbscan, min_samples=2)
            dbscan.fit(pred_embeddings)
            labels = dbscan.labels_
            num_clusters = len(set(labels)) - (1 if -1 in labels else 0)
            # DBSCAN may mark every point as noise (0 clusters). KMeans needs at
            # least one cluster, so fall back to a single cluster.
            num_clusters = max(num_clusters, 1)

            clustering_model_pred = KMeans(n_clusters=num_clusters)
            clustering_model_pred.fit(pred_embeddings)
            labels_pred = clustering_model_pred.labels_

            if self.args.use_embeddings_for_labels:
                # Ground-truth embeddings are cached one directory up from this run's results directory.
                true_embeddings_path = os.path.join(os.path.dirname(self.results_path), f'true_embeddings_{self.args.intent_embeddings}.pickle')
                true_embeddings = self.get_embeddings(y_true, true_embeddings_path)
                clustering_model_true = KMeans(n_clusters=num_clusters)
                clustering_model_true.fit(true_embeddings)
                labels_true = clustering_model_true.labels_
            else:
                labels_true = LabelEncoder().fit_transform(y_true)

        else:
            # Independent encoders: the two label spaces are unrelated.
            labels_true = LabelEncoder().fit_transform(y_true)
            labels_pred = LabelEncoder().fit_transform(y_pred)
        return labels_pred, labels_true

    def validate(self, results_path = None):
        """Score the predictions and store qualitative and quantitative results.

        Args:
            results_path: Optional directory of a previous run. When given,
                ``self.results_path`` is replaced by it and predictions are
                read from its ``results.csv``. Otherwise ``self.final_output``
                is written (or appended) to ``results.csv`` in the current
                results directory and scored.

        Side effects:
            Appends a row to ``scores.csv`` in the results directory (creating
            it if needed) and logs the scores to wandb when ``use_wandb`` is set.
        """
        if results_path is not None:
            self.results_path = results_path
            # Load the stored predictions of a previous run.
            results = pd.read_csv(os.path.join(self.results_path, 'results.csv'))
            y_true = results['gt_label'].tolist()
            y_pred = results['intent'].tolist()

        else:
            # Generate csv tables with qualitative results
            results_csv = os.path.join(self.results_path, 'results.csv')
            df = pd.DataFrame(self.final_output)

            # Append to csv if exists
            if os.path.exists(results_csv):
                df.to_csv(results_csv, mode='a', index=False, header=False)
            else:
                df.to_csv(results_csv, index=False)

            # Obtain clusterings
            y_true = [item['gt_label'] for item in self.final_output]
            y_pred = [item['intent'] for item in self.final_output]

        # get_clusters returns (labels_pred, labels_true); unpack in that order so
        # each name holds what it says.
        labels_pred, labels_true = self.get_clusters(
            y_pred=y_pred,
            y_true=y_true
        )
        # Compute clustering scores
        # NOTE(review): assumes clustering_score takes (y_true, y_pred) — confirm in src.utils.utils.
        res = clustering_score(labels_true, labels_pred)

        # Generate csv table with quantitative results
        res['date'] = self.timestamp
        res['id'] = self.exp_id
        res['readable_id'] = self.readable_id
        res['num_discovered_intents'] = len(set(y_pred))
        res['results_path'] = self.results_path

        scores_csv = os.path.join(self.results_path, 'scores.csv')
        df = pd.DataFrame([res])
        if os.path.exists(scores_csv):
            df.to_csv(scores_csv, mode='a', index=False, header=False)
        else:
            df.to_csv(scores_csv, index=False)

        # Report in wandb
        if self.args.use_wandb:
            wandb.log(res)

    def get_data(self):
        """Build the ``MyData`` wrapper from the configuration and report the test-set size.

        Returns:
            The ``MyData`` instance (constructed with ``self.args`` passed twice).
        """
        dataset_bundle = MyData(self.args, self.args)
        num_test_samples = dataset_bundle.test_dataloader.dataset.__len__()
        print(f"Number of test samples: {num_test_samples}")
        return dataset_bundle
