from xml.parsers.expat import model
from openai import OpenAI
import base64
import requests
from time import sleep
import logging
import json
import csv
import sys
from src.data_structure import yaml
import os
import dill
import re
from collections import defaultdict
from copy import deepcopy

# OpenAI credentials and the shared client, both built once at import time
# from the OPENAI_API_KEY environment variable.
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)

# General utils

def load_from_file(fpath, noheader=True):
    """
    Load the content of a file, picking the parser from the file extension.

    Supported extensions:
        pkl:  object restored with ``dill.load``
        txt:  the whole text as one string when "prompt" occurs anywhere in
              ``fpath``, otherwise a list of whitespace-stripped lines
        json: object parsed with ``json.load``
        csv:  list of rows, each row a list of strings
        yaml: object parsed with ``yaml.load`` using ``yaml.FullLoader``

    Args:
        fpath: path of the file to read.
        noheader: CSV only. When True the first row is treated as a header and
            dropped, i.e. the returned rows contain no header row.

    Returns:
        The parsed content; its type depends on the extension (see above).

    Raises:
        ValueError: if the extension is not one of the supported types.
    """
    ftype = os.path.splitext(fpath)[-1][1:]
    if ftype == 'pkl':
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(fpath, 'rb') as rfile:
            return dill.load(rfile)
    if ftype == 'txt':
        with open(fpath, 'r') as rfile:
            if 'prompt' in fpath:
                # Prompt files are kept verbatim, including blank lines.
                return rfile.read()
            return [line.strip() for line in rfile]
    if ftype == 'json':
        with open(fpath, 'r') as rfile:
            return json.load(rfile)
    if ftype == 'csv':
        # newline='' lets the csv module handle line endings itself.
        with open(fpath, 'r', newline='') as rfile:
            csvreader = csv.reader(rfile)
            if noheader:
                # Drop the header row; the default keeps an empty file from
                # raising StopIteration.
                next(csvreader, None)
            return list(csvreader)
    if ftype == 'yaml':
        # NOTE: FullLoader is not meant for untrusted input; only load trusted files.
        with open(fpath, 'r') as rfile:
            return yaml.load(rfile, Loader=yaml.FullLoader)
    raise ValueError(f"ERROR: file type {ftype} not recognized")

def save_to_file(data, fpth, mode=None):
    """
    Write ``data`` to ``fpth``, picking the serializer from the file extension.

    Supported extensions:
        pkl:  ``dill.dump`` (default mode 'wb')
        txt:  ``data`` written as-is, so it must be a string (default mode 'w')
        json: ``json.dump`` with sorted keys and 4-space indent (default mode 'w')
        csv:  ``csv.writer.writerows``, so ``data`` is an iterable of rows
              (default mode 'w')
        yaml: ``yaml.dump`` (default mode 'w')

    Args:
        data: object to serialize.
        fpth: destination path; the parent directory must already exist.
        mode: optional ``open`` mode overriding the per-type default, e.g. 'a'
            to append. It is honoured for every file type, yaml included.

    Raises:
        ValueError: if the extension is not one of the supported types.
    """
    ftype = os.path.splitext(fpth)[-1][1:]
    if ftype == 'pkl':
        with open(fpth, mode or 'wb') as wfile:
            dill.dump(data, wfile)
    elif ftype == 'txt':
        with open(fpth, mode or 'w') as wfile:
            wfile.write(data)
    elif ftype == 'json':
        with open(fpth, mode or 'w') as wfile:
            json.dump(data, wfile, sort_keys=True, indent=4)
    elif ftype == 'csv':
        # newline='' lets the csv module control line endings.
        with open(fpth, mode or 'w', newline='') as wfile:
            csv.writer(wfile).writerows(data)
    elif ftype == 'yaml':
        with open(fpth, mode or 'w') as wfile:
            yaml.dump(data, wfile)
    else:
        raise ValueError(f"ERROR: file type {ftype} not recognized")

