import os
import time
import json
import requests
import numpy as np
# Load the OpenAI API key at import time. The path is relative, so it is resolved
# against the current working directory; importing this module raises
# FileNotFoundError if the file is not found there.
with open('openai_api_key.txt', 'r') as file:
    # Load api key from file, tolerating surrounding whitespace and any leading '['
    # / trailing ']' characters (e.g. a key pasted as "[sk-...]").
    os.environ['OPENAI_API_KEY'] = file.read().strip().lstrip('[').rstrip(']').strip()
# The key is exported through the OPENAI_API_KEY environment variable so that the
# SDK client below (built without an explicit api_key, so the SDK reads that
# variable) and code that calls os.getenv('OPENAI_API_KEY') directly see the same key.
from openai import OpenAI, OpenAIError
# Shared module-level client used by the SDK-based helpers below.
client = OpenAI()
from utils.util import letters, save_json, bold

# Model names that get_openai_response sends to the Chat Completions endpoint
# (fine-tuned "ft..." names are accepted separately).
# gpt-3.5-turbo-instruct is intentionally NOT listed: it is served only by the
# legacy Completions endpoint, so a chat request with it always fails. Its price
# is still kept in api_costs.
CHAT_COMPLETION_API_MODEL_NAMES = [
    "gpt-4-0125-preview", "gpt-4-1106-preview", "gpt-4-0613", "gpt-4-32k-0613",
    "gpt-3.5-turbo-0125", "gpt-3.5-turbo-1106", "gpt-4o-2024-05-13",
    # aliases (see `aliases`); gpt-4o is priced in api_costs
    "gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-4-turbo-preview", "gpt-4o",
]

# Per-1000-token prices in dollars, stored as (input_cost, output_cost).
# cost = price * tokens / 1000; the per-1M figures in the comments are price * 1000.
api_costs = {
    # --- GPT-3.5-Turbo family: https://platform.openai.com/docs/models/gpt-3-5-turbo ---
    'gpt-3.5-turbo-0125': (0.0005, 0.0015),       # 16k context;  $0.50 / $1.50 per 1M
    'gpt-3.5-turbo-1106': (0.001, 0.002),         # 16k context;  $1.00 / $2.00 per 1M
    'gpt-3.5-turbo-instruct': (0.0015, 0.002),    # 4k context;   $1.50 / $2.00 per 1M
    # --- GPT-4 family: https://platform.openai.com/docs/models/gpt-4 ---
    'gpt-4-0125-preview': (0.01, 0.03),           # 128k context; $10 / $30 per 1M
    'gpt-4-1106-preview': (0.01, 0.03),           # 128k context; $10 / $30 per 1M
    'gpt-4-0613': (0.03, 0.06),                   # 8k context;   $30 / $60 per 1M
    'gpt-4-32k-0613': (0.06, 0.12),               # 32k context;  $60 / $120 per 1M
    'gpt-4o': (0.005, 0.015),                     # 128k context; $5 / $15 per 1M
}
# Maps a model name to the key used in the price tables (api_costs / finetuning_costs).
aliases = {
    'gpt-3.5-turbo': 'gpt-3.5-turbo-0125',        # floating alias -> dated snapshot
    'gpt-4': 'gpt-4-0613',
    'gpt-4-32k': 'gpt-4-32k-0613',
    'gpt-4-turbo-preview': 'gpt-4-0125-preview',
    'gpt-4o-2024-05-13': 'gpt-4o',                # dated snapshot -> priced under 'gpt-4o'
}
# LEGACY models (no pricing or alias entries above): gpt-3.5-turbo-16k, gpt-3.5-turbo-0613, gpt-3.5-turbo-16k-0613

### FINE-TUNING FUNCTIONS ###

# Fine-tuning prices in dollars per 1000 tokens, stored as
# (training_cost, input_cost, output_cost):
#   [0] is used by compute_finetuning_cost (trained tokens),
#   [1:] is used by compute_cost for requests to a fine-tuned model.
finetuning_costs = {
    'gpt-3.5-turbo-0125': (0.008, 0.003, 0.006),
    # GPT-4 fine-tuning pricing: https://openai.com/gpt-4-ft-experimental-pricing/
    'gpt-4-0613': (0.09, 0.045, 0.09),
}

