import numpy as np
import pandas as pd
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment
import re
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder
from langchain.embeddings import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer
import os
import pickle
import hashlib

import warnings
# Installs a blanket "ignore" filter at import time, so every warning raised
# after this module is imported is suppressed (not only those from this file).
warnings.filterwarnings("ignore")

def hungray_aligment(y_true, y_pred):
    """Align predicted cluster ids with true labels via optimal assignment.

    Builds a contingency matrix of (predicted id, true id) co-occurrences and
    solves the linear assignment problem that maximises the matched counts.

    Parameters
    ----------
    y_true : array-like of int
        Ground-truth label ids. Values are used as matrix indices, so they
        must be integers.
    y_pred : array-like of int
        Predicted cluster ids, with the same number of elements as ``y_true``.

    Returns
    -------
    ind : numpy.ndarray
        Array of shape (D, 2); each row is a (predicted id, true id) pair
        selected by the assignment.
    w : numpy.ndarray
        Float matrix of shape (D, D) where ``w[p, t]`` counts the samples
        predicted as ``p`` whose true label is ``t``.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` do not contain the same number of items.
    """
    # Coerce so lists/tuples work too; the code below needs ndarray methods.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same size, "
            f"got {y_true.size} and {y_pred.size}"
        )

    # Square matrix large enough to index by any id seen on either side.
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D))
    # Unbuffered add: repeated (pred, true) pairs each contribute one count.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    ind = np.transpose(np.asarray(linear_sum_assignment(w.max() - w)))
    return ind, w

def clustering_accuracy_score(y_true, y_pred):
    """Return the clustering accuracy under the best label assignment.

    Uses ``hungray_aligment`` to pair predicted ids with true ids, sums the
    contingency counts of the selected pairs and divides by ``y_pred.size``.
    """
    pairs, contingency = hungray_aligment(y_true, y_pred)
    # Pick the count of every selected (predicted, true) pair in one shot.
    matched = contingency[pairs[:, 0], pairs[:, 1]].sum()
    return matched / y_pred.size

def clustering_score(y_true, y_pred):
    """Compute NMI, ARI and ACC for a clustering, as percentages.

    Each metric is multiplied by 100 and rounded to two decimals. The result
    is a dict with the keys 'NMI', 'ARI' and 'ACC', in that order.
    """
    scorers = (
        ('NMI', normalized_mutual_info_score),
        ('ARI', adjusted_rand_score),
        ('ACC', clustering_accuracy_score),
    )
    return {
        name: round(scorer(y_true, y_pred) * 100, 2)
        for name, scorer in scorers
    }

def parse_output(output_text):
    """Extract (id, utterance, intent) records from model output text.

    Every occurrence of ``ID: <digits>, Utterance: <text>, Intent: <text>``
    becomes one dict with the keys 'id' (int), 'utterance' and 'intent'
    (both stripped of surrounding whitespace). Text that does not match the
    pattern is ignored; the result is a list in order of appearance.
    """
    record_re = re.compile(r"ID:\s*(\d+),\s*Utterance:\s*(.*),\s*Intent:\s*(.+)")
    return [
        {
            'id': int(found.group(1)),
            'utterance': found.group(2).strip(),
            'intent': found.group(3).strip(),
        }
        for found in record_re.finditer(output_text)
    ]

def get_prompt_template():
    """Return the prompt template used for intent grouping.

    The template has three ``str.format`` placeholders: ``{prompt}``,
    ``{few_shot_examples}`` and ``{test_examples}``. Every line after the
    first carries a four-space indent, and the text ends with a blank line
    followed by four spaces; both are part of the returned string.
    """
    template_lines = [
        "{prompt}",
        "",
        "    CONTEXT EXAMPLES:",
        "    {few_shot_examples}",
        "",
        "    You should only respond in the format as described bellow:",
        "",
        "    RESPONSE FORMAT:",
        "",
        "    ID: <i>, Utterance: <content>, Intent: <intent>",
        "    Where id is the same id as the one in the context example, content is the utterance and intent is the intent label.",
        "",
        "    Based on the above guidelines, you should group the following utterances together and assign them an intent:",
        "",
        "    TEST EXAMPLES:",
        "    {test_examples}",
        "",
        "    ",
    ]
    return "\n".join(template_lines)

