from abc import ABC
import json, requests
import time
from abc import abstractmethod
from random import random

import numpy as np
from openai import OpenAI
from groq import Groq
import weave
import wandb
import config
import utils
from ExpConfig.ExpConfig import PromptConfig
from config import *
from tqdm import tqdm
from together import Together
from cerebras.cloud.sdk import Cerebras
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from utils import get_latest_modified_file_with_prefix
from concurrent.futures import ThreadPoolExecutor, as_completed
from Testers.cache_manager import CacheManager, cache_response
import base64
from PIL import Image
import io
import re
from typing import List, Tuple


def cast_type_from_str(input):
    """Convert string values to ``int``; return every other value unchanged.

    Args:
        input: A value read from a parsed LLM answer. Strings are passed through ``int()``.

    Returns:
        ``int(input)`` when ``input`` is a ``str``, otherwise ``input`` itself.

    Raises:
        ValueError: if ``input`` is a string that ``int()`` cannot parse.
    """
    return int(input) if isinstance(input, str) else input

def image_to_base64_data_url(image_path_or_array, image_format="png"):
    """Encode an image as a ``data:image/<format>;base64,...`` URL.

    Args:
        image_path_or_array: A filesystem path (``str``), an ``np.ndarray`` accepted by
            ``PIL.Image.fromarray``, or an already-loaded ``PIL.Image.Image``.
        image_format: Format passed to ``PIL.Image.Image.save``; also used as the MIME subtype.

    Returns:
        The data URL string.

    Raises:
        ValueError: if the input is none of the supported types.
    """
    if isinstance(image_path_or_array, str):
        # It's a path. The context manager releases the file handle deterministically,
        # including for multi-frame formats that PIL would otherwise keep open.
        with Image.open(image_path_or_array) as opened:
            pil_img = opened.convert("RGB")
    elif isinstance(image_path_or_array, np.ndarray):
        # It's a NumPy array
        pil_img = Image.fromarray(image_path_or_array).convert("RGB")
    elif isinstance(image_path_or_array, Image.Image):
        # Already decoded; just normalise the colour mode.
        pil_img = image_path_or_array.convert("RGB")
    else:
        raise ValueError("Input must be image path (str), numpy array or PIL image")

    buffered = io.BytesIO()
    pil_img.save(buffered, format=image_format)
    encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
    # The 'data:image/png;base64,' prefix is crucial for OpenAI Vision API compatibility
    return f"data:image/{image_format};base64,{encoded}"


def extract_and_split_paths(input_str: str) -> List[Tuple[str, str]]:
    """Split ``input_str`` into plain-text segments and the contents of ``<PATH>...</PATH>`` tags.

    Example:
        ``'text1<PATH>p1</PATH>text2'`` -> ``[('text1', 'text'), ('p1', 'path'), ('text2', 'text')]``

    Empty segments are dropped. That includes leading or trailing text when the string starts or
    ends with a tag, and the contents of an empty tag. Tag contents that span a newline are not
    matched, because the pattern does not use DOTALL.

    Args:
        input_str: The string to process.

    Returns:
        A list of ``(substring, label)`` tuples, where label is ``'text'`` or ``'path'``.
    """
    # With a single capturing group, re.split interleaves the captured paths at odd
    # positions: [text, path, text, path, ..., text]. The non-greedy group stops at
    # the first closing tag.
    segments = re.split(r'<PATH>(.*?)</PATH>', input_str)
    return [
        (segment, 'path' if position % 2 else 'text')
        for position, segment in enumerate(segments)
        if segment
    ]


def is_base64_image_or_path(s: str) -> bool:
    """Return True if ``s`` is an inline base64 image data URI.

    Despite the name, this does NOT recognise file paths. Any string that does not start with
    ``data:image/`` (after surrounding whitespace is stripped) returns False. ``generate_prompt``
    relies on this: a False result makes it resolve ``s`` as a path relative to
    ``config.PROMPT_CONFIG_PATH``.

    Args:
        s: Candidate string. Non-string values return False.

    Returns:
        True for ``data:image/...`` URIs, otherwise False.
    """
    return isinstance(s, str) and s.strip().startswith("data:image/")




def get_llm_dict_value(llm_feedback, key, text_feedback_translation_dict=None):
    """Read ``key`` from a parsed LLM JSON answer, optionally translating textual values.

    Args:
        llm_feedback: The parsed JSON object (a dict) from the LLM reply.
        key: A single key, or a list of keys whose values are collected in order.
        text_feedback_translation_dict: Optional map from upper-cased textual answers
            (e.g. "YES") to labels. When ``None``:
            - a single key's value is returned as-is;
            - list-key values go through ``cast_type_from_str``.

    Returns:
        For a single key: the (translated) value. Returns "KEY_ERROR" if the key is missing,
        or "VALUE_ERROR" if the value is not a string or is absent from the translation dict.

        For a list of keys: a list of the (translated/cast) values of the keys that are present.
        Missing keys are skipped. Returns "VALUE_ERROR" as soon as one value cannot be
        translated or cast.
    """
    if isinstance(key, list):
        return_list = []
        for k in key:
            # Missing keys are skipped rather than reported, as before.
            if k not in llm_feedback:
                continue
            value = llm_feedback[k]
            if text_feedback_translation_dict is None:
                # TODO: check type
                try:
                    return_list.append(cast_type_from_str(value))
                except ValueError:
                    return "VALUE_ERROR"
                continue
            try:
                normalized = value.upper()
            except AttributeError:
                # Non-string value (e.g. a number) cannot be looked up textually.
                return "VALUE_ERROR"
            if normalized not in text_feedback_translation_dict:
                return "VALUE_ERROR"
            # Accumulate and keep going (previously `return return_list.append(...)` returned None here).
            return_list.append(text_feedback_translation_dict[normalized])
        return return_list

    if key not in llm_feedback:
        return "KEY_ERROR"
    value = llm_feedback[key]
    if text_feedback_translation_dict is None:
        return value
    try:
        return text_feedback_translation_dict[value.upper()]
    except (AttributeError, KeyError, TypeError):
        # Non-string value, or an answer the translation dict does not know.
        return "VALUE_ERROR"







