import time
from openai import AzureOpenAI
import json, os

def get_rating_from_azure_chat_call(model, message, max_tokens):
    """Ask an Azure OpenAI chat model for an integer rating, retrying until one parses.

    The prompt is sent as a single user message. A reply is accepted when, after
    lower-casing, removing any ``"rating:"`` label and stripping whitespace, it is
    an optional leading ``-`` followed by decimal digits.

    Args:
        model: Passed through as ``model`` to ``client.chat.completions.create``.
        message: Prompt text sent as the user message.
        max_tokens: Completion token limit for each request.

    Returns:
        The parsed integer rating.

    Note:
        Retries indefinitely. Content-filtered or non-integer replies are retried
        immediately; exceptions from the API call are retried after ``wait_time``
        seconds.
    """
    # NOTE(review): credentials are hard-coded placeholders; consider loading them from configuration.
    api_key = "redacted"
    api_base = "redacted"
    client = AzureOpenAI(api_key=api_key,
                         api_version="2023-05-15",
                         azure_endpoint=api_base)
    full_prompt = [{'role': 'user', 'content': message}]

    wait_time = 5
    while True:
        try:
            response = client.chat.completions.create(
                model=model,
                messages=full_prompt,
                max_tokens=max_tokens,
                n=1
            )
            choice = response.choices[0]
        except Exception as e:
            # Broad on purpose: transient network / rate-limit / API errors are all retried.
            print(f'Caught exception {e}.')
            print(f'Waiting {wait_time} seconds.')
            time.sleep(wait_time)
            continue

        if choice.finish_reason == "content_filter":
            # Don't try to parse a filtered reply; its content may be empty.
            print("Content filter triggered, trying again")
            continue

        content = (choice.message.content or "").strip().lower()
        content = content.replace("rating:", "").strip()

        # Accept exactly one optional leading '-'. isdecimal() only admits characters
        # int() understands, so the conversion cannot raise (unlike the old
        # lstrip('-').isdigit() check, which passed "--5" and "²").
        digits = content[1:] if content.startswith("-") else content
        if digits.isdecimal():
            return int(content)
        print(f"Response '{content}' is not a valid integer, trying again")


def get_index_from_azure_chat_call(model, message, max_tokens):
    """Ask an Azure OpenAI chat model for a non-negative integer, retrying until one parses.

    The prompt is sent as a single user message. A reply is accepted when, after
    lower-casing and stripping whitespace, it consists only of decimal digits.

    Args:
        model: Passed through as ``model`` to ``client.chat.completions.create``.
        message: Prompt text sent as the user message.
        max_tokens: Completion token limit for each request.

    Returns:
        The parsed non-negative integer.

    Note:
        Retries indefinitely. Content-filtered or non-integer replies are retried
        immediately; exceptions from the API call are retried after ``wait_time``
        seconds.
    """
    # NOTE(review): credentials are hard-coded placeholders; consider loading them from configuration.
    api_key = "redacted"
    api_base = "redacted"
    client = AzureOpenAI(api_key=api_key,
                         api_version="2023-05-15",
                         azure_endpoint=api_base)
    full_prompt = [{'role': 'user', 'content': message}]

    wait_time = 5
    while True:
        try:
            response = client.chat.completions.create(
                model=model,
                messages=full_prompt,
                max_tokens=max_tokens,
                n=1
            )
            choice = response.choices[0]
        except Exception as e:
            # Broad on purpose: transient network / rate-limit / API errors are all retried.
            print(f'Caught exception {e}.')
            print(f'Waiting {wait_time} seconds.')
            time.sleep(wait_time)
            continue

        if choice.finish_reason == "content_filter":
            # Don't try to parse a filtered reply; its content may be empty.
            print("Content filter triggered, trying again")
            continue

        content = (choice.message.content or "").strip().lower()
        # isdecimal() (unlike isdigit()) only admits characters int() accepts,
        # so the conversion below cannot raise.
        if content.isdecimal():
            return int(content)
        print(f"Response '{content}' is not a valid integer, trying again")


def int_to_word(num):
    """Spell out an integer between 2 and 10 (inclusive) as a lowercase English word.

    Any other value yields the sentinel string "Number out of range" rather than
    raising an exception.
    """
    spelled = dict(zip(
        range(2, 11),
        ("two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"),
    ))
    return spelled.get(num, "Number out of range")
    

