from openai import OpenAI 
from dataclasses import dataclass
from openai import OpenAI
from ast import literal_eval
import torch.nn.functional as F
import torch
import concurrent.futures


# We use an LLM to construct the Y-space conditional on some X
@dataclass
class TopicDiscoveryPrompt:
	"""Prompt pair asking the model to partition a list of strings into semantic groups.

	Both attributes are unannotated, so they are plain class attributes (not
	dataclass fields) and are read directly off the class, e.g.
	`TopicDiscoveryPrompt.user_prompt`.
	"""

	# System message. NOTE: the attribute name is misspelled ("prompot"), but it is
	# referenced under this exact name elsewhere in this file, so renaming it here
	# alone would break those references.
	system_prompot = "You are a helpful assistant, and you output python code"

	# User message template; `{string_list}` is filled in with `str.format`.
	# The model is asked to reply with a Python list of (title, description) tuples.
	# The whitespace inside the literal is part of the prompt sent to the model.
	user_prompt = """
	Your task is to come up with groups that partition the following list of strings at the semantic level.
	{string_list} 
	You must answer exclusively with a valid list of tuples, where element 0 is a category title (string), and element 1 is a very brief description of that category
	"""


@dataclass
class TopicAssignmentPrompt:
	"""Prompt pair asking the model to assign one string to one of a set of numbered topics.

	Both attributes are unannotated, so they are plain class attributes (not
	dataclass fields) and are read directly off the class, e.g.
	`TopicAssignmentPrompt.user_prompt`.
	"""

	# System message. NOTE: the attribute name is misspelled ("prompot"), but it is
	# referenced under this exact name elsewhere in this file, so renaming it here
	# alone would break those references.
	system_prompot = "You are a helpful assistant, and you only output integers within the range specified by the number of topics"

	# User message template; `{topics}` and `{string}` are filled in with `str.format`.
	# The model is asked to answer with a 0-indexed topic ID only.
	# The whitespace inside the literal (mixed tabs/spaces, trailing blanks) is part
	# of the prompt sent to the model.
	user_prompt = """
	Given a set of topics, their descriptions, and their IDs (0-indexed) and a new string, assign the string to the most appropriate topic.
	Reply only with the ID of the topics.
     
    Topics: 
    {topics}

    String to be classified:
    {string}

    topic ID:  
	"""


# NOTE: refactor this, put it somewhere else (it can be used by other modules)
def two_msg_call(
    openai_client: OpenAI,
    system_prompt: str,
    user_prompt: str,
    model: str = "gpt-4o-mini",
    return_logits: bool = False,
    top_logprobs: int = 20  # Number of top tokens to return if return_logits is True
) -> str | tuple[str, object] | None:
    """Send a (system, user) message pair to a chat model and return its reply.

    Args:
        openai_client: Client exposing `chat.completions.create`.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        model: Name of the chat model to query.
        return_logits: If True, request token log-probabilities and return them
            together with the text.
        top_logprobs: Number of top tokens whose log-probabilities are requested
            per position; only sent when `return_logits` is True.

    Returns:
        - the text of the first choice when `return_logits` is False;
        - a `(text, logprobs)` tuple when `return_logits` is True, where
          `logprobs` is the `logprobs` attribute of the first choice;
        - `None` if anything raised while calling the API or reading the
          response. The exception is printed, not re-raised. Note that this is
          `None` (not a tuple) even when `return_logits` is True, so callers
          must check for it before unpacking.
    """
    try:
        response = openai_client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            logprobs=return_logits,  # Request log probabilities if return_logits is True
            top_logprobs=top_logprobs if return_logits else None
        )

        first_choice = response.choices[0]
        output_text = first_choice.message.content

        if return_logits:
            return output_text, first_choice.logprobs
        return output_text

    except Exception as e:
        # Best-effort call: report the failure and let the caller decide what to do.
        print(e)
        return None


