#!/usr/bin/env python
# coding: utf-8

# In[ ]:
import openai
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import concurrent.futures
import copy
import GPy

# Optimization iteration function
def random_search(X, Y, xx, theta, num_iter, noise_intensity, d):
    """Baseline that duels two uniformly sampled, distinct arms each round.

    Args:
        X: Not used by this baseline; accepted so the signature parallels the
            other experiment functions.
        Y: Past duel outcomes. New outcomes are appended to a local copy only;
            the caller's array is not modified and the copy is not returned.
        xx: 2-D array of candidate arms, one arm per row.
        theta: Forwarded unchanged to ``observe_duel`` and ``compute_f``.
        num_iter: Number of duels to run.
        noise_intensity: Forwarded unchanged to ``observe_duel``.
        d: Forwarded unchanged to ``observe_duel`` and ``compute_f``.

    Returns:
        Array of shape ``(-1, 2)`` holding, per round, the ``compute_f`` values
        of the first and second arm.
    """
    reward_pairs = []

    for _ in np.arange(num_iter):
        # Draw two different row indices, then unpack the two selected rows.
        pair_idx = np.random.choice(xx.shape[0], 2, replace=False)
        first_arm, second_arm = xx[pair_idx, :]

        outcome = observe_duel(xx, first_arm, second_arm, theta, noise_intensity, d)
        Y = np.append(Y, outcome)

        reward_pairs.append(
            [compute_f(xx, first_arm, theta, d), compute_f(xx, second_arm, theta, d)]
        )

    # Flatten first so that array-valued rewards still end up as two columns.
    return np.array(reward_pairs).flatten().reshape(-1, 2)

def _pair_features(candidate, reference):
    """Build the model input for the ordered pair (candidate, reference)."""
    if using_difference:
        return np.asarray(candidate - reference)
    return np.concatenate((candidate, reference))


def _score_arms_against(X, Y, xx, reference, itr, query_fn):
    """Query the model once per arm in ``xx``, each paired with ``reference``.

    Returns a list with one float array per arm. A ``None`` reply is reported
    on stdout and becomes NaN after the float conversion.
    """
    scores = []
    for candidate in xx:
        prompt = make_prompt_dueling_experiment(
            X, Y, _pair_features(candidate, reference)
        )
        msg = query_fn(prompt, itr, model=gpt_model)
        if msg is None:
            print("No valid message received")
        scores.append(np.asarray(msg, dtype=float))
    return scores


def _random_argmax(scores):
    """Return a first-axis index of the maximum score, ties broken at random.

    NaN entries (missing model replies) are ignored. If every entry is NaN the
    index is drawn uniformly from all arms so the experiment can continue.
    """
    scores = np.atleast_1d(np.asarray(scores, dtype=float))
    valid = ~np.isnan(scores)
    if not valid.any():
        return np.random.choice(scores.shape[0])
    best = np.max(scores[valid])
    return np.random.choice(np.where(scores == best)[0])


def optimize_experiment(X, Y, N_repeat_arm_1, xx, theta, num_iter, noise_intensity, d):
    """Run the model-guided dueling experiment and collect per-round rewards.

    Each round:
      1. First arm: for ``N_repeat_arm_1`` random reference arms, every arm is
         scored against the reference via ``get_valid_msg``; the scores are
         averaged over the references and the best arm is taken (ties broken
         uniformly at random).
      2. Second arm: every arm is scored against the chosen first arm via
         ``get_valid_msg_large_t`` and the best one is taken.
      3. The duel is observed with ``observe_duel`` and the pair and outcome
         are appended to the local history used to build later prompts.

    Args:
        X: History of duel features, one row per past duel (arm differences
            when ``using_difference`` is set, otherwise concatenated arms).
        Y: History of duel outcomes.
        N_repeat_arm_1: Number of random reference arms used for the first
            arm; must be at least 1.
        xx: 2-D array of candidate arms, one arm per row.
        theta: Forwarded unchanged to ``observe_duel`` and ``compute_f``.
        num_iter: Number of rounds.
        noise_intensity: Forwarded unchanged to ``observe_duel``.
        d: Forwarded unchanged to ``observe_duel`` and ``compute_f``.

    Returns:
        List with one ``[compute_f(first arm), compute_f(second arm)]`` entry
        per round.

    Raises:
        ValueError: If ``N_repeat_arm_1`` is smaller than 1.
    """
    if N_repeat_arm_1 < 1:
        raise ValueError("N_repeat_arm_1 must be at least 1")

    all_rewards = []

    for itr in np.arange(num_iter):

        # First arm: average each arm's score over several random opponents.
        acq_arm_1_all = []
        for _ in range(N_repeat_arm_1):
            rand_arm = xx[np.random.choice(xx.shape[0])]
            acq_arm_1_all.append(
                _score_arms_against(X, Y, xx, rand_arm, itr, get_valid_msg)
            )
        acq_arm_1_all_mean = np.mean(np.array(acq_arm_1_all), axis=0)
        x_t_1 = xx[_random_argmax(acq_arm_1_all_mean)]

        # Second arm: score every arm against the chosen first arm, not the
        # last random reference left over from the loop above.
        acq_arm_2 = np.squeeze(
            _score_arms_against(X, Y, xx, x_t_1, itr, get_valid_msg_large_t)
        )
        x_t_2 = xx[_random_argmax(acq_arm_2)]

        y_t = observe_duel(xx, x_t_1, x_t_2, theta, noise_intensity, d)  # Sampling

        if using_difference:
            if X.ndim == 1:
                X = X.reshape(-1, 1)
            X = np.concatenate((X, (x_t_1 - x_t_2).reshape(1, -1)), axis=0)
        else:
            concatenated_feature = np.concatenate((x_t_1, x_t_2))
            X = np.concatenate((X, concatenated_feature.reshape(1, -1)), axis=0)

        Y = np.append(Y, y_t)

        rewards = [compute_f(xx, x_t_1, theta, d), compute_f(xx, x_t_2, theta, d)]
        all_rewards.append(rewards)
        print("Iteration:", itr)

    return all_rewards

