import io
import openai.error
import requests
import random
import base64
import imageio
from PIL import Image
import numpy as np
import openai, time

def check_response(skill_response, incremental_skill_response):
    """Normalise an LLM skill-selection reply to ``"<index>; <skill list>"``.

    Skill lists are space-separated text in which a token of the form
    ``"<n>."`` or ``"<n>.<word>"`` marks skill number ``n``.

    Args:
        skill_response: The current skill list (before the query).
        incremental_skill_response: The raw reply, expected to look like
            ``"<index>;<updated skill list>"``, or the sentinel string
            ``"gpt_error"`` when no answer was obtained. If the reply has no
            ``;`` the current skill list is kept.

    Returns:
        ``"<index>; "`` followed by the space-joined tokens of the updated
        list. The list is cut so it holds at most one skill more than
        ``skill_response``. ``<index>`` is kept when it lies in
        ``1..min(#listed skills, #original skills + 1)``; otherwise it is
        re-sampled at random from that range (falling back to a random
        original index, or 1 when no indices exist at all).
    """

    def token_index(token):
        # "3.Close" / "3." -> 3; anything without a leading "<digits>." -> None.
        # isdecimal (not isdigit) guarantees int() cannot raise, e.g. on "²".
        head, dot, _ = token.partition('.')
        return int(head) if dot and head.isdecimal() else None

    def indices_of(tokens):
        return [idx for idx in map(token_index, tokens) if idx is not None]

    original_indices = indices_of(skill_response.split(' '))

    if incremental_skill_response == "gpt_error":
        # No usable reply: behave as if the model picked a random existing
        # skill and echoed the current list unchanged.
        fallback = random.choice(original_indices) if original_indices else 1
        incremental_skill_response = f"{fallback};{skill_response}"

    try:
        add_skills = incremental_skill_response.split(';', 1)[1].strip()
    except IndexError:
        # Reply lacks the "<index>;" prefix: keep the current skill list.
        add_skills = skill_response
    add_lines = add_skills.split(' ')
    add_indices = indices_of(add_lines)

    # At most one new skill may be added per query.
    max_skills = len(original_indices) + 1
    if len(add_indices) > max_skills:
        # Cut at the first token numbered max_skills + 1. Matching on the parsed
        # number (not the literal token "N.") also handles "N.Content" tokens,
        # for which list.index() used to raise ValueError.
        for pos, token in enumerate(add_lines):
            if token_index(token) == max_skills + 1:
                add_lines = add_lines[:pos]
                break

    try:
        selected_index = int(incremental_skill_response.split(';')[0].strip())
    except ValueError:
        selected_index = 0  # never valid, so it is re-sampled below

    valid_indices = list(range(1, min(len(add_indices), max_skills) + 1))
    if selected_index not in valid_indices:
        if valid_indices:
            selected_index = random.choice(valid_indices)
        elif original_indices:
            selected_index = random.choice(original_indices)
        else:
            selected_index = 1

    return ' '.join([f"{selected_index};"] + add_lines)




def extract_valid_indice(llm_response):
    """Return the 0-based skill index selected in an LLM reply.

    ``llm_response`` is expected to look like ``"<k>; 1.skill 2.skill ..."``
    (the shape returned by ``check_response``). Tokens of the form
    ``"<n>."`` / ``"<n>.<word>"`` after the first ``;`` define the valid
    1-based indices.

    If ``<k>`` is not one of them, a valid index is drawn at random. If
    there are no valid indices at all, index 1 is assumed.

    Returns:
        The chosen 1-based index minus one. The skill list is 1-based, the
        caller's list is 0-based.
    """
    try:
        selected_index = int(llm_response.split(';')[0].strip())
    except ValueError:
        # Unparseable prefix: start from 1; it is re-sampled below if 1 is
        # not among the listed indices.
        selected_index = 1

    valid_indices = []
    _prefix, separator, skills = llm_response.partition(';')
    if separator:
        for token in skills.strip().split(' '):
            number, dot, _ = token.partition('.')
            # isdecimal (not isdigit) so int() never raises on e.g. "²"; such
            # a token used to discard every index via a broad except.
            if dot and number.isdecimal():
                valid_indices.append(int(number))

    if selected_index not in valid_indices:
        selected_index = random.choice(valid_indices) if valid_indices else 1

    return selected_index - 1