class LMAbstraction:
    """LLM-backed topic abstraction over strings.

    Topics are `(title, description)` tuples. They are either discovered from a
    list of sentences with `fit`, or supplied directly through `from_topics`.
    `predict` then maps every input string to a probability vector over the
    topics, derived from the log-probabilities of the first token the model
    emits when asked for a topic ID.

    Attributes set once topics are available:
        topics: list of `(title, description)` tuples.
        topic_string: `str(topics)`.
        id_to_label / label_to_id: mappings between topic index and title.
    """

    # Title of the catch-all topic. Probability mass the model puts on topic IDs
    # that do not exist is routed to this topic when it is present.
    _OTHER_LABEL = 'Other'

    def __init__(
            self,
            openai_client: OpenAI,
            model_name: str = "gpt-4o-mini",
            max_workers: int = 16
        ):
        """Store the client and settings; topics are set later by `fit` or `from_topics`.

        Args:
            openai_client: Client used for all model calls.
            model_name: Name of the chat model to query.
            max_workers: Number of threads used to classify sentences in parallel.
        """
        self.openai_client = openai_client
        self.max_workers = max_workers
        self.model_name = model_name

    @classmethod
    def from_topics(cls, topics: list[tuple[str, str]], openai_client: OpenAI, model_name: str = "gpt-4o-mini", max_workers: int = 16):
        """Build an instance from already-known topics, skipping discovery.

        Args:
            topics: List of `(title, description)` tuples.
            openai_client: Client used for all model calls.
            model_name: Name of the chat model to query.
            max_workers: Number of threads used to classify sentences in parallel.

        Returns:
            A new instance with `topics` and the derived mappings set.
        """
        instance = cls(openai_client, model_name, max_workers)
        instance._set_topics(topics)
        return instance

    def _set_topics(self, topics: list[tuple[str, str]]) -> None:
        """Store `topics` and rebuild every attribute derived from them."""
        self.topics = topics
        # NOTE: this introduces subtle inconsistency, in that a string generated like so
        # does not include some formatting (eg. quotes and newlines) of the raw model
        # output. This shouldn't, but might, affect the model when classifying a new string.
        self.topic_string = str(topics)
        self.id_to_label = {idx: topic for idx, (topic, _) in enumerate(topics)}
        self.label_to_id = {topic: idx for idx, (topic, _) in enumerate(topics)}

    @staticmethod
    def _parse_topics(raw: str) -> list[tuple[str, str]]:
        """Parse the model's topic-discovery answer into a list of 2-tuples.

        Markdown code fences are stripped before the text is evaluated as a
        Python literal.

        Raises:
            ValueError, SyntaxError, TypeError, MemoryError, RecursionError:
                if the text is not a valid literal (from `ast.literal_eval`).
            ValueError: if the literal is not a sequence of 2-item sequences.
        """
        cleaned = raw.replace("```python", '').replace('```', '').strip()
        parsed = literal_eval(cleaned)
        if not isinstance(parsed, (list, tuple)):
            raise ValueError(f"Expected a list of (title, description) tuples, got {type(parsed).__name__}")

        topics = []
        for entry in parsed:
            if not isinstance(entry, (list, tuple)) or len(entry) != 2:
                raise ValueError(f"Expected a (title, description) pair, got {entry!r}")
            topics.append(tuple(entry))
        return topics

    @property
    def topic_presentation_string(self):
        """One line per topic (`<id> - TOPIC: ..., DESCRIPTION: ...`), as shown to the model."""
        return '\n'.join(
            f"{idx} - TOPIC: {topic}, \t DESCRIPTION: {desc}"
            for idx, (topic, desc) in enumerate(self.topics)
        )

    def fit(self, sentences: list[str], other_class: bool = True):
        '''
        Fit the model by discovering the relevant topics for a list of sentences.
        If `other_class`, add a category for sentences that do not fit any of the discovered topics.

        The model is queried up to three times; an attempt is retried when the call
        fails or when the answer cannot be parsed into a list of 2-tuples.
        The instance is only modified once an attempt succeeds.

        Raises:
            RuntimeError: if no attempt produced a usable list of topics.
        '''
        max_attempts = 3
        topics = None
        for _ in range(max_attempts):
            response = two_msg_call(
                self.openai_client,
                TopicDiscoveryPrompt.system_prompot,
                TopicDiscoveryPrompt.user_prompt.format(string_list=sentences),
                model=self.model_name
            )
            if response is None:
                # The call itself failed; two_msg_call has already printed the error.
                print("Retrying...")
                continue
            try:
                topics = self._parse_topics(response)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
                print(e)
                print("Retrying...")
                continue
            break

        if topics is None:
            raise RuntimeError(f"Topic discovery failed after {max_attempts} attempts")

        if other_class:
            topics.append((self._OTHER_LABEL, 'Anything that doesn\'t fit the above topics'))
        self._set_topics(topics)

    def print_topics(self):
        """Print every topic title and description."""
        assert hasattr(self, 'topics'), "You must fit the model first"
        for topic, desc in self.topics:
            print('---')
            print(f'Topic: {topic} \nDescription: {desc}')

    def _topic_distribution(self, top_logprobs: list[dict]) -> torch.Tensor:
        """Turn the top-token log-probabilities of one answer into topic probabilities.

        Args:
            top_logprobs: Dicts with a `'token'` and a `'logprob'` key, one per
                candidate first token.

        Returns:
            A 1-D tensor with one entry per topic. Candidates whose token is not
            an integer are ignored. Integers outside the valid ID range are
            credited to the 'Other' topic when there is one; without an 'Other'
            topic they are ignored as well, so the result is a distribution over
            the known topics. The result is all zeros if no candidate is usable.
        """
        n_topics = len(self.topics)
        other_id = self.label_to_id.get(self._OTHER_LABEL)

        # keep only the tokens that are associated with integers (ie. do not crash `int()`)
        # TODO: should we assign the rest to the 'Other' category?
        candidates = []
        for entry in top_logprobs:
            try:
                candidates.append((int(entry['token']), entry['logprob']))
            except ValueError:
                continue

        # NOTE: the api only returns the logprobs of the requested number of top tokens.
        # Typically the integer ones cover > 99% of the mass, but we check this.
        covered_p_mass = torch.tensor([logprob for _, logprob in candidates]).exp().sum()
        if covered_p_mass < 0.95:
            print(f"Warning: only {covered_p_mass:.2f} of the probability mass is covered by the top 20 tokens")

        # since we assume that the model has a high recall (i.e. it will recognise strings
        # of some topic when it sees them) the mass associated with unseen indices goes
        # to the 'other' category, if there is one
        targets, logprobs = [], []
        for topic_id, logprob in candidates:
            target = topic_id if 0 <= topic_id < n_topics else other_id
            if target is not None:
                targets.append(target)
                logprobs.append(logprob)

        all_topic_probs = torch.zeros(n_topics)
        if not targets:
            return all_topic_probs

        # Renormalise over the retained candidates, then accumulate per topic
        # (several tokens, e.g. "1" and " 1", can map to the same ID).
        probs = F.softmax(torch.tensor(logprobs), dim=0)
        for target, prob in zip(targets, probs):
            all_topic_probs[target] += prob
        return all_topic_probs

    def _predict(self, sentence: str) -> torch.Tensor:
        """Classify one sentence and return its probability vector over the topics.

        Raises:
            RuntimeError: if the model call failed.
        """
        result = two_msg_call(
            self.openai_client,
            TopicAssignmentPrompt.system_prompot,
            TopicAssignmentPrompt.user_prompt.format(
                topics=self.topic_presentation_string, string=sentence
            ),
            model=self.model_name,
            return_logits=True
        )
        if result is None:
            # two_msg_call signals failure with a bare None, which cannot be unpacked.
            raise RuntimeError(f"The model call failed while classifying {sentence!r}")

        _, logprobs = result
        # Only the first generated token is inspected: it is expected to be the topic ID.
        first_token = logprobs.to_dict()['content'][0]
        return self._topic_distribution(first_token['top_logprobs'])

    def _batch_predict(self, sentences: list[str]) -> torch.Tensor:
        """
        Process multiple sentences in parallel for prediction.

        Uses a thread pool with `self.max_workers` threads.

        Args:
            sentences: List of sentences to predict topics for

        Returns:
            Tensor of probabilities for each sentence and topic, with rows in the
            same order as `sentences` (shape `(0, n_topics)` for an empty list)
        """
        if not sentences:
            # torch.stack cannot handle an empty list
            return torch.zeros((0, len(self.topics)))

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # `map` yields the results in input order, whatever the completion order
            rows = list(executor.map(self._predict, sentences))
        return torch.stack(rows)

    def predict(self, sentences: str | list[str]) -> torch.Tensor:
        """Return one row of topic probabilities per input sentence.

        Args:
            sentences: A single string or a list of strings.

        Raises:
            TypeError: if `sentences` is neither a string nor a list of strings.
        """
        if isinstance(sentences, str):
            sentences = [sentences]

        if not isinstance(sentences, list) or not all(isinstance(s, str) for s in sentences):
            raise TypeError("Input must be a string or a list of strings")
        return self._batch_predict(sentences=sentences)

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Alias for `predict`."""
        return self.predict(*args, **kwargs)


