import torch
from transformers import AutoTokenizer
from utils.util import extract_first_sentences, extract_sentences
from utils.loggers import loggers
from openai import OpenAI
import time
import random

class ProposalModel():
    '''
    Proposal model used to generate candidate completions during the search process.

    Talks to an OpenAI-compatible chat-completions server on localhost. The port is
    drawn at random from ``config['port']`` at construction and again before every
    ``get_response`` call, so successive calls may hit different servers.

    Config keys read here: 'port', 'proposal', 'gsm8k', 'max_tokens',
    'num_candidates', 'temperature', 'frequency_penalty', 'presence_penalty',
    plus the optional 'ensure_length' (default 2048), 'max_trials' (default 3)
    and 'retry_delay' (seconds between retries, default 10).
    '''

    def __init__(self, config):
        self.config = config
        # Placeholder key; presumably the local server does not validate it.
        self.api_key = "proposal"
        print(f'Proposal Model is {config["proposal"]}')
        self._connect()

    def _connect(self):
        '''Create a client for a randomly chosen port and check it serves the proposal model.

        Sets ``self.api_base``, ``self.client`` and ``self.model``.

        Raises:
            AssertionError: if ``config['proposal']`` is not among the server's model ids.
        '''
        port = random.choice(self.config['port'])
        self.api_base = f"http://localhost:{port}/v1"
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.api_base
        )
        self.model = self.config['proposal']
        # One HTTP round trip: fails fast if the chosen server is down or hosts another model.
        model_lists = [data.id for data in self.client.models.list().data]
        assert self.model in model_lists, f"Model {self.model} not found in {model_lists}"

    def switch_port(self):
        '''Reconnect to a (possibly different) port drawn at random from ``config['port']``.'''
        self._connect()

    def query(self,
            prompt,
            partial_steps,
            temp=None,
            n=None,
            stop=None,
            max_tokens=None,
        ):
        '''Request ``n`` chat completions for ``prompt`` (+ ``partial_steps``), with retries.

        Args:
            prompt: user prompt text.
            partial_steps: partial solution generated so far ('' for none).
            temp, n, max_tokens: override config 'temperature', 'num_candidates'
                and 'max_tokens' when not None.
            stop: optional stop sequence(s), forwarded to the API only when not None.

        Returns:
            ``(contents, finish_reasons)``, two lists with one entry per choice.
            Each content is whitespace-stripped. Returns ``None`` once all
            ``config.get('max_trials', 3)`` attempts have failed.
        '''
        # In gsm8k mode, or with no partial steps, everything goes into one user
        # turn and the server appends a generation prompt. Otherwise the partial
        # steps are sent as a separate message with the non-standard "special" role
        # and no generation prompt (assumes the server's chat template handles
        # that role — confirm against the serving setup).
        single_turn = bool(self.config['gsm8k'] or not partial_steps)
        if single_turn:
            message_chat = [{"role": "user", "content": prompt + partial_steps}]
        else:
            message_chat = [{"role": "user", "content": prompt}, {"role": "special", "content": partial_steps}]

        # Built once: config lookups are not transient failures worth retrying.
        request = dict(
            model=self.model,
            messages=message_chat,
            max_tokens=self.config['max_tokens'] if max_tokens is None else max_tokens,
            n=self.config['num_candidates'] if n is None else n,
            temperature=self.config['temperature'] if temp is None else temp,
            frequency_penalty=self.config['frequency_penalty'],
            presence_penalty=self.config['presence_penalty'],
            extra_body={
                "add_generation_prompt": single_turn,
                "add_special_tokens": True
            },
        )
        if stop is not None:
            request["stop"] = stop

        max_trials = self.config.get('max_trials', 3)
        retry_delay = self.config.get('retry_delay', 10)
        for num_trials in range(1, max_trials + 1):
            try:
                responses = self.client.chat.completions.create(**request)
                contents = [choice.message.content.strip() for choice in responses.choices]
                finish_reasons = [choice.finish_reason for choice in responses.choices]
                if not contents:
                    raise RuntimeError("No response from the model")
                return contents, finish_reasons
            except Exception as e:
                print(e)
                loggers["error"].info(f"Error in query: {e}")
                # No point sleeping after the final attempt.
                if num_trials < max_trials:
                    time.sleep(retry_delay)
        print(f"Retry exceed the max_retries {max_trials} times.", flush=True)
        return None

    def get_response(
                self,
                prompt,
                partial_steps='',
                temp=None,
                n=None,
                stop=None,
                max_tokens=None,
                extract_first_sentence=True,
                search_step=None
            ):
        '''Switch to a random server port, query the model, and post-process the completions.

        Returns:
            When ``extract_first_sentence`` is true: the result of ``extract_first_sentences``.
            Otherwise: the stripped completions, passed through ``extract_sentences``
            when ``search_step`` is given. Returns the string ``'<SKIP>'`` if the query
            or post-processing fails. Errors from ``switch_port`` itself are not caught.
        '''
        self.switch_port()
        try:
            result = self.query(prompt, partial_steps, temp=temp, n=n, stop=stop, max_tokens=max_tokens)
            if result is None:
                loggers["error"].info("Error in get_response: query exhausted its retries")
                return '<SKIP>'
            contents, finish_reasons = result

            if extract_first_sentence:
                # For s1k-series tasks, keep only the first several sentences whose
                # combined length reaches ensure_length, to avoid too many search steps.
                ensure_length = self.config.get('ensure_length', 2048)
                return extract_first_sentences(contents, search_step, gsm8k=self.config['gsm8k'], finish_reasons=finish_reasons, ensure_length=ensure_length)

            stripped = [a.strip() for a in contents]
            if search_step is not None:
                return extract_sentences(stripped, search_step=search_step)
            return stripped
        except Exception as e:
            loggers["error"].info(f"Error in get_response: {e}")
            return '<SKIP>'