class LMFeedbackVerifier(ABC):
    def __init__(self, env, feedback_type, data_path, condition_list=[], purpose_str="", system_prompt="You are an expert game player", data_filter=None, no_log=False, no_load_data=False, no_check_correct=False, project_name=None, cache_manager=None):
        self.no_log = no_log
        if not no_log:
            condition_str = "_".join(condition_list)
            file_name = os.path.basename(data_path)
            import re
            match = re.search(r'\d+', file_name)
            group = match.group() if match else None
            if group is not None:
                distribution = int(group)
            else:
                distribution = "traverse"
            if project_name is not None:
                self.project_name = project_name
            else:
                self.project_name = f"{env}-{purpose_str}-{condition_str}-{feedback_type}-{distribution}"
            
            if wandb.run is not None:
                self.project_name = wandb.run.entity + "/" + wandb.run.project 
            
            
            print(self.project_name)

            weave.init(self.project_name)
        self.env = env
        self.purpose_str = purpose_str
        self.prompt_config_path = ""
        self.feedback_type = feedback_type
        self.data_path = data_path
        self.condition_list = condition_list
        self.data_filter = data_filter
        self.no_load_data = no_load_data
        self.no_check_correct = no_check_correct
        if not self.no_load_data:
            self.load_data()
            self.filter_data()
        self.feedback_type_to_translation_dict = None #Expected to be set in children classes
        self.obs_representation_extractor = None #Expected in children classes
        self.image_representation_extractor = None #Expected in children classes
        self.number_to_action_dict = None
        self.use_action_map_dict = True
        self.preference_to_number_dict = {
            "FIRST": 1,
            "SECOND": -1
        }
        self.response_to_numeric_dict = {
            "YES": 1,
            "NO": -1
        }
        self.feedback_type_to_verify_key_dict = \
        {
            "binary_feedback": "feedback",
            "preference": "preference",
            "action_advising": "action",
            "delta_action": "index",
            "goal_advising": "goal"
        }
        if feedback_type == "binary_feedback":
            self.feedback_type_to_translation_dict = self.response_to_numeric_dict
        if feedback_type == "preference":
            self.feedback_type_to_translation_dict = self.preference_to_number_dict
        if feedback_type == "action_advising":
            self.feedback_type_to_translation_dict = None
        self.delta_action_prompt_cot = ""
        self.continuous_action = False
        self.continuous_action_threshold = None
        self.goal_advising_base_prompt_cot = ""
        self.if_optimal_prompt_cot = ""
        self.action_advising_base_prompt_cot = ""

        self.preference_base_prompt_cot = ""

        self.goal_advising_base_prompt_cot = ""
        self.system_prompt = system_prompt

        #ICL prompts
        self.icl_prompt_goal_advising = ""
        self.icl_prompt_action_advising = ""
        self.icl_prompt_delta_action = ""
        self.icl_prompt_preference = ""
        self.icl_prompt_binary_feedback = ""


        # Explain prediction task
        self.task_prompt_dict = {
            "binary_feedback": "You will be asked to judge an action whether it is the optimal action, given a state. \n",
            "action_advising": "You will be asked to predict the best action given a state. \n",
            "preference": "You will be asked to compare two actions given a state and provide preference. \n",
            "goal_advising": "You will be asked to predict the best next state given a state. \n",
            "delta_action": "You will be asked to correct and action given a state and an action. \n"
        }


        self.explicit_thinking_guides = ""

        self.cache_manager = cache_manager
        if self.cache_manager is not None:
            self.get_response = cache_response(self.cache_manager)(self.get_response)

    def load_data(self):
        if self.data_path is not None:
            if os.path.exists(self.data_path):
                self.data = np.load(self.data_path, allow_pickle=True)
        else:
            print("No data path provided")
        print("Data loaded: ", len(self.data))

    def filter_data(self):
        if self.data_filter is not None and self.data_filter != "":
            list_of_filters = self.data_filter.split("=")
            for filter in list_of_filters:
                #String to function handles
                print("Apply data filter: ", filter)
                filter_function = getattr(utils, filter)
                self.data = filter_function(self.data)
        print("Data filtered: ", len(self.data))

    def verify_single_data_point(self, data_point, source, prompt, system_prompt, model, url, api_key, verify_key, tokeniser=None):
        """Query the model once for ``data_point`` and score the reply.

        Args:
            data_point: The dataset entry being verified (passed to ``general_verify``).
            source, url, model, api_key, system_prompt: Forwarded unchanged to ``self.get_response``.
            prompt: The prompt generated for ``data_point``.
            verify_key: Key of the LLM's JSON answer that is compared with the expert feedback.
            tokeniser: Optional tokenizer used to report token lengths when ``CHECK_TOKEN_LENGTH`` is set.

        Returns:
            The dict built by ``general_verify``, extended with ``Purpose``, ``PromptConfigPath`` and
            ``Prompt``. It also carries ``ReasoningContext`` when the backend returned reasoning, and
            ``TokenLength`` when token lengths were measured.
        """
        response = self.get_response(source=source, prompt=prompt, url=url, model=model, api_key=api_key,
                                     system_prompt=system_prompt)
        # Some backends return (answer, reasoning); only the answer is parsed and scored.
        has_reasoning = isinstance(response, tuple)
        reasoning_content = None
        if has_reasoning:
            response, reasoning_content = response
        if config.MY_DEBUG:
            print(prompt)
            print(response)
        return_dict = self.general_verify(data=data_point, response=response,
                                          verify_key=verify_key,
                                          text_feedback_translation_dict=self.feedback_type_to_translation_dict)
        # Attach the reasoning only now: general_verify builds a fresh dict, which previously
        # discarded a ReasoningContext set before the call.
        if has_reasoning:
            return_dict["ReasoningContext"] = reasoning_content

        if config.MY_DEBUG:
            print(return_dict)

        return_dict["Purpose"] = self.purpose_str
        return_dict["PromptConfigPath"] = self.prompt_config_path

        if CHECK_TOKEN_LENGTH and tokeniser is not None:
            num_tokens = len(tokeniser.tokenize(prompt))
            print("Prompt Length in token: ", num_tokens)

            # An empty/None reply counts as zero tokens instead of crashing the tokenizer.
            num_tokens2 = len(tokeniser.tokenize(response if response is not None else ""))
            print("Response Length in token: ", num_tokens2)
            return_dict["TokenLength"] = (num_tokens, num_tokens2)
        return_dict["Prompt"] = prompt

        return return_dict

    def check_prompt(self, max_item_num=-1, skip_neutral=False):

        if skip_neutral and self.feedback_type == "preference":
            processed_dataset = [i for i in self.data if i["feedback"] != 0]
        else:
            processed_dataset = self.data

        # np.random.seed(42)
        random_indexes = np.arange(len(processed_dataset))
        # np.random.shuffle(random_indexes)
        processed_dataset = [processed_dataset[i] for i in random_indexes]
        for index, i in enumerate(processed_dataset):
            if index > max_item_num and max_item_num > 0:
                break
            prompt = self.generate_prompt(i)
            # print("Data: \n", i)
            print("Feedback: ", i["feedback"])
            if "expert_actions" in i:
                print("Expert actions: ", i["expert_actions"])
            print("\n")
            print(prompt)
            print("\n")
            print("\n")


    def verify(self, round, source="ollama", url="http://localhost:11434/api/chat", model="Meta-Llama-3.1-8B-Instruct", max_item_num=-1, start_from="", uniform_sampling=False, skip_neutral=False, api_key="", pair_preference=False, numer_of_thread=1, order_preserving=False):
        """Query the LLM ``round`` times per data point, score every reply and save the results.

        Args:
            round: Number of LLM queries per data point.
            source, url, model, api_key: Forwarded to ``verify_single_data_point``.
            max_item_num: If > 0, only data points with index < ``max_item_num`` are considered.
            start_from: Controls which result file is used.
                - "" starts a new result file.
                - "LATEST" resumes the most recently modified result file for this dataset,
                  or starts a new one if none exists.
                - Any other value is the ``created_time`` part of an existing result file to resume.
            uniform_sampling: Shuffle the data with a permutation persisted next to the results.
            skip_neutral: For preference feedback, drop data points whose expert feedback is 0.
            pair_preference: For preference feedback, add a mirrored copy of every pair
                (actions swapped, feedback negated).
            numer_of_thread: 1 runs sequentially; > 1 uses a thread pool of that size.
            order_preserving: With threads, keep results in dataset order (collected at the end).

        Returns:
            The list of result dicts, including any results loaded when resuming.
        """
        if CHECK_TOKEN_LENGTH:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B")
        else:
            tokenizer = None
        system_prompt = self.system_prompt

        # Load or create a new log file
        save_path, nothing_found = self._verify_resolve_save_path(model, start_from)
        resuming = start_from != "" and not nothing_found

        result_list = []
        start_index = 0
        if resuming:
            previous_results_list = np.load(save_path, allow_pickle=True)
            result_list = list(previous_results_list)
            # Sequential runs checkpoint only after all `round` queries of a data point finish, so
            # len // round points are complete and the next one has exactly that index. The former
            # "+ 1" skipped one data point on every resume. (Unordered threaded runs save in
            # completion order, so resuming them is only approximate.)
            start_index = len(previous_results_list) // round
            print("Start from: ", start_index)

        # Uniform sampling and shuffle; the permutation is persisted so a resumed run sees the same order.
        if uniform_sampling:
            random_indexes_path = save_path + "uniform_indexes.npy"
            if resuming:
                random_indexes = np.load(random_indexes_path)
            else:
                random_indexes = np.random.choice(len(self.data), len(self.data), replace=False)
                np.save(random_indexes_path, random_indexes)
            processed_dataset = self.data[random_indexes]
        else:
            processed_dataset = self.data

        # Skip neutral preference
        if skip_neutral and self.feedback_type == "preference":
            processed_dataset = [i for i in processed_dataset if i["feedback"] != 0]

        # Ensure preferences come in pairs
        if pair_preference and self.feedback_type == "preference":
            new_data_set = []
            for i in processed_dataset:
                new_data_set.append(i)
                i_copy = i.copy()
                i_copy["feedback"] = -i["feedback"]
                i_copy["action1"], i_copy["action2"] = i_copy["action2"], i_copy["action1"]
                new_data_set.append(i_copy)
            processed_dataset = new_data_set

        possible_max_item_num = min(len(processed_dataset), max_item_num)
        if max_item_num > 0:
            dataset = processed_dataset[start_index:possible_max_item_num]
        else:
            dataset = processed_dataset[start_index:]

        print("Total number in dataset: ", len(self.data))
        print("Total number to check: ", len(dataset))
        bar = tqdm(dataset)
        stats = {"correct": 0, "wrong": 0, "illegal_json": 0, "illegal_dict": 0}
        verify_key = self.feedback_type_to_verify_key_dict[self.feedback_type]

        def verify_point(data_point, prompt):
            return self.verify_single_data_point(data_point, source, prompt, system_prompt, model, url,
                                                 api_key, verify_key, tokenizer)

        if numer_of_thread == 1:
            for i in bar:
                prompt = self.generate_prompt(i)
                for _ in range(round):
                    return_dict = verify_point(i, prompt)
                    if MY_DEBUG:
                        print(return_dict)
                    result_list.append(return_dict)
                    self._verify_tally(return_dict, stats)
                    bar.set_description(self._verify_stats_message(stats))
                # Checkpoint only after every round of this data point (the resume index relies on it).
                np.save(save_path, result_list)
        elif numer_of_thread > 1:
            # Multiple threads
            repeated_dataset = dataset * round if isinstance(dataset, list) else dataset.repeat(round)
            total = len(repeated_dataset)
            start_time = time.time()
            with ThreadPoolExecutor(max_workers=numer_of_thread) as executor:
                if not order_preserving:
                    # Results are appended in completion order, not dataset order.
                    futures = [executor.submit(verify_point, i, self.generate_prompt(i)) for i in repeated_dataset]
                    for done, future in enumerate(as_completed(futures), start=1):
                        return_dict = future.result()
                        current_time = time.time()
                        result_list.append(return_dict)
                        self._verify_tally(return_dict, stats)
                        print(self._verify_progress_message(done, total, current_time - start_time))
                        print(self._verify_stats_message(stats))
                        if len(result_list) % numer_of_thread == 0:
                            np.save(save_path, result_list)
                else:
                    # executor.map yields results in input order. Extend rather than replace
                    # result_list, so results loaded when resuming are not dropped from the saved file.
                    new_results = list(executor.map(lambda i: verify_point(i, self.generate_prompt(i)),
                                                    repeated_dataset))
                    for return_dict in new_results:
                        self._verify_tally(return_dict, stats)
                    result_list.extend(new_results)
                    current_time = time.time()
                    if new_results:
                        print(self._verify_progress_message(len(new_results), total, current_time - start_time))
                    print(self._verify_stats_message(stats))
                    np.save(save_path, result_list)

            np.save(save_path, result_list)
        return result_list

    def _verify_resolve_save_path(self, model, start_from):
        """Return ``(save_path, nothing_found)`` for ``verify``.

        ``nothing_found`` is True when a new result file is started (there is nothing to resume from).
        """
        condition = "_".join(self.condition_list) if len(self.condition_list) > 0 else "normal"
        results_dir = PERSISTENT_DATA_PATH + "/results/{env}/{feedback_type}/{condition}/{model}/".format(
            env=self.env, feedback_type=self.feedback_type, condition=condition, model=model)
        datapath = os.path.basename(self.data_path)

        if start_from == "":
            save_path = results_dir + "{datapath}_{created_time}_{purpose}.npy".format(
                datapath=datapath, created_time=time.time(), purpose=self.purpose_str)
            print("Save to: ", save_path)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)  # Create directories if they don't exist
            return save_path, True

        nothing_found = False
        if start_from == "LATEST":
            save_path = get_latest_modified_file_with_prefix(results_dir, prefix=datapath + "_")
            if save_path is None:
                nothing_found = True
                # NOTE(review): unlike the other branches, this file name has no purpose suffix.
                save_path = results_dir + "{datapath}_{created_time}.npy".format(
                    datapath=datapath, created_time=time.time())
                print("Nothing found. Save to: ", save_path)
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
        else:
            save_path = results_dir + "{datapath}_{created_time}_{purpose}.npy".format(
                datapath=datapath, created_time=start_from, purpose=self.purpose_str)
        print("Start from: ", save_path)
        return save_path, nothing_found

    @staticmethod
    def _verify_tally(return_dict, stats):
        """Add one scored result to the running counters in ``stats`` (mutated in place)."""
        if return_dict["Correct"]:
            stats["correct"] += 1
        else:
            stats["wrong"] += 1
        if not return_dict["JSONCorrect"]:
            stats["illegal_json"] += 1
        if (not return_dict["KeyCorrect"]) or (not return_dict["ValueCorrect"]):
            stats["illegal_dict"] += 1

    @staticmethod
    def _verify_stats_message(stats):
        """Format accuracy/noise ratios; an empty tally is reported instead of dividing by zero."""
        total = stats["correct"] + stats["wrong"]
        if total == 0:
            return "Accuracy: n/a (no results)"
        return (f"Accuracy: {stats['correct'] / total}, Illegal JSON: {stats['illegal_json'] / total}, "
                f"Noise: {stats['wrong'] / total}, Illegal KeyValue: {stats['illegal_dict'] / total}")

    @staticmethod
    def _verify_progress_message(done, total, elapsed):
        """Format progress, per-item time and ETA for ``done`` of ``total`` queries (``done`` must be > 0)."""
        per_item = elapsed / done
        return (f"Progress: {done / total}, time elapsed: {elapsed}, time per data point: {per_item} "
                f"ETA: {per_item * (total - done)}")

    def llm_feedback_extract(self, llm_feedback):
        feedback_type_to_verify_key_dict = self.feedback_type_to_verify_key_dict
        verify_key = feedback_type_to_verify_key_dict[self.feedback_type]
        if "```json" in llm_feedback:
            response = llm_feedback.split("```json")[-1]
            response = response.split("```")[0]
        response = llm_feedback
        try:
            response = response.split("{")
            response = response[-1]
            response = "{" + response
            llm = json.loads(response)
            llm_feedback = get_llm_dict_value(llm, verify_key, self.feedback_type_to_translation_dict)
            return llm_feedback
        except Exception as e:
            print(e)
            return None

    def general_verify(self, data, response, verify_key, text_feedback_translation_dict):
        ret_dict = {"State": data, "Response": response}
        ret_dict["JSONCorrect"] = True

        ret_dict["ValueCorrect"] = True
        ret_dict["KeyCorrect"] = True
        ret_dict["ValueCorrect"] = True

        if not self.no_check_correct:
            expert_feedback = data["feedback"]
        else:
            expert_feedback = None


        if response == "" or response == None:
            ret_dict["Correct"] = False
            ret_dict["JSONCorrect"] = False
            ret_dict["LLMFeedback"] = None
            return ret_dict

        ## Deal ```json ``` format
        if "```json" in response:
            response = response.split("```json")[-1]
            response = response.split("```")[0]


        try:
            response = response.split("{")
            response = response[-1]
            response = "{" + response
            llm = json.loads(response)
        except Exception as e:
            print(e)
            ret_dict["Correct"] = False
            ret_dict["JSONCorrect"] = False
            return ret_dict

        llm_feedback = get_llm_dict_value(llm, verify_key, text_feedback_translation_dict)


        if isinstance(expert_feedback, bool):
            if expert_feedback:
                expert_feedback = 1
            else:
                expert_feedback = -1
        if config.MY_DEBUG:
            print("LLM Feedback",llm_feedback)
            print("Expert Feedback", expert_feedback)

        if llm_feedback == "KEY_ERROR" or llm_feedback == "VALUE_ERROR":
            ret_dict["Correct"] = False
            if llm_feedback == "KEY_ERROR":
                ret_dict["KeyCorrect"] = False
            else:
                ret_dict["KeyCorrect"] = True
            if llm_feedback == "VALUE_ERROR":
                ret_dict["ValueCorrect"] = False
            else:
                ret_dict["ValueCorrect"] = True
        else:
            if not self.no_check_correct:
                if self.continuous_action and self.feedback_type == "action_advising":
                    try:
                        import ast

                        llm_feedback = ast.literal_eval(llm_feedback)
                        if isinstance(llm_feedback, tuple):
                            #tuple to list
                            llm_feedback = list(llm_feedback)
                        for i in range(len(llm_feedback)):
                            if llm_feedback[i] == "open":
                                llm_feedback[i] = 1
                            if llm_feedback[i] == "close":
                                llm_feedback[i] = -1
                        # print(llm_feedback)
                        llm_feedback = np.asarray(llm_feedback).astype(float)
                        # print(llm_feedback)
                        distance = np.linalg.norm(llm_feedback - np.asarray(expert_feedback[0]), ord=2)
                        if distance < self.continuous_action_threshold:
                            ret_dict["Correct"] = True
                        else:
                            ret_dict["Correct"] = False
                    except Exception as e:
                        # print(e)
                        # print("LLM Feedback", llm_feedback)
                        # print("Expert Feedback", expert_feedback)
                        ret_dict["Correct"] = False
                        ret_dict["ValueCorrect"] = False
                else:
                    if llm_feedback == expert_feedback or (expert_feedback == 0 and self.feedback_type == "preference") or ((self.feedback_type == "action_advising" or self.feedback_type == "goal_advising") and llm_feedback in expert_feedback):
                        ret_dict["Correct"] = True
                    else:
                        ret_dict["Correct"] = False
            else:
                ret_dict["Correct"] = None
                ret_dict["LLMFeedback"] = llm_feedback

        return ret_dict

    def general_generate_prompt(self, data, use_action_map_dict=True):
        base_prompt = self.base_prompt

        if "egocentric" in self.condition_list:
            base_prompt = self.base_prompt_egocentric


        action_dict = self.number_to_action_dict


        def default_action_mapping(x):
            return str(x)

        if use_action_map_dict and self.number_to_action_dict is not None:
            action_map = lambda x: action_dict[x]
        else:
            action_map = default_action_mapping

        if_optimal_prompt_cot = self.if_optimal_prompt_cot
        action_advising_base_prompt_cot = self.action_advising_base_prompt_cot
        preference_base_prompt_cot = self.preference_base_prompt_cot
        delta_action_prompt_cot = self.delta_action_prompt_cot
        goal_advising_base_prompt_cot = self.goal_advising_base_prompt_cot
        if "explicit_action_by_action_cot" in self.condition_list:
            if_optimal_prompt_cot = self.if_optimal_prompt_explicit_action_by_action_cot
            action_advising_base_prompt_cot = self.action_advising_base_prompt_explicit_action_by_action_cot

        obs_representation = self.obs_representation_extractor(data, self.condition_list)
        if "ICL" in self.condition_list:
            if self.feedback_type == "binary_feedback":
                icl =  self.icl_prompt_binary_feedback
            if self.feedback_type == "preference":
                icl = self.icl_prompt_preference
            if self.feedback_type == "action_advising":
                icl = self.icl_prompt_action_advising
            if self.feedback_type == "delta_action":
                icl = self.icl_prompt_delta_action
            if self.feedback_type == "goal_advising":
                icl = self.icl_prompt_goal_advising
        else:
            icl = ""

        if "task_prompt" in self.condition_list:
            task_prompt = self.task_prompt_dict[self.feedback_type]
        else:
            task_prompt = ""
        if "explicit_thinking_guides" in self.condition_list:
            thinking_guidance = self.explicit_thinking_guides
        else:
            thinking_guidance = ""

        if self.feedback_type == "binary_feedback":
            final_string = (base_prompt + task_prompt + thinking_guidance + icl +
                            if_optimal_prompt_cot.replace("OBSREPRESENTATION", obs_representation)
                            .replace("ACTION", action_map(data["action"])))
        if self.feedback_type == "action_advising":
            final_string = (base_prompt + task_prompt + thinking_guidance + icl +
                            action_advising_base_prompt_cot.replace("OBSREPRESENTATION", obs_representation))
        if self.feedback_type == "preference":
            final_string = (base_prompt + task_prompt + thinking_guidance + icl +
                            preference_base_prompt_cot.replace("OBSREPRESENTATION", obs_representation)
                            .replace("ACTION1", action_map(data["action1"])).replace("ACTION2", action_map(data["action2"])))

        if self.feedback_type == "delta_action":
            final_string = (base_prompt + task_prompt + thinking_guidance + icl +
                            delta_action_prompt_cot.replace("OBSREPRESENTATION", obs_representation))

        if self.feedback_type == "goal_advising":
            final_string = (base_prompt + task_prompt + thinking_guidance + icl +
                            goal_advising_base_prompt_cot.replace("OBSREPRESENTATION", obs_representation))
            
        return final_string


    @abstractmethod
    def domain_specific_prompt_process(self, data, prompt):
        return prompt

    def generate_prompt(self, data):
        """
        Build the prompt for ``data``, optionally as a multimodal content list.

        The text prompt is produced by ``general_generate_prompt`` and then
        passed through the ``domain_specific_prompt_process`` hook. If
        ``"image_observation"`` is not in ``self.condition_list``, that text is
        returned unchanged.

        Otherwise the text is split on ``<PATH>...</PATH>`` markers (via
        ``extract_and_split_paths``) and converted into a list of
        OpenAI-style content parts, in order:

        * text segments become ``{"type": "text", "text": ...}``;
        * the special path ``IMAGEOBSERVATION`` is replaced by the image(s)
          returned by ``self.image_representation_extractor`` for this step;
        * a path accepted by ``is_base64_image_or_path`` is used as the image
          URL as-is;
        * any other path is loaded relative to ``config.PROMPT_CONFIG_PATH``
          and base64-encoded.

        If no ``IMAGEOBSERVATION`` marker was present, a short caption and the
        observation image(s) are appended at the end.

        :param data: the per-step data used to build the prompt.
        :return: the prompt string, or a list of content-part dicts in
            image-observation mode.
        """
        prompt = self.general_generate_prompt(data, use_action_map_dict=self.use_action_map_dict)
        final_prompt = self.domain_specific_prompt_process(data, prompt)

        if "image_observation" not in self.condition_list:
            return final_prompt

        payload = []
        has_designated_obs_position = False
        for segment, label in extract_and_split_paths(final_prompt):
            if label != "path":
                payload.append({
                    "type": "text",
                    "text": segment
                })
            elif segment == "IMAGEOBSERVATION":
                # Designated position for the current observation image(s).
                has_designated_obs_position = True
                payload.extend(self._observation_image_parts(data))
            elif is_base64_image_or_path(segment):
                # Already usable as an image URL; pass it through untouched.
                payload.append(self._image_url_part(segment))
            else:
                # Static image referenced by the prompt config; load it from disk.
                image_file = config.PROMPT_CONFIG_PATH + "/" + segment
                payload.append(self._image_url_part(image_to_base64_data_url(image_file)))

        if not has_designated_obs_position:
            # No designated position in the prompt: append the observation at the end.
            payload.append({
                "type": "text",
                "text": "This is your image observation."
            })
            payload.extend(self._observation_image_parts(data))
        return payload

    @staticmethod
    def _image_url_part(url):
        """Wrap an image URL (or base64 data URL) as an ``image_url`` content part."""
        return {
            "type": "image_url",
            "image_url": {"url": url}
        }

    def _observation_image_parts(self, data):
        """
        Encode the observation image(s) for ``data`` as ``image_url`` parts.

        ``self.image_representation_extractor`` may return either a single
        image or a list of images. Each one is converted with
        ``image_to_base64_data_url``.
        """
        images = self.image_representation_extractor(data, self.condition_list)
        if not isinstance(images, list):
            images = [images]
        return [self._image_url_part(image_to_base64_data_url(img)) for img in images]

    # TODO deal with API Calls none response
    # TODO log other possible metrics from API calls
    # The DeepSeek API is not stable (and apparently has no rate limits), so retry indefinitely.
    @weave.op()
    # @retry(stop=stop_after_attempt(100), retry_error_callback=lambda x: None)
    @retry()
    def get_response(self, source, prompt, model="llama3.1:8b-instruct-fp16", url="http://localhost:11434/api/chat", api_key="", system_prompt="You are an expert game player"):
        """
        Send ``prompt`` to the chat backend named by ``source``.

        Every backend except ``"deepseek"`` returns the reply text. The
        DeepSeek getter returns ``(content, reasoning_content)`` so that the
        reasoning trace of reasoning models is available too. ``"TEST"``
        returns a fixed JSON string without any network call. An unrecognised
        ``source`` returns ``None``.

        Because of the bare ``@retry()`` decorator, any exception raised by a
        backend causes the whole call to be retried with no attempt limit.

        :param source: backend name, e.g. "ollama", "openai", "vllm-serve".
        :param prompt: user message content (a string, or a list of content
            parts as produced by ``generate_prompt``).
        :param model: model identifier passed to the backend.
        :param url: endpoint URL (ignored by backends with a fixed base URL).
        :param api_key: API key. Not forwarded to "ollama" or "vec-inf".
        :param system_prompt: system message content.
        :return: the backend's reply (see above), or None for an unknown source.
        """
        if source == "TEST":
            # return "THIS IS A TEST FOR CACHE."
            return """{"action": "TURN LEFT", "reasoning": "Test reasoning."}"""

        # Backends whose getters take no API key.
        keyless_getters = {
            "ollama": self.response_getter_ollama,
            "vec-inf": self.response_getter_vec_inf,
        }
        # Backends whose getters take (prompt, model, url, api_key).
        keyed_getters = {
            "together": self.response_getter_together_ai,
            "groq": self.response_getter_groq,
            "deepseek": self.response_getter_deepseek,
            "nebius": self.response_getter_nebius,
            "cerebras": self.response_getter_cerebras,
            "vllm-serve": self.response_getter_vllm_serve,
            "openai": self.response_getter_openai,
        }

        if source in keyless_getters:
            return keyless_getters[source](prompt, model, url, system_prompt=system_prompt)
        if source in keyed_getters:
            return keyed_getters[source](prompt, model, url, api_key, system_prompt=system_prompt)
        return None


    def response_getter_ollama(self, prompt, model="llama3.1:8b-instruct-fp16", url="http://localhost:11434/api/chat",
                               return_original=False, system_prompt="You are an expert game player", timeout=None):
        """
        Query an Ollama ``/api/chat`` endpoint with a single non-streaming request.

        :param prompt: user message content.
        :param model: Ollama model tag.
        :param url: full chat endpoint URL.
        :param return_original: if True, return the whole decoded JSON body
            instead of just the assistant message text.
        :param system_prompt: system message content.
        :param timeout: optional ``requests`` timeout in seconds. The default
            ``None`` waits indefinitely, as before.
        :return: the assistant message content, the full JSON body when
            ``return_original`` is set, or None if the status code is not 200.
        """
        headers = {
            'Content-Type': 'application/json'  # This tells the API you're sending JSON
        }
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            "stream": False,
        }

        response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=timeout)

        if response.status_code != 200:
            # Print the status and body (not just the Response repr) so the
            # server's error message is visible when diagnosing failures.
            print(f"Ollama request failed with status {response.status_code}: {response.text}")
            return None

        body = response.json()
        return body if return_original else body["message"]["content"]

    def response_getter_groq(self, prompt, model="llama3.1:8b-instruct-fp16", url="", api_key="", system_prompt="You are an expert game player"):
        """
        Query the Groq chat-completions API.

        ``url`` is accepted for signature parity with the other getters but is
        not used. If ``api_key`` is the empty string, the key is read from the
        ``GROQ_API_KEY`` environment variable.

        :return: the assistant message content.
        """
        # Only the empty-string default falls back to the environment; any
        # other value (including None) is passed through as given.
        resolved_key = api_key if api_key != "" else os.environ.get("GROQ_API_KEY")
        groq_client = Groq(api_key=resolved_key)

        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        completion = groq_client.chat.completions.create(messages=conversation, model=model)
        return completion.choices[0].message.content

    def response_getter_deepseek(self, prompt, model="deepseek-chat", url="", api_key="", system_prompt="You are an expert game player"):
        """
        Query the DeepSeek chat API through the OpenAI-compatible client.

        ``url`` is accepted for signature parity with the other getters but is
        ignored; the base URL is fixed to ``https://api.deepseek.com``.

        :return: ``(content, reasoning_content)``. ``reasoning_content`` is
            None when the response message carries no such field
            (NOTE(review): presumably the case for non-reasoning models such
            as ``deepseek-chat`` — confirm).
        """
        client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            stream=False
        )
        message = response.choices[0].message
        # Read the field defensively: if a plain attribute access raised,
        # get_response's unbounded @retry() would retry the call forever.
        reasoning_content = getattr(message, "reasoning_content", None)
        return message.content, reasoning_content
    
    def response_getter_openai(self, prompt, model="PLACEHOLDER_FOR_ANOYNOMITY/openrouter-4o-mini", url="https://router.PLACEHOLDER_FOR_ANOYNOMITYlab.com", api_key="", system_prompt="You are an expert game player"):
        """
        Query an OpenAI-compatible endpoint (``url`` is used as the client's base URL).

        When ``config.MY_DEBUG`` is set, the request inputs and the raw
        response are printed.

        :return: the assistant message content.
        """
        if config.MY_DEBUG:
            debug_fields = (
                ("Model: ", model),
                ("URL: ", url),
                ("System Prompt: ", system_prompt),
                ("Prompt: ", prompt),
            )
            for label, value in debug_fields:
                print(label, value)
            print("--------------------------------")

        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        proxy_client = OpenAI(api_key=api_key, base_url=url)
        result = proxy_client.chat.completions.create(model=model, messages=conversation)

        if config.MY_DEBUG:
            print(result)
            print(result.choices[0].message.content)
            print("--------------------------------")
        return result.choices[0].message.content


    def response_getter_vec_inf(self, prompt, model="Meta-Llama-3.1-8B-Instruct", url="http://gpu005:8080/v1",
                                system_prompt="You are an expert game player"):
        """
        Query a vec-inf server through the OpenAI-compatible client.

        No real API key is used; the client is given the placeholder ``"EMPTY"``.

        :return: the assistant message content.
        """
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        local_client = OpenAI(base_url=url, api_key="EMPTY")
        reply = local_client.chat.completions.create(model=model, messages=conversation)
        return reply.choices[0].message.content

    def response_getter_together_ai(self, prompt, model="meta-llama/Meta-Llama-3-8B-Instruct-Turbo", url="", api_key="", system_prompt="You are an expert game player"):
        """
        Query the Together AI chat-completions API.

        ``url`` is accepted for signature parity with the other getters but is
        not used.

        :return: the assistant message content.
        """
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        together_client = Together(api_key=api_key)
        reply = together_client.chat.completions.create(model=model, messages=conversation)
        return reply.choices[0].message.content

    def response_getter_nebius(self, prompt, model="meta-llama/Meta-Llama-3.1-8B-Instruct", url="", api_key="", system_prompt="You are an expert game player"):
        """
        Query Nebius AI Studio through the OpenAI-compatible client.

        ``url`` is accepted for signature parity with the other getters but is
        ignored; the base URL is fixed to ``https://api.studio.nebius.ai/v1/``.

        :return: the assistant message content.
        """
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        nebius_client = OpenAI(base_url="https://api.studio.nebius.ai/v1/", api_key=api_key)
        reply = nebius_client.chat.completions.create(model=model, messages=conversation)
        return reply.choices[0].message.content

    def response_getter_cerebras(self, prompt, model="meta-llama/Meta-Llama-3.1-8B-Instruct", url="", api_key="", system_prompt="You are an expert game player"):
        """
        Query the Cerebras Cloud chat-completions API.

        ``url`` is accepted for signature parity with the other getters but is
        not used.

        :return: the assistant message content.
        """
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        cerebras_client = Cerebras(api_key=api_key)
        reply = cerebras_client.chat.completions.create(messages=conversation, model=model)
        return reply.choices[0].message.content

    def response_getter_vllm_serve(self, prompt, model="/h/PLACEHOLDER_FOR_ANOYNOMITYli/PLACEHOLDER_FOR_ANOYNOMITYlab_storage/models/QwQ-32B", url="http://localhost:8000/v1/chat/completions", api_key="", system_prompt="You are an expert game player"):
        """
        POST a chat-completions request to a ``vllm serve`` endpoint.

        :param prompt: user message content.
        :param model: model name/path as registered with the server.
        :param url: full chat-completions endpoint URL.
        :param api_key: if non-empty, sent as ``Authorization: Bearer <key>``.
            With the empty default, the request is unchanged from before.
        :param system_prompt: system message content.
        :return: the assistant message content.
        :raises requests.HTTPError: if the server answers with a 4xx/5xx status.
            The body is printed first so the server's error message is visible.
        """
        headers = {
            "Content-Type": "application/json"
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        data = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ]
        }
        response = requests.post(url, headers=headers, data=json.dumps(data))
        # Gate debug output behind config.MY_DEBUG, consistent with response_getter_openai.
        if config.MY_DEBUG:
            print(response.status_code)
        if response.status_code != 200:
            print(f"vLLM request failed with status {response.status_code}: {response.text}")
            # Fail with an HTTPError rather than an opaque KeyError on the body.
            response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def load_prompt_from_config(self, path):
        self.prompt_config = PromptConfig()
        self.prompt_config.from_yml(path)
        self.prompt_config_path = path
        print("Load config from: ", path)
        for key, value in self.prompt_config.__dict__.items():
            # If key in config and not None and in current object, set it
            if key in self.__dict__ and value is not None:
                setattr(self, key, value)