#!/usr/bin/env python
# coding: utf-8

# In[ ]:
import openai
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import concurrent.futures
import copy
import GPy

# This is an auxiliary function to help clean up the response generated by the LLM
def remove_extra_dots(input_string):
    """Keep only the first '.' in *input_string* and drop every later '.'.

    A string with no '.' is returned unchanged.

    >>> remove_extra_dots("1.2.3")
    '1.23'
    """
    # partition() yields (before, '.', after) on the first dot, or
    # (whole, '', '') when there is none -- the join below covers both cases.
    head, dot, tail = input_string.partition('.')
    return head + dot + tail.replace('.', '')

# Get a valid message, ensuring the return value is between 0 and 1 (inclusive). Try up to max_retries times.
def get_valid_msg(prompt, itr, model="gpt-3.5-turbo", max_retries=10):
    """Query the LLM until it answers with a single number in [0, 1].

    Args:
        prompt: Prompt text sent to the model.
        itr: Forwarded to ``get_chatgpt_response_variable_t``; presumably it
            drives that helper's variable temperature (TODO confirm).
        model: Model name forwarded to the request helper.
        max_retries: Maximum number of requests before giving up.

    Returns:
        The parsed value as a float in [0, 1], or ``None`` if no valid value
        was obtained within ``max_retries`` attempts.
    """
    import re  # local import keeps this block self-contained

    # A signed decimal, optionally with an exponent: "0.5", ".5", "1.", "-0.2", "5e-1".
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"

    for attempt in range(1, max_retries + 1):
        # response = get_chatgpt_response(prompt, model=model)
        response = get_chatgpt_response_variable_t(prompt, itr, model=model)  # Variable temperature
        raw = response.choices[0].message.content or ""  # content can be None

        # Prefer the "#value#" format the prompt requests; otherwise accept the
        # reply only if it contains exactly one number. Stripping non-digits and
        # concatenating the rest would turn "0.3 or 0.4" into 0.304 and "-0.2"
        # into 0.2, i.e. silently accept values the model never produced.
        tagged = re.search(r"#\s*(" + number + r")\s*#", raw)
        if tagged:
            msg = tagged.group(1)
        else:
            found = re.findall(number, raw)
            msg = found[0] if len(found) == 1 else raw.strip()

        try:
            msg_float = float(msg)
        except ValueError:
            pass  # Not a number; fall through and retry
        else:
            if 0 <= msg_float <= 1:  # NaN fails this check too
                return msg_float

        print(f"Attempt {attempt}, received invalid result: {msg}, retrying...")

    # Out of retries: signal failure to the caller with None.
    print("Exceeded maximum retry attempts, returning invalid result")
    return None

def get_valid_msg_large_t(prompt, itr, model="gpt-3.5-turbo", max_retries=10):
    """Like ``get_valid_msg`` but uses the large variable-temperature request helper.

    Args:
        prompt: Prompt text sent to the model.
        itr: Forwarded to ``get_chatgpt_response_variable_t_large``; presumably
            it drives that helper's temperature schedule (TODO confirm).
        model: Model name forwarded to the request helper.
        max_retries: Maximum number of requests before giving up.

    Returns:
        The parsed value as a float in [0, 1], or ``None`` if no valid value
        was obtained within ``max_retries`` attempts.
    """
    import re  # local import keeps this block self-contained

    # A signed decimal, optionally with an exponent: "0.5", ".5", "1.", "-0.2", "5e-1".
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"

    for attempt in range(1, max_retries + 1):
        # response = get_chatgpt_response_t_large(prompt, model=model)
        response = get_chatgpt_response_variable_t_large(prompt, itr, model=model)  # Large variable temperature
        raw = response.choices[0].message.content or ""  # content can be None

        # Prefer the "#value#" format the prompt requests; otherwise accept the
        # reply only if it contains exactly one number. Stripping non-digits and
        # concatenating the rest would turn "0.3 or 0.4" into 0.304 and "-0.2"
        # into 0.2, i.e. silently accept values the model never produced.
        tagged = re.search(r"#\s*(" + number + r")\s*#", raw)
        if tagged:
            msg = tagged.group(1)
        else:
            found = re.findall(number, raw)
            msg = found[0] if len(found) == 1 else raw.strip()

        try:
            msg_float = float(msg)
        except ValueError:
            pass  # Not a number; fall through and retry
        else:
            if 0 <= msg_float <= 1:  # NaN fails this check too
                return msg_float

        print(f"Attempt {attempt}, received invalid result: {msg}, retrying...")

    # Out of retries: signal failure to the caller with None.
    print("Exceeded maximum retry attempts, returning invalid result")
    return None

def make_prompt_dueling_experiment(X, Y, x_test):
    """Build an in-context prompt asking the LLM for a value in [0, 1].

    Every (X[i], Y[i]) pair is written as an "input: ..., output: ..." line.
    The lines run from the sample farthest from ``x_test`` to the nearest one,
    so the nearest example sits directly above the final query line. The query
    line leaves the output for ``x_test`` blank.

    Args:
        X: Observed inputs, either 1-D (one feature per sample) or 2-D
            (one row per sample).
        Y: Observed outputs, indexable in parallel with ``X``.
        x_test: Query input whose output the model should predict.

    Returns:
        The prompt string.
    """
    # The four spaces in the middle reproduce what the backslash line
    # continuation embedded in the original literal. They are kept verbatim so
    # the text sent to the model does not change.
    header = (
        "Help me predict the value for the last input as a continuous value between 0 and 1."
        "    Your response MUST only contain the value in the format of #value#.\n"
    )

    # Distance from every sample to the query: Euclidean per row for
    # multi-feature data, absolute difference for single-feature data.
    if X.ndim > 1:
        dists = np.linalg.norm(X - x_test, axis=1)
    else:
        dists = np.abs(X - x_test)

    # Farthest first. Python's sort stays stable under reverse=True, so samples
    # at equal distance keep their original relative order.
    order = sorted(range(len(dists)), key=dists.__getitem__, reverse=True)

    examples = "".join(f"input: {X[i]}, output: {Y[i]}\n" for i in order)
    return header + examples + "input: " + str(x_test) + ", output: "