def get_response_from_azure_chat_call(model, message, max_tokens):
    """Get a free-text completion from an Azure OpenAI chat model.

    The prompt is sent as a single user message and the stripped reply text is
    returned. If the reply is cut off (``finish_reason == "length"``), the request
    is repeated with a 1.5x larger token budget, capped at 4096. Once the budget
    is already at or above the cap, the truncated text is returned as-is.

    Args:
        model: Passed through as ``model`` to ``client.chat.completions.create``.
        message: Prompt text sent as the user message.
        max_tokens: Initial completion token limit.

    Returns:
        The stripped reply text (possibly truncated, see above).

    Note:
        Retries indefinitely. Content-filtered replies are retried immediately;
        exceptions from the API call are retried after ``wait_time`` seconds.
    """
    token_cap = 4096
    # NOTE(review): credentials are hard-coded placeholders; consider loading them from configuration.
    api_key = "redacted"
    api_base = "redacted"
    client = AzureOpenAI(api_key=api_key,
                         api_version="2023-05-15",
                         azure_endpoint=api_base)
    full_prompt = [{'role': 'user', 'content': message}]

    wait_time = 5
    while True:
        try:
            response = client.chat.completions.create(
                model=model,
                messages=full_prompt,
                max_tokens=max_tokens,
                n=1
            )
            choice = response.choices[0]
        except Exception as e:
            # Broad on purpose: transient network / rate-limit / API errors are all retried.
            print(f'Caught exception {e}.')
            print(f'Waiting {wait_time} seconds.')
            time.sleep(wait_time)
            continue

        if choice.finish_reason == "content_filter":
            print("Content filter triggered, trying again")
            continue

        content = (choice.message.content or "").strip()

        if choice.finish_reason == "length":
            if max_tokens < token_cap:
                # Grow by 1.5x but never past the cap (the old code used max(), which
                # could request e.g. 4500 tokens). The "+ 1" guarantees progress for
                # tiny budgets, where int(1 * 1.5) == 1.
                max_tokens = min(max(int(max_tokens * 1.5), max_tokens + 1), token_cap)
                print(f"Not enough tokens, trying again with {max_tokens} tokens")
                continue
            print("4096 token limit reached, using what we have")

        return content


def parse_response(args, response, type):
    """Map a model's answer about the most desirable apartment to a 0-based index.

    Args:
        args: Namespace-like object. ``args.verbose`` is always read; ``args.model``
            is read only for "cot" parsing.
        response: Raw model output text.
        type: "zero_shot" expects a bare number, optionally prefixed with
            "apartment " (case-insensitive). If the text is not such a number, it
            falls back to "cot" parsing of the original response.
            "cot" asks the model (via ``get_index_from_azure_chat_call``) to extract
            the chosen apartment number from free-form text.

    Returns:
        The extracted 1-based answer minus one. The "cot" extractor is instructed
        to answer 0 when no valid answer is found, which therefore becomes -1.

    Raises:
        ValueError: if ``type`` is not "zero_shot" or "cot".
    """
    if args.verbose:
        print(f"type: {type}, Response:\n{response}\n")

    if type == "zero_shot":
        prefix = "apartment "
        processed_response = response.lower().strip()
        if processed_response.startswith(prefix):
            # Slice off the literal prefix. str.lstrip("apartment ") would strip any
            # run of those characters, e.g. turning "apartment a3" into "3".
            processed_response = processed_response[len(prefix):].strip()
        # isdecimal() guarantees int() succeeds (isdigit() accepts e.g. "²").
        if processed_response.isdecimal():
            return int(processed_response) - 1
        # Not a bare number: let the "cot" path extract it; it applies the -1 itself.
        return parse_response(args, response, "cot")

    if type == "cot":
        prompt = f"The following is a response to a question about which apartment is most desirable to a tenant. Determine the answer in the response. Possible valid answers are numbers 1 through 4. In all other cases, respond with 0. Respond with only the number, do not include anything else.\n\nResponse: {response}"
        return get_index_from_azure_chat_call(args.model, prompt, 10) - 1

    raise ValueError(f"Unknown response type {type!r}")


def parse_method_response(args, response, prompt_method):
    """Parse a raw model response according to the prompting method that produced it.

    Args:
        args: Namespace-like object; ``args.verbose`` is read here and the object is
            forwarded to ``parse_response``.
        response: Raw model output text.
        prompt_method: One of "distractor_nback", "very_carefully_think",
            "cot_time_limit", "cot", "no_ratings_cot", "zero_shot",
            "no_ratings_zero_shot".

    Returns:
        The 0-based apartment index produced by ``parse_response``.

    Raises:
        ValueError: if a "distractor_nback" response contains none of the expected
            closing markers, or if ``prompt_method`` is not recognized.
    """
    if args.verbose:
        print(f"method: {prompt_method}, Response:\n{response}\n")

    if prompt_method == "distractor_nback":
        # Keep only the text after the last closing "...true]"/"...false]"-style marker.
        # Presumably this ends the distractor task's list of booleans, so the
        # remainder holds the apartment answer (TODO confirm against the prompt).
        markers = (
            "true]", "false]", "True]", "False]",
            "True']", "False']", 'True"]', 'False"]',
        )
        length = len(response)
        for marker in markers:
            response = response.split(marker)[-1]
        # An unchanged length means no marker was found.
        if len(response) == length:
            # ValueError is a subclass of Exception, so existing `except Exception` callers still work.
            raise ValueError(f"parsing failed: {response}")
        return parse_response(args, response, "zero_shot")

    if prompt_method in ("very_carefully_think", "cot_time_limit", "cot", "no_ratings_cot"):
        return parse_response(args, response, "cot")

    if prompt_method in ("zero_shot", "no_ratings_zero_shot"):
        return parse_response(args, response, "zero_shot")

    raise ValueError(f"Prompt method {prompt_method} not recognized")


