import copy
import os
import time
from typing import Dict, List, Mapping

import openai
from openai import OpenAI
import numpy as np
import wandb
from transformers import GPT2TokenizerFast

# Module-level API client shared by the completion helpers in this file. It is created when
# the module is imported, with the key taken from the OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# GPT-2 tokenizer used by the helpers in this file to count and truncate prompt tokens.
# NOTE(review): counts may differ from the tokenizer of the model actually queried - confirm.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

"""
Helper methods for working with the OpenAI API
"""

def completion_with_retries(model: str, prompt: str, max_retries: int = 10, retry_interval_sec: int = 20, validation_fn=None, **kwargs) -> Mapping:
    """
    Call the OpenAI completions endpoint, retrying on API errors and failed validation.

    Args:
        model: Model name forwarded to the API. For 'davinci' the prompt is first
            truncated (from the middle) to fit a 2048-token context.
        prompt: Prompt text to complete.
        max_retries: Maximum number of attempts. API errors and validation failures
            each consume one attempt.
        retry_interval_sec: Seconds to wait after an API error before the next attempt.
            Validation failures are retried immediately.
        validation_fn: Optional callable taking the response; a falsy result discards
            the response and triggers another attempt.
        **kwargs: Extra arguments forwarded to `client.completions.create`.

    Returns:
        The API response, or an empty dict if every attempt failed.
    """
    # TODO: the client does not read `openai.organization`; pass the organization when the
    # client is instantiated, e.g. 'OpenAI(organization=os.getenv("OPENAI_API_ORG"))'.
    # Report on the client that actually sends the request, and never print the key itself:
    # stdout is routinely captured in logs.
    key_status = "<redacted>" if client.api_key else "<not set>"
    print("Using OpenAI API with key:", key_status, " | org:", client.organization)

    if model == 'davinci':  # smaller context length; TODO: Handle this more neatly
        max_toks = 2048
        n_toks = num_tokens(prompt)
        if n_toks > max_toks - 1:
            print(f"Prompt has {n_toks} tokens, which is greater than the maximum {max_toks} tokens. Truncating prompt to {max_toks} tokens.")
            prompt = truncate_prompt(prompt, max_tokens=max_toks - 1, cut_out="middle")
            print("New prompt:", prompt)
            print("New prompt length:", len(prompt))

    for n_attempts_remaining in range(max_retries, 0, -1):
        # The "openai_api" metric is 0.5 while a request is in flight, 1 on success, 0 on error.
        if wandb.run is not None:
            wandb.log({"openai_api": 0.5}, step=int(time.time() * 1000))
        try:
            res = client.completions.create(model=model, prompt=prompt, **kwargs)
            if wandb.run is not None:
                wandb.log({"openai_api": 1}, step=int(time.time() * 1000))

            if validation_fn is not None:
                if not validation_fn(res):
                    print(f"VALIDATION FAILED!\nResponse:\n{res}\nTrying again.")
                    continue
            return res
        # OpenAIError is the base class of every exception the v1 client raises (rate limit,
        # connection, timeout, status errors). The pre-v1 names (ServiceUnavailableError,
        # TryAgain, ...) no longer exist, and referencing them here would raise
        # AttributeError instead of retrying.
        except openai.OpenAIError as e:
            if wandb.run is not None:
                wandb.log({"openai_api": 0}, step=int(time.time() * 1000))
            print(e)
            attempts_left = n_attempts_remaining - 1
            if attempts_left == 0:
                # Nothing left to retry, so don't waste a full interval sleeping.
                print("Hit openai.error exception. No attempts remaining; giving up.", flush=True)
                break
            print(f"Hit openai.error exception. Waiting {retry_interval_sec} seconds for retry... ({attempts_left} attempts remaining)", flush=True)
            time.sleep(retry_interval_sec)
    return {}