def convert_job_to_dict(job):
    """
    Turn a job object into a plain dict suitable for JSON serialization.

    job: any object with a __dict__ (e.g. an openai fine-tuning job)
    returns: dict mapping each attribute name to str(value); None becomes 'None'
    """
    as_strings = {}
    for field_name, field_value in job.__dict__.items():
        as_strings[field_name] = str(field_value)
    return as_strings

def finetune_model(trainfile_path, model_name, valfile_path=None, return_json=False,
                   hyperparameters=None):
    """
    Upload training (and optional validation) data and start an OpenAI fine-tuning job.

    trainfile_path: string, path to the training data file (uploaded with purpose="fine-tune")
    model_name: string, base model to fine-tune
    valfile_path: string or None, path to a validation data file; None means no validation file
    return_json: bool, if True return the job as a dict of stringified fields
                 (via convert_job_to_dict) instead of the job object
    hyperparameters: dict or None, overrides merged on top of the defaults
                     {"n_epochs": 3, "batch_size": 1, "learning_rate_multiplier": 1}
    returns: the created fine-tuning job object, or its dict form if return_json
    """
    def _upload(path):
        # Upload one file and return its ID; the context manager closes the
        # handle once the upload call returns.
        with open(path, "rb") as data_file:
            return client.files.create(file=data_file, purpose="fine-tune").id

    trainfile_id = _upload(trainfile_path)
    valfile_id = _upload(valfile_path) if valfile_path else None

    job_hyperparameters = {
        "n_epochs": 3,
        "batch_size": 1,
        "learning_rate_multiplier": 1,
    }
    if hyperparameters:
        job_hyperparameters.update(hyperparameters)

    ft_job = client.fine_tuning.jobs.create(
        training_file=trainfile_id,
        validation_file=valfile_id,
        model=model_name,
        hyperparameters=job_hyperparameters,
    )

    if return_json:
        return convert_job_to_dict(ft_job)
    return ft_job

def retrieve_finetuning(id):
    """
    Fetch the current state of a fine-tuning job.

    id: string, the fine-tuning job ID
    returns: the job object returned by client.fine_tuning.jobs.retrieve
    """
    jobs_api = client.fine_tuning.jobs
    return jobs_api.retrieve(id)

def compute_finetuning_cost(ft_job, model_name):
    """
    Compute the training cost of a fine-tuning job in dollars.

    ft_job: openai fine-tuning job object, or a dict with a 'trained_tokens' entry
            (the value may be an int or its string form, as produced by
            convert_job_to_dict)
    model_name: string, base model name or an alias from `aliases`
    returns: tuple (cost, trained_tokens) -- cost in dollars (float), trained_tokens (int)
    raises: ValueError if the model has no fine-tuning price, or if the job reports
            no trained tokens (trained_tokens is None)
    """
    model = aliases.get(model_name, model_name)
    if model not in finetuning_costs:
        raise ValueError(f"Model name {model} not supported")

    trained_tokens = ft_job['trained_tokens'] if isinstance(ft_job, dict) else ft_job.trained_tokens
    # A job without a token count has trained_tokens None ('None' once stringified by
    # convert_job_to_dict); multiplying it would fail with an unhelpful TypeError.
    if trained_tokens is None or trained_tokens == 'None':
        raise ValueError("Fine-tuning job has no trained_tokens value yet")
    trained_tokens = int(trained_tokens)

    training_cost_per_1k = finetuning_costs[model][0]
    return training_cost_per_1k * trained_tokens / 1000, trained_tokens
    
def get_finetuning_messages(examples):
    """
    Build chat-format fine-tuning records from prompt/response pairs.

    examples: iterable of dicts, each with 'prompt_text' and 'response_text' keys
    returns: list of {"messages": [user_message, assistant_message]} dicts, one per example
    """
    def to_record(example):
        user_turn = {"role": "user", "content": example['prompt_text']}
        assistant_turn = {"role": "assistant", "content": example['response_text']}
        return {"messages": [user_turn, assistant_turn]}

    return [to_record(example) for example in examples]

