import requests
import json
import torch as th
import random
import copy, heapq
import numpy as np
import time
from utils import api_key_list

class SentenceEncoder:
    """Turn sentences into embedding vectors using the OpenAI embeddings API.

    Only ``model_name='ChatGPT'`` can be constructed; any other value makes
    ``__init__`` raise ``NotImplementedError``.
    """

    # Seconds to wait before retrying a failed embedding request.
    retry_delay = 10
    # Seconds before ``requests.post`` gives up on a stalled connection; the
    # resulting exception is handled by the retry loop like any other failure.
    request_timeout = 60

    def __init__(self,
                 model_name: str = 'ChatGPT',
                 repo_id: str = 'all-mpnet-base-v2',
                 ) -> None:
        """Configure the HTTP headers for the selected backend.

        Args:
            model_name: backend selector; only ``'ChatGPT'`` is supported.
            repo_id: not used by this implementation.

        Raises:
            NotImplementedError: if ``model_name`` is not ``'ChatGPT'``.
        """
        self.model_name = model_name
        if self.model_name == 'ChatGPT':
            # Chat-completions endpoint; stored on the instance but not used by
            # get_embedding, which posts to the embeddings endpoint instead.
            self.url = "https://api.openai.com/v1/chat/completions"
            self.headers = {
                "Content-Type": "application/json",
                # NOTE(review): the key is sent verbatim; OpenAI expects
                # "Bearer <key>", so presumably api_key_list entries already
                # include that prefix — confirm.
                "Authorization": f"{api_key_list[0]}"
            }
        else:
            raise NotImplementedError

    def get_embedding(self, input, max_retries=None):
        """Fetch the ``text-embedding-ada-002`` embedding of ``input``.

        Args:
            input: text to embed, sent as the request's ``"input"`` field.
            max_retries: number of retries allowed after a failed attempt;
                ``None`` (the default) keeps retrying indefinitely.

        Returns:
            ``response["data"][0]["embedding"]`` from the parsed JSON reply.

        Raises:
            TypeError: if the request body cannot be JSON-serialized.
            RuntimeError: after ``max_retries`` retries have all failed; the
                last underlying exception is chained as ``__cause__``.
        """
        data = {
            "model": "text-embedding-ada-002",
            "input": input
        }
        url = "https://api.openai.com/v1/embeddings"
        # Serialize once, outside the retry loop: a body that cannot be
        # encoded will never succeed, so it should not be retried.
        payload = json.dumps(data).encode('utf-8')

        attempt = 0
        # Loop instead of recursing so a long outage cannot hit the
        # interpreter's recursion limit.
        while True:
            try:
                response = requests.post(url, headers=self.headers, data=payload,
                                         timeout=self.request_timeout)
                body = json.loads(response.content.decode("utf-8"))
                return body["data"][0]["embedding"]
            except Exception as e:
                # Network errors, timeouts, malformed JSON and error replies
                # lacking "data" are all treated as transient and retried.
                if max_retries is not None and attempt >= max_retries:
                    raise RuntimeError(
                        f"embedding request failed after {attempt + 1} attempt(s)"
                    ) from e
                attempt += 1
                time.sleep(self.retry_delay)

    def encode(self, sentence: str) -> th.Tensor:
        """Embed ``sentence`` and return the embedding as a tensor.

        Raises:
            NotImplementedError: if ``model_name`` is not a known backend.
        """
        if self.model_name == 'ChatGPT':
            raw_embedding = self.get_embedding(sentence)
        elif self.model_name == 'Sentence-Transformers':
            # NOTE(review): __init__ rejects this model_name and never sets
            # self.model, so this branch is not reachable through __init__.
            raw_embedding = self.model.encode(sentence)
        else:
            raise NotImplementedError

        return th.as_tensor(raw_embedding)