# def encode_image(image_path):
#   with open(image_path, "rb") as image_file:
#     return base64.b64encode(image_file.read()).decode('utf-8')

def raw_prompt(prompt, img_path_list=None, model="gpt-5"):
    """
    Send one user message (text plus optional images) and return the reply text.

    A new OpenAI client is created on every call from the OPENAI_API_KEY
    environment variable. No retry is performed; API errors propagate.

    Args:
        prompt: text part of the user message.
        img_path_list: optional iterable of image file paths. Each image is
            attached after the text as a base64 data URL, in the given order.
        model: name of the chat-completion model to query.

    Returns:
        The message content of the first returned choice.
    """
    # Local import keeps this block self-contained.
    import mimetypes

    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    content = [{"type": "text", "text": prompt}]
    for image_path in (img_path_list or []):
        with open(image_path, "rb") as f:
            base64_image = base64.b64encode(f.read()).decode("utf-8")
        # Declare the media type that matches the file extension; fall back to
        # the previous hard-coded PNG when the extension is unknown.
        mime_type = mimetypes.guess_type(str(image_path))[0]
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/png"
        content.append(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime_type};base64,{base64_image}"
                },
            }
        )

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": content}],
    )

    return response.choices[0].message.content

def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as a base64 UTF-8 string."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    return base64.b64encode(raw_bytes).decode("utf-8")

def prompt2msg(query_prompt, vision=False):
    """
    Convert a completion-style prompt into a list of chat messages.

    The prompt is split on blank lines ("\\n\\n"). The first segment becomes
    the system message; all remaining segments (examples and query) are joined
    back with "\\n\\n" into a single user message. A prompt with only one
    segment yields just the system message.

    Support prompts for
        RER: e.g., data/osm/rer_prompt_16.txt
        symbolic translation: e.g., data/prompt_symbolic_batch12_perm/prompt_nexamples1_symbolic_batch12_perm_ltl_formula_9_42_fold0.txt
        end-to-end translation: e.g., data/osm/osm_full_e2e_prompt_boston_0.txt
    :param query_prompt: prompt string, or an already-split list of segments.
    :param vision: when True the message text is stored under the key "text"
        instead of "content".
    :return: list of message dicts, each with a "role" and the text key.
    """
    if isinstance(query_prompt, str):
        prompt_splits = query_prompt.split("\n\n")
    else:
        prompt_splits = query_prompt

    tag = "text" if vision else "content"
    msg = [{"role": "system", tag: prompt_splits[0]}]
    if len(prompt_splits) > 1:
        msg.append({"role": "user", tag: "\n\n".join(prompt_splits[1:])})

    return msg

