import numpy as np
import pandas as pd
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment
import re
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder
from langchain.embeddings import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer
import os
import pickle
import warnings
warnings.filterwarnings("ignore")

def hungray_aligment(y_true, y_pred):
    """
    Match predicted cluster ids to true label ids with the Hungarian algorithm.

    Args:
        y_true: integer numpy array of ground-truth label ids.
        y_pred: integer numpy array of predicted cluster ids.

    Returns:
        (ind, w): ``ind`` is an array of (predicted id, true id) pairs, one
        row per id; ``w`` is the square co-occurrence count matrix with rows
        indexed by predicted id and columns by true id.
    """
    num_ids = max(y_pred.max(), y_true.max()) + 1
    counts = np.zeros((num_ids, num_ids))
    for sample in range(y_pred.size):
        counts[y_pred[sample], y_true[sample]] += 1

    # linear_sum_assignment minimises cost, so flip the counts to make the
    # best-overlapping pairs the cheapest ones.
    cost = counts.max() - counts
    assignment = linear_sum_assignment(cost)
    pairs = np.asarray(assignment).T
    return pairs, counts

def clustering_accuracy_score(y_true, y_pred):
    """
    Clustering accuracy under the best one-to-one cluster/label matching.

    The matching comes from ``hungray_aligment``; the score is the number of
    samples that fall on matched (predicted id, true id) pairs divided by
    the number of predictions.
    """
    pairs, counts = hungray_aligment(y_true, y_pred)
    # Look up the count of every matched pair in one indexing operation.
    matched = counts[pairs[:, 0], pairs[:, 1]].sum()
    return matched / y_pred.size

def clustering_score(y_true, y_pred):
    """
    Compute NMI, ARI and clustering accuracy for a predicted labelling.

    Returns:
        dict with keys 'NMI', 'ARI' and 'ACC'; each value is the metric
        multiplied by 100 and rounded to two decimals.
    """
    nmi = normalized_mutual_info_score(y_true, y_pred)
    ari = adjusted_rand_score(y_true, y_pred)
    acc = clustering_accuracy_score(y_true, y_pred)
    scores = {
        'NMI': round(nmi*100, 2),
        'ARI': round(ari*100, 2),
        'ACC': round(acc*100, 2),
    }
    return scores

def extract_numbers(line):
    """
    Collect every integer found in a string, in order of appearance.

    A range such as "3-7" contributes only its two endpoints (3 and 7);
    the values in between are not expanded.
    """
    pattern = r'(\d+)-(\d+)|(\d+)'
    # Each findall hit is a 3-tuple of groups; unused groups are empty strings.
    return [
        int(piece)
        for groups in re.findall(pattern, line)
        for piece in groups
        if piece
    ]

def extract_intent(line):
    """
    Return the word following "Intent:" in a string, or None if absent.

    At least one whitespace character is required between the colon and
    the intent name, and the name is limited to word characters.
    """
    found = re.search(r'Intent:\s+(\w+)', line)
    return found.group(1) if found else None

def parse_example_string(example_str):
    """
    Parses the example string and returns a dictionary of intent to utterance numbers.
    This is used to parse an example response from the prompt for OpenAI's language model for our group intent classification task.

    The text is scanned line by line. A line starting with "Group" opens a
    group; a later line starting with "Intent" names that group's intent, and
    the utterance numbers are read from the line immediately before the
    "Intent" line. If the same intent name appears more than once, the last
    occurrence wins.

    Args:
        example_str: multi-line response text.

    Returns:
        dict mapping intent name (str) to a list of ints.
    """
    intent_dict = {}
    lines = example_str.split("\n")
    group_name = None
    # Track the position with enumerate: looking the line up again with
    # lines.index(line) returns the FIRST identical line, which picks the
    # wrong preceding line when an intent line is repeated.
    for position, line in enumerate(lines):
        if line.startswith("Group"):
            group_name = line.split(":")[0].strip()
        elif line.startswith("Intent"):
            intent_name = extract_intent(line)
            # group_name is only set after a "Group" line has been seen, so
            # position >= 1 here and lines[position - 1] is a real predecessor.
            if group_name and intent_name:
                intent_dict[intent_name] = extract_numbers(lines[position - 1])
    return intent_dict

def number_utterances(df: pd.DataFrame, group) -> tuple:
    """
    Number the utterances of a group of rows.
    This is used to format the data passed to the prompt for OpenAI's language model.
    We also keep track of the mapping between the displayed number and the row index.

    Args:
        df: DataFrame with a "text" column.
        group: sequence of integer row positions (used with ``df.iloc``).

    Returns:
        (numbered, number_to_idx): ``numbered`` is a string with one
        "<n>. <utterance>\\n" line per utterance, numbered from 1;
        ``number_to_idx`` maps each number n to ``group[n-1]``.
    """
    utterances = df.iloc[group]["text"].tolist()
    numbered_lines = []
    number_to_idx = {}
    for number, utterance in enumerate(utterances, start=1):
        numbered_lines.append(f"{number}. {utterance}\n")
        number_to_idx[number] = group[number - 1]
    # Join once instead of growing a string with += inside the loop.
    return "".join(numbered_lines), number_to_idx