def get_fine_tuning_checkpoints(job_id, timeout=30):
    """
    Fetch the checkpoints for a fine-tuning job from the OpenAI REST API.

    Args:
        job_id (str): The ID of the fine-tuning job (e.g. 'ftjob-...').
        timeout (float): Seconds to wait for the HTTP request before giving up;
            without a timeout the request could block indefinitely.

    Returns:
        list or None: The response's 'data' list (checkpoint dicts) on success.
        None if the request raises (connection error, timeout, ...), returns a
        non-200 status, or the response has no 'data' entry.

    Example:
        checkpoints = get_fine_tuning_checkpoints('ftjob-...')
    """
    # Authenticate with the key stored in the OPENAI_API_KEY environment variable.
    api_key = os.getenv('OPENAI_API_KEY')
    url = f"https://api.openai.com/v1/fine_tuning/jobs/{job_id}/checkpoints"
    headers = {
        'Authorization': f'Bearer {api_key}',
        'Content-Type': 'application/json'
    }
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
    except requests.RequestException as e:
        # Keep the documented "None on failure" contract for network errors too.
        print("Failed to retrieve checkpoints:", str(e))
        return None
    if response.status_code != 200:
        print("Failed to retrieve checkpoints:", response.status_code, response.text)
        return None
    return response.json().get('data', None)  # Get the 'data' list if available

def save_job_checkpoints(job_id, save_dir):
    """
    Download a fine-tuning job's checkpoints and save each one as JSON.

    Files are written to f'{save_dir}ft_checkpoint_{i}.json' with i starting at 1;
    save_dir is used as a plain string prefix, so it should normally end with a
    path separator.

    The fetched list is reversed before numbering. This presumably turns a
    newest-first API listing into chronological order (checkpoint 1 = earliest);
    TODO confirm against the API's ordering.

    If no checkpoint list could be fetched, a message is printed and nothing is saved.
    """
    checkpoints = get_fine_tuning_checkpoints(job_id)
    if checkpoints is None:
        # None signals a failed request or a response without 'data';
        # slicing/reversing it would raise TypeError.
        print(bold("No checkpoints saved for job:"), job_id)
        return
    for i, checkpoint in enumerate(reversed(checkpoints), start=1):
        checkpoint_path = f'{save_dir}ft_checkpoint_{i}.json'
        save_json(checkpoint, checkpoint_path)
        print(bold(f"Saved checkpoint {i} to:"), checkpoint_path)
    
def retrieve_and_save_job(job_id, save_dir, save_checkpoints=True):
    """
    Fetch a fine-tuning job, save its details as JSON, and optionally its checkpoints.

    job_id: string, the fine-tuning job ID
    save_dir: string prefix for output files; details go to
              f'{save_dir}job_status_{status}.json'
    save_checkpoints: bool, if True also call save_job_checkpoints for this job
    """
    job = retrieve_finetuning(job_id)
    status_path = f'{save_dir}job_status_{job.status}.json'
    save_json(convert_job_to_dict(job), status_path)
    print(bold("Saved final job details to:"), status_path)
    if not save_checkpoints:
        return
    save_job_checkpoints(job_id, save_dir)

### OPENAI API FUNCTIONS ###

def construct_response_probabilities(logprobs_content_list):
    """
    Convert per-token logprob entries into per-token probability dicts.

    logprobs_content_list: list of openai ChatCompletionTokenLogprob objects, one per
                           generated token (must be exactly a list)
    returns: list with one {token: probability} dict per generated token, in order,
             each built by construct_token_probabilities from that token's top_logprobs
    raises: ValueError if the argument is not a list
    """
    if type(logprobs_content_list) is not list:
        raise ValueError("logprobs_content_list must be a list")
    return [construct_token_probabilities(entry) for entry in logprobs_content_list]

def construct_token_probabilities(logprobs_content):
    """
    Turn the top-logprob alternatives for one generated token into probabilities.

    logprobs_content: object with a `top_logprobs` iterable whose items expose
                      `.token` and `.logprob` (e.g. ChatCompletionTokenLogprob)
    returns: dict {token: exp(logprob)}; if a token repeats, the later entry wins
    """
    token_probabilities = {}
    for candidate in logprobs_content.top_logprobs:
        token_probabilities[candidate.token] = np.exp(candidate.logprob)
    return token_probabilities

def standardize_token_probs(token_prob_dict):
    """
    Normalize token strings and merge the probabilities of tokens that collide.

    Each token has spaces, parentheses and newlines removed and is upper-cased, so
    e.g. " A", "(a)" and "A\\n" all map to "A"; their probabilities are summed.

    token_prob_dict: dict {token: probability}
    returns: dict {standardized_token: summed probability}, in first-seen order
    """
    strip_table = str.maketrans('', '', ' ()\n')
    merged = {}
    for raw_token, probability in token_prob_dict.items():
        key = raw_token.translate(strip_table).upper()
        if key not in merged:
            merged[key] = probability
        else:
            merged[key] += probability
    return merged

