from typing import List, Any
import os
import gc
import json
import sys
import time
import argparse
import random
import copy
import warnings
import pickle as pc
from tqdm import tqdm
from glob import glob

import torch
import openai

from config import ex
from utils import fixed_seed, OPENAI_MODEL_NAMES, OPENAI_CHAT_MODEL_NAMES
from load_dataset import prepare_task2_prompt_dataset


def call_gpt3(
    prompt: Any,
    model_name: str = None, 
    max_len: int = 128, 
    temp: float = 0.0, 
    num_log_probs: int = 5, 
    echo: bool = False,
    n: int = 1,
    stop: Any = '\n',
    freq_penalty: float = 0.0,
    pres_penalty: float = 0.0,
    top_p: float = 1.0,
    best_of: int = 1
):
    """Call the OpenAI Completion endpoint, retrying until a response arrives.

    Args:
        prompt: A single prompt string, or a list of prompt strings for a
            batched request.
        model_name: Completion model to query.
        max_len: ``max_tokens`` for the request.
        temp: Sampling temperature.
        num_log_probs: ``logprobs`` for the request.
        echo: Whether the API should echo the prompt back.
        n: Completions to generate per prompt.
        stop: Stop sequence(s).
        freq_penalty: ``frequency_penalty``.
        pres_penalty: ``presence_penalty``.
        top_p: Nucleus-sampling mass.
        best_of: ``best_of`` for the request.

    Returns:
        The response object returned by ``openai.Completion.create``.

    Raises:
        openai.error.InvalidRequestError: The request itself is malformed
            (e.g. prompt too long); retrying cannot help.
        openai.error.AuthenticationError: The API key is missing or invalid;
            retrying cannot help.
    """
    # openai<1.0 picks up OPENAI_API_KEY at import time; never overwrite a key
    # that is already configured, only fill it in when nothing has set it.
    if not openai.api_key:
        openai.api_key = os.environ.get("OPENAI_API_KEY", "")

    while True:
        try:
            return openai.Completion.create(
                model=model_name,
                prompt=prompt,
                max_tokens=max_len,
                temperature=temp,
                n=n,
                logprobs=num_log_probs,
                echo=echo,
                stop=stop,
                presence_penalty=pres_penalty,
                frequency_penalty=freq_penalty,
                top_p=top_p,
                best_of=best_of
            )
        except openai.error.OpenAIError as e:
            print(f"OpenAIError: {e}.")
            if isinstance(e, openai.error.InvalidRequestError):
                # something is wrong with the request itself: e.g., prompt too long
                print(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n")
                raise
            if isinstance(e, openai.error.AuthenticationError):
                # A bad/missing key will never succeed on retry.
                raise
            # Every other API error is treated as transient: back off and retry.
            time.sleep(2)

def call_chat_gpt3(
    messages: List[dict],
    model_name: str = None, 
    max_len: int = 128, 
    temp: float = 0.0, 
    n: int = 1,
    stop: Any = '\n',
    freq_penalty: float = 0.0,
    pres_penalty: float = 0.0,
    top_p: float = 1.0,
):
    """Call the OpenAI ChatCompletion endpoint, retrying until a response arrives.

    Args:
        messages: Chat messages, each a ``{"role": ..., "content": ...}`` dict.
        model_name: Chat model to query.
        max_len: ``max_tokens`` for the request.
        temp: Sampling temperature.
        n: Number of choices to generate.
        stop: Stop sequence(s).
        freq_penalty: ``frequency_penalty``.
        pres_penalty: ``presence_penalty``.
        top_p: Nucleus-sampling mass.

    Returns:
        The response object returned by ``openai.ChatCompletion.create``.

    Raises:
        openai.error.InvalidRequestError: The request itself is malformed
            (e.g. context too long); retrying cannot help.
        openai.error.AuthenticationError: The API key is missing or invalid;
            retrying cannot help.
    """
    # openai<1.0 picks up OPENAI_API_KEY at import time; never overwrite a key
    # that is already configured, only fill it in when nothing has set it.
    if not openai.api_key:
        openai.api_key = os.environ.get("OPENAI_API_KEY", "")

    while True:
        try:
            return openai.ChatCompletion.create(
                model=model_name,
                messages=messages,
                max_tokens=max_len,
                temperature=temp,
                n=n,
                stop=stop,
                presence_penalty=pres_penalty,
                frequency_penalty=freq_penalty,
                top_p=top_p,
            )
        except openai.error.OpenAIError as e:
            print(f"OpenAIError: {e}.")
            if isinstance(e, openai.error.InvalidRequestError):
                # something is wrong with the request itself: e.g., prompt too long
                print(f"InvalidRequestError\nPrompt passed in:\n\n{messages}\n\n")
                raise
            if isinstance(e, openai.error.AuthenticationError):
                # A bad/missing key will never succeed on retry.
                raise
            # Every other API error is treated as transient: back off and retry.
            time.sleep(2)

def make_chat_format(instance, rounding_step=1):
    """Build the chat message list for one task-2 instance.

    The conversation always opens with a fixed system prompt followed by the
    instance's ``task2_prompt_input`` as the user turn. For ``rounding_step``
    values other than 1, each earlier round ``r`` in ``1 .. rounding_step-1``
    appends the stored ``"{r}_task1_openai_resp"`` text followed by a
    ``"continue"`` user turn.

    NOTE(review): earlier responses are sent with role "system", not
    "assistant" — confirm this is the intended conversation shape.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": instance["task2_prompt_input"]},
    ]
    if rounding_step != 1:
        for past_round in range(1, rounding_step):
            earlier_reply = instance[f"{past_round}_task1_openai_resp"]
            conversation.extend([
                {"role": "system", "content": earlier_reply},
                {"role": "user", "content": "continue"},
            ])
    return conversation

def _estimate_chat_cost(model_name, prompt_tokens, completion_tokens, total_tokens):
    """Return ``(usd, krw)`` estimated cost, or ``(None, None)`` for an unpriced model."""
    # Per-1K-token prices hard-coded by the author; 1304 is the KRW-per-USD
    # rate this script uses for its cost reports.
    if model_name == 'gpt-4':
        usd = (prompt_tokens/1000)*0.03 + (completion_tokens/1000)*0.06
    elif model_name == 'gpt-3.5-turbo':
        usd = (total_tokens/1000)*0.002
    else:
        return None, None
    return usd, usd * 1304


def openai_chat_inference(dataset, config, result_save_dir):
    """Generate task-2 outputs with an OpenAI chat model, one request per instance.

    Each instance is deep-copied and given the first choice's message content
    under ``"{rounding_step}_task2_openai_resp"``. Writes
    ``{rounding_step}_{template_type}_generation.json`` (the augmented
    instances) and ``{rounding_step}_{template_type}_usage.txt`` (token counts
    and, when the model has a price entry, an estimated cost) into
    ``result_save_dir``.
    """
    total_tokens = 0
    prompt_tokens = 0
    completion_tokens = 0
    results = []
    for instance in tqdm(dataset, total=len(dataset)):
        messages = make_chat_format(instance, config["rounding_step"])

        resp = call_chat_gpt3(
            messages,
            model_name=config["model_name"],
            max_len=config['max_len'],
            temp=config['temp'],
            n=config['n'],
            stop=config['stop'],
            freq_penalty=config['freq_penalty'],
            pres_penalty=config['pres_penalty'],
            top_p=config['top_p']
        )
        total_tokens += resp.usage["total_tokens"]
        prompt_tokens += resp.usage["prompt_tokens"]
        completion_tokens += resp.usage["completion_tokens"]

        new_instance = copy.deepcopy(instance)
        # Only the first choice is kept, even if config['n'] > 1.
        new_instance["{}_task2_openai_resp".format(config["rounding_step"])] = resp.choices[0].message["content"]
        results.append(new_instance)

    print("# of results:", len(results))

    estimated_cost, estimated_cost_won = _estimate_chat_cost(
        config["model_name"], prompt_tokens, completion_tokens, total_tokens
    )

    file_prefix = "{}_{}".format(config["rounding_step"], config["template_type"])
    with open(os.path.join(result_save_dir, f"{file_prefix}_generation.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent="\t")

    with open(os.path.join(result_save_dir, f"{file_prefix}_usage.txt"), "w", encoding="utf-8") as f:
        f.write(f"Total number of tokens: {total_tokens}, {prompt_tokens}, {completion_tokens}\n")
        if estimated_cost is None:
            # No price entry: report it instead of crashing on an unbound name.
            f.write(f"Estimated Costs: unknown (no price entry for model {config['model_name']})")
        else:
            f.write(f"Estimated Costs: ${estimated_cost} >> {estimated_cost_won} won")


def openai_inference(dataset, config, result_save_dir):
    """Generate task-2 outputs with an OpenAI completion model, in batches.

    Prompts (``task2_prompt_input``) are sent ``config["batch_size"]`` at a
    time. Every returned choice becomes a deep copy of its source instance
    with the generated text under ``task2_openai_resp``. Writes
    ``{template_type}_generation.json`` and ``{template_type}_usage.txt`` into
    ``result_save_dir``.
    """
    batch_size = config["batch_size"]
    batch_num = int(len(dataset)/batch_size)
    if len(dataset)%batch_size != 0:
        batch_num += 1
    print(len(dataset), batch_num, batch_size)

    # For a list of prompts with n completions each, choice.index runs over all
    # prompt_count * n choices (prompt_idx * n + sample_idx), so integer
    # division by n recovers the prompt; with n == 1 this is choice.index.
    samples_per_prompt = config['n']

    total_tokens = 0
    results = []
    for batch_idx in tqdm(range(batch_num)):
        batch = dataset[batch_idx*batch_size:(batch_idx+1)*batch_size]

        prompts = [ele['task2_prompt_input'] for ele in batch]

        resp = call_gpt3(
            prompts,
            model_name=config["model_name"],
            max_len=config['max_len'],
            temp=config['temp'],
            num_log_probs=config['num_log_probs'],
            echo=config['echo'],
            n=samples_per_prompt,
            stop=config['stop'],
            freq_penalty=config['freq_penalty'],
            pres_penalty=config['pres_penalty'],
            top_p=config['top_p']
        )

        total_tokens += resp.usage["total_tokens"]

        for choice in resp.choices:
            new_instance = copy.deepcopy(batch[choice.index // samples_per_prompt])
            new_instance["task2_openai_resp"] = choice.text

            results.append(new_instance)

    print("# of results:", len(results))

    # $0.02 per 1K tokens; 1304 is the KRW-per-USD rate used by this script.
    estimated_cost = (total_tokens/1000)*0.02
    estimated_cost_won = estimated_cost * 1304

    with open(os.path.join(result_save_dir, "{}_generation.json".format(config["template_type"])), "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent="\t")

    with open(os.path.join(result_save_dir, "{}_usage.txt".format(config["template_type"])), "w", encoding="utf-8") as f:
        f.write(f"Total number of tokens: {total_tokens}\n")
        f.write(f"Estimated Costs: ${estimated_cost} >> {estimated_cost_won} won")


@ex.automain
def main(_config):
    """Entry point: run task-2 generation for the configured OpenAI model.

    Loads the task-1 results JSON from ``./reports/<file_version>/<model_name>/...``
    and builds the task-2 prompt dataset from it. Optionally shuffles and
    truncates the dataset to ``sampled_num`` items, then dispatches to the
    completion or chat inference routine. Outputs are written under
    ``<log_dir>/<file_version>/<model_name>/<datatype>/<seed>/``.

    Raises:
        ValueError: If ``model_name`` is in neither ``OPENAI_MODEL_NAMES`` nor
            ``OPENAI_CHAT_MODEL_NAMES``.
    """
    # Work on a private copy so nothing below mutates the injected config.
    _config = copy.deepcopy(_config)

    fixed_seed(_config["seed"])

    # load generations from task 1: model names containing 'text' keep one
    # results file per model; others keep one per rounding step under "-1".
    report_dir = os.path.join('./reports', _config["file_version"], _config["model_name"])
    if 'text' in _config["model_name"]:
        task1_result_path = os.path.join(report_dir, "image-sharing-turn-prediction-v8_results.json")
    else:
        task1_result_path = os.path.join(
            report_dir,
            str(-1),
            "{}_image-sharing-turn-prediction-v8_results.json".format(_config["rounding_step"]),
        )

    # JSON is UTF-8; don't let the platform locale pick the decoder.
    with open(task1_result_path, 'r', encoding='utf-8') as f:
        task1_results = json.load(f)

    # prepare prompt dataset
    prompt_dataset = prepare_task2_prompt_dataset(_config["model_name"], _config["rounding_step"], task1_results, _config["template_type"])

    if _config["sampled_test"]:
        # Shuffle happens after fixed_seed, so the subsample is reproducible.
        random.shuffle(prompt_dataset)
        prompt_dataset = prompt_dataset[:_config["sampled_num"]]

    result_save_dir = os.path.join(
        _config["log_dir"],
        _config["file_version"],
        _config["model_name"],
        _config['datatype'],
        str(_config["seed"]),
    )
    os.makedirs(result_save_dir, exist_ok=True)

    if _config["model_name"] in OPENAI_MODEL_NAMES:
        openai_inference(prompt_dataset, _config, result_save_dir)
    elif _config["model_name"] in OPENAI_CHAT_MODEL_NAMES:
        openai_chat_inference(prompt_dataset, _config, result_save_dir)
    else:
        raise ValueError("wrong model name!")