def chat_completion_with_retries(model: str, sys_prompt: str, prompt: str, max_retries: int = 5, retry_interval_sec: int = 20, api_keys=None, validation_fn=None, **kwargs) -> Mapping:
    """
    Call the OpenAI chat completions endpoint, retrying on API errors and failed validation.

    The system and user prompts are truncated beforehand so that, together with the
    requested generation length, they fit a hard-coded 4096-token budget. Token counts
    come from this module's tokenizer.

    Args:
        model: Model name forwarded to the API.
        sys_prompt: Content of the system message (truncated from the middle if needed).
        prompt: Content of the user message (truncated from the front if needed).
        max_retries: Maximum number of attempts. API errors and validation failures
            each consume one attempt.
        retry_interval_sec: Seconds to wait after an API error before the next attempt.
            Validation failures are retried immediately.
        api_keys: Unused by this function; kept so existing callers keep working.
        validation_fn: Optional callable taking the response; a falsy result discards
            the response and triggers another attempt.
        **kwargs: Extra arguments forwarded to `client.chat.completions.create`.
            `max_tokens` (default 1 for budgeting purposes) is also used to size the
            prompt budget.

    Returns:
        The API response, or an empty dict if every attempt failed.

    Raises:
        AssertionError: If the prompts still exceed the budget after truncation.
    """
    # TODO: the client does not read `openai.organization`; pass the organization when the
    # client is instantiated, e.g. 'OpenAI(organization=os.getenv("OPENAI_API_ORG"))'.
    # Report on the client that actually sends the request, and never print the key itself:
    # stdout is routinely captured in logs.
    key_status = "<redacted>" if client.api_key else "<not set>"
    print("Using OpenAI API with key:", key_status, " | org:", client.organization)

    # Truncate if necessary
    generation_tokens = kwargs.get('max_tokens', 1)
    max_toks = 4096 - generation_tokens
    n_sys_toks = num_tokens(sys_prompt)
    n_user_toks = num_tokens(prompt)
    print("n_sys_toks", n_sys_toks)
    print("n_user_toks", n_user_toks)
    if n_sys_toks + n_user_toks > max_toks:
        print(f"Truncating prompt from {n_sys_toks + n_user_toks} tokens to {max_toks} tokens")
        # We do truncation in a way that preserves the ratio of system and user tokens
        padding = 1  # TODO: figure out why this is necessary, in some cases `truncate_prompt` will return a prompt with an additional token
        new_sys_toks = int(max_toks * n_sys_toks / (n_sys_toks + n_user_toks)) - padding
        new_user_toks = max_toks - new_sys_toks - padding
        print("Truncating to new_sys_toks", new_sys_toks, "new_user_toks", new_user_toks)
        sys_prompt = truncate_prompt(sys_prompt, max_tokens=new_sys_toks, cut_out="middle")
        prompt = truncate_prompt(prompt, max_tokens=new_user_toks, cut_out="front")
        n_sys_toks = num_tokens(sys_prompt)
        n_user_toks = num_tokens(prompt)
        print("New prompts:")
        print(n_sys_toks, n_user_toks)
        print(sys_prompt)
        print(prompt)
        assert n_sys_toks + n_user_toks <= max_toks, f"Truncation failed: {n_sys_toks} + {n_user_toks} > {max_toks}"

    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": prompt},
    ]

    for n_attempts_remaining in range(max_retries, 0, -1):
        # The "openai_api" metric is 0.5 while a request is in flight, 1 on success, 0 on error.
        if wandb.run is not None:
            wandb.log({"openai_api": 0.5}, step=int(time.time() * 1000))
        try:
            res = client.chat.completions.create(model=model, messages=messages, **kwargs)
            if wandb.run is not None:
                wandb.log({"openai_api": 1}, step=int(time.time() * 1000))

            if validation_fn is not None:
                if not validation_fn(res):
                    print(f"VALIDATION FAILED!\nResponse:\n{res}\nTrying again.")
                    continue
            return res
        # OpenAIError is the base class of every exception the v1 client raises (rate limit,
        # connection, timeout, status errors). The pre-v1 names (ServiceUnavailableError,
        # TryAgain, ...) no longer exist, and referencing them here would raise
        # AttributeError instead of retrying.
        except openai.OpenAIError as e:
            if wandb.run is not None:
                wandb.log({"openai_api": 0}, step=int(time.time() * 1000))
            print(e)
            attempts_left = n_attempts_remaining - 1
            if attempts_left == 0:
                # Nothing left to retry, so don't waste a full interval sleeping.
                print("Hit openai.error exception. No attempts remaining; giving up.", flush=True)
                break
            print(f"Hit openai.error exception. Waiting {retry_interval_sec} seconds for retry... ({attempts_left} attempts remaining)", flush=True)
            time.sleep(retry_interval_sec)
    return {}


def logit_whitelist(labels: List[str]):
    """
    Build a `logit_bias` mapping that biases the model towards the given labels.

    Each label is tokenized and only its first token id is used as a key; every key
    gets a bias of 100 so the model is steered to predict only these options.
    """
    bias_by_token = {}
    for label in labels:
        first_token_id = tokenizer.encode(label)[0]
        bias_by_token[first_token_id] = 100
    return bias_by_token