class GPT4:
    """
    Small wrapper around the module-level OpenAI ``client`` for chat completions.

    Every request is retried up to ``MAX_TRIES`` times on any exception, with a
    fixed sleep between attempts. When all attempts fail a ``RuntimeError`` is
    raised, chained to the last underlying error.
    """

    # Number of attempts per request before giving up.
    MAX_TRIES = 15

    def __init__(self, engine="gpt-5-nano", temp=0, max_tokens=500, n=1, stop=['\n\n\n']):
        """
        Args:
            engine: model name passed to the API.
            temp: sampling temperature used by ``generate`` for non-o1 engines.
            max_tokens: completion token limit used by ``generate`` for non-o1 engines.
            n: number of completions requested by ``generate`` for non-o1 engines.
            stop: stop sequences used by ``generate`` for non-o1 engines.
        """
        self.engine = engine
        self.temp = temp
        self.max_tokens = max_tokens
        self.n = n
        # Copy list values so the shared default list (and a caller's list) is
        # never aliased by, or mutated through, this instance.
        self.stop = list(stop) if isinstance(stop, list) else stop

    def _create_with_retry(self, request, wait_sec, echo_prompt=None):
        """
        Call ``client.chat.completions.create(**request)`` with retries.

        Args:
            request: keyword arguments for the API call.
            wait_sec: seconds to sleep after each failed attempt.
            echo_prompt: when not None, the prompt is printed and logged with
                each failure report.

        Returns:
            The API response of the first successful attempt.

        Raises:
            RuntimeError: if every attempt failed.
        """
        last_error = None
        for ntries in range(self.MAX_TRIES):
            try:
                return client.chat.completions.create(**request)
            except Exception as e:
                # Retry boundary: every failure is reported and retried.
                last_error = e
                wait_msg = f"{ntries}: waiting for the server. sleep for {wait_sec} sec..."
                print(e)
                print(wait_msg)
                if echo_prompt is not None:
                    print(f"The prompt right now is:\n\n{echo_prompt}")
                logging.info(wait_msg)
                if echo_prompt is not None:
                    logging.info(f"The prompt right now is:\n\n{echo_prompt}")
                sleep(wait_sec)
                logging.info("OK continue")
        raise RuntimeError(
            f"OpenAI request failed after {self.MAX_TRIES} attempts"
        ) from last_error

    @staticmethod
    def _image_part(img_path):
        """Build the ``image_url`` content part for one image file."""
        # Local import keeps this block self-contained.
        import mimetypes

        # Declare the media type that matches the file extension; fall back to
        # the previous hard-coded PNG when the extension is unknown.
        mime_type = mimetypes.guess_type(str(img_path))[0]
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/png"
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:{mime_type};base64,{encode_image(img_path)}"
            },
        }

    # @retry(wait_fixed=15000, stop_max_attempt_number=5)
    def generate(self, query_prompt):
        '''
        Query the model with a text-only prompt.

        query_prompt: query with task description and in-context examples split with \\n\\n

        For engines whose name contains 'o1' the whole prompt is sent as one
        user message with no sampling parameters. For every other engine the
        prompt is converted with ``prompt2msg`` and sent together with
        temperature, n, stop and max_tokens.

        Returns a list of stripped completion strings.
        Raises RuntimeError when all retry attempts fail.
        '''
        # NOTE(review): only names containing 'o1' take the parameter-free
        # path; confirm that the configured engine (including the default)
        # accepts temperature/stop/max_tokens.
        if 'o1' in self.engine:
            request = {
                "model": self.engine,
                "messages": [{"role": "user", "content": query_prompt}],
            }
        else:
            request = {
                "model": self.engine,
                "messages": prompt2msg(query_prompt),
                "temperature": self.temp,
                "n": self.n,
                "stop": self.stop,
                "max_tokens": self.max_tokens,
            }

        raw_responses = self._create_with_retry(request, wait_sec=10)

        if self.n == 1:
            return [raw_responses.choices[0].message.content.strip()]
        # Response objects are read by attribute, as in the n == 1 case.
        return [choice.message.content.strip() for choice in raw_responses.choices]

    # @retry(wait_fixed=15000, stop_max_attempt_number=5)
    def generate_multimodal(self, query_prompt, imgs, max_tokens=1500, logprobs=False, temp = None):
        '''
        Query the model with text followed by images.

        Separate function on purpose to call the multimodal API. The prompt is
        converted with ``prompt2msg(vision=True)`` and all images are appended,
        in order, to the content of the last message.

        query_prompt: prompt text (see ``prompt2msg``).
        imgs: iterable of image file paths.
        max_tokens, logprobs, temp: accepted for backward compatibility; the
            request currently sends only the model name and the messages, so
            these values are not forwarded to the API.

        Returns a list of completion strings (unstripped when ``self.n == 1``).
        Raises RuntimeError when all retry attempts fail.
        '''
        messages = [
            {
                "role": part["role"],
                "content": [{"type": "text", "text": part["text"]}],
            }
            for part in prompt2msg(query_prompt, vision=True)
        ]
        for img in imgs:
            # Images always go to the last message.
            messages[-1]["content"].append(self._image_part(img))

        response = self._create_with_retry(
            {"model": self.engine, "messages": messages},
            wait_sec=5,
            echo_prompt=query_prompt,
        )

        if self.n == 1:
            return [response.choices[0].message.content]
        # Response objects are read by attribute, as in the n == 1 case.
        return [choice.message.content.strip() for choice in response.choices]

