import ast
import time
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()

import pandas as pd
import time
from tqdm import tqdm
import transformers
import torch
import json
import os
import argparse
from sklearn.metrics import confusion_matrix
import torch
import copy
from sld.llm_rule_template import camera_convert_rules
from sld.llm_template import convert_perspetive_prompt


# print(MODEL_SIZE)
# print(LLM_DEVICE)

# TODO: Set up and start a vLLM server before running this code (its port is read from the vLLM_PORT environment variable).

def call_llm(messages, model, tokenizer, max_token=8192):
    """Generate a single chat completion with a locally loaded model.

    Args:
        messages: Chat messages handed to ``tokenizer.apply_chat_template``.
        model: Object exposing ``device`` and ``generate``
            (presumably a Hugging Face causal LM -- confirm with callers).
        tokenizer: Tokenizer used for templating, encoding and decoding.
        max_token: Maximum number of newly generated tokens.

    Returns:
        The decoded text of the first (and only) generated sequence, with
        the leading prompt-length tokens removed and special tokens skipped.
    """
    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer([chat_text], return_tensors="pt").to(model.device)

    # NOTE(review): the near-zero temperature is presumably meant to make
    # the output effectively deterministic -- confirm.
    full_sequences = model.generate(
        **encoded,
        max_new_tokens=max_token,
        temperature=0.000000001,
    )

    # Drop the first len(prompt) tokens of every output so that only the
    # newly generated continuation is decoded.
    continuations = []
    for prompt_ids, sequence in zip(encoded.input_ids, full_sequences):
        continuations.append(sequence[len(prompt_ids):])

    decoded = tokenizer.batch_decode(continuations, skip_special_tokens=True)
    return decoded[0]


def get_update_prompt(prompt, det_results, full_rules=True, trial_count=3):
    """Ask the vLLM-served model to rewrite ``prompt`` using camera-conversion rules.

    The perspective template (``convert_perspetive_prompt``) has its
    ``<RULES>`` placeholder replaced with a numbered rule list, the user
    prompt and detection results are appended, and the reply is searched
    for the text following "updated prompt".

    Args:
        prompt (str): The user prompt to convert.
        det_results: Iterable of detected objects; the last element of each
            entry is looked up as a facing direction in
            ``camera_convert_rules``.
        full_rules (bool): When True, every rule group is sent. When False
            and ``prompt`` contains "facing", only the rule groups whose key
            occurs after the first "facing" in the prompt, or as a detected
            facing direction, are sent (all groups if none match).
        trial_count (int): Number of retries allowed after a failed attempt,
            so at most ``trial_count + 1`` requests are made.

    Returns:
        str: The lower-cased updated prompt extracted from the reply, or the
        original ``prompt`` unchanged once all attempts have failed.
    """
    api_key = "EMPTY"
    port_number = os.getenv("vLLM_PORT")
    openai_api_base = f"http://localhost:{port_number}/v1"

    client = OpenAI(api_key=api_key, base_url=openai_api_base)
    # Use whichever model the server reports first.
    model = client.models.list().data[0].id

    potential_facing_dir = set()
    if not full_rules and "facing" in prompt:
        # maxsplit=1 keeps *everything* after the first "facing"; a plain
        # split()[1] would stop at the next occurrence of the word.
        after_facing_prompt = prompt.split("facing", 1)[1]
        potential_facing_dir.update(
            rule_key for rule_key in camera_convert_rules
            if rule_key in after_facing_prompt
        )

        # Facing directions reported by the detector count as well.
        for det_obj in det_results:
            facing_dir = det_obj[-1]
            if facing_dir in camera_convert_rules:
                potential_facing_dir.add(facing_dir)

    if potential_facing_dir:
        # Follow the rule table's own order instead of set order so the rule
        # numbering in the prompt is reproducible between runs.
        used_rules = [
            rule_key for rule_key in camera_convert_rules
            if rule_key in potential_facing_dir
        ]
    else:
        used_rules = list(camera_convert_rules)

    # NOTE(review): numbering starts at 5, presumably because the template
    # already contains rules 1-4 -- confirm against convert_perspetive_prompt.
    selected_rules = (
        rule for rule_key in used_rules for rule in camera_convert_rules[rule_key]
    )
    rule_text = "".join(
        f"{rule_index}. {rule}\n"
        for rule_index, rule in enumerate(selected_rules, start=5)
    )

    update_perspective_prompt = convert_perspetive_prompt.replace("<RULES>", rule_text)

    cur_info = "User prompt: {:}\n Current Object: {:}\n".format(prompt, det_results)
    messages = [
        {"role": "user", "content": update_perspective_prompt + cur_info}
    ]

    while True:
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                extra_body={"chat_template_kwargs": {"enable_thinking": True}},
                timeout=3000,
                max_tokens=8196,
            )
            raw_update_prompt = response.choices[0].message.content

            # Strip markdown emphasis, then keep what follows the
            # "updated prompt" marker (and its colon, if there is one).
            update_prompt = raw_update_prompt.lower().replace("*", "")
            update_prompt = update_prompt.split("updated prompt")[1].strip()
            colon_pos = update_prompt.find(":")
            if colon_pos != -1:
                update_prompt = update_prompt[colon_pos + 1:].strip()
            return update_prompt
        except Exception as e:
            # Deliberately broad: request failures and unparsable replies are
            # both handled by retrying, then falling back to the input prompt.
            print("Error Cannot Find Update Prompt with error:", e)
            if trial_count <= 0:
                return prompt
            print(f"Try {trial_count} more times")
            trial_count -= 1