def parse_output(output_text):
    """
    Parse "ID: <n>, Utterance: <text>, Intent: <label>" records from text.

    Returns:
        list of dicts with keys 'id' (int), 'utterance' (str, stripped) and
        'intent' (str, stripped), one per match, in order of appearance.
    """
    record_pattern = r"ID:\s*(\d+),\s*Utterance:\s*(.*),\s*Intent:\s*(.+)"
    return [
        {
            'id': int(raw_id),
            'utterance': raw_utterance.strip(),
            'intent': raw_intent.strip(),
        }
        for raw_id, raw_utterance, raw_intent in re.findall(record_pattern, output_text)
    ]

def get_prompt_template():
    """
    Return the prompt template text.

    The template contains the placeholders {prompt}, {few_shot_examples} and
    {test_examples}. The leading spaces on each line are part of the returned
    string.
    """
    template = """{prompt}

    CONTEXT EXAMPLES:
    {few_shot_examples}

    You should only respond in the format as described bellow:

    RESPONSE FORMAT:

    ID: <i>, Utterance: <content>, Intent: <intent>
    Where id is the same id as the one in the context example, content is the utterance and intent is the intent label.

    Based on the above guidelines, you should group the following utterances together and assign them an intent:

    TEST EXAMPLES:
    {test_examples}

    """
    return template

def get_hard_few_shot_examples(args, data, batch=None):
    """
    Build few-shot example lines for the "hard samples" setting.

    Args:
        args: object exposing ``smart_few_shot`` and ``batch_size``.
        data: object exposing ``train_labeled_loader.dataset``, ``tokenizer``
            and ``train_label_map`` (label name -> label id).
        batch: currently unused.

    Returns:
        list of "Example <i>: Utterance: <text>, Intent: <label>" strings.
        The list is empty when ``args.smart_few_shot`` is set, because that
        mode is not implemented yet.
    """
    few_shot_samples = []
    if args.smart_few_shot:
        for i in range(args.batch_size):
            print("Not implemented yet")
    else:
        # Invert the label map locally (id -> name). This name was previously
        # referenced without being defined in this function, which raised
        # NameError on the first sample.
        class_map_train = {v: k for k, v in data.train_label_map.items()}
        for i in range(args.batch_size):
            sample = data.train_labeled_loader.dataset[i]
            utterance = data.tokenizer.decode(sample['input_ids'], skip_special_tokens = True)
            label = class_map_train[int(sample['labels'])]
            few_shot_samples.append(f"Example {i}: Utterance: {utterance}, Intent: {label}")
    return few_shot_samples

def get_few_shot_examples(args, data, few_shot_type, batch = None):
    """
    Produce the few-shot examples inserted into the prompt.

    When ``few_shot_type`` contains 'hard_samples' the result of
    ``get_hard_few_shot_examples`` is returned unchanged. Otherwise the first
    ``N`` labeled training samples are decoded and returned as one
    newline-joined string of "Example <i>: Utterance: ..., Intent: ..." lines.
    """
    # Invert the label map: label id -> label name.
    class_map_train = {label_id: name for name, label_id in data.train_label_map.items()}

    if 'hard_samples' in few_shot_type:
        return get_hard_few_shot_examples(args, data, batch)

    def _format_example(i):
        sample = data.train_labeled_loader.dataset[i]
        utterance = data.tokenizer.decode(sample['input_ids'], skip_special_tokens = True)
        label = class_map_train[int(sample['labels'])]
        return f"Example {i}: Utterance: {utterance}, Intent: {label}"

    # NOTE(review): `N` is not defined in this function or anywhere in the
    # visible part of the file; unless it is a module-level name defined
    # elsewhere, this branch raises NameError. Confirm the intended count.
    return '\n'.join(_format_example(i) for i in range(N))
    
def get_test_samples(args, data, i, batch, b_size, verbose=False, class_map=None):
    """
    Format one batch of test utterances for the prompt.

    Args:
        args: object exposing ``batch_size`` (used to derive global ids).
        data: object exposing ``tokenizer.decode``.
        i: index of the batch.
        batch: mapping with 'input_ids' and 'labels' entries, indexable by
            position within the batch.
        b_size: number of samples to read from ``batch``.
        verbose: print each decoded utterance when True.
        class_map: mapping from label id to label name; must be provided.

    Returns:
        (text, gt_labels, ids): newline-joined "ID:<id>, <utterance>" lines,
        the ground-truth label names, and the global sample ids.
    """
    prompt_lines, gt_labels, ids = [], [], []
    first_id = i * args.batch_size
    for offset in range(b_size):
        sample_id = first_id + offset
        ids.append(sample_id)
        utterance = data.tokenizer.decode(batch['input_ids'][offset], skip_special_tokens = True)
        label = class_map[int(batch['labels'][offset])]
        prompt_lines.append(f"ID:{sample_id}, {utterance}")
        if verbose:
            print(f"test example {sample_id}: {utterance}")
        gt_labels.append(label)

    return '\n'.join(prompt_lines), gt_labels, ids