def get_save_fpath(directory: str, fname: str, ftype: str) -> str:
    """
    Return the first path of the form ``{directory}/{fname}_{k}.{ftype}``
    (k = 1, 2, ...) that does not exist yet. Nothing is created on disk.
    """
    index = 1
    candidate = f"{directory}/{fname}_{index}.{ftype}"
    while os.path.exists(candidate):
        index += 1
        candidate = f"{directory}/{fname}_{index}.{ftype}"
    return candidate

def setup_logging(dir_name, env_name) -> str:
    """
    Route root-logger output to a fresh log file and to the console.

    The log file is ``{dir_name}/{env_name}_log_raw_results_{k}.log`` with the
    first unused ``k``. Handlers already attached to the root logger are
    removed and closed before the new ones are installed.

    Args:
        dir_name: directory for the log file; created if missing.
        env_name: prefix of the log file name.

    Returns:
        Path of the new log file.
    """
    os.makedirs(dir_name, exist_ok=True)
    save_path = get_save_fpath(dir_name, f"{env_name}_log_raw_results", "log")

    # Detach AND close old handlers so a previous log file handle is released;
    # iterate over a copy because removeHandler mutates the list.
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        handler.close()

    logging.basicConfig(level=logging.INFO,
                        format='%(message)s',
                        handlers=[
                            logging.FileHandler(save_path, mode='w'),
                            logging.StreamHandler()
                        ]
    )
    logging.info(f"log files will be saved at {save_path}")

    return save_path

def clean_logging(save_path, keyword_list=['HTTP']):
    """
    Rewrite the log file in place, dropping every line that contains at least
    one of the keywords anywhere in the line (plain substring match).
    """
    with open(save_path, 'r') as src:
        kept_lines = [
            entry for entry in src.readlines()
            if all(keyword not in entry for keyword in keyword_list)
        ]
    with open(save_path, 'w') as dst:
        dst.writelines(kept_lines)

def load_tasks(load_path, task_config):
    """
    Load ``{load_path}/tasks.yaml`` and turn skill strings back into skill objects.

    Tasks are saved separately since predicate invention takes them as input.
    Steps are processed in ascending numeric order and stored under integer
    keys. Step 0 is copied unchanged; for every other step the "skill" string
    (``Name(arg1, arg2)``) is replaced by the lifted skill of that name from
    ``task_config['skills']`` grounded with the parsed arguments.

    Args:
        load_path: directory containing tasks.yaml.
        task_config :: dict('env': str, 'skills': Skill, 'objects': dict, 'Env_description': str, 'Initial_observation': dict)

    Returns:
        tasks :: dict(task_name: (step: dict("skill": grounded_skill, 'image':img_path, 'success': Bool)))

    Raises:
        ValueError: if a skill string is malformed or names an unknown skill.
    """
    tasks = load_from_file(f"{load_path}/tasks.yaml")
    converted_tasks = defaultdict(dict)
    # reassemble skill strings into skill objects
    for task_name, task_meta in tasks.items():
        # Iterate over (key, value) pairs directly so that both integer and
        # string step keys work; the step value is never looked up again by a
        # converted key.
        for step_key, step_meta in sorted(task_meta.items(), key=lambda item: int(item[0])):
            step = int(step_key)
            converted_tasks[task_name][step] = deepcopy(step_meta)
            if step == 0:
                continue

            skill_string = str(step_meta["skill"])
            match = re.match(r"(\w+)\((.*)\)", skill_string.strip())
            if match is None:
                raise ValueError(
                    f"ERROR: cannot parse skill string {skill_string!r} "
                    f"(task {task_name!r}, step {step})"
                )
            skill_name, raw_parameters = match.groups()
            # NOTE(review): a skill written without arguments, e.g. "Skill()",
            # yields [''] here -- confirm that ground_with expects that.
            parameters = raw_parameters.split(", ")

            # assuming every different skill has a different name: first match wins
            lifted_skill = next(
                (skill for skill in task_config['skills'].values() if skill.name == skill_name),
                None,
            )
            if lifted_skill is None:
                raise ValueError(
                    f"ERROR: skill {skill_name!r} not found in task_config "
                    f"(task {task_name!r}, step {step})"
                )
            converted_tasks[task_name][step]["skill"] = lifted_skill.ground_with(parameters)

    return converted_tasks