def read_prompt_from_file(filename):
    """Return the skill list stored in ``filename`` flattened to one line.

    The file is expected to hold ``"<index>; <skills>"``, which is what
    ``query_llm`` saves. Whatever follows the first ``;`` is returned with
    newlines and tabs replaced by spaces and surrounding whitespace stripped.

    An empty string is returned in two cases:

    - the file does not exist;
    - the file's stripped content lacks either a ``;`` or a space.
    """
    try:
        with open(filename, 'r') as handle:
            text = handle.read().strip()
    except FileNotFoundError:
        return ""

    has_separator = ';' in text
    has_space = ' ' in text
    if not (has_separator and has_space):
        return ""

    # Drop the "<index>;" prefix and collapse line breaks/tabs into spaces.
    _prefix, _sep, skills = text.partition(';')
    flattened = skills.strip().translate({ord('\n'): ' ', ord('\t'): ' '})
    return flattened.strip()

def save_response_to_file(response, filename, log_path=None):
    """Overwrite ``filename`` with ``response`` and optionally log it.

    Args:
        response: Text to persist (the normalised skill-selection reply).
        filename: File that is truncated and rewritten with ``response``.
        log_path: Optional history file. When given, a newline followed by
            ``response`` is appended to it.
    """
    with open(filename, 'w') as file:
        file.write(response)
    # The log used to be opened at the hard-coded path "", which can never be
    # opened and made every call raise FileNotFoundError after the write
    # above; it is now opt-in.
    if log_path:
        with open(log_path, 'a') as log_file:
            log_file.write("\n" + response)

class gpt_agent():
    """Two-turn GPT image queries via the legacy ``openai.ChatCompletion`` API.

    Configuration is written to the module-level ``openai`` settings
    (``openai.api_base`` / ``openai.api_key``), so it applies to every user of
    the ``openai`` package in this process.

    Each ask method makes at most ``*_call_cnt_sup`` API attempts per question.
    ``openai.error.APIError`` and ``ServiceUnavailableError`` are retried after
    a 10 s pause. Once the attempt budget is used up, the string
    ``"gpt_error"`` is returned instead of a response. Other exceptions
    propagate to the caller.
    """

    def __init__(self, key_list=None, api_base="https://api.ai-gaochao.cn/v1", model="gpt-4o"):
        """
        Args:
            key_list: API keys to draw from. One is picked at random now, and
                a fresh one after every APIError. If empty/None,
                ``openai.api_key`` is left untouched so a key configured
                elsewhere is used.
            api_base: Base URL of the OpenAI-compatible endpoint.
            model: Chat model name passed to ``ChatCompletion.create``.
        """
        openai.api_base = api_base
        # Previously a hard-coded empty list was passed to random.choice,
        # which raised IndexError on every construction.
        self.key_list = list(key_list) if key_list else []
        self.model = model
        self._rotate_key()

        # Attempts made for the in-flight question (0 when idle) and the cap.
        self.init_ask_call_cnt = 0
        self.init_ask_call_cnt_sup = 3

        self.term_ask_call_cnt = 0
        self.term_ask_call_cnt_sup = 3

    def _rotate_key(self):
        """Set a random key from ``key_list`` as the global openai key (no-op if empty)."""
        if self.key_list:
            openai.api_key = random.choice(self.key_list)

    @staticmethod
    def _user_message(text, img_base64):
        """Build a user chat message with ``text`` followed by one inline JPEG."""
        return {
            "role": "user",
            "content": [
                {"type": "text", "text": f"{text}"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                },
            ],
        }

    def _ask(self, messages, counter_attr, limit):
        """Send ``messages`` with up to ``limit`` attempts.

        ``counter_attr`` names the attempt counter attribute. It tracks the
        current attempt and is reset to 0 when the question finishes
        (success, give-up or exception), so every question gets the full
        budget.

        Returns:
            The API response object, or ``"gpt_error"`` after ``limit``
            failed attempts.
        """
        try:
            for attempt in range(1, limit + 1):
                setattr(self, counter_attr, attempt)
                try:
                    return openai.ChatCompletion.create(model=self.model, messages=messages)
                except openai.error.APIError:
                    print('======> APIError error: will retry after 10 seconds')
                    time.sleep(10)
                    self._rotate_key()
                except openai.error.ServiceUnavailableError:
                    print('======> Service unavailable error: will retry after 10 seconds')
                    time.sleep(10)
            print("======> Achieve call count limit, Return!")
            return "gpt_error"
        finally:
            setattr(self, counter_attr, 0)

    def init_ask(self, text1, initial_img_base64):
        """First turn: send the task description ``text1`` with the initial image."""
        messages = [self._user_message(text1, initial_img_base64)]
        return self._ask(messages, "init_ask_call_cnt", self.init_ask_call_cnt_sup)

    def term_ask(self, text1, initial_img_base64, initial_response_content, text2, terminal_img_base64):
        """Second turn: replay the first exchange, then send ``text2`` with the terminal image.

        Returns ``"gpt_error"`` immediately if the first turn failed.
        """
        if initial_response_content == "gpt_error":
            print("======> Initial query failed, Return!")
            return "gpt_error"
        messages = [
            self._user_message(text1, initial_img_base64),
            {
                "role": "assistant",
                "content": [{"type": "text", "text": initial_response_content}],
            },
            self._user_message(text2, terminal_img_base64),
        ]
        # Uses its own counter; the original incremented init_ask_call_cnt here.
        return self._ask(messages, "term_ask_call_cnt", self.term_ask_call_cnt_sup)