def reparse(args):
    """Re-run zero-shot parsing over the 40 stored gpt-4o zero_shot_cot result files.

    For each ``apartments/results/gpt-4o/zero_shot_cot/{i}.json`` with i in 0..39:
    print the stored ``zero_shot_response``, re-parse it with
    ``parse_response(..., "zero_shot")``, and report when the result differs from
    the stored ``parsed_zero_shot_response``. Each file is always rewritten with
    the freshly parsed value.
    """
    for idx in range(40):
        json_path = f"apartments/results/gpt-4o/zero_shot_cot/{idx}.json"
        with open(json_path, "r") as fh:
            record = json.load(fh)

        raw_answer = record["zero_shot_response"]
        print(raw_answer)
        fresh = parse_response(args, raw_answer, "zero_shot")

        changed = fresh != record["parsed_zero_shot_response"]
        if changed:
            print(f"Mismatch corrected for {idx}")

        record["parsed_zero_shot_response"] = fresh
        with open(json_path, "w") as fh:
            json.dump(record, fh)


def separate_zero_shot_cot(num_files=40, base_dir="apartments/results/gpt-4o"):
    """Split combined zero-shot/CoT result files into separate per-method folders.

    Reads ``{base_dir}/zero_shot_cot/{i}.json`` for i in ``range(num_files)`` and
    writes ``{base_dir}/zero_shot/{i}.json`` and ``{base_dir}/cot/{i}.json``.

    Key routing:
        - Keys containing "zero_shot" go only to the zero-shot file. This check runs
          first, so a key containing both substrings lands on the zero-shot side.
        - Otherwise, keys containing "cot" go only to the CoT file.
        - Every other key is copied to both files.

    Args:
        num_files: Number of numbered input files to process (default 40).
        base_dir: Directory holding the ``zero_shot_cot`` input folder. The output
            folders are created alongside it if missing.
    """
    src_dir = os.path.join(base_dir, "zero_shot_cot")
    zero_shot_dir = os.path.join(base_dir, "zero_shot")
    cot_dir = os.path.join(base_dir, "cot")
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(zero_shot_dir, exist_ok=True)
    os.makedirs(cot_dir, exist_ok=True)

    for i in range(num_files):
        with open(os.path.join(src_dir, f"{i}.json"), "r") as f:
            data = json.load(f)

        zero_shot_data = {}
        cot_data = {}
        for key, value in data.items():
            if "zero_shot" in key:
                zero_shot_data[key] = value
            elif "cot" in key:
                cot_data[key] = value
            else:
                zero_shot_data[key] = value
                cot_data[key] = value

        with open(os.path.join(zero_shot_dir, f"{i}.json"), "w") as f:
            json.dump(zero_shot_data, f)
        with open(os.path.join(cot_dir, f"{i}.json"), "w") as f:
            json.dump(cot_data, f)


def fix_key_names():
    """Rename CoT-specific keys to generic names in the 40 gpt-4o CoT result files.

    In each ``apartments/results/gpt-4o/cot/{i}.json`` with i in 0..39:
        - ``parsed_cot_response`` becomes ``parsed_response``
        - ``cot_prompt`` becomes ``prompt``
        - ``cot_response`` becomes ``response``

    Each file is rewritten in place. A file missing any of the old keys (e.g. one
    already migrated) raises KeyError before that file is written.
    """
    renames = (
        ("parsed_cot_response", "parsed_response"),
        ("cot_prompt", "prompt"),
        ("cot_response", "response"),
    )
    for idx in range(40):
        target = f"apartments/results/gpt-4o/cot/{idx}.json"
        with open(target, "r") as fh:
            record = json.load(fh)

        for old_key, new_key in renames:
            record[new_key] = record.pop(old_key)

        with open(target, "w") as fh:
            json.dump(record, fh)


def fix_key_names_zero_shot():
    """Rename zero-shot-specific keys to generic names in the 40 gpt-4o zero-shot result files.

    In each ``apartments/results/gpt-4o/zero_shot/{i}.json`` with i in 0..39:
        - ``parsed_zero_shot_response`` becomes ``parsed_response``
        - ``zero_shot_prompt`` becomes ``prompt``
        - ``zero_shot_response`` becomes ``response``

    Each file is rewritten in place. A file missing any of the old keys (e.g. one
    already migrated) raises KeyError before that file is written.
    """
    key_map = (
        ("parsed_zero_shot_response", "parsed_response"),
        ("zero_shot_prompt", "prompt"),
        ("zero_shot_response", "response"),
    )
    for idx in range(40):
        file_path = f"apartments/results/gpt-4o/zero_shot/{idx}.json"
        with open(file_path, "r") as handle:
            payload = json.load(handle)

        for legacy, generic in key_map:
            payload[generic] = payload.pop(legacy)

        with open(file_path, "w") as handle:
            json.dump(payload, handle)