def save_results(skill2operator, lifted_pred_list, grounded_predicate_truth_value_log, save_directory):
    """
    Save results as separate files under ``save_directory``.

    Operators are saved both as a pkl file for later usage and as strings in a
    yaml file for human readability.

    Files written:
        operators/skill2operator.pkl
        operators/operators.yaml
        predicates/predicates.yaml
        transitions/grounded_predicate_truth_value_log.yaml

    Args:
        grounded_predicate_truth_value_log {task_name:{step:PredicateState}}
        lifted_pred_list :: list[Predicate]
        skill2operator :: {lifted_skill: [(LiftedPDDLAction, {pid: int: type: str})]}
        save_directory: root directory of the results; sub-directories are
            created when missing.
    """
    # Every directory written to below must exist, including 'transitions'.
    for subdir in ['operators', 'predicates', 'transitions']:
        os.makedirs(f"{save_directory}/{subdir}", exist_ok=True)

    save_to_file(skill2operator, f"{save_directory}/operators/skill2operator.pkl")
    # Human-readable view: the string form of each operator (first item of
    # every operator tuple). A skill whose operator list is None or empty
    # maps to an empty list.
    readable_operators = {
        lifted_skill: [str(operator_meta[0]) for operator_meta in (operator_metas or [])]
        for lifted_skill, operator_metas in skill2operator.items()
    }
    save_to_file(readable_operators, f"{save_directory}/operators/operators.yaml")

    save_to_file(lifted_pred_list, f"{save_directory}/predicates/predicates.yaml")

    save_to_file(grounded_predicate_truth_value_log, f"{save_directory}/transitions/grounded_predicate_truth_value_log.yaml")

    logging.info(f"results have been saved to {save_directory}")

def load_results(load_fpath, task_config):
    """
    Load tasks, operators, predicate list, and truth value log.

    Each item is loaded on a best-effort basis: when loading fails, a warning
    is logged and an empty default is used instead.

    Args:
        load_fpath: root directory of previously saved results.
        task_config: task configuration; ``task_config['skills']`` supplies the
            lifted skills used for the operator default and for ``load_tasks``.

    Returns:
        (tasks, skill2operator, lifted_pred_list, grounded_predicate_truth_value_log)
        with defaults ``{}``, ``{lifted_skill: None, ...}``, ``[]`` and ``{}``.
    """
    # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
    try:
        tasks = load_tasks(f"{load_fpath}/transitions", task_config)
    except Exception as e:
        logging.warning(f"could not load tasks from {load_fpath}/transitions, using empty tasks: {e}")
        tasks = {}

    # save_results writes 'skill2operator.pkl'; 'operator.pkl' is the name this
    # function used to read and is kept as a fallback.
    for pkl_name in ("skill2operator.pkl", "operator.pkl"):
        try:
            skill2operator = load_from_file(f"{load_fpath}/operators/{pkl_name}")
            break
        except Exception:
            continue
    else:
        logging.warning(f"could not load operators from {load_fpath}/operators, using empty operators")
        skill2operator = {lifted_skill: None for lifted_skill in list(task_config['skills'].values())}

    try:
        lifted_pred_list = load_from_file(f"{load_fpath}/predicates/predicates.yaml")
    except Exception as e:
        logging.warning(f"could not load predicates from {load_fpath}/predicates, using empty list: {e}")
        lifted_pred_list = []

    try:
        grounded_predicate_truth_value_log = load_from_file(f"{load_fpath}/transitions/grounded_predicate_truth_value_log.yaml")
    except Exception as e:
        logging.warning(f"could not load truth value log from {load_fpath}/transitions, using empty log: {e}")
        grounded_predicate_truth_value_log = {}

    logging.info(f"loaded results from {load_fpath}")

    return tasks, skill2operator, lifted_pred_list, grounded_predicate_truth_value_log