def get_embedder(embedder_name):
    """Instantiate the embedding model selected by ``embedder_name``.

    Parameters
    ----------
    embedder_name : str
        Any name containing ``'ada'`` selects ``OpenAIEmbeddings()``; the
        exact name ``'sentence-transformer'`` selects
        ``SentenceTransformer('all-MiniLM-L6-v2')``.

    Returns
    -------
    The constructed embedding model.

    Raises
    ------
    ValueError
        If ``embedder_name`` matches neither supported option (previously
        this surfaced as an UnboundLocalError on ``return model``).
    """
    if 'ada' in embedder_name:
        return OpenAIEmbeddings()
    if embedder_name == 'sentence-transformer':
        return SentenceTransformer('all-MiniLM-L6-v2')
    raise ValueError(
        f"Unsupported embedder_name {embedder_name!r}: expected a name "
        "containing 'ada' or exactly 'sentence-transformer'"
    )

# TODO: Include all this in the class, defining embedder as an attribute
def _get_embeddings(model, text_list):
    """Embed ``text_list`` with whichever API ``model`` exposes.

    ``OpenAIEmbeddings`` instances are called through ``embed_documents``;
    any other model is called through ``encode``.
    """
    is_openai = isinstance(model, OpenAIEmbeddings)
    return model.embed_documents(text_list) if is_openai else model.encode(text_list)

def get_embeddings(model, text_list, embeddings_path):
    """Return one embedding per entry of ``text_list``, cached on disk.

    Embeddings are computed once per unique label and stored as a
    ``{label: embedding}`` dict pickled at ``embeddings_path``. Labels that
    are not yet in the cache (including the case where no cache file exists)
    are embedded with ``_get_embeddings`` and the cache file is rewritten.

    Parameters
    ----------
    model
        Embedding model forwarded to ``_get_embeddings``; only used when some
        label is missing from the cache.
    text_list : list
        Labels to embed; duplicates are allowed and must be hashable.
    embeddings_path : str
        Path of the pickle cache file.

    Returns
    -------
    list
        Embeddings aligned with ``text_list`` (same length and order).
    """
    if os.path.exists(embeddings_path):
        # NOTE: pickle.load can execute arbitrary code; embeddings_path must
        # only ever point at a cache file this project wrote itself.
        with open(embeddings_path, 'rb') as handle:
            embedding_dict = pickle.load(handle)
    else:
        embedding_dict = {}

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # batch sent to the model is deterministic (unlike iterating a set).
    missing = [label for label in dict.fromkeys(text_list) if label not in embedding_dict]
    if missing:
        # A stale cache used to raise KeyError below; fill the gaps instead.
        new_embeddings = _get_embeddings(model, missing)
        embedding_dict.update(zip(missing, new_embeddings))
        with open(embeddings_path, 'wb') as handle:
            pickle.dump(embedding_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return [embedding_dict[label] for label in text_list]

    # else: # Here we will compute embeddings for all intents
    #     if os.path.exists(embeddings_path):
    #         with open(embeddings_path, 'rb') as handle:
    #             embeddings = pickle.load(handle)
    #     else:
    #         embeddings = _get_embeddings(model, text_list)
    #         with open(embeddings_path, 'wb') as handle:
    #             pickle.dump(embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)
    #     return embeddings


def hash_dict(exp_dict):
    """Create a hash for an experiment. Credits to github.com/haven-ai!

    Keys are visited in sorted order. Nested dicts are replaced by their own
    hash, and a non-empty list whose first element is a dict is replaced by
    the concatenated hashes of its elements. Every other value contributes
    its ``str()`` form.

    Parameters
    ----------
    exp_dict : dict
        An experiment, which is a single set of hyper-parameters

    Returns
    -------
    hash_id: str
        MD5 hex digest identifying the experiment

    Raises
    ------
    ValueError
        If ``exp_dict`` is not a dict, a key contains ".", a value is a
        tuple, or a list that starts with a dict contains a non-dict.
    """
    if not isinstance(exp_dict, dict):
        raise ValueError("exp_dict is not a dict")

    pieces = []
    for key in sorted(exp_dict.keys()):
        if "." in key:
            raise ValueError(". has special purpose")

        value = exp_dict[key]
        if isinstance(value, dict):
            encoded = hash_dict(value)
        elif isinstance(value, tuple):
            raise ValueError(f"{value} tuples can't be hashed yet, consider converting tuples to lists")
        elif isinstance(value, list) and value and isinstance(value[0], dict):
            sub_hashes = []
            for element in value:
                if not isinstance(element, dict):
                    raise ValueError("all have to be dicts")
                sub_hashes.append(hash_dict(element))
            encoded = "".join(sub_hashes)
        else:
            encoded = value

        # str() on both sides keeps the hashed text identical to "<key>/<value>".
        pieces.append(str(key) + "/" + str(encoded))

    return hashlib.md5("".join(pieces).encode()).hexdigest()