import openai
from openai import OpenAI
import anthropic
import os
import time
import torch
import gc
import tiktoken
from typing import Dict, List


class GPT:
    """Thin wrapper around an OpenAI-compatible chat-completions client.

    The backend is selected from ``model_name``: names containing ``'gpt'``
    use the OpenAI API, names containing ``'together'`` use Together's
    OpenAI-compatible endpoint. API keys are read from the
    ``OPENAI_API_KEY`` / ``TOGETHER_API_KEY`` environment variables.
    """

    API_RETRY_SLEEP = 10  # seconds slept after a failed request
    API_ERROR_OUTPUT = "$ERROR$"  # returned in place of a dict when all retries fail
    API_QUERY_SLEEP = 0.5  # additional seconds slept after a failed attempt
    API_MAX_RETRY = 5  # attempts per conversation
    API_TIMEOUT = 20  # passed as the request ``timeout``
    API_LOGPROBS = True  # ask the API for per-token log-probabilities
    API_TOP_LOGPROBS = 20  # alternatives requested per generated position
    TOGETHER_BASE_URL = "https://api.together.xyz/v1"

    def __init__(self, model_name):
        """Create the API client matching ``model_name``.

        Raises:
            ValueError: if ``model_name`` contains neither ``'gpt'`` nor
                ``'together'``, or if the Together backend is selected but
                ``TOGETHER_API_KEY`` is not set.
        """
        self.model_name = model_name
        if 'gpt' in model_name:
            # ``None`` lets the client fall back to its own key lookup.
            self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        elif 'together' in model_name:
            together_api_key = os.environ.get("TOGETHER_API_KEY")
            if not together_api_key:
                # Refuse to continue: without an explicit key the client would
                # fall back to the OpenAI key and send it to a third party.
                raise ValueError("TOGETHER_API_KEY environment variable is not set")
            self.client = OpenAI(
                api_key=together_api_key,
                base_url=self.TOGETHER_BASE_URL,
            )
        else:
            raise ValueError(f"Unknown model name: {model_name}")
        self.tokenizer = tiktoken.encoding_for_model("gpt-4")
        self.tokenizer.vocab_size = 100256

    @staticmethod
    def _top_logprobs(choice):
        """Return one ``{token: logprob}`` dict per generated position.

        Yields an empty list when the response carries no log-probabilities
        instead of failing on a missing attribute.
        """
        logprobs = getattr(choice, "logprobs", None)
        content = getattr(logprobs, "content", None)
        if not content:
            return []
        return [
            {candidate.token: candidate.logprob for candidate in position.top_logprobs}
            for position in content
        ]

    def generate(
        self, convs: List[List[Dict]], max_n_tokens: int, temperature: float, top_p: float
    ):
        """Query the API once per conversation, retrying on API errors.

        Args:
            convs: conversations, each a list of chat message dicts passed
                unchanged as ``messages``.
            max_n_tokens: forwarded as ``max_tokens``.
            temperature: forwarded sampling temperature.
            top_p: forwarded nucleus-sampling parameter.

        Returns:
            A list aligned with ``convs``. Each entry is a dict with keys
            ``'text'``, ``'logprobs'``, ``'n_input_tokens'`` and
            ``'n_output_tokens'``, or ``API_ERROR_OUTPUT`` if every attempt
            for that conversation raised ``openai.OpenAIError``.
        """
        outputs = []
        for conv in convs:
            output = self.API_ERROR_OUTPUT
            for _ in range(self.API_MAX_RETRY):
                try:
                    response = self.client.chat.completions.create(
                        model=self.model_name,
                        messages=conv,
                        max_tokens=max_n_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        timeout=self.API_TIMEOUT,
                        logprobs=self.API_LOGPROBS,
                        top_logprobs=self.API_TOP_LOGPROBS,
                        seed=0,
                    )
                except openai.OpenAIError as e:
                    print(type(e), e)
                    time.sleep(self.API_RETRY_SLEEP)
                    # NOTE: this extra pause only happens after a failure;
                    # a successful request leaves the retry loop immediately.
                    time.sleep(self.API_QUERY_SLEEP)
                    continue

                choice = response.choices[0]
                output = {
                    'text': choice.message.content,
                    'logprobs': self._top_logprobs(choice),
                    'n_input_tokens': response.usage.prompt_tokens,
                    'n_output_tokens': response.usage.completion_tokens,
                }
                break
            outputs.append(output)
        return outputs