def init_new_iter(env, method, run_idx):
    """
    Prepare the working directory of the next iteration of a run.

    The run directory is ``results/{method}/{env}/runs/{run_idx}`` (relative to
    the current working directory). Finished iterations are the entries whose
    names are plain integers.

    - If finished iterations exist, the one with the largest index ``k`` is
      copied with ``cp -r`` to ``{k+1}_partial``.
    - Otherwise ``0_partial`` is created with empty sub-directories
      operators, predicates, transitions, skill_sequences and log.

    Args:
        env: environment name (path component).
        method: method name (path component).
        run_idx: index of the run (path component).

    Returns:
        Path of the new iteration directory. Note that the path ends with a
        trailing slash only when it was copied from a previous iteration.

    Raises:
        subprocess.CalledProcessError: if copying the previous iteration fails.
    """
    # Local import keeps this block self-contained.
    import subprocess

    # prepare folder structures
    run_dir = f"results/{method}/{env}/runs/{run_idx}"
    os.makedirs(run_dir, exist_ok=True)
    iters = [int(name) for name in os.listdir(run_dir) if name.isdigit()]
    if iters:  # previous iterations exist
        # largest iteration number
        iter_idx = max(iters)
        new_iter_idx = iter_idx + 1
        new_iter_dir = f"{run_dir}/{new_iter_idx}_partial/"

        # Copy the previous iteration into the new one. An argument list (no
        # shell) keeps paths with spaces or shell metacharacters intact, and
        # check=True surfaces a failed copy instead of ignoring it.
        prev_iter_dir = f"{run_dir}/{iter_idx}"
        subprocess.run(
            ["cp", "-r", prev_iter_dir, f"{run_dir}/{new_iter_idx}_partial"],
            check=True,
        )
    else:
        new_iter_idx = 0
        new_iter_dir = f"{run_dir}/{new_iter_idx}_partial"
        os.makedirs(new_iter_dir, exist_ok=True)
        for subdir in ['operators', 'predicates', 'transitions', 'skill_sequences', 'log']:
            os.makedirs(f"{new_iter_dir}/{subdir}", exist_ok=True)

    return new_iter_dir