def get_option_probabilities(prompt_text, options, llm):
    """
    Score multiple-choice options by the model's probabilities for the first output token.

    prompt_text: string, prompt sent to the model
    options: sequence of answer options; only its length is used, to take that many
             labels from `letters`
    llm: string, model name passed to robust_openai_query
    returns: (response, answer_probs), where answer_probs maps each option letter to the
             standardized probability of that letter as the single generated token
             (queried with n_probs=20), or 0 if it is not among the returned alternatives
    """
    response, _tokens, per_token_probs = robust_openai_query(
        prompt_text, llm, temperature=0.0, n_probs=20, max_tokens=1)
    # Merge variants such as " A" / "(A)" / "a" into "A" before lookup.
    first_token_probs = standardize_token_probs(per_token_probs[0])
    answer_probs = {}
    for letter in letters[:len(options)]:
        answer_probs[letter] = first_token_probs.get(letter, 0)
    return response, answer_probs

def get_text_from_response(response):
    """
    Return the message text of the first choice in a chat completion.

    response: openai ChatCompletion object (anything with .choices[0].message.content)
    """
    first_choice = response.choices[0]
    return first_choice.message.content

def get_openai_response(prompt_text, model_name="gpt-3.5-turbo",
                        temperature=0.0, n_probs=0, max_tokens=512):
    """
    Query an OpenAI chat model with a single user message.

    prompt_text: string, the user message (leading/trailing whitespace is stripped)
    model_name: string, a name in CHAT_COMPLETION_API_MODEL_NAMES or a fine-tuned
                model name starting with "ft"
    temperature: float, sampling temperature
    n_probs: int, if > 0, request the top n_probs (token, probability) alternatives
             for every generated token
    max_tokens: int, maximum number of tokens to generate
    returns:
        - if n_probs > 0: (response, tokens, token_probabilities), where tokens is the
          list of generated tokens and token_probabilities is a list with one
          {token: probability} dict per generated token
        - otherwise: the openai ChatCompletion response object
    raises: ValueError if model_name is not supported
    """
    messages = [{'role': 'user', 'content': prompt_text.strip()}]
    logprobs = n_probs > 0
    if not (model_name in CHAT_COMPLETION_API_MODEL_NAMES or model_name.startswith("ft")):
        raise ValueError(f"Model name {model_name} not supported")

    # Logprob parameters must actually be sent when n_probs > 0: without them the API
    # returns choices[0].logprobs = None and reading .content below would crash.
    # They are omitted otherwise so plain requests stay unchanged.
    logprob_kwargs = {'logprobs': True, 'top_logprobs': n_probs} if logprobs else {}
    response = client.chat.completions.create(model=model_name, messages=messages,
                                              max_tokens=max_tokens, temperature=temperature,
                                              **logprob_kwargs)
    if not logprobs:
        return response

    logprobs_content = response.choices[0].logprobs.content
    tokens = [content.token for content in logprobs_content]
    token_probabilities = construct_response_probabilities(logprobs_content)
    return response, tokens, token_probabilities
    
def robust_openai_query(prompt_text, model_name="gpt-3.5-turbo",
                        temperature=0.0, n_probs=0, max_tokens=512, max_attempts=None):
    """
    Call get_openai_response, retrying when the OpenAI API raises an error.

    Takes the same arguments as get_openai_response and returns its result.
    After the k-th failed attempt it sleeps k + 2 seconds before retrying.

    max_attempts: int or None, re-raise the last error after this many attempts;
                  None (default) keeps retrying until success.

    Errors that cannot succeed on retry -- HTTP 4xx responses other than 408
    (timeout), 409 (conflict) and 429 (rate limit), e.g. an invalid request, a bad
    API key or an unknown model -- are re-raised immediately instead of looping forever.
    """
    from openai import APIStatusError  # OpenAI errors that carry an HTTP status_code

    attempts = 0
    while True:
        attempts += 1
        try:
            return get_openai_response(prompt_text, model_name=model_name,
                                       temperature=temperature, n_probs=n_probs,
                                       max_tokens=max_tokens)
        except OpenAIError as e:
            status = e.status_code if isinstance(e, APIStatusError) else None
            permanent = status is not None and 400 <= status < 500 and status not in (408, 409, 429)
            out_of_attempts = max_attempts is not None and attempts >= max_attempts
            if permanent or out_of_attempts:
                raise
            wait_seconds = attempts + 2
            print("ERROR! I'm going to sleep for " + str(wait_seconds) + "s")
            print("Error:", str(e))
            time.sleep(wait_seconds)