class base_query_strategy(object):
    """Base strategy for choosing at most ``k`` questions to query.

    Subclasses override :meth:`customized_query`; :meth:`query` handles the
    case where there are already ``k`` or fewer questions.
    """

    def __init__(self, k=3):
        # Number of questions to select per query.
        self.k = k

    def customized_query(self, original_problem, question_list) -> list:
        """Return ``k`` distinct questions chosen uniformly at random.

        ``original_problem`` is not used by this strategy.

        Raises:
            ValueError: if ``question_list`` has fewer than ``k`` items.
        """
        # Sample positions rather than the items themselves: np.random.choice
        # on the raw list coerces it to an array, which fails for tuple/list
        # items and converts mixed types (e.g. ints to strings). Sampling
        # indices draws from the global NumPy RNG in exactly the same way.
        indices = np.random.choice(len(question_list), self.k, replace=False)
        return [question_list[i] for i in indices]

    def query(self, original_problem, question_list):
        """Return up to ``k`` questions from ``question_list``.

        When there are ``k`` or fewer questions, ``question_list`` itself is
        returned unchanged; otherwise selection is delegated to
        :meth:`customized_query`.
        """
        if len(question_list) <= self.k:
            return question_list
        return self.customized_query(original_problem, question_list)


class random_query(base_query_strategy):
    """Select ``k`` questions uniformly at random.

    Adds no behaviour of its own; everything is inherited from
    ``base_query_strategy``.
    """


from sklearn.cluster import KMeans
class diversity_based_query(base_query_strategy):
    """Select diverse questions by clustering their embeddings with k-means.

    Each question is embedded with :class:`SentenceEncoder`, the embeddings
    are split into ``k`` clusters, and one question is drawn at random from
    each cluster.
    """

    def __init__(self, k=3):
        super().__init__(k)
        # Encoder used to embed every candidate question.
        self.encoder = SentenceEncoder()

    def customized_query(self, original_problem, question_list) -> list:
        """Return ``k`` questions, one sampled from each k-means cluster.

        ``original_problem`` is not used by this strategy.
        """
        question_embedding = np.array(
            [self.encoder.get_embedding(question) for question in question_list]
        )
        labels = KMeans(n_clusters=self.k, random_state=0).fit(question_embedding).labels_

        # One random representative per cluster, in sorted label order.
        chosen = [int(np.random.choice(np.where(labels == label)[0]))
                  for label in np.unique(labels)]

        # KMeans can produce fewer than k distinct labels when there are fewer
        # than k distinct embeddings (e.g. duplicate questions); top up with
        # questions not yet chosen so the caller still receives k of them.
        shortfall = self.k - len(chosen)
        if shortfall > 0:
            remaining = np.setdiff1d(np.arange(len(question_list)), chosen)
            chosen.extend(int(i) for i in np.random.choice(remaining, shortfall, replace=False))

        return [question_list[i] for i in chosen]


class similarity_based_query(base_query_strategy):
    """Select the questions whose embeddings are closest to the original problem."""

    def __init__(self, k=3):
        super().__init__(k)
        # Encoder used to embed the problem and every candidate question.
        self.encoder = SentenceEncoder()

    def compute_similarity(self, target, current):
        """Return the Euclidean distance between two embedding vectors.

        Despite the name this is a distance: smaller means more similar.
        """
        target, current = np.array(target), np.array(current)

        return np.linalg.norm(target - current)

    def get_smaller_from_list(self, my_list: list, n=1):
        """Return the ``n`` smallest values of ``my_list`` and their positions.

        Returns:
            ``(values, indices)``, both ordered by ascending value; equal
            values are ordered by ascending index. Indices are always distinct.
        """
        # Rank positions by their value. heapq.nsmallest is stable, so ties keep
        # ascending-index order. Selecting positions directly never yields the
        # same index twice, even when the list contains inf or repeated values.
        smaller_index = heapq.nsmallest(n, range(len(my_list)), key=my_list.__getitem__)
        smaller_number = [my_list[i] for i in smaller_index]
        return smaller_number, smaller_index

    def customized_query(self, original_problem, question_list) -> list:
        """Return the ``k`` questions nearest to ``original_problem``, closest first."""
        topic_embedding = self.encoder.get_embedding(original_problem)

        question_embedding = [self.encoder.get_embedding(question) for question in question_list]
        distances = [self.compute_similarity(topic_embedding, que_emb) for que_emb in question_embedding]

        _, idx = self.get_smaller_from_list(distances, n=self.k)

        return [question_list[i] for i in idx]