# Manual smoke test: send one text prompt plus one image to the model through
# GPT4.generate_multimodal and print the returned list of responses.
# The commented-out lines below are alternative images/prompts from earlier experiments.
if __name__ == "__main__":
    gpt = GPT4(engine="o1")
    # imgs = ["test_imgs/test_0.png", "test_imgs/test_1.png"]
    # imgs = ["test_imgs/pickup.png"]
    # imgs = ["test_imgs/success.png", "test_imgs/failure.png"]
    # imgs = ["test_imgs/failure.png"]
    # imgs = ["test_imgs/caption.png"]
    # imgs = ["test_imgs/1.jpg", "test_imgs/3.jpg"]
    # NOTE: the first assignment below is immediately overwritten by the second one.
    imgs = ['tasks/exps/GoTo/After_GoTo_DiningTable_CoffeeTable_True_2.jpg']
    imgs = ['tasks/exps/GoTo/After_GoTo_CoffeeTable_DiningTable_True_2.jpg']
    # txt = "The robot is executing pickup() action. There are certain PDDL predicates that are related to the task. Please propose them."
    # txt = "The robot exectued an action called pickup(Apple). The two images are egocentric observation of the robot before and after the execution. Can you tell which one is before and which one is after execution?"
    # txt = 'A robot is executing tasks in the envrionment. Here is what the robot sees from an egocentric view. Please provide a general description of the type of the environment, such as household or facotry, and the robots, such as its mobility and embodiement.'
    # txt = "You are a robot,  and all the images are exactly what you see. You are commanded to execute PickUp() action, and the first image shows a successful attempt, while the second one is a failure. How can you guide the robot from the failure image to the successful image using the provided actions: MoveGripperLeft, MoveGripperRight, MoveGripperForward, MoveGripperBackward, MoveGripperUp, MoveGripperDown?"
    # txt = 'If this is what you see exactly, in which picture the robot gripper is on the left of the image? Is it the first one or the second?'
    # txt = "If this is what you see exactly, is the bread located on the left of the table or the right in both pictures?"
    # txt = 'What is on the left of the table?'
    # txt = "If this is what you see exactly, is the bread on the left of the gripper or on the right?"
    # txt = "You are a mobile manipulator robot. Is this image from a first-person view or a third-person view?"
    # txt = "If these images are what you see exactly, from the second image to the first image, which direction did the gripper move to, left or right?"
    
    # txt = "You are a robot,  and all the images are exactly what you see. You are commanded to execute PickUp() action, how can you guide the robot from the second image to the first image using the provided actions: MoveGripperLeft, MoveGripperRight, MoveGripperForward, MoveGripperBackward, MoveGripperUp, MoveGripperDown?"
    # txt = "You are a robot,  this image is exactly what you see right now. You are commanded to execute PickUp() action, and you have two actions available: MoveGripperLeft() amd MoveGripperRight(), if this image is exactly what you are seeing from your eyes, what would be your next action?"
    # txt = "What are the objects in this picture?"
    # txt = """
    # There are certain predicates associated with different skills, please find out the ones that have changed their truth value by comparing the visual observation before and after the execution of the skill. You only have to fill the predicates and truth values before and after the execution on the "Effect" line without explanation. Note that not all predicates are necessary to form the effect, you should only select the most essential ones based on the visual observation.
    # Skill: PickUp(object, location)
    # Predicates: 'AtLocation(object,location)', 'Holding(object)', 'At(location)', 'IsReachable(object)', 'IsFreeHand()'
    # object = Book, location = Table
    # """
    # txt = "A robot is executing a skill PickUp(Book, DiningTable). Given the following egocentric observation from the robot, what is the truth value of the predicate HasEmptyHands()? Answer with reasoning and True or False in a separate line. Note that the blue sphere on the gripper is a part of the gripper and is not an object, and it's a simulated environment so you can only tell if an object is grasped by determining if it is moved to the air by the gripper. Also, do not assume the object is in the scene.\nHasEmptyHand(): The robot's hands are empty and not holding anything."
    # Active prompt: ask for the truth value of isClose(CoffeeTable) given the image above.
    txt = "A robot is executing a skill GoTo(CoffeeTable, DiningTable). Given the following egocentric observation from the robot, what is the truth value of the predicate isClose(CoffeeTable)? Answer with reasoning and True or False in a separate line. Note that the blue sphere on the gripper is a part of the gripper and is not an object, and it's a simulated environment so you can only tell if an object is grasped by determining if it is moved to the air by the gripper. Also, do not assume the object is in the scene.\nisClose(loc): The robot is physically close to the location `loc`."
    # NOTE: generate_multimodal sends only the model name and the messages to
    # the API, so logprobs=True has no effect on the request.
    responses = gpt.generate_multimodal(txt, imgs,logprobs=True)
    # responses = gpt.generate_multimodal(txt, imgs)
    # responses = gpt.generate(txt)
    print(responses)