def _image_to_base64_jpeg(image_path):
    """Return the first frame of ``image_path`` as a base64-encoded JPEG string."""
    frame = imageio.mimread(image_path)[0]
    image = Image.fromarray(frame)
    # JPEG cannot store alpha/palette modes (Pillow raises on e.g. RGBA), so
    # flatten anything that is not already RGB or greyscale.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def query_llm(initial_image_path, terminal_image_path,
              skill_path = ''):
    """Ask GPT-4o which skill was executed between two frames.

    Args:
        initial_image_path: File readable by ``imageio.mimread``; its first
            frame is the initial state.
        terminal_image_path: Same, for the terminal state.
        skill_path: File holding the current skill list. It is read with
            ``read_prompt_from_file`` and overwritten with the new reply. If
            the file is missing or holds no skills, a preset list of kitchen
            skills is used to warm up. With the default ``''``, nothing is
            saved.

    Returns:
        The normalised reply ``"<index>; <skill list>"`` from
        ``check_response``.
    """
    initial_img_base64 = _image_to_base64_jpeg(initial_image_path)
    terminal_img_base64 = _image_to_base64_jpeg(terminal_image_path)

    # Read the current skill list; fall back to preset skills for warm-up.
    skill_content = read_prompt_from_file(skill_path)
    if not skill_content:
        skill_content = "1. No obvious skill 2. Open the microwave door 3. Close the cabinet 4. Turn on the stove burner 5. Move the kettle"

    ############################ kitchen
    text1 = f" Are you familiar with the Franka Kitchen environment used for reinforcement learning? In this environment, various skills represent different subtasks.\
            You can deduce which skill was executed or identify the agent's subprocess based on the initial and terminal observations. I will now show you an image of the initial state.\
            Here are examples of possible subtasks already summarized: {skill_content}\
            From my perspective you can reason what the skill is according to the initial observation and terminal observation. Now i will show you the image, it is the initial state. \
            If you think the task happened between two states in the summarized list, point the index out. remember maybe many changes happen between two images, choose the most suitable skill name. If not in them, \
            and you still regard it as a new skill, try your best to give only one suitable new name. Each time, you can provide only one new name if necessary and carefully ensure that it is truly different from the names in the original set. In total, You can describe up to 20 skills."
    text2 = "Now I will show you the image of the terminal state, which means it is the last frame of this segment. You should first judge what is the most obvious one if many changes happen. Check the difference and choose the most probable one. \
            Please determine whether to choose from the given subtask list or name yourself. Write down your choice with the following format: the index selected by you; all the skills index and content. For example 1;\n 1.content of 1\n 2.content of 2 etc. Do not need to display your reasoning process."

    query_agent = gpt_agent()

    # First API call: task description plus the initial frame.
    initial_response = query_agent.init_ask(text1, initial_img_base64)
    if initial_response != "gpt_error":
        initial_response_content = initial_response['choices'][0]['message']['content']
    else:
        initial_response_content = initial_response

    # Second API call: replay the first turn, then show the terminal frame.
    reasoning_response = query_agent.term_ask(text1, initial_img_base64, initial_response_content, text2, terminal_img_base64)
    if reasoning_response != "gpt_error":
        final_response_content = reasoning_response['choices'][0]['message']['content']
        final_response_content = final_response_content.replace('\n', ' ').replace('\t', ' ')
    else:
        final_response_content = reasoning_response

    final_response_content = check_response(skill_content, final_response_content)

    # Persist the updated list so the next call builds on it. With the default
    # skill_path='' there is no file to write (open('') raises).
    if skill_path:
        save_response_to_file(final_response_content, skill_path)

    return final_response_content

if __name__ == "__main__":
    # The paths used to be hard-coded to '' (which imageio cannot open); take
    # them from the command line, keeping '' as the default.
    import argparse

    parser = argparse.ArgumentParser(
        description="Ask GPT-4o which skill was executed between two frames.")
    parser.add_argument("initial_image_path", nargs="?", default="",
                        help="file readable by imageio.mimread; first frame = initial state")
    parser.add_argument("terminal_image_path", nargs="?", default="",
                        help="file readable by imageio.mimread; first frame = terminal state")
    parser.add_argument("--skill-path", default="",
                        help="skill-list file to read and update (optional)")
    args = parser.parse_args()

    response = query_llm(args.initial_image_path, args.terminal_image_path,
                         skill_path=args.skill_path)
    print(response)