def get_class_probs(res: Dict, class_labels: List[str]):
    """
    Turn the first-token logprobs of a completion response into class probabilities.

    Args:
        res: Completion response. It is read via attribute access
            (`res.choices[0].logprobs.top_logprobs[0]`), where the last element is a
            mapping from token text to logprob.
        class_labels: Token strings of the output classes, in the order the
            probabilities should be returned.

    Returns:
        Array of probabilities, one per entry of `class_labels`, normalized to sum to 1.

    Raises:
        AssertionError: If every label was present in `top_logprobs` (or there are at
            least as many entries as labels) and the raw probabilities do not sum to 1.
    """
    top_logprobs = res.choices[0].logprobs.top_logprobs[0]
    # OAI returns max 5 top_logprobs, so sometimes the label is not there and we need to
    # fill it with a negligible value.
    logprobs = np.array([top_logprobs.get(label, -100) for label in class_labels], dtype=float)

    # Convert logprobs into probs
    probs = np.exp(logprobs)
    assert np.isclose(probs.sum(), 1) or len(class_labels) > len(top_logprobs), \
        f"Probabilities don't sum to 1: {probs} (class_labels: {class_labels}, top_logprobs: {top_logprobs})"
    # Make it exact by renormalizing (softmax over the selected logprobs)
    return probs / probs.sum()

def truncate_prompt(prompt: str, max_tokens: int, cut_out: str) -> str:
    """
    Truncate prompt to at most max_tokens tokens, cutting from the front, middle or back.

    Args:
        prompt: Text to truncate. Returned unchanged if it already fits.
        max_tokens: Token budget. Values below 0 are treated as 0.
        cut_out: Which part to remove: "front" keeps the last tokens, "back" keeps the
            first tokens, "middle" keeps `max_tokens // 2` tokens from each end.

    Returns:
        The decoded, truncated prompt. Note that the callers in this file allow for the
        decoded text re-tokenizing to one extra token.

    Raises:
        ValueError: If truncation is needed and `cut_out` is not one of the above.
    """
    toks = tokenizer.encode(prompt)
    n_toks = len(toks)
    if n_toks <= max_tokens:
        return prompt

    budget = max(max_tokens, 0)
    # Slice bounds are computed from the start of the list: a negative-index slice such as
    # `toks[-0:]` would return the whole list when the number of tokens to keep is 0.
    if cut_out == "front":
        toks = toks[n_toks - budget:]
    elif cut_out == "middle":
        # Keep the front and back half of the prompt
        toks_at_each_end = budget // 2
        toks = toks[:toks_at_each_end] + toks[n_toks - toks_at_each_end:]
    elif cut_out == "back":
        toks = toks[:budget]
    else:
        raise ValueError(f"Invalid cut_out value: {cut_out}")
    return tokenizer.decode(toks)

def exceeds_max_tokens(text: str, max_tokens: int) -> bool:
    """
    Report whether text tokenizes to more than max_tokens tokens.

    Prints a warning with both counts when the limit is exceeded.
    """
    n_toks = len(tokenizer.encode(text))
    if n_toks <= max_tokens:
        return False
    print(
        f"WARNING: Maximum token length exceeded ({n_toks} > {max_tokens})")
    return True


def num_tokens(text: str):
    """Return the number of tokens the module-level tokenizer produces for text."""
    token_ids = tokenizer.encode(text)
    return len(token_ids)


def assemble_prompt(prompt_parts: List[str], query_text: str, max_tokens: int = 2049):
    """
    Return a prompt satisfying max_tokens, with as many prompt_parts as possible
    while making sure to include the query text.

    i.e full_prompt = prompt_parts[:as_many_as_fits] + [query_text]

    Parts are joined with newlines and dropped from the end until the prompt fits.
    If the query text alone is still over the limit, it is returned on its own
    (a warning has been printed by `exceeds_max_tokens`) rather than raising.
    """
    # Shallow copy so the caller's list is never mutated; the elements are strings.
    instruction_parts = list(prompt_parts)
    while True:
        full_prompt = "\n".join(instruction_parts + [query_text])
        if not exceeds_max_tokens(full_prompt, max_tokens=max_tokens):
            break
        if not instruction_parts:
            # Nothing left to drop: the query itself is too long. Popping from the empty
            # list would raise IndexError, so hand back the query-only prompt instead.
            break
        # Still too long: pop off the last few-shot example and try again.
        instruction_parts.pop()
    return full_prompt