class HuggingFace:
    """Batch text generation with a local model, returning full-vocabulary log-probabilities."""

    def __init__(self, model_name, model, tokenizer):
        """Store the model/tokenizer pair and precompute lookup tables.

        Args:
            model_name: name used to switch on model-specific handling
                (substrings ``'llama2'`` and ``'llama3'`` are checked).
            model: object exposing ``generate``, ``device`` and ``config``.
            tokenizer: tokenizer matching ``model``.
        """
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        # Token id -> token string, with the '▁' marker rendered as a space.
        self.pos_to_token_dict = {v: k.replace('▁', ' ') for k, v in self.tokenizer.get_vocab().items()}
        self.eos_token_ids = [self.tokenizer.eos_token_id]
        if 'llama3' in self.model_name.lower():
            # Also stop generation at the "<|eot_id|>" token for these models.
            self.eos_token_ids.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
        self.pad_token_id = self.tokenizer.pad_token_id

    def generate(self,
                 full_prompts_list: List[str],
                 max_n_tokens: int,
                 temperature: float,
                 top_p: float = 1.0) -> List[Dict]:
        """Generate a continuation for every prompt in the batch.

        Args:
            full_prompts_list: fully formatted prompt strings.
            max_n_tokens: maximum number of new tokens to return per prompt.
            temperature: ``0`` selects greedy decoding; any other value
                enables sampling at that temperature.
            top_p: nucleus-sampling parameter forwarded to the model.

        Returns:
            One dict per prompt with keys ``'text'`` (decoded continuation),
            ``'logprobs'`` (one ``{token: logprob}`` dict per generated
            position), ``'n_input_tokens'`` (non-padding prompt tokens) and
            ``'n_output_tokens'``.
        """
        is_llama2 = 'llama2' in self.model_name.lower()
        if is_llama2:
            # The first generated position is discarded below for these
            # models, so request one extra token to compensate.
            max_n_tokens += 1
        batch_size = len(full_prompts_list)
        vocab_size = len(self.tokenizer.get_vocab())
        inputs = self.tokenizer(full_prompts_list, return_tensors='pt', padding=True)
        # Use the device object itself: its ``index`` is None for CPU models.
        device = self.model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        input_ids = inputs["input_ids"]

        sampling = temperature != 0
        output = self.model.generate(
            **inputs,
            max_new_tokens=max_n_tokens,
            do_sample=bool(sampling),
            temperature=temperature if sampling else None,
            eos_token_id=self.eos_token_ids,
            pad_token_id=self.tokenizer.pad_token_id,
            top_p=top_p,
            output_scores=True,
            return_dict_in_generate=True,
        )
        output_ids = output.sequences
        if not self.model.config.is_encoder_decoder:
            # Decoder-only models echo the prompt; keep only the new tokens.
            output_ids = output_ids[:, input_ids.shape[1]:]
        if is_llama2:
            output_ids = output_ids[:, 1:]

        generated_texts = self.tokenizer.batch_decode(output_ids)
        # One (batch, vocab) array of log-probabilities per generated position.
        logprobs_tokens = [torch.nn.functional.log_softmax(step_scores, dim=-1).cpu().numpy()
                           for step_scores in output.scores]
        if is_llama2:
            logprobs_tokens = logprobs_tokens[1:]

        id_to_token = self.pos_to_token_dict
        vocab_ids = range(vocab_size)
        logprob_dicts = [
            [{id_to_token[i_vocab]: step_logprobs[i_batch][i_vocab] for i_vocab in vocab_ids}
             for step_logprobs in logprobs_tokens]
            for i_batch in range(batch_size)
        ]

        # Count real prompt tokens from the attention mask rather than by
        # comparing ids against a hard-coded pad id of 0.
        attention_mask = inputs.get("attention_mask")
        pad_token_id = self.tokenizer.pad_token_id
        if attention_mask is not None:
            n_input_tokens = [int(row.sum()) for row in attention_mask]
        elif pad_token_id is not None:
            n_input_tokens = [int((row != pad_token_id).sum()) for row in input_ids]
        else:
            n_input_tokens = [input_ids.shape[1]] * batch_size

        outputs = [{'text': generated_texts[i_batch],
                    'logprobs': logprob_dicts[i_batch],
                    'n_input_tokens': n_input_tokens[i_batch],
                    'n_output_tokens': len(output_ids[i_batch]),
                   } for i_batch in range(batch_size)
        ]

        # Drop every reference to the batch tensors (including the generation
        # output that holds the per-step scores) so the cache can be released.
        del inputs, input_ids, attention_mask, output, output_ids
        gc.collect()
        torch.cuda.empty_cache()

        return outputs

    