def get_updated_layout(message, config, enable_reasoning=True, trial_count=5):
    """Request an updated object layout from the vLLM-served model and parse it.

    The reply is expected to contain the marker "Updated Objects" followed
    by a Python list literal; the text from the first "[" to the last "]"
    after that marker is parsed with ``ast.literal_eval``.

    Args:
        message (str): The user message sent to the model.
        config: Not read by this function; kept so existing call sites keep
            working. The model used is the first one reported by the server.
        enable_reasoning (bool): Forwarded as the ``enable_thinking``
            chat-template keyword argument.
        trial_count (int): Number of retries allowed after a failed attempt,
            so at most ``trial_count + 1`` requests are made.

    Returns:
        tuple: ``(updated_bboxes, raw_response)`` on success, where
        ``updated_bboxes`` is the parsed list and ``raw_response`` the full
        reply text; ``(None, [])`` once all attempts have failed.
    """
    api_key = "EMPTY"
    port_number = os.getenv("vLLM_PORT")
    openai_api_base = f"http://localhost:{port_number}/v1"

    messages = [{"role": "user", "content": message}]
    client = OpenAI(api_key=api_key, base_url=openai_api_base)
    # Use whichever model the server reports first.
    model = client.models.list().data[0].id

    while True:
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                extra_body={"chat_template_kwargs": {"enable_thinking": enable_reasoning}},
                timeout=3000,
                max_tokens=8196,
            )
            raw_response = response.choices[0].message.content

            # Keep the text after the marker, then cut out the outermost
            # [...] span and turn it into a Python object.
            bbox_data = raw_response.split("Updated Objects")[1]
            start_index = bbox_data.index("[")
            end_index = bbox_data.rindex("]") + 1
            updated_bboxes = ast.literal_eval(bbox_data[start_index:end_index])
            print(updated_bboxes)

            return updated_bboxes, raw_response
        except Exception as e:
            # Deliberately broad: request failures and unparsable replies are
            # both handled by retrying until the retry budget is used up.
            print(e)
            if trial_count <= 0:
                return None, []
            trial_count -= 1
def get_key_objects(message, config, enable_reasoning=True):
    """Query the vLLM-served model once and parse objects, background and negation.

    Parameters:
    message (str): The message to process.
    config: Object whose ``get("openai", "model")`` yields the model name to request.
    enable_reasoning (bool): Forwarded as the ``enable_thinking`` chat-template keyword argument.

    Returns:
    tuple: ``(parsed_result, raw_response)`` where ``parsed_result`` is a dict
    with the keys "objects", "bg_prompt" and "neg_prompt", and
    ``raw_response`` is the complete reply text.

    No errors are caught here: a reply without the "Objects:",
    "Background:" or "Negation:" markers, or without a parsable list,
    raises from the corresponding parsing step.
    """
    model_name = config.get("openai", "model")
    port = os.getenv("vLLM_PORT")
    client = OpenAI(api_key="EMPTY", base_url=f"http://localhost:{port}/v1")

    completion = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": message}],
        extra_body={"chat_template_kwargs": {"enable_thinking": enable_reasoning}},
        timeout=1500,
        max_tokens=8192,
    )
    raw_response = completion.choices[0].message.content

    # Object list: outermost [...] span in the text after "Objects:".
    objects_section = raw_response.split("Objects:")[1]
    list_begin = objects_section.index("[")
    list_end = objects_section.rindex("]")
    key_objects = ast.literal_eval(objects_section[list_begin:list_end + 1])

    # Background: only the rest of the line that carries the marker.
    background_tail = raw_response.split("Background:")[1]
    background_line = background_tail.split("\n")[0]

    # Negation: the text after the marker.
    negation_tail = raw_response.split("Negation:")[1]

    return (
        {
            "objects": key_objects,
            "bg_prompt": background_line.strip(),
            "neg_prompt": negation_tail.strip(),
        },
        raw_response,
    )