def compute_cost(response, model_name, finetuned_model=False):
    """
    Compute the dollar cost of a single chat-completion response.

    response: openai ChatCompletion object; response.usage.prompt_tokens and
              response.usage.completion_tokens are read
    model_name: string, model used for the request -- a priced name, an alias from
                `aliases`, or a fine-tuned name of the form 'ft:<base-model>:...'
    finetuned_model: bool, if True price with the fine-tuned inference rates in
                     `finetuning_costs`, otherwise with `api_costs`
    returns: tuple (input_cost, output_cost) in dollars
    raises: ValueError if the selected price table has no entry for the base model
    """
    prompt_tokens = response.usage.prompt_tokens
    output_tokens = response.usage.completion_tokens
    # 'ft:<base-model>:...' -> '<base-model>'
    if model_name.startswith("ft"):
        model_name = model_name.split(':')[1]
    model = aliases.get(model_name, model_name)

    # Check membership in the table actually used for pricing, so an unsupported
    # fine-tuned base model raises ValueError rather than KeyError.
    price_table = finetuning_costs if finetuned_model else api_costs
    if model not in price_table:
        raise ValueError(f"Model name {model} not supported")
    if finetuned_model:
        input_cost, output_cost = price_table[model][1:]  # skip the training price
    else:
        input_cost, output_cost = price_table[model]
    return input_cost * prompt_tokens / 1000, output_cost * output_tokens / 1000
    

def update_and_save_costs(costs_dict, responses, query_type, model_name, save_dir):
    """
    Update costs dictionary with new responses and save it to costs.json.

    costs_dict: dictionary with 'session' and 'all time' sections (see initialize_costs);
                updated in place
    responses: a single response object, or an iterable of responses from the model
    query_type: string, type of query (e.g. cot_explanation or faithfulness); a missing
                query_type gets a fresh zeroed {'input', 'output', 'total'} entry
    model_name: string, name of the model; names starting with "ft:" are priced as
                fine-tuned models
    save_dir: string, path to the save directory; writes os.path.join(save_dir, 'costs.json')
    returns: dictionary, the updated costs dictionary (same object as costs_dict)
    """
    # Wrap a single response instead of calling list() on it: list() on one response
    # object would not yield [response] (it fails or iterates the object's fields).
    if hasattr(responses, 'usage'):
        responses = [responses]
    elif not isinstance(responses, list):
        responses = list(responses)

    finetuned = model_name.startswith("ft:")
    for response in responses:
        input_cost, output_cost = compute_cost(response, model_name, finetuned_model=finetuned)
        for time_type in ['session', 'all time']:
            bucket = costs_dict[time_type].setdefault(
                query_type, {'input': 0, 'output': 0, 'total': 0})
            bucket['input'] += input_cost
            bucket['output'] += output_cost
            bucket['total'] += input_cost + output_cost

    # Use os.path.join so this writes the same file initialize_costs reads, even when
    # save_dir has no trailing separator.
    with open(os.path.join(save_dir, 'costs.json'), 'w') as f:
        json.dump(costs_dict, f, indent=4)
    return costs_dict

def initialize_costs(save_dir):
    """
    Load costs.json from save_dir if present, otherwise build a fresh costs dictionary.

    save_dir: string, path to the save directory (costs.json is looked up via os.path.join)
    returns: dictionary with 3 layers of keys:
      1. 'session' / 'all time': cost for the current session or accumulated over all time
      2. 'cot_explanation' / 'faithfulness': cost per query type
      3. 'input' / 'output' / 'total': input, output and combined cost
    When the file exists, its contents are kept except that 'session' is reset to zeros.
    """
    query_types = ['cot_explanation', 'faithfulness']

    def zeroed_section():
        # Fresh dicts on every call so no two sections share counters.
        return {query_type: {'input': 0, 'output': 0, 'total': 0} for query_type in query_types}

    costs_path = os.path.join(save_dir, 'costs.json')
    if not os.path.exists(costs_path):
        return {time_type: zeroed_section() for time_type in ['session', 'all time']}

    with open(costs_path, 'r') as handle:
        loaded = json.load(handle)
    loaded['session'] = zeroed_section()  # a new session starts from zero
    return loaded