def _get_embeddings(model, text_list):
    """
    Embed a list of texts with the given model.

    OpenAIEmbeddings instances are called through ``embed_documents``; any
    other model is called through ``encode``.
    """
    if not isinstance(model, OpenAIEmbeddings):
        return model.encode(text_list)
    return model.embed_documents(text_list)

def get_embeddings(model, text_list, embeddings_path):
    """
    Return one embedding per entry of ``text_list``, using an on-disk cache.

    Embeddings are computed once per unique label and stored as a pickled
    dict (label -> embedding) at ``embeddings_path``. If the cache file
    exists but lacks some of the requested labels, only the missing ones are
    embedded and the cache file is rewritten with the merged dict.

    Args:
        model: embedding model accepted by ``_get_embeddings``.
        text_list: list of hashable labels (strings) to embed.
        embeddings_path: path of the pickle cache file.

    Returns:
        list of embeddings aligned with ``text_list``.
    """
    if os.path.exists(embeddings_path):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # embeddings_path at cache files this project wrote itself.
        with open(embeddings_path, 'rb') as handle:
            embedding_dict = pickle.load(handle)
    else:
        embedding_dict = {}

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # batch sent to the model is deterministic.
    missing_labels = [label for label in dict.fromkeys(text_list) if label not in embedding_dict]
    if missing_labels:
        # A stale cache used to raise KeyError below; fill the gaps instead.
        new_embeddings = _get_embeddings(model, missing_labels)
        embedding_dict.update(zip(missing_labels, new_embeddings))
        with open(embeddings_path, 'wb') as handle:
            pickle.dump(embedding_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    embeddings = [embedding_dict[label] for label in text_list]
    return embeddings

    # else: # Here we will compute embeddings for all intents
    #     if os.path.exists(embeddings_path):
    #         with open(embeddings_path, 'rb') as handle:
    #             embeddings = pickle.load(handle)
    #     else:
    #         embeddings = _get_embeddings(model, text_list)
    #         with open(embeddings_path, 'wb') as handle:
    #             pickle.dump(embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)
    #     return embeddings


def get_clusters(y_pred, y_true, results_path, intent_embeddings = 'text-embedding-ada-002', num_clusters = 150, use_embeddings_for_labels = True, random_state = None):
    """
    Turn predicted and true intent strings into integer cluster labels.

    Args:
        y_pred: list of predicted intent strings.
        y_true: list of ground-truth intent strings.
        results_path: directory for the prediction-embedding cache; the
            true-label embedding cache is stored in its parent directory.
        intent_embeddings: embedding backend. A name containing 'ada' uses
            OpenAIEmbeddings; 'sentence-transformer' uses the
            'all-MiniLM-L6-v2' SentenceTransformer. A falsy value skips
            embeddings and label-encodes both inputs directly.
        num_clusters: number of KMeans clusters.
        use_embeddings_for_labels: when True the true labels are also
            embedded and clustered; otherwise they are label-encoded.
        random_state: forwarded to KMeans; None keeps the library default
            (non-deterministic initialisation).

    Returns:
        (labels_pred, labels_true): integer label arrays.

    Raises:
        ValueError: if ``intent_embeddings`` is set to an unsupported name.
    """
    if not intent_embeddings:
        labels_true = LabelEncoder().fit_transform(y_true)
        labels_pred = LabelEncoder().fit_transform(y_pred)
        return labels_pred, labels_true

    if 'ada' in intent_embeddings:
        model = OpenAIEmbeddings()
    elif intent_embeddings == 'sentence-transformer':
        intent_embeddings = 'all-MiniLM-L6-v2'
        model = SentenceTransformer(intent_embeddings)
    else:
        # Previously `model` stayed unbound here and the function crashed
        # later with an unrelated-looking error.
        raise ValueError(
            f"Unsupported intent_embeddings {intent_embeddings!r}; "
            "expected a name containing 'ada' or 'sentence-transformer'"
        )

    # Cache file names carry the resolved embedding name.
    pred_embeddings_path = os.path.join(results_path, f'pred_embeddings_{intent_embeddings}.pickle')
    true_embeddings_path = os.path.join(os.path.dirname(results_path), f'true_embeddings_{intent_embeddings}.pickle')

    pred_embeddings = get_embeddings(model, y_pred, pred_embeddings_path)
    clustering_model_pred = KMeans(n_clusters=num_clusters, random_state=random_state)
    clustering_model_pred.fit(pred_embeddings)
    labels_pred = clustering_model_pred.labels_

    if use_embeddings_for_labels:
        true_embeddings = get_embeddings(model, y_true, true_embeddings_path)
        clustering_model_true = KMeans(n_clusters=num_clusters, random_state=random_state)
        clustering_model_true.fit(true_embeddings)
        labels_true = clustering_model_true.labels_
    else:
        labels_true = LabelEncoder().fit_transform(y_true)

    return labels_pred, labels_true