{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### openai test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output_1:\n",
      "In the realm of codes and lines that loop,\n",
      "There lies a concept both clever and astute.\n",
      "Recursion, a dance of function's delight,\n",
      "Infinite beauty, within a programmer's sight.\n",
      "\n",
      "Like a mirror reflecting its own reflection,\n",
      "Recursive functions call upon their own affection.\n",
      "A self-repeating cycle, a loop within a loop,\n",
      "Unraveling mysteries, like a hidden troupe.\n",
      "\n",
      "With elegance and grace, they unravel and unwind,\n",
      "Each iteration a journey of a unique kind.\n",
      "Through the depths of logic, they endlessly dive,\n",
      "A recursive melody where codes strive.\n",
      "\n",
      "Think of Russian dolls nested snug,\n",
      "Each a container, a tiny bug.\n",
      "One within another, a pattern unfolds,\n",
      "In programming tales, recursion holds.\n",
      "\n",
      "A function calls itself, a magic so surreal,\n",
      "Infinite possibilities, a coder's zeal.\n",
      "Through layers of abstraction, a story is told,\n",
      "In the enchanting world of recursive code.\n",
      "output_2:\n",
      "In the realm of code where logic flows,\n",
      "There lies a concept that programmers know,\n",
      "Recursion, a mystical, looping art,\n",
      "That captures minds and tugs at the heart.\n",
      "\n",
      "Like a fractal pattern, endlessly repeating,\n",
      "A function calls itself without needing\n",
      "To break the problem into smaller parts,\n",
      "It dances in circles, connecting its parts.\n",
      "\n",
      "First, a base case to end the cycle,\n",
      "Then, a recursive call with a meaningful style,\n",
      "Dividing the task into pieces so small,\n",
      "Until the solution emerges, standing tall.\n",
      "\n",
      "A magical spell, a powerful tool,\n",
      "In the coder's arsenal, a golden jewel,\n",
      "It unravels puzzles with elegant grace,\n",
      "In the infinite loops of cyberspace.\n",
      "\n",
      "So embrace the beauty of recursion's dance,\n",
      "A poetic expression, a programming trance,\n",
      "As it weaves through algorithms, intricate and fine,\n",
      "In the wondrous world of code divine.\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "from openai import OpenAI\n",
    "from openai import AsyncOpenAI\n",
    "\n",
    "client = OpenAI(api_key='sk-rifpc-2Gg7xjJ4qrwzWY7hUhZKT3BlbkFJBkz9CHkx9LkVsSciz9Tg')\n",
    "\n",
    "client_async = AsyncOpenAI(api_key='sk-rifpc-2Gg7xjJ4qrwzWY7hUhZKT3BlbkFJBkz9CHkx9LkVsSciz9Tg')\n",
    "\n",
    "completion_1 = client.chat.completions.create(\n",
    "  model=\"gpt-3.5-turbo\",\n",
    "  messages=[\n",
    "    {\"role\": \"system\", \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},\n",
    "    {\"role\": \"user\", \"content\": \"Compose a poem that explains the concept of recursion in programming.\"}\n",
    "  ]\n",
    ")\n",
    "\n",
    "completion_2 = await client_async.chat.completions.create(\n",
    "  model=\"gpt-3.5-turbo\",\n",
    "  messages=[\n",
    "    {\"role\": \"system\", \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},\n",
    "    {\"role\": \"user\", \"content\": \"Compose a poem that explains the concept of recursion in programming.\"}\n",
    "  ]\n",
    ")\n",
    "\n",
    "print(\"output_1:\")\n",
    "print(completion_1.choices[0].message.content)\n",
    "print(\"output_2:\")\n",
    "print(completion_2.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Eval test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import ast\n",
    "import shutil\n",
    "import subprocess\n",
    "import tempfile\n",
    "import time\n",
    "import inspect\n",
    "import json\n",
    "import re\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import asyncio\n",
    "\n",
    "from openai import OpenAI\n",
    "from openai import AsyncOpenAI\n",
    "\n",
    "client = OpenAI(api_key='sk-rifpc-2Gg7xjJ4qrwzWY7hUhZKT3BlbkFJBkz9CHkx9LkVsSciz9Tg')\n",
    "\n",
    "client_async = AsyncOpenAI(api_key='sk-rifpc-2Gg7xjJ4qrwzWY7hUhZKT3BlbkFJBkz9CHkx9LkVsSciz9Tg')\n",
    "\n",
    "class Planner_agent:\n",
    "    def __init__(self, api_key, max_tokens=4096, store_path='store_path', model='gpt-3.5-turbo', embedding_model=\"text-embedding-3-small\", json_file_paths=[], RAG_config=None):\n",
    "        self.client = OpenAI(api_key=api_key)\n",
    "        self.client_async = AsyncOpenAI(api_key=api_key)\n",
    "        self.max_tokens = max_tokens\n",
    "        self.store_path = store_path\n",
    "        self.model = model\n",
    "        self.embedding_model = embedding_model\n",
    "        if json_file_paths:\n",
    "            self.json_df = self.load_json_to_dataframe(json_file_paths)\n",
    "        else:\n",
    "            self.json_df = None\n",
    "        self.embed_df = None\n",
    "        self.RAG_config = RAG_config\n",
    "\n",
    "    def get_RAG_prompt(self, json_df, task_text):\n",
    "\n",
    "        if json_df is None:\n",
    "            raise ValueError(\"json_df is None. Please provide a valid dataframe.\")\n",
    "        \n",
    "        function_name = json_df[json_df['type'] == 'function']['name'].tolist()\n",
    "        class_name = json_df[json_df['type'] == 'class']['name'].tolist()\n",
    "        evaluation_name = json_df[json_df['type'] == 'evaluation']['name'].tolist()\n",
    "        function_description = json_df[json_df['type'] == 'function']['description'].tolist()\n",
    "        class_description = json_df[json_df['type'] == 'class']['description'].tolist()\n",
    "        evaluation_description = json_df[json_df['type'] == 'evaluation']['description'].tolist()\n",
    "\n",
    "        function_text = \"\\n\\n\".join([f\"Function: {name}\\nDescription: {desc}\" for name, desc in zip(function_name, function_description)])\n",
    "        class_text = \"\\n\\n\".join([f\"Class: {name}\\nDescription: {desc}\" for name, desc in zip(class_name, class_description)])\n",
    "        evaluation_text = \"\\n\\n\".join([f\"Evaluation: {name}\\nDescription: {desc}\" for name, desc in zip(evaluation_name, evaluation_description)])\n",
    "        \n",
    "        prompt_message = f\"\"\"\n",
    "You are given a task description, a list of function descriptions, a list of class descriptions and a list of evaluation modules. Your goal is to select the functions and classes that are directly relevant to the task description and suitable evaluation modules for the task. Please output the selected functions, classes and evaluation modules in the following format:\n",
    "\n",
    "Selected Functions:\n",
    "- FunctionName1: Reason for selection\n",
    "- FunctionName2: Reason for selection\n",
    "\n",
    "Selected Classes:\n",
    "- ClassName1: Reason for selection\n",
    "- ClassName2: Reason for selection\n",
    "\n",
    "Selected Evaluation Modules:\n",
    "- EvalName1: Reason for selection\n",
    "- EvalName2: Reason for selection\n",
    "\n",
    "Only include functions, classes and modules that are relevant to the task description, if no suitable selection, leave it empty.\n",
    "\n",
    "### Task Description:\n",
    "{task_text}\n",
    "\n",
    "### Available Functions:\n",
    "{function_text}\n",
    "\n",
    "### Available Classes:\n",
    "{class_text}\n",
    "\n",
    "### Available Evaluation Modules:\n",
    "{evaluation_text}\n",
    "\n",
    "Output the selected functions, classes and modules according to the format above.\"\"\"\n",
    "\n",
    "        return prompt_message\n",
    "    \n",
    "    def extract_selected_functions_classes(self, llm_output):\n",
    "        functions_pattern = r\"Selected Functions:\\n((?:- .+\\n)*)\"\n",
    "        classes_pattern = r\"Selected Classes:\\n((?:- .+\\n)*)\"\n",
    "        evaluation_pattern = r\"Selected Evaluation Modules:\\n((?:- .+\\n)*)\"\n",
    "\n",
    "        functions_match = re.search(functions_pattern, llm_output)\n",
    "        classes_match = re.search(classes_pattern, llm_output)\n",
    "        evaluation_match = re.search(evaluation_pattern, llm_output)\n",
    "        \n",
    "        if functions_match:\n",
    "            selected_functions = re.findall(r\"- (\\w+):\", functions_match.group(1))\n",
    "        else:\n",
    "            selected_functions = []\n",
    "        \n",
    "        if classes_match:\n",
    "            selected_classes = re.findall(r\"- (\\w+):\", classes_match.group(1))\n",
    "        else:\n",
    "            selected_classes = []\n",
    "\n",
    "        if evaluation_match:\n",
    "            selected_evaluations = re.findall(r\"- (\\w+):\", evaluation_match.group(1))\n",
    "        else:\n",
    "            selected_evaluations = []\n",
    "        \n",
    "        return {\n",
    "            \"functions\": selected_functions,\n",
    "            \"classes\": selected_classes,\n",
    "            \"evaluations\": selected_evaluations\n",
    "        }\n",
    "\n",
    "    def filter_df_by_selected_functions_classes(self, df, selected_dict):\n",
    "        selected_functions = selected_dict['functions']\n",
    "        selected_classes = selected_dict['classes']\n",
    "        selected_evaluations = selected_dict['evaluations']\n",
    "        \n",
    "        # Filter based on the type and name matching\n",
    "        filtered_df = df[\n",
    "            ((df['type'] == 'function') & df['name'].isin(selected_functions)) |\n",
    "            ((df['type'] == 'class') & df['name'].isin(selected_classes)) |\n",
    "            ((df['type'] == 'evaluation') & df['name'].isin(selected_evaluations))\n",
    "        ]\n",
    "        \n",
    "        return filtered_df\n",
    "    \n",
    "    def get_RAG_result(self, task_text, json_df, print_prompt=False):\n",
    "        prompt_message = self.get_RAG_prompt(json_df, task_text)\n",
    "        if print_prompt:\n",
    "            print(\"RAG Prompt:\")\n",
    "            print(prompt_message)\n",
    "        llm_output = self.get_llm_response(prompt_message)\n",
    "        selected_functions_classes = self.extract_selected_functions_classes(llm_output)\n",
    "        filtered_df = self.filter_df_by_selected_functions_classes(json_df, selected_functions_classes)\n",
    "        return filtered_df \n",
    "    \n",
    "    async def perform_RAG_step(self, RAG_method, task_text, print_prompt):\n",
    "        print(\"RAG step started.\")\n",
    "        \n",
    "        if RAG_method == \"EMBEDDING\":\n",
    "            self.embed_df = await self.generate_embeddings(self.json_df, self.client_async, model=self.embedding_model)\n",
    "            print(\"Embeddings generated.\")\n",
    "            filtered_df = await self.filter_by_similarity(self.embed_df, task_text, self.client_async, **self.RAG_config)\n",
    "            eval_filtered_df = filtered_df[filtered_df['type'] == 'evaluation'].copy(deep=True)\n",
    "\n",
    "        elif RAG_method == \"AGENT\":\n",
    "            filtered_df = self.get_RAG_result(task_text, self.json_df, print_prompt)\n",
    "            eval_filtered_df = filtered_df[filtered_df['type'] == 'evaluation'].copy(deep=True)\n",
    "\n",
    "        else:\n",
    "            raise ValueError(\"Invalid RAG_method. Must be 'EMBEDDING' or 'AGENT'.\")\n",
    "        \n",
    "        return filtered_df, eval_filtered_df\n",
    "\n",
    "    async def get_embedding_async(self, text, client_async, model=\"text-embedding-3-small\"):\n",
    "        text = text.replace(\"\\n\", \" \")\n",
    "        response = await client_async.embeddings.create(input=text, model=model)\n",
    "        return response.data[0].embedding\n",
    "\n",
    "    async def generate_embeddings(self, df, client_async, model=\"text-embedding-3-small\"):\n",
    "        if 'description_embedding' not in df.columns:\n",
    "            df['description_embedding'] = None\n",
    "        if 'keywords_embedding' not in df.columns:\n",
    "            df['keywords_embedding'] = None\n",
    "\n",
    "        descriptions_to_embed_indices = df[df['description_embedding'].isna()].index\n",
    "        descriptions_to_embed = df.loc[descriptions_to_embed_indices, 'description'].tolist()\n",
    "        \n",
    "        keywords_to_embed_indices = df[df['keywords_embedding'].isna()].index\n",
    "        keywords_to_embed = df.loc[keywords_to_embed_indices, 'keywords'].tolist()\n",
    "\n",
    "        description_embeddings = await asyncio.gather(\n",
    "            *[self.get_embedding_async(desc, client_async, model) for desc in descriptions_to_embed]\n",
    "        )\n",
    "\n",
    "        for idx, embedding in zip(descriptions_to_embed_indices, description_embeddings):\n",
    "            df.at[idx, 'description_embedding'] = embedding\n",
    "\n",
    "        async def get_average_keyword_embedding(keywords):\n",
    "            keyword_embeddings = await asyncio.gather(\n",
    "                *[self.get_embedding_async(kw, client_async, model) for kw in keywords]\n",
    "            )\n",
    "            # print(len(keyword_embeddings))\n",
    "            # return np.mean(keyword_embeddings, axis=0).tolist()\n",
    "            return keyword_embeddings\n",
    "\n",
    "        keywords_embeddings = await asyncio.gather(\n",
    "            *[get_average_keyword_embedding(kw_list) for kw_list in keywords_to_embed]\n",
    "        )\n",
    "\n",
    "        for idx, embedding in zip(keywords_to_embed_indices, keywords_embeddings):\n",
    "            df.at[idx, 'keywords_embedding'] = embedding\n",
    "\n",
    "        return df\n",
    "\n",
    "    def cosine_similarity(self, a, b):\n",
    "        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n",
    "\n",
    "    async def filter_by_similarity(self, origin_embed_df, input_text, client_async, target = \"both\", target_type=[\"function\", \"class\"], weight=0.5, keywords_score=\"average\", min_threshold=0.0, min_num=1, keep_scores=False, model=\"text-embedding-3-small\"):\n",
    "\n",
    "        embed_df = origin_embed_df.copy(deep=True)\n",
    "        if isinstance(target_type, str):\n",
    "            target_type = [target_type]\n",
    "        embed_df = embed_df[embed_df['type'].isin(target_type)]\n",
    "\n",
    "        # Generate embedding for the input text\n",
    "        input_embedding = await self.get_embedding_async(input_text, client_async, model)\n",
    "        \n",
    "        # Ensure keywords_embedding is a list of lists\n",
    "        def ensure_list_of_lists(embedding):\n",
    "            return embedding if isinstance(embedding, list) and all(isinstance(i, list) for i in embedding) else [embedding]\n",
    "\n",
    "        # Define the function to calculate similarity based on the target\n",
    "        def calculate_similarity(row):\n",
    "            if target == \"description\":\n",
    "                return self.cosine_similarity(input_embedding, row['description_embedding'])\n",
    "            elif target == \"keywords\":\n",
    "                keyword_embeddings = ensure_list_of_lists(row['keywords_embedding'])\n",
    "                if isinstance(keywords_score, int):\n",
    "                    keyword_similarities = [self.cosine_similarity(input_embedding, kw_emb) for kw_emb in keyword_embeddings]\n",
    "                    top_k_similarities = sorted(keyword_similarities, reverse=True)[:keywords_score]\n",
    "                    return np.mean(top_k_similarities)\n",
    "                else:\n",
    "                    return np.mean([self.cosine_similarity(input_embedding, kw_emb) for kw_emb in keyword_embeddings])\n",
    "            elif target == \"both\":\n",
    "                description_similarity = self.cosine_similarity(input_embedding, row['description_embedding'])\n",
    "                keyword_embeddings = ensure_list_of_lists(row['keywords_embedding'])\n",
    "                if isinstance(keywords_score, int):\n",
    "                    keyword_similarities = [self.cosine_similarity(input_embedding, kw_emb) for kw_emb in keyword_embeddings]\n",
    "                    top_k_similarities = sorted(keyword_similarities, reverse=True)[:keywords_score]\n",
    "                    keywords_similarity = np.mean(top_k_similarities)\n",
    "                else:\n",
    "                    keywords_similarity = np.mean([self.cosine_similarity(input_embedding, kw_emb) for kw_emb in keyword_embeddings])\n",
    "                return weight * description_similarity + (1 - weight) * keywords_similarity\n",
    "\n",
    "        # Calculate similarity for each row\n",
    "        embed_df['similarity'] = embed_df.apply(calculate_similarity, axis=1)\n",
    "        \n",
    "        # Filter rows based on min_threshold\n",
    "        filtered_df = embed_df[embed_df['similarity'] >= min_threshold]\n",
    "        \n",
    "        # Ensure at least min_num rows are selected\n",
    "        if len(filtered_df) < min_num:\n",
    "            filtered_df = embed_df.nlargest(min_num, 'similarity')\n",
    "        \n",
    "        # Drop the similarity column before returning the result\n",
    "        if not keep_scores:\n",
    "            filtered_df = filtered_df.drop(columns=['similarity'])\n",
    "\n",
    "        return filtered_df\n",
    "\n",
    "    def extract_name_from_code(self, code_str, code_type):\n",
    "        if code_type == \"function\":\n",
    "            match = re.match(r\"def (\\w+)\\(\", code_str)\n",
    "        elif code_type == \"class\":\n",
    "            match = re.match(r\"class (\\w+)\\(\", code_str)\n",
    "        elif code_type == \"evaluation\":\n",
    "            match = re.match(r\"def (\\w+)\\(\", code_str)\n",
    "        else:\n",
    "            return None\n",
    "        if match:\n",
    "            return match.group(1)\n",
    "        return None\n",
    "\n",
    "    def find_missing_entities(self, base_df, function_list, class_list, evaluation_list):\n",
    "        missing_functions = []\n",
    "        missing_classes = []\n",
    "        missing_evaluations = []\n",
    "        \n",
    "        # Extract names from the provided lists\n",
    "        function_names = [self.extract_name_from_code(code, \"function\") for code in function_list]\n",
    "        class_names = [self.extract_name_from_code(code, \"class\") for code in class_list]\n",
    "        evaluation_names = [self.extract_name_from_code(code, \"evaluation\") for code in evaluation_list]\n",
    "        \n",
    "        # Check which names are not in the dataframe\n",
    "        for name in function_names:\n",
    "            if name and not ((base_df['type'] == 'function') & (base_df['name'] == name)).any():\n",
    "                missing_functions.append(name)\n",
    "        \n",
    "        for name in class_names:\n",
    "            if name and not ((base_df['type'] == 'class') & (base_df['name'] == name)).any():\n",
    "                missing_classes.append(name)\n",
    "        \n",
    "        for name in evaluation_names:\n",
    "            if name and not ((base_df['type'] == 'evaluation') & (base_df['name'] == name)).any():\n",
    "                missing_evaluations.append(name)\n",
    "        \n",
    "        return missing_functions, missing_classes, missing_evaluations\n",
    "\n",
    "    def extract_json_to_dict(self, gpt_output):\n",
    "        json_pattern = re.compile(r'```json(.*?)```', re.DOTALL)\n",
    "        match = json_pattern.search(gpt_output)\n",
    "        \n",
    "        if not match:\n",
    "            raise ValueError(\"No JSON format found in the provided string.\")\n",
    "        \n",
    "        json_str = match.group(1).strip()\n",
    "        \n",
    "        try:\n",
    "            json_dict = json.loads(json_str)\n",
    "        except json.JSONDecodeError as e:\n",
    "            raise ValueError(f\"Error decoding JSON: {e}\")\n",
    "        \n",
    "        return json_dict\n",
    "\n",
    "    def save_dict_to_json_file(self, data, file_path):\n",
    "        with open(file_path, 'w', encoding='utf-8') as f:\n",
    "            json.dump(data, f, ensure_ascii=False, indent=4)\n",
    "\n",
    "    def append_content_to_dict(self, base_dict, function_list, class_list):\n",
    "        def extract_name_from_string(code):\n",
    "            # Regex patterns to match class and function definitions\n",
    "            class_pattern = r'class\\s+(\\w+)\\s*[\\(:]'\n",
    "            function_pattern = r'def\\s+(\\w+)\\s*\\('\n",
    "\n",
    "            # Search for class name\n",
    "            class_match = re.search(class_pattern, code)\n",
    "            if class_match:\n",
    "                return class_match.group(1)\n",
    "\n",
    "            # Search for function name\n",
    "            function_match = re.search(function_pattern, code)\n",
    "            if function_match:\n",
    "                return function_match.group(1)\n",
    "\n",
    "            return None\n",
    "\n",
    "        def append_or_overwrite_content(item_list, string_list):\n",
    "            for s in string_list:\n",
    "                name = extract_name_from_string(s)\n",
    "                # print(name)\n",
    "                exist_item = False\n",
    "                if name:\n",
    "                    for item in item_list:\n",
    "                        if item['name'] == name:\n",
    "                            item['content'] = s\n",
    "                            exist_item = True\n",
    "                            break\n",
    "        \n",
    "        append_or_overwrite_content(base_dict['functions'], function_list)\n",
    "        append_or_overwrite_content(base_dict['classes'], class_list)\n",
    "\n",
    "        return base_dict\n",
    "\n",
    "    def get_json_data(self, input_file, openai_client, model='gpt-4o-mini', save_path=None):\n",
    "\n",
    "        if isinstance(input_file, str):\n",
    "            file_path = input_file\n",
    "            functions, classes, modules = load_file(file_path)\n",
    "            evaluations = []\n",
    "        elif isinstance(input_file, dict):\n",
    "            file_path = \"default_path.py\"\n",
    "            functions = input_file.get(\"functions\", [])\n",
    "            classes = input_file.get(\"classes\", [])\n",
    "            modules = input_file.get(\"modules\", [])\n",
    "            evaluations = input_file.get(\"evaluation\", [])\n",
    "\n",
    "        modules_str = '\\n'.join(modules)\n",
    "        functions_str = '\\n\\n'.join(functions)\n",
    "        classes_str = '\\n\\n'.join(classes)\n",
    "        evaluations_str = '\\n\\n'.join(evaluations)\n",
    "\n",
    "        # use openai to generate a json file that contains the description of functions and classes, also include the keywords, file path and related modules\n",
    "\n",
    "        prompt = f\"\"\"Generate a JSON file that describes the functions and classes in the given Python file. The file is located at {file_path}. \n",
    "\n",
    "        The file contains the following functions:\n",
    "        {functions_str}\n",
    "\n",
    "        The file contains the following classes:\n",
    "        {classes_str}\n",
    "\n",
    "        The file contains the following evaluation modules:\n",
    "        {evaluations_str}\n",
    "\n",
    "        The JSON file should include the following keywords: type, name, file_path, description and keywords.\n",
    "        Where the type is either 'function', 'class' or 'evaluation', name are the function or class name, file_path is the path to the file, and description is a brief description of the function or class, and keywords are the keywords that describe the function or class.\n",
    "        This JSON file will be used to generate documentation for the functions and classes in the given Python file and will be used for an RAG system to search for functions and classes based on keywords.\n",
    "        Note that the functions and classes should be stored into different lists in the JSON file.\n",
    "        The JSON's structure should be like this:\n",
    "        {{\n",
    "            \"functions\": [\n",
    "                {{\n",
    "                    \"type\": \"function\",\n",
    "                    \"name\": \"function_name\",\n",
    "                    \"file_path\": \"file_path\",\n",
    "                    \"description\": \"function_description\",\n",
    "                    \"keywords\": [\"keyword1\", \"keyword2\"]\n",
    "                }},\n",
    "                ...\n",
    "            ],\n",
    "            \"classes\": [\n",
    "                {{\n",
    "                    \"type\": \"class\",\n",
    "                    \"name\": \"class_name\",\n",
    "                    \"file_path\": \"file_path\",\n",
    "                    \"description\": \"class_description\",\n",
    "                    \"keywords\": [\"keyword1\", \"keyword2\"]\n",
    "                }},\n",
    "                ...\n",
    "            ],\n",
    "            \"evaluation\": [\n",
    "                {{\n",
    "                    \"type\": \"evaluation\",\n",
    "                    \"name\": \"evaluation_name\",\n",
    "                    \"file_path\": \"file_path\",\n",
    "                    \"description\": \"evaluation_description\",\n",
    "                    \"keywords\": [\"keyword1\", \"keyword2\"]\n",
    "                }},\n",
    "                ...\n",
    "            ]\n",
    "        }}\n",
    "        \"\"\"\n",
    "\n",
    "        completion = openai_client.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=[\n",
    "                {\"role\": \"user\", \"content\": prompt}\n",
    "            ]\n",
    "            )\n",
    "        \n",
    "        output_json = self.extract_json_to_dict(completion.choices[0].message.content)\n",
    "        output_json = self.append_content_to_dict(output_json, functions, classes)\n",
    "\n",
    "        if save_path:\n",
    "            self.save_dict_to_json_file(output_json, save_path)\n",
    "\n",
    "        return output_json\n",
    "\n",
    "    def load_json_to_dataframe(self, json_inputs):\n",
    "\n",
    "        if isinstance(json_inputs, str):\n",
    "            json_inputs = [json_inputs]\n",
    "        elif isinstance(json_inputs, dict):\n",
    "            json_inputs = [json_inputs]\n",
    "\n",
    "        combined_data = []\n",
    "        for json_input in json_inputs:\n",
    "            if isinstance(json_input, str):\n",
    "                with open(json_input, 'r') as file:\n",
    "                    data = json.load(file)\n",
    "            elif isinstance(json_input, dict):\n",
    "                data = json_input\n",
    "            else:\n",
    "                print(type(json_input))\n",
    "                raise ValueError(\"Invalid input type. Must be a string (file path) or a dictionary.\")\n",
    "            \n",
    "            eval_modules = data.get(\"evaluation\", [])\n",
    "            functions = data.get(\"functions\", [])\n",
    "            classes = data.get(\"classes\", [])\n",
    "            \n",
    "            combined_data.extend(functions+classes+eval_modules)\n",
    "            \n",
    "        \n",
    "        df = pd.DataFrame(combined_data)\n",
    "        \n",
    "        return df\n",
    "\n",
    "    def files_to_json_to_dataframe(self, file_paths, openai_client, model='gpt-4o-mini', save_path=None):\n",
    "        if isinstance(file_paths, str):\n",
    "            file_paths = [file_paths]\n",
    "\n",
    "        json_data_list = []\n",
    "        for file_path_i in file_paths:\n",
    "            json_data_list.append(self.get_json_data(file_path_i, openai_client, model, save_path))\n",
    "\n",
    "        return self.load_json_to_dataframe(json_data_list)\n",
    "\n",
    "    def get_llm_response(self, text):\n",
    "        completion = self.client.chat.completions.create(\n",
    "            model=self.model,\n",
    "            max_tokens=self.max_tokens,\n",
    "            messages=[\n",
    "                {\"role\": \"user\", \"content\": text}\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        return completion.choices[0].message.content\n",
    "    \n",
    "    async def get_llm_response_async(self, text):\n",
    "        completion = await self.client_async.chat.completions.create(\n",
    "            model=self.model,\n",
    "            max_tokens=self.max_tokens,\n",
    "            messages=[\n",
    "                {\"role\": \"user\", \"content\": text}\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        return completion.choices[0].message.content\n",
    "\n",
    "    def extract_code_from_gpt_output(self, gpt_output):\n",
    "        lines = gpt_output.splitlines()\n",
    "        in_code_block = False\n",
    "        pure_code = []\n",
    "        have_code_block = False\n",
    "        missing_functions = []\n",
    "        missing_classes = []\n",
    "        collecting_functions = False\n",
    "        collecting_classes = False\n",
    "\n",
    "        for line in lines:\n",
    "            # Check for Missing Functions section\n",
    "            if line.startswith(\"Missing Functions:\"):\n",
    "                collecting_functions = True\n",
    "                collecting_classes = False\n",
    "                missing_functions.append(line)\n",
    "                continue\n",
    "\n",
    "            # Check for Missing Classes section\n",
    "            if line.startswith(\"Missing Classes:\"):\n",
    "                collecting_classes = True\n",
    "                collecting_functions = False\n",
    "                missing_classes.append(line)\n",
    "                continue\n",
    "\n",
    "            # Collect lines under Missing Functions\n",
    "            if collecting_functions:\n",
    "                if line.strip() and not line.startswith(\"-\"):\n",
    "                    collecting_functions = False\n",
    "                elif line.strip() == \"- None\" or line.strip() == \"None\":\n",
    "                    missing_functions = []\n",
    "                    collecting_functions = False\n",
    "                else:\n",
    "                    missing_functions.append(line)\n",
    "                continue\n",
    "\n",
    "            # Collect lines under Missing Classes\n",
    "            if collecting_classes:\n",
    "                if line.strip() and not line.startswith(\"-\"):\n",
    "                    collecting_classes = False\n",
    "                elif line.strip() == \"- None\" or line.strip() == \"None\":\n",
    "                    missing_classes = []\n",
    "                    collecting_classes = False\n",
    "                else:\n",
    "                    missing_classes.append(line)\n",
    "                continue\n",
    "\n",
    "            # Check for the start of a code block\n",
    "            if line.startswith(\"```\"):\n",
    "                in_code_block = not in_code_block\n",
    "                have_code_block = True\n",
    "            # If inside a code block, collect the code lines\n",
    "            elif in_code_block:\n",
    "                pure_code.append(line)\n",
    "\n",
    "        # If Missing Functions or Missing Classes are found and are not empty, return them\n",
    "        if missing_functions or missing_classes:\n",
    "            return \"\\n\".join(missing_functions + missing_classes), \"RAG\"\n",
    "\n",
    "        # If no code block or missing sections are found, return the original gpt_output\n",
    "        if not have_code_block:\n",
    "            return gpt_output, None\n",
    "\n",
    "        return \"\\n\".join(pure_code), \"code\"\n",
    "\n",
    "    def load_file(self, path):\n",
    "        #debug\n",
    "        with open(path, 'r') as file:\n",
    "            code = file.read()\n",
    "\n",
    "        tree = ast.parse(code)\n",
    "\n",
    "        functions = []\n",
    "        classes = []\n",
    "        modules = []\n",
    "\n",
    "        for node in tree.body:\n",
    "            if isinstance(node, ast.FunctionDef):\n",
    "                functions.append(ast.unparse(node))\n",
    "            elif isinstance(node, ast.ClassDef):\n",
    "                classes.append(ast.unparse(node))\n",
    "            elif isinstance(node, ast.Import):\n",
    "                for alias in node.names:\n",
    "                    modules.append(f\"import {alias.name}\")\n",
    "            elif isinstance(node, ast.ImportFrom):\n",
    "                module_name = node.module\n",
    "                for alias in node.names:\n",
    "                    if alias.asname:\n",
    "                        modules.append(f\"from {module_name} import {alias.name} as {alias.asname}\")\n",
    "                    else:\n",
    "                        modules.append(f\"from {module_name} import {alias.name}\")\n",
    "\n",
    "        return functions, classes, modules\n",
    "    \n",
    "    def get_planner_prompt(self, task_text, paths, print_prompt=False, can_use_RAG = False):\n",
    "\n",
    "        if isinstance(paths, str):\n",
    "            paths = [paths]\n",
    "\n",
    "        if isinstance(paths, list):\n",
    "            functions = []\n",
    "            classes = []\n",
    "            modules = []\n",
    "            for path in paths:\n",
    "                functions_, classes_, modules_ = self.load_file(path)\n",
    "                functions.extend(functions_)\n",
    "                classes.extend(classes_)\n",
    "                modules.extend(modules_)\n",
    "\n",
    "            modules_str = '\\n'.join(modules)\n",
    "            functions_str = '\\n\\n'.join(functions)\n",
    "            classes_str = '\\n\\n'.join(classes)\n",
    "        # if the input is a data frame\n",
    "        elif isinstance(paths, pd.DataFrame):\n",
    "            functions = paths[paths['type'] == 'function']['content'].tolist()\n",
    "            classes = paths[paths['type'] == 'class']['content'].tolist()\n",
    "            functions_str = '\\n\\n'.join(functions)\n",
    "            classes_str = '\\n\\n'.join(classes)\n",
    "            modules_str = \"\"\n",
    "        else:\n",
    "            raise ValueError(f\"Invalid input type. Must be a string (file path) or a dataframe. Got {type(paths)}\")\n",
    "\n",
    "        RAG_text = \"\"\"If you find some functions or classes are missing, please provide the name or the description of the missing functions or classes, with following format:\n",
    "\n",
    "Missing Functions:\n",
    "- FunctionName1 or FunctionDescription1\n",
    "- FunctionName2 or FunctionDescription2\n",
    "\n",
    "Missing Classes:\n",
    "- ClassName1 or ClassDescription1\n",
    "- ClassName2 or ClassDescription2\n",
    " \n",
    "Otherwise, try to finish the task by using class or functions defined above ONLY, do not trying to make your own functions or classes if it is not necessary.\"\"\"\n",
    "\n",
    "        prompt_message = f\"\"\"Provided code:\n",
    "\n",
    "{modules_str}\n",
    "{functions_str}\n",
    "{classes_str}\n",
    "\n",
    "read the code above, try to finish a task only based on these class or functions defined above. Assuming all the classes from the provided code are already imported, provide the code only without any explanation and you can use these class or functions directly without write it again in your code. Try to finish the task by using class or functions defined above ONLY, do not trying to make your own functions or classes if it is not necessary.\n",
    "\n",
    "{RAG_text if can_use_RAG else \"\"}\n",
    "\n",
    "task:\n",
    "{task_text}\"\"\"\n",
    "\n",
    "        if print_prompt:\n",
    "            print(\"Planner Prompt:\")\n",
    "            print(prompt_message)\n",
    "\n",
    "        return prompt_message\n",
    "    \n",
    "    def parse_imports(self, file_path):\n",
    "        with open(file_path, \"r\") as file:\n",
    "            tree = ast.parse(file.read(), filename=file_path)\n",
    "\n",
    "        imports = []\n",
    "        local_files = []\n",
    "        \n",
    "        for node in ast.walk(tree):\n",
    "            if isinstance(node, ast.Import):\n",
    "                for alias in node.names:\n",
    "                    if alias.asname:\n",
    "                        imports.append(f\"import {alias.name} as {alias.asname}\")\n",
    "                    else:\n",
    "                        imports.append(f\"import {alias.name}\")\n",
    "            elif isinstance(node, ast.ImportFrom):\n",
    "                if node.module:\n",
    "                    from_imports = ', '.join(\n",
    "                        [f\"{alias.name} as {alias.asname}\" if alias.asname else alias.name for alias in node.names]\n",
    "                    )\n",
    "                    imports.append(f\"from {node.module} import {from_imports}\")\n",
    "                    \n",
    "                    # Check if it's a relative import\n",
    "                    if node.level > 0:\n",
    "                        # Construct the potential file path\n",
    "                        base_dir = os.path.dirname(file_path)\n",
    "                        module_path = os.path.join(base_dir, *node.module.split('.'))\n",
    "                        \n",
    "                        # Consider .py and __init__.py as possible targets\n",
    "                        possible_files = [\n",
    "                            f\"{module_path}.py\",\n",
    "                            os.path.join(module_path, \"__init__.py\")\n",
    "                        ]\n",
    "                        \n",
    "                        for p_file in possible_files:\n",
    "                            if os.path.exists(p_file):\n",
    "                                local_files.append(p_file)\n",
    "                                # Recursively parse imports in the found local file\n",
    "                                nested_imports = self.parse_imports(p_file)\n",
    "                                for idx, nested_import in nested_imports.items():\n",
    "                                    nested_import['dependence'] = file_path\n",
    "                                return nested_imports\n",
    "        \n",
    "        # Return the imports as a dictionary\n",
    "        return {idx + 1: {\"text\": imp, \"local_file\": file_path, \"dependence\": local_files} for idx, imp in enumerate(imports)}\n",
    "    \n",
    "    def write_imports_to_file(self, source_files, target_file, return_text_only=False):\n",
    "        target_dir = os.path.dirname(target_file)\n",
    "        all_imports = []\n",
    "        visited_files = set()\n",
    "\n",
    "        def collect_imports(file_path):\n",
    "            if file_path in visited_files:\n",
    "                return\n",
    "            visited_files.add(file_path)\n",
    "\n",
    "            # Parse imports from the source file\n",
    "            parsed_imports = self.parse_imports(file_path)\n",
    "            \n",
    "            for import_info in parsed_imports.values():\n",
    "                imp_text = import_info['text']\n",
    "                imp_file = import_info['local_file']\n",
    "                dependencies = import_info['dependence']\n",
    "                \n",
    "                if imp_file not in visited_files:\n",
    "                    # Copy the local file to the target directory\n",
    "                    if imp_file not in source_files:\n",
    "                        destination = os.path.join(target_dir, os.path.basename(imp_file))\n",
    "                        shutil.copy2(imp_file, destination)\n",
    "                    \n",
    "                    # Recursively collect imports for dependencies\n",
    "                    for dep_file in dependencies:\n",
    "                        collect_imports(dep_file)\n",
    "                \n",
    "                all_imports.append(imp_text)\n",
    "\n",
    "        # Collect imports from each source file\n",
    "        for source_file in source_files:\n",
    "            collect_imports(source_file)\n",
    "\n",
    "        # Write collected imports to the target file\n",
    "        if not return_text_only:\n",
    "            with open(target_file, \"w\") as target:\n",
    "                import_text = \"\\n\".join(sorted(set(all_imports)))\n",
    "                target.write(import_text)\n",
    "\n",
    "        return import_text\n",
    "\n",
    "    def write_code_to_file(self, py_file_paths, running_code, store_path, use_py_files=True):\n",
    "        os.makedirs(store_path, exist_ok=True)\n",
    "        main_file_path = os.path.join(store_path, 'main.py')\n",
    "        if isinstance(py_file_paths, str):\n",
    "            py_file_paths = [py_file_paths]\n",
    "\n",
    "        if isinstance(py_file_paths, list):\n",
    "            if use_py_files:\n",
    "                for file_path in py_file_paths:\n",
    "                    shutil.copy(file_path, store_path)\n",
    "\n",
    "                with open(main_file_path, 'w') as main_file:\n",
    "                    for file_path in py_file_paths:\n",
    "                        module_name = os.path.splitext(os.path.basename(file_path))[0]\n",
    "                        main_file.write(f'from {module_name} import *\\n')\n",
    "                    \n",
    "                    main_file.write('\\n' + running_code)\n",
    "            else:\n",
    "                import_str = self.write_imports_to_file(py_file_paths, main_file_path)\n",
    "                functions = []\n",
    "                classes = []\n",
    "                for file_path in py_file_paths:\n",
    "                    functions_, classes_, modules_ = self.load_file(file_path)\n",
    "                    functions.extend(functions_)\n",
    "                    classes.extend(classes_)\n",
    "                with open(main_file_path, 'a') as main_file:\n",
    "                    main_file.write('\\n'.join(functions) + '\\n')\n",
    "                    main_file.write('\\n'.join(classes) + '\\n')\n",
    "\n",
    "                self.reorder_classes_functions(main_file_path)\n",
    "                \n",
    "                with open(main_file_path, 'a') as main_file:\n",
    "                    main_file.write('\\n' + running_code)\n",
    "        # if the input is a data frame\n",
    "        elif isinstance(py_file_paths, pd.DataFrame):\n",
    "            functions = py_file_paths[py_file_paths['type'] == 'function']['content'].tolist()\n",
    "            classes = py_file_paths[py_file_paths['type'] == 'class']['content'].tolist()\n",
    "            # get all unique file paths in the dataframe as a list\n",
    "            file_paths = py_file_paths['file_path'].unique().tolist()\n",
    "            if use_py_files:\n",
    "                for file_path in file_paths:\n",
    "                    shutil.copy(file_path, store_path)\n",
    "                with open(main_file_path, 'w') as main_file:\n",
    "                    for file_path in file_paths:\n",
    "                        module_name = os.path.splitext(os.path.basename(file_path))[0]\n",
    "                        main_file.write(f'from {module_name} import *\\n')\n",
    "            else:\n",
    "                print(\"debug_1\")\n",
    "                import_str = self.write_imports_to_file(file_paths, main_file_path)\n",
    "\n",
    "            with open(main_file_path, 'a') as main_file:\n",
    "                print(\"debug_2\")\n",
    "                main_file.write('\\n'.join(functions) + '\\n')\n",
    "                main_file.write('\\n'.join(classes) + '\\n')\n",
    "\n",
    "            self.reorder_classes_functions(main_file_path)\n",
    "\n",
    "            with open(main_file_path, 'a') as main_file:\n",
    "                print(\"debug_3\")\n",
    "                main_file.write('\\n' + running_code)\n",
    "\n",
    "        # with open(main_file_path, 'r') as file_temp_test:\n",
    "        #     print(\"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\")\n",
    "        #     print(\"writting code to file\")\n",
    "        #     code_temp_test = file_temp_test.read()\n",
    "        #     print(code_temp_test)\n",
    "        #     print(\"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\")\n",
    "        return main_file_path\n",
    "    \n",
    "    def get_class_dependencies(self, node):\n",
    "        dependencies = set()\n",
    "        for base in node.bases:\n",
    "            if isinstance(base, ast.Name):\n",
    "                dependencies.add(base.id)\n",
    "            elif isinstance(base, ast.Attribute):\n",
    "                dependencies.add(base.attr)\n",
    "        return dependencies\n",
    "\n",
    "    def reorder_classes_functions(self, file_path):\n",
    "        with open(file_path, 'r') as file:\n",
    "            source = file.read()\n",
    "\n",
    "        # Parse the source code into an AST\n",
    "        tree = ast.parse(source)\n",
    "        \n",
    "        # Identify classes and functions\n",
    "        class_defs = []\n",
    "        func_defs = []\n",
    "        main_block = []\n",
    "        \n",
    "        for node in tree.body:\n",
    "            if isinstance(node, ast.ClassDef):\n",
    "                class_defs.append(node)\n",
    "            elif isinstance(node, ast.FunctionDef):\n",
    "                func_defs.append(node)\n",
    "            else:\n",
    "                # Collect everything else (imports, statements, etc.)\n",
    "                if isinstance(node, ast.If) and isinstance(node.test, ast.Compare) and \\\n",
    "                isinstance(node.test.left, ast.Name) and node.test.left.id == \"__name__\":\n",
    "                    main_block = tree.body[tree.body.index(node):]\n",
    "                    break\n",
    "        \n",
    "        # Determine dependencies and reorder\n",
    "        ordered_classes = []\n",
    "        for cls in class_defs:\n",
    "            dependencies = set()\n",
    "            for base in cls.bases:\n",
    "                if isinstance(base, ast.Name):\n",
    "                    dependencies.add(base.id)\n",
    "            if not dependencies:\n",
    "                ordered_classes.append(cls)\n",
    "            else:\n",
    "                # Find the right place to insert this class based on dependencies\n",
    "                index = 0\n",
    "                for i, c in enumerate(ordered_classes):\n",
    "                    if c.name in dependencies:\n",
    "                        index = i + 1\n",
    "                ordered_classes.insert(index, cls)\n",
    "        \n",
    "        # Reconstruct the code with the correct order\n",
    "        ordered_code = ast.unparse(ordered_classes + func_defs)\n",
    "        \n",
    "        # Re-add any imports or other top-level elements before the main block\n",
    "        imports_and_others = [node for node in tree.body if not isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.If))]\n",
    "        final_code = ast.unparse(imports_and_others + ordered_classes + func_defs + main_block)\n",
    "        \n",
    "        # Write the reordered code back to the original file\n",
    "        with open(file_path, 'w') as file:\n",
    "            file.write(final_code)\n",
    "\n",
    "    def test_code_runability(self, main_file_path, max_runtime):\n",
    "        # print(\"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\")\n",
    "        # print(\"Testing code runability\")\n",
    "        # with open(main_file_path, 'r') as file_test_temp:\n",
    "        #     code_test_temp = file_test_temp.read()\n",
    "        #     print(code_test_temp)\n",
    "        # print(\"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\")\n",
    "        try:\n",
    "            start_time = time.time()\n",
    "            result = subprocess.run(\n",
    "                ['python', main_file_path],\n",
    "                cwd=self.store_path,\n",
    "                stdout=subprocess.PIPE,\n",
    "                stderr=subprocess.PIPE,\n",
    "                timeout=max_runtime\n",
    "            )\n",
    "            elapsed_time = time.time() - start_time\n",
    "\n",
    "            if result.returncode != 0:\n",
    "                # The code crashed\n",
    "                return False, result.stderr.decode('utf-8')\n",
    "            else:\n",
    "                # The code finished within max_runtime\n",
    "                return True, None\n",
    "\n",
    "        except subprocess.TimeoutExpired:\n",
    "            # The code's running time is out of max_runtime\n",
    "            return True, \"The code's running time exceeded the maximum runtime\"\n",
    "\n",
    "    def import_and_run(self, py_file_paths, running_code, store_path, max_runtime, use_py_files=True):\n",
    "\n",
    "        os.makedirs(store_path, exist_ok=True)\n",
    "        main_file_path = os.path.join(store_path, 'main.py')\n",
    "\n",
    "        if use_py_files:\n",
    "            for file_path in py_file_paths:\n",
    "                shutil.copy(file_path, store_path)\n",
    "\n",
    "            with open(main_file_path, 'w') as main_file:\n",
    "                for file_path in py_file_paths:\n",
    "                    module_name = os.path.splitext(os.path.basename(file_path))[0]\n",
    "                    main_file.write(f'from {module_name} import *\\n')\n",
    "                \n",
    "                main_file.write('\\n' + running_code)\n",
    "        else:\n",
    "            functions, classes, modules = self.load_file(py_file_paths[0])\n",
    "            with open(main_file_path, 'w') as main_file:\n",
    "                main_file.write('\\n'.join(modules) + '\\n')\n",
    "                main_file.write('\\n'.join(functions) + '\\n')\n",
    "                main_file.write('\\n'.join(classes) + '\\n')\n",
    "                main_file.write('\\n' + running_code)\n",
    "\n",
    "        try:\n",
    "            start_time = time.time()\n",
    "            result = subprocess.run(\n",
    "                ['python', main_file_path],\n",
    "                cwd=store_path,\n",
    "                stdout=subprocess.PIPE,\n",
    "                stderr=subprocess.PIPE,\n",
    "                timeout=max_runtime\n",
    "            )\n",
    "            elapsed_time = time.time() - start_time\n",
    "\n",
    "            if result.returncode != 0:\n",
    "                # The code crashed\n",
    "                return False, result.stderr.decode('utf-8')\n",
    "            else:\n",
    "                # The code finished within max_runtime\n",
    "                return True, None\n",
    "\n",
    "        except subprocess.TimeoutExpired:\n",
    "            # The code's running time is out of max_runtime\n",
    "            return True, \"The code's running time exceeded the maximum runtime\"\n",
    "        \n",
    "    def get_reflection_prompt(self, task_text, paths, running_code, message, print_prompt=False, can_use_RAG = False):\n",
    "\n",
    "        if isinstance(paths, str):\n",
    "            paths = [paths]\n",
    "\n",
    "        if isinstance(paths, list):\n",
    "            functions = []\n",
    "            classes = []\n",
    "            modules = []\n",
    "            for path in paths:\n",
    "                functions_, classes_, modules_ = self.load_file(path)\n",
    "                functions.extend(functions_)\n",
    "                classes.extend(classes_)\n",
    "                modules.extend(modules_)\n",
    "        # if the input is a data frame\n",
    "        elif isinstance(paths, pd.DataFrame):\n",
    "            functions = paths[paths['type'] == 'function']['content'].tolist()\n",
    "            classes = paths[paths['type'] == 'class']['content'].tolist()\n",
    "            modules = paths[paths['type'] == 'module']['content'].tolist()\n",
    "\n",
    "        modules_str = '\\n'.join(modules)\n",
    "        functions_str = '\\n\\n'.join(functions)\n",
    "        classes_str = '\\n\\n'.join(classes)\n",
    "\n",
    "        error_message = message\n",
    "\n",
    "        RAG_TEXT = \"\"\"If you find some functions or classes are missing, please provide the name or the description of the missing functions or classes, with following format:\n",
    "\n",
    "Missing Functions:\n",
    "- FunctionName1 or FunctionDescription1\n",
    "- FunctionName2 or FunctionDescription2\n",
    "\n",
    "Missing Classes:\n",
    "- ClassName1 or ClassDescription1\n",
    "- ClassName2 or ClassDescription2\n",
    "\"\"\"\n",
    "        RAG_query = \"\"\"If you think the provided code is sufficient to finish the task, please provide the fixed code without any explanation. \n",
    "\n",
    "Otherwise, if the code is missing some functions or classes, please provide the name or the description of the missing functions or classes, with following format:\n",
    "Missing Functions:\n",
    "- FunctionName1 or FunctionDescription1\n",
    "- FunctionName2 or FunctionDescription2\n",
    "\n",
    "Missing Classes:\n",
    "- ClassName1 or ClassDescription1\n",
    "- ClassName2 or ClassDescription2\"\"\"\n",
    "\n",
    "        normal_query = \"return the fixed code without any explanation:\"\n",
    "        reflection_prompt = f\"\"\"Provided code:\n",
    "\n",
    "{modules_str}\n",
    "{functions_str}\n",
    "{classes_str}\n",
    "\n",
    "read the code above, try to finish a task only based on these class or functions defined above. Assuming all the classes from the provided code are already imported, provide the code only without any explanation and you can use these class or functions directly without write it again in your code.\n",
    "\n",
    "{RAG_TEXT if can_use_RAG else \"\"}\n",
    "Otherwise, try to finish the task by using class or functions defined above ONLY, do not trying to make your own functions or classes if it is not necessary.\n",
    "\n",
    "task:\n",
    "{task_text}\n",
    "\n",
    "on the previous run, the code crashed, the code is provided below, please fix the error based on the error message, or ask for additional functions or classes, original code:\n",
    "{running_code}\n",
    "\n",
    "error message:\n",
    "{error_message}\n",
    "\n",
    "{RAG_query if can_use_RAG else normal_query}\"\"\"\n",
    "\n",
    "        if print_prompt:\n",
    "            print(\"Reflection Prompt:\")\n",
    "            print(reflection_prompt)\n",
    "        return reflection_prompt\n",
    "    \n",
    "    def run_external_tests(self, main_file_path, test_functions):\n",
    "\n",
    "        if test_functions is None or test_functions == []:\n",
    "            return True, \"No tests provided\"\n",
    "\n",
    "        # check if test_functions is a list of functions\n",
    "        if not all(callable(func) for func in test_functions):\n",
    "            raise ValueError(\"test_functions must be a list of functions\")\n",
    "\n",
    "        messages = []\n",
    "\n",
    "        for test_function in test_functions:\n",
    "            passed, message = test_function(main_file_path)\n",
    "            if not passed:\n",
    "                messages.append(message)\n",
    "\n",
    "        if not messages:\n",
    "            return True, \"All tests passed\"\n",
    "        else:\n",
    "            return False, \"\\n\".join(messages)\n",
    "\n",
    "    def function_to_string(self, func):\n",
    "        return inspect.getsource(func)\n",
    "\n",
    "    def string_to_function(self, func_str, func_name):\n",
    "        \"\"\"Converts a string representation of a Python function back to a function object.\"\"\"\n",
    "        exec(func_str)\n",
    "        return locals()[func_name]\n",
    "\n",
    "    def get_code_interpreter_prompt(self, code_text, print_prompt=False):\n",
    "\n",
    "        prompt_message = f\"\"\"Given the following code, please provide a brief description of what the code does. \n",
    "Code:\n",
    "{code_text}\"\"\"\n",
    "        \n",
    "        if print_prompt:\n",
    "            print(\"Code Interpreter Prompt:\")\n",
    "            print(prompt_message)\n",
    "\n",
    "        return prompt_message\n",
    "\n",
    "    async def run_planner(self, task_text, paths, max_runtime = 10, max_error_num = 3, use_py_files=True, test_functions=[], RAG_method = \"AGENT\", print_prompt=False, use_code_interpreter=False):\n",
    "        print(\"########################################################################################\")\n",
    "        print(\"planner started.\")\n",
    "\n",
    "        if isinstance(paths, str):\n",
    "            paths = [paths]\n",
    "\n",
    "        functions = []\n",
    "        classes = []\n",
    "        modules = []\n",
    "        if paths:\n",
    "            for path in paths:\n",
    "                functions_, classes_, modules_ = self.load_file(path)\n",
    "                functions.extend(functions_)\n",
    "                classes.extend(classes_)\n",
    "                modules.extend(modules_)\n",
    "\n",
    "        if self.json_df is None:\n",
    "            self.json_df = self.files_to_json_to_dataframe(paths, self.client, save_path=self.store_path)\n",
    "        else:\n",
    "            test_function_texts = [self.function_to_string(func) for func in test_functions]\n",
    "            missing_functions, missing_classes, missing_evaluation = self.find_missing_entities(self.json_df, functions, classes, test_function_texts)\n",
    "            if missing_functions or missing_classes:\n",
    "                missing_dict = {\"functions\": missing_functions, \"classes\": missing_classes, \"evaluation\": missing_evaluation}\n",
    "                missing_json = self.get_json_data(missing_dict, self.client, model=self.model, save_path=self.store_path)\n",
    "                missing_df = self.load_json_to_dataframe(missing_json)\n",
    "                self.json_df = pd.concat([self.json_df, missing_df], ignore_index=True).drop_duplicates(subset=['type', 'name'])\n",
    "\n",
    "        if self.RAG_config:\n",
    "            print(\"########################################################################################\")\n",
    "            print(\"RAG started.\")\n",
    "            filtered_df, eval_filtered_df = await self.perform_RAG_step(RAG_method, task_text, print_prompt)\n",
    "            \n",
    "            input_source = filtered_df\n",
    "            test_functions_content = eval_filtered_df['content'].tolist()\n",
    "            test_functions_names = eval_filtered_df['name'].tolist()\n",
    "            test_functions = [self.string_to_function(func_str, func_name) for func_str, func_name in zip(test_functions_content, test_functions_names)]\n",
    "            \n",
    "            print(\"selected functions and classes:\")\n",
    "            print(filtered_df[['type', 'name']])\n",
    "            print(\"RAG step finished.\")\n",
    "        else:\n",
    "            input_source = paths\n",
    "\n",
    "        prompt_message = self.get_planner_prompt(task_text, input_source, print_prompt, can_use_RAG=self.RAG_config is not None)\n",
    "        gpt_response = self.get_llm_response(prompt_message)\n",
    "        running_code, output_message = self.extract_code_from_gpt_output(gpt_response)\n",
    "\n",
    "        if output_message == \"RAG\":\n",
    "            print(\"########################################################################################\")\n",
    "            print(\"Asking for missing functions and classes.\")\n",
    "            print(running_code)\n",
    "            new_filtered_df, new_eval_filtered_df = await self.perform_RAG_step(RAG_method, running_code, print_prompt)\n",
    "            input_source = pd.concat([input_source, new_filtered_df], ignore_index=True).drop_duplicates(subset=['type', 'name'])\n",
    "            print(\"selected functions and classes:\")\n",
    "            print(input_source[['type', 'name']])\n",
    "            prompt_message = self.get_planner_prompt(task_text, input_source, print_prompt, can_use_RAG=False)\n",
    "            gpt_response = self.get_llm_response(prompt_message)\n",
    "            running_code, output_message = self.extract_code_from_gpt_output(gpt_response)\n",
    "            print(\"debug\")\n",
    "        elif output_message is None:\n",
    "            prompt_message = self.get_planner_prompt(task_text, input_source, print_prompt, can_use_RAG=False)\n",
    "            gpt_response = self.get_llm_response(prompt_message)\n",
    "            running_code, output_message = self.extract_code_from_gpt_output(gpt_response)\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "        store_path = self.store_path\n",
    "\n",
    "        print(\"########################################################################################\")\n",
    "        print(f\"current running code:\\n {running_code}\")\n",
    "        main_file_path = self.write_code_to_file(input_source, running_code, store_path, use_py_files)\n",
    "\n",
    "        print(\"########################################################################################\")\n",
    "        print(\"testing the code.\")\n",
    "        result, message = self.test_code_runability(main_file_path, max_runtime)\n",
    "        # result, message = self.import_and_run(paths, running_code, store_path, max_runtime, use_py_files)\n",
    "\n",
    "        if result:\n",
    "            result, external_message = self.run_external_tests(main_file_path, test_functions)\n",
    "            message = external_message if not result else message\n",
    "\n",
    "        error_num = 0\n",
    "        while not result and error_num < max_error_num:\n",
    "            print(\"########################################################################################\")\n",
    "            print(f\"current running code:\\n {running_code}\")\n",
    "            print(f\"error message:\\n {message}\")\n",
    "            print(f\"planner error, trying to fix the error {error_num + 1} time(s).\")\n",
    "\n",
    "            reflection_prompt = self.get_reflection_prompt(task_text, input_source, running_code, message, print_prompt, can_use_RAG=self.RAG_config is not None)\n",
    "            gpt_response = self.get_llm_response(reflection_prompt)\n",
    "            running_code, output_message = self.extract_code_from_gpt_output(gpt_response)\n",
    "\n",
    "            # if \"NameError\" in message and \"is not defined\" in message:\n",
    "            #     output_message = \"RAG\"\n",
    "            #     running_code = message\n",
    "\n",
    "            if output_message == \"RAG\":\n",
    "                print(\"########################################################################################\")\n",
    "                print(\"Asking for missing functions and classes.\")\n",
    "                print(running_code)\n",
    "                new_filtered_df, new_eval_filtered_df = await self.perform_RAG_step(RAG_method, running_code, print_prompt)\n",
    "                input_source = pd.concat([input_source, new_filtered_df], ignore_index=True).drop_duplicates(subset=['type', 'name'])\n",
    "                print(\"selected functions and classes:\")\n",
    "                print(input_source[['type', 'name']])\n",
    "                prompt_message = self.get_planner_prompt(task_text, input_source, print_prompt, can_use_RAG=False)\n",
    "                gpt_response = self.get_llm_response(prompt_message)\n",
    "                running_code, output_message = self.extract_code_from_gpt_output(gpt_response)\n",
    "            elif output_message is None:\n",
    "                print(\"########################################################################################\")\n",
    "                print(\"invalid planner output, trying to re-run the planner.\")\n",
    "                prompt_message = self.get_planner_prompt(task_text, input_source, print_prompt, can_use_RAG=False)\n",
    "                gpt_response = self.get_llm_response(prompt_message)\n",
    "                running_code, output_message = self.extract_code_from_gpt_output(gpt_response)\n",
    "            else:\n",
    "                pass\n",
    "\n",
    "            print(\"########################################################################################\")\n",
    "            print(f\"current running code:\\n {running_code}\")\n",
    "            main_file_path = self.write_code_to_file(input_source, running_code, store_path, use_py_files)\n",
    "\n",
    "            print(\"########################################################################################\")\n",
    "            print(\"testing the code.\")\n",
    "            result, message = self.test_code_runability(main_file_path, max_runtime)\n",
    "\n",
    "            if result:\n",
    "                result, external_message = self.run_external_tests(main_file_path, test_functions)\n",
    "                message = external_message if not result else message\n",
    "\n",
    "            # result, message = self.import_and_run(paths, running_code, store_path, max_runtime, use_py_files)\n",
    "\n",
    "            error_num += 1\n",
    "\n",
    "        if result:\n",
    "            print(\"planner finished successfully.\")\n",
    "            print(\"final running code:\")\n",
    "            print(running_code)\n",
    "\n",
    "            # score the code\n",
    "            task_embedding = await self.get_embedding_async(task_text, client_async, self.embedding_model)\n",
    "            if use_code_interpreter:\n",
    "                interpreter_prompt = self.get_code_interpreter_prompt(running_code, print_prompt)\n",
    "                interpreter_response = self.get_llm_response(interpreter_prompt)\n",
    "                code_embedding = await self.get_embedding_async(interpreter_response, client_async, self.embedding_model)\n",
    "                print(\"code interpreter response:\")\n",
    "                print(interpreter_response)\n",
    "            else:\n",
    "                code_embedding = await self.get_embedding_async(running_code, client_async, self.embedding_model)\n",
    "\n",
    "            similarity_score = self.cosine_similarity(task_embedding, code_embedding)\n",
    "            print(f\"similarity score: {similarity_score}\")\n",
    "\n",
    "        else:\n",
    "            print(\"planner finished with errors.\")\n",
    "            print(f\"error message:\\n {message}\")\n",
    "            print(f\"similarity score: 0\")\n",
    "\n",
    "        # debug\n",
    "        print(\"test functions:\")\n",
    "        print(test_functions)\n",
    "        return result, message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "########################################################################################\n",
      "planner started.\n",
      "########################################################################################\n",
      "RAG started.\n",
      "RAG step started.\n",
      "selected functions and classes:\n",
      "     type                name\n",
      "9   class          Snake_Game\n",
      "10  class               Snake\n",
      "11  class                Food\n",
      "14  class       HealthFeature\n",
      "15  class   SpeedBoostFeature\n",
      "16  class        EnemyFeature\n",
      "18  class  VarietyFoodFeature\n",
      "RAG step finished.\n",
      "########################################################################################\n",
      "Asking for missing functions and classes.\n",
      "Missing Classes:\n",
      "- Feature or Base class for different features (HealthFeature, SpeedBoostFeature, etc.) depends on it.\n",
      "RAG step started.\n",
      "selected functions and classes:\n",
      "    type                name\n",
      "0  class          Snake_Game\n",
      "1  class               Snake\n",
      "2  class                Food\n",
      "3  class       HealthFeature\n",
      "4  class   SpeedBoostFeature\n",
      "5  class        EnemyFeature\n",
      "6  class  VarietyFoodFeature\n",
      "7  class             Feature\n",
      "debug\n",
      "########################################################################################\n",
      "current running code:\n",
      " import pygame\n",
      "\n",
      "def main():\n",
      "    game = Snake_Game()\n",
      "    \n",
      "    health_feature = HealthFeature(initial_health=3)\n",
      "    speed_boost_feature = SpeedBoostFeature(speed_increase=5)\n",
      "    enemy_feature = EnemyFeature()\n",
      "    variety_food_feature = VarietyFoodFeature()\n",
      "    \n",
      "    game.add_feature(health_feature)\n",
      "    game.add_feature(speed_boost_feature)\n",
      "    game.add_feature(enemy_feature)\n",
      "    game.add_feature(variety_food_feature)\n",
      "    \n",
      "    game.run()\n",
      "    pygame.quit()\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "debug_1\n",
      "debug_2\n",
      "debug_3\n",
      "########################################################################################\n",
      "testing the code.\n",
      "########################################################################################\n",
      "current running code:\n",
      " import pygame\n",
      "\n",
      "def main():\n",
      "    game = Snake_Game()\n",
      "    \n",
      "    health_feature = HealthFeature(initial_health=3)\n",
      "    speed_boost_feature = SpeedBoostFeature(speed_increase=5)\n",
      "    enemy_feature = EnemyFeature()\n",
      "    variety_food_feature = VarietyFoodFeature()\n",
      "    \n",
      "    game.add_feature(health_feature)\n",
      "    game.add_feature(speed_boost_feature)\n",
      "    game.add_feature(enemy_feature)\n",
      "    game.add_feature(variety_food_feature)\n",
      "    \n",
      "    game.run()\n",
      "    pygame.quit()\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "error message:\n",
      " Traceback (most recent call last):\n",
      "  File \"D:\\Python_project\\llm_agent\\output\\snake_game_test\\main.py\", line 249, in <module>\n",
      "    main()\n",
      "  File \"D:\\Python_project\\llm_agent\\output\\snake_game_test\\main.py\", line 233, in main\n",
      "    game = Snake_Game()\n",
      "  File \"D:\\Python_project\\llm_agent\\output\\snake_game_test\\main.py\", line 156, in __init__\n",
      "    self.ui = UI(self.screen, self.snake, self.food)\n",
      "NameError: name 'UI' is not defined\n",
      "\n",
      "planner error, trying to fix the error 1 time(s).\n",
      "########################################################################################\n",
      "current running code:\n",
      " import pygame\n",
      "import random\n",
      "from abc import ABC, abstractmethod\n",
      "\n",
      "class UI:\n",
      "    def __init__(self, screen, snake, food):\n",
      "        self.screen = screen\n",
      "        self.snake = snake\n",
      "        self.food = food\n",
      "\n",
      "    def draw(self):\n",
      "        # Draw snake\n",
      "        for pos in self.snake.positions:\n",
      "            pygame.draw.rect(self.screen, (0, 255, 0), pygame.Rect(pos[0], pos[1], 10, 10))\n",
      "        # Draw food\n",
      "        pygame.draw.rect(self.screen, (255, 0, 0), pygame.Rect(self.food.position[0], self.food.position[1], 10, 10))\n",
      "\n",
      "class Snake_Game:\n",
      "    # ... (rest of the code remains unchanged)\n",
      "\n",
      "def main():\n",
      "    game = Snake_Game()\n",
      "    \n",
      "    health_feature = HealthFeature(initial_health=3)\n",
      "    speed_boost_feature = SpeedBoostFeature(speed_increase=5)\n",
      "    enemy_feature = EnemyFeature()\n",
      "    variety_food_feature = VarietyFoodFeature()\n",
      "    \n",
      "    game.add_feature(health_feature)\n",
      "    game.add_feature(speed_boost_feature)\n",
      "    game.add_feature(enemy_feature)\n",
      "    game.add_feature(variety_food_feature)\n",
      "    \n",
      "    game.run()\n",
      "    pygame.quit()\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "debug_1\n",
      "debug_2\n",
      "debug_3\n",
      "########################################################################################\n",
      "testing the code.\n",
      "########################################################################################\n",
      "current running code:\n",
      " import pygame\n",
      "import random\n",
      "from abc import ABC, abstractmethod\n",
      "\n",
      "class UI:\n",
      "    def __init__(self, screen, snake, food):\n",
      "        self.screen = screen\n",
      "        self.snake = snake\n",
      "        self.food = food\n",
      "\n",
      "    def draw(self):\n",
      "        # Draw snake\n",
      "        for pos in self.snake.positions:\n",
      "            pygame.draw.rect(self.screen, (0, 255, 0), pygame.Rect(pos[0], pos[1], 10, 10))\n",
      "        # Draw food\n",
      "        pygame.draw.rect(self.screen, (255, 0, 0), pygame.Rect(self.food.position[0], self.food.position[1], 10, 10))\n",
      "\n",
      "class Snake_Game:\n",
      "    # ... (rest of the code remains unchanged)\n",
      "\n",
      "def main():\n",
      "    game = Snake_Game()\n",
      "    \n",
      "    health_feature = HealthFeature(initial_health=3)\n",
      "    speed_boost_feature = SpeedBoostFeature(speed_increase=5)\n",
      "    enemy_feature = EnemyFeature()\n",
      "    variety_food_feature = VarietyFoodFeature()\n",
      "    \n",
      "    game.add_feature(health_feature)\n",
      "    game.add_feature(speed_boost_feature)\n",
      "    game.add_feature(enemy_feature)\n",
      "    game.add_feature(variety_food_feature)\n",
      "    \n",
      "    game.run()\n",
      "    pygame.quit()\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "error message:\n",
      "   File \"D:\\Python_project\\llm_agent\\output\\snake_game_test\\main.py\", line 250\n",
      "    def main():\n",
      "    ^\n",
      "IndentationError: expected an indented block\n",
      "\n",
      "planner error, trying to fix the error 2 time(s).\n",
      "########################################################################################\n",
      "current running code:\n",
      " import pygame\n",
      "import random\n",
      "from abc import ABC, abstractmethod\n",
      "\n",
      "class UI:\n",
      "    def __init__(self, screen, snake, food):\n",
      "        self.screen = screen\n",
      "        self.snake = snake\n",
      "        self.food = food\n",
      "\n",
      "    def draw(self):\n",
      "        # Draw snake\n",
      "        for pos in self.snake.positions:\n",
      "            pygame.draw.rect(self.screen, (0, 255, 0), pygame.Rect(pos[0], pos[1], 10, 10))\n",
      "        # Draw food\n",
      "        pygame.draw.rect(self.screen, (255, 0, 0), pygame.Rect(self.food.position[0], self.food.position[1], 10, 10))\n",
      "\n",
      "class Snake_Game:\n",
      "    def __init__(self):\n",
      "        pygame.init()\n",
      "        self.screen = pygame.display.set_mode((800, 600))\n",
      "        self.clock = pygame.time.Clock()\n",
      "        self.snake = Snake()\n",
      "        self.food = Food()\n",
      "        self.ui = UI(self.screen, self.snake, self.food)\n",
      "        self.running = True\n",
      "        self.features = []\n",
      "        self.score = 0\n",
      "\n",
      "    def add_feature(self, feature):\n",
      "        self.features.append(feature)\n",
      "        feature.modify_game(self)\n",
      "\n",
      "    def run(self):\n",
      "        while self.running:\n",
      "            self.handle_events()\n",
      "            self.update()\n",
      "            self.render()\n",
      "            self.clock.tick(15)\n",
      "\n",
      "    def handle_events(self):\n",
      "        for event in pygame.event.get():\n",
      "            if event.type == pygame.QUIT:\n",
      "                self.running = False\n",
      "            for feature in self.features:\n",
      "                feature.handle_event(event)\n",
      "\n",
      "    def update(self):\n",
      "        self.snake.move()\n",
      "        if self.snake.check_collision(self.food.position):\n",
      "            self.snake.grow()\n",
      "            self.food.reposition()\n",
      "            self.score += 1\n",
      "        for feature in self.features:\n",
      "            feature.update()\n",
      "\n",
      "    def render(self):\n",
      "        self.screen.fill((0, 0, 0))\n",
      "        self.ui.draw()\n",
      "        pygame.display.flip()\n",
      "\n",
      "class Snake:\n",
      "    def __init__(self):\n",
      "        self.positions = [(100, 100), (90, 100), (80, 100)]\n",
      "        self.direction = (10, 0)\n",
      "        self.score = 0\n",
      "\n",
      "    def move(self):\n",
      "        head = self.positions[0]\n",
      "        new_head = (head[0] + self.direction[0], head[1] + self.direction[1])\n",
      "        self.positions = [new_head] + self.positions[:-1]\n",
      "\n",
      "    def grow(self):\n",
      "        self.positions.append(self.positions[-1])\n",
      "        self.score += 1\n",
      "\n",
      "    def check_collision(self, position):\n",
      "        return self.positions[0] == position\n",
      "\n",
      "    def set_direction(self, direction):\n",
      "        self.direction = direction\n",
      "\n",
      "    def check_self_collision(self):\n",
      "        return self.positions[0] in self.positions[1:]\n",
      "\n",
      "    def check_wall_collision(self):\n",
      "        head = self.positions[0]\n",
      "        return not (0 <= head[0] < 800 and 0 <= head[1] < 600)\n",
      "\n",
      "class Food:\n",
      "    def __init__(self, food_type='normal'):\n",
      "        self.type = food_type\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "\n",
      "    def reposition(self):\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "\n",
      "class HealthFeature(Feature):\n",
      "    def __init__(self, initial_health=3):\n",
      "        self.health = initial_health\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_health():\n",
      "            original_draw()\n",
      "            self.draw_health()\n",
      "        self.ui.draw = draw_with_health\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_health():\n",
      "            if self.snake.check_self_collision() or self.snake.check_wall_collision():\n",
      "                self.health -= 1\n",
      "                if self.health <= 0:\n",
      "                    game.running = False\n",
      "                else:\n",
      "                    self.respawn_snake()\n",
      "            else:\n",
      "                original_update()\n",
      "        game.update = update_with_health\n",
      "\n",
      "    def draw_health(self):\n",
      "        font = pygame.font.Font(None, 36)\n",
      "        health_text = font.render(f'Health: {self.health}', True, (255, 255, 255))\n",
      "        self.ui.screen.blit(health_text, (10, 10))\n",
      "\n",
      "    def respawn_snake(self):\n",
      "        self.snake.positions = [(100, 100), (90, 100), (80, 100)]\n",
      "        self.snake.direction = (10, 0)\n",
      "\n",
      "class SpeedBoostFeature(Feature):\n",
      "    def __init__(self, speed_increase=10):\n",
      "        self.speed_increase = speed_increase\n",
      "        self.original_speed = 15\n",
      "        self.boost_active = False\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        original_handle_events = game.handle_events\n",
      "\n",
      "        def handle_events_with_speed_boost():\n",
      "            original_handle_events()\n",
      "            keys = pygame.key.get_pressed()\n",
      "            self.boost_active = keys[pygame.K_SPACE]\n",
      "        game.handle_events = handle_events_with_speed_boost\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_speed_boost():\n",
      "            game.clock.tick(self.original_speed + self.speed_increase if self.boost_active else self.original_speed)\n",
      "            original_update()\n",
      "        game.update = update_with_speed_boost\n",
      "\n",
      "class EnemyFeature(Feature):\n",
      "    def __init__(self):\n",
      "        self.enemies = []\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        self.add_enemy()\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_enemies():\n",
      "            for enemy in self.enemies:\n",
      "                if self.snake.check_collision(enemy.position):\n",
      "                    game.running = False\n",
      "                enemy.move()\n",
      "            original_update()\n",
      "        game.update = update_with_enemies\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_enemies():\n",
      "            original_draw()\n",
      "            self.draw_enemies()\n",
      "        self.ui.draw = draw_with_enemies\n",
      "\n",
      "    def add_enemy(self):\n",
      "        self.enemies.append(Enemy())\n",
      "\n",
      "    def draw_enemies(self):\n",
      "        for enemy in self.enemies:\n",
      "            pygame.draw.rect(self.ui.screen, (255, 255, 0), pygame.Rect(enemy.position[0], enemy.position[1], 10, 10))\n",
      "\n",
      "class VarietyFoodFeature(Feature):\n",
      "    def __init__(self):\n",
      "        self.foods = []\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        self.add_variety_food()\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_variety_food():\n",
      "            for food in self.foods:\n",
      "                if self.snake.check_collision(food.position):\n",
      "                    self.snake.grow()\n",
      "                    if food.type == 'big':\n",
      "                        self.snake.grow()\n",
      "                    food.reposition()\n",
      "            original_update()\n",
      "        game.update = update_with_variety_food\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_variety_food():\n",
      "            original_draw()\n",
      "            self.draw_foods()\n",
      "        self.ui.draw = draw_with_variety_food\n",
      "\n",
      "    def add_variety_food(self):\n",
      "        self.foods.append(Food('normal'))\n",
      "        self.foods.append(Food('big'))\n",
      "\n",
      "    def draw_foods(self):\n",
      "        for food in self.foods:\n",
      "            color = (255, 0, 0) if food.type == 'normal' else (0, 0, 255)\n",
      "            pygame.draw.rect(self.ui.screen, color, pygame.Rect(food.position[0], food.position[1], 10, 10))\n",
      "\n",
      "class Feature(ABC):\n",
      "    @abstractmethod\n",
      "    def modify_game(self, game):\n",
      "        pass\n",
      "\n",
      "    def handle_event(self, event):\n",
      "        pass\n",
      "\n",
      "    def update(self):\n",
      "        pass\n",
      "\n",
      "def main():\n",
      "    game = Snake_Game()\n",
      "    \n",
      "    health_feature = HealthFeature(initial_health=3)\n",
      "    speed_boost_feature = SpeedBoostFeature(speed_increase=5)\n",
      "    enemy_feature = EnemyFeature()\n",
      "    variety_food_feature = VarietyFoodFeature()\n",
      "    \n",
      "    game.add_feature(health_feature)\n",
      "    game.add_feature(speed_boost_feature)\n",
      "    game.add_feature(enemy_feature)\n",
      "    game.add_feature(variety_food_feature)\n",
      "    \n",
      "    game.run()\n",
      "    pygame.quit()\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "debug_1\n",
      "debug_2\n",
      "debug_3\n",
      "########################################################################################\n",
      "testing the code.\n",
      "########################################################################################\n",
      "current running code:\n",
      " import pygame\n",
      "import random\n",
      "from abc import ABC, abstractmethod\n",
      "\n",
      "class UI:\n",
      "    def __init__(self, screen, snake, food):\n",
      "        self.screen = screen\n",
      "        self.snake = snake\n",
      "        self.food = food\n",
      "\n",
      "    def draw(self):\n",
      "        # Draw snake\n",
      "        for pos in self.snake.positions:\n",
      "            pygame.draw.rect(self.screen, (0, 255, 0), pygame.Rect(pos[0], pos[1], 10, 10))\n",
      "        # Draw food\n",
      "        pygame.draw.rect(self.screen, (255, 0, 0), pygame.Rect(self.food.position[0], self.food.position[1], 10, 10))\n",
      "\n",
      "class Snake_Game:\n",
      "    def __init__(self):\n",
      "        pygame.init()\n",
      "        self.screen = pygame.display.set_mode((800, 600))\n",
      "        self.clock = pygame.time.Clock()\n",
      "        self.snake = Snake()\n",
      "        self.food = Food()\n",
      "        self.ui = UI(self.screen, self.snake, self.food)\n",
      "        self.running = True\n",
      "        self.features = []\n",
      "        self.score = 0\n",
      "\n",
      "    def add_feature(self, feature):\n",
      "        self.features.append(feature)\n",
      "        feature.modify_game(self)\n",
      "\n",
      "    def run(self):\n",
      "        while self.running:\n",
      "            self.handle_events()\n",
      "            self.update()\n",
      "            self.render()\n",
      "            self.clock.tick(15)\n",
      "\n",
      "    def handle_events(self):\n",
      "        for event in pygame.event.get():\n",
      "            if event.type == pygame.QUIT:\n",
      "                self.running = False\n",
      "            for feature in self.features:\n",
      "                feature.handle_event(event)\n",
      "\n",
      "    def update(self):\n",
      "        self.snake.move()\n",
      "        if self.snake.check_collision(self.food.position):\n",
      "            self.snake.grow()\n",
      "            self.food.reposition()\n",
      "            self.score += 1\n",
      "        for feature in self.features:\n",
      "            feature.update()\n",
      "\n",
      "    def render(self):\n",
      "        self.screen.fill((0, 0, 0))\n",
      "        self.ui.draw()\n",
      "        pygame.display.flip()\n",
      "\n",
      "class Snake:\n",
      "    def __init__(self):\n",
      "        self.positions = [(100, 100), (90, 100), (80, 100)]\n",
      "        self.direction = (10, 0)\n",
      "        self.score = 0\n",
      "\n",
      "    def move(self):\n",
      "        head = self.positions[0]\n",
      "        new_head = (head[0] + self.direction[0], head[1] + self.direction[1])\n",
      "        self.positions = [new_head] + self.positions[:-1]\n",
      "\n",
      "    def grow(self):\n",
      "        self.positions.append(self.positions[-1])\n",
      "        self.score += 1\n",
      "\n",
      "    def check_collision(self, position):\n",
      "        return self.positions[0] == position\n",
      "\n",
      "    def set_direction(self, direction):\n",
      "        self.direction = direction\n",
      "\n",
      "    def check_self_collision(self):\n",
      "        return self.positions[0] in self.positions[1:]\n",
      "\n",
      "    def check_wall_collision(self):\n",
      "        head = self.positions[0]\n",
      "        return not (0 <= head[0] < 800 and 0 <= head[1] < 600)\n",
      "\n",
      "class Food:\n",
      "    def __init__(self, food_type='normal'):\n",
      "        self.type = food_type\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "\n",
      "    def reposition(self):\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "\n",
      "class HealthFeature(Feature):\n",
      "    def __init__(self, initial_health=3):\n",
      "        self.health = initial_health\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_health():\n",
      "            original_draw()\n",
      "            self.draw_health()\n",
      "        self.ui.draw = draw_with_health\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_health():\n",
      "            if self.snake.check_self_collision() or self.snake.check_wall_collision():\n",
      "                self.health -= 1\n",
      "                if self.health <= 0:\n",
      "                    game.running = False\n",
      "                else:\n",
      "                    self.respawn_snake()\n",
      "            else:\n",
      "                original_update()\n",
      "        game.update = update_with_health\n",
      "\n",
      "    def draw_health(self):\n",
      "        font = pygame.font.Font(None, 36)\n",
      "        health_text = font.render(f'Health: {self.health}', True, (255, 255, 255))\n",
      "        self.ui.screen.blit(health_text, (10, 10))\n",
      "\n",
      "    def respawn_snake(self):\n",
      "        self.snake.positions = [(100, 100), (90, 100), (80, 100)]\n",
      "        self.snake.direction = (10, 0)\n",
      "\n",
      "class SpeedBoostFeature(Feature):\n",
      "    def __init__(self, speed_increase=10):\n",
      "        self.speed_increase = speed_increase\n",
      "        self.original_speed = 15\n",
      "        self.boost_active = False\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        original_handle_events = game.handle_events\n",
      "\n",
      "        def handle_events_with_speed_boost():\n",
      "            original_handle_events()\n",
      "            keys = pygame.key.get_pressed()\n",
      "            self.boost_active = keys[pygame.K_SPACE]\n",
      "        game.handle_events = handle_events_with_speed_boost\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_speed_boost():\n",
      "            game.clock.tick(self.original_speed + self.speed_increase if self.boost_active else self.original_speed)\n",
      "            original_update()\n",
      "        game.update = update_with_speed_boost\n",
      "\n",
      "class EnemyFeature(Feature):\n",
      "    def __init__(self):\n",
      "        self.enemies = []\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        self.add_enemy()\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_enemies():\n",
      "            for enemy in self.enemies:\n",
      "                if self.snake.check_collision(enemy.position):\n",
      "                    game.running = False\n",
      "                enemy.move()\n",
      "            original_update()\n",
      "        game.update = update_with_enemies\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_enemies():\n",
      "            original_draw()\n",
      "            self.draw_enemies()\n",
      "        self.ui.draw = draw_with_enemies\n",
      "\n",
      "    def add_enemy(self):\n",
      "        self.enemies.append(Enemy())\n",
      "\n",
      "    def draw_enemies(self):\n",
      "        for enemy in self.enemies:\n",
      "            pygame.draw.rect(self.ui.screen, (255, 255, 0), pygame.Rect(enemy.position[0], enemy.position[1], 10, 10))\n",
      "\n",
      "class VarietyFoodFeature(Feature):\n",
      "    def __init__(self):\n",
      "        self.foods = []\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        self.add_variety_food()\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_variety_food():\n",
      "            for food in self.foods:\n",
      "                if self.snake.check_collision(food.position):\n",
      "                    self.snake.grow()\n",
      "                    if food.type == 'big':\n",
      "                        self.snake.grow()\n",
      "                    food.reposition()\n",
      "            original_update()\n",
      "        game.update = update_with_variety_food\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_variety_food():\n",
      "            original_draw()\n",
      "            self.draw_foods()\n",
      "        self.ui.draw = draw_with_variety_food\n",
      "\n",
      "    def add_variety_food(self):\n",
      "        self.foods.append(Food('normal'))\n",
      "        self.foods.append(Food('big'))\n",
      "\n",
      "    def draw_foods(self):\n",
      "        for food in self.foods:\n",
      "            color = (255, 0, 0) if food.type == 'normal' else (0, 0, 255)\n",
      "            pygame.draw.rect(self.ui.screen, color, pygame.Rect(food.position[0], food.position[1], 10, 10))\n",
      "\n",
      "class Feature(ABC):\n",
      "    @abstractmethod\n",
      "    def modify_game(self, game):\n",
      "        pass\n",
      "\n",
      "    def handle_event(self, event):\n",
      "        pass\n",
      "\n",
      "    def update(self):\n",
      "        pass\n",
      "\n",
      "def main():\n",
      "    game = Snake_Game()\n",
      "    \n",
      "    health_feature = HealthFeature(initial_health=3)\n",
      "    speed_boost_feature = SpeedBoostFeature(speed_increase=5)\n",
      "    enemy_feature = EnemyFeature()\n",
      "    variety_food_feature = VarietyFoodFeature()\n",
      "    \n",
      "    game.add_feature(health_feature)\n",
      "    game.add_feature(speed_boost_feature)\n",
      "    game.add_feature(enemy_feature)\n",
      "    game.add_feature(variety_food_feature)\n",
      "    \n",
      "    game.run()\n",
      "    pygame.quit()\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "error message:\n",
      " Traceback (most recent call last):\n",
      "  File \"D:\\Python_project\\llm_agent\\output\\snake_game_test\\main.py\", line 481, in <module>\n",
      "    main()\n",
      "  File \"D:\\Python_project\\llm_agent\\output\\snake_game_test\\main.py\", line 474, in main\n",
      "    game.add_feature(enemy_feature)\n",
      "  File \"D:\\Python_project\\llm_agent\\output\\snake_game_test\\main.py\", line 261, in add_feature\n",
      "    feature.modify_game(self)\n",
      "  File \"D:\\Python_project\\llm_agent\\output\\snake_game_test\\main.py\", line 393, in modify_game\n",
      "    self.add_enemy()\n",
      "  File \"D:\\Python_project\\llm_agent\\output\\snake_game_test\\main.py\", line 411, in add_enemy\n",
      "    self.enemies.append(Enemy())\n",
      "NameError: name 'Enemy' is not defined\n",
      "\n",
      "planner error, trying to fix the error 3 time(s).\n",
      "########################################################################################\n",
      "current running code:\n",
      " import pygame\n",
      "import random\n",
      "from abc import ABC, abstractmethod\n",
      "\n",
      "class UI:\n",
      "    def __init__(self, screen, snake, food):\n",
      "        self.screen = screen\n",
      "        self.snake = snake\n",
      "        self.food = food\n",
      "\n",
      "    def draw(self):\n",
      "        # Draw snake\n",
      "        for pos in self.snake.positions:\n",
      "            pygame.draw.rect(self.screen, (0, 255, 0), pygame.Rect(pos[0], pos[1], 10, 10))\n",
      "        # Draw food\n",
      "        pygame.draw.rect(self.screen, (255, 0, 0), pygame.Rect(self.food.position[0], self.food.position[1], 10, 10))\n",
      "\n",
      "class Snake_Game:\n",
      "    def __init__(self):\n",
      "        pygame.init()\n",
      "        self.screen = pygame.display.set_mode((800, 600))\n",
      "        self.clock = pygame.time.Clock()\n",
      "        self.snake = Snake()\n",
      "        self.food = Food()\n",
      "        self.ui = UI(self.screen, self.snake, self.food)\n",
      "        self.running = True\n",
      "        self.features = []\n",
      "        self.score = 0\n",
      "\n",
      "    def add_feature(self, feature):\n",
      "        self.features.append(feature)\n",
      "        feature.modify_game(self)\n",
      "\n",
      "    def run(self):\n",
      "        while self.running:\n",
      "            self.handle_events()\n",
      "            self.update()\n",
      "            self.render()\n",
      "            self.clock.tick(15)\n",
      "\n",
      "    def handle_events(self):\n",
      "        for event in pygame.event.get():\n",
      "            if event.type == pygame.QUIT:\n",
      "                self.running = False\n",
      "            for feature in self.features:\n",
      "                feature.handle_event(event)\n",
      "\n",
      "    def update(self):\n",
      "        self.snake.move()\n",
      "        if self.snake.check_collision(self.food.position):\n",
      "            self.snake.grow()\n",
      "            self.food.reposition()\n",
      "            self.score += 1\n",
      "        for feature in self.features:\n",
      "            feature.update()\n",
      "\n",
      "    def render(self):\n",
      "        self.screen.fill((0, 0, 0))\n",
      "        self.ui.draw()\n",
      "        pygame.display.flip()\n",
      "\n",
      "class Snake:\n",
      "    def __init__(self):\n",
      "        self.positions = [(100, 100), (90, 100), (80, 100)]\n",
      "        self.direction = (10, 0)\n",
      "        self.score = 0\n",
      "\n",
      "    def move(self):\n",
      "        head = self.positions[0]\n",
      "        new_head = (head[0] + self.direction[0], head[1] + self.direction[1])\n",
      "        self.positions = [new_head] + self.positions[:-1]\n",
      "\n",
      "    def grow(self):\n",
      "        self.positions.append(self.positions[-1])\n",
      "        self.score += 1\n",
      "\n",
      "    def check_collision(self, position):\n",
      "        return self.positions[0] == position\n",
      "\n",
      "    def set_direction(self, direction):\n",
      "        self.direction = direction\n",
      "\n",
      "    def check_self_collision(self):\n",
      "        return self.positions[0] in self.positions[1:]\n",
      "\n",
      "    def check_wall_collision(self):\n",
      "        head = self.positions[0]\n",
      "        return not (0 <= head[0] < 800 and 0 <= head[1] < 600)\n",
      "\n",
      "class Food:\n",
      "    def __init__(self, food_type='normal'):\n",
      "        self.type = food_type\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "\n",
      "    def reposition(self):\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "\n",
      "class HealthFeature(Feature):\n",
      "    def __init__(self, initial_health=3):\n",
      "        self.health = initial_health\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_health():\n",
      "            original_draw()\n",
      "            self.draw_health()\n",
      "        self.ui.draw = draw_with_health\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_health():\n",
      "            if self.snake.check_self_collision() or self.snake.check_wall_collision():\n",
      "                self.health -= 1\n",
      "                if self.health <= 0:\n",
      "                    game.running = False\n",
      "                else:\n",
      "                    self.respawn_snake()\n",
      "            else:\n",
      "                original_update()\n",
      "        game.update = update_with_health\n",
      "\n",
      "    def draw_health(self):\n",
      "        font = pygame.font.Font(None, 36)\n",
      "        health_text = font.render(f'Health: {self.health}', True, (255, 255, 255))\n",
      "        self.ui.screen.blit(health_text, (10, 10))\n",
      "\n",
      "    def respawn_snake(self):\n",
      "        self.snake.positions = [(100, 100), (90, 100), (80, 100)]\n",
      "        self.snake.direction = (10, 0)\n",
      "\n",
      "class SpeedBoostFeature(Feature):\n",
      "    def __init__(self, speed_increase=10):\n",
      "        self.speed_increase = speed_increase\n",
      "        self.original_speed = 15\n",
      "        self.boost_active = False\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        original_handle_events = game.handle_events\n",
      "\n",
      "        def handle_events_with_speed_boost():\n",
      "            original_handle_events()\n",
      "            keys = pygame.key.get_pressed()\n",
      "            self.boost_active = keys[pygame.K_SPACE]\n",
      "        game.handle_events = handle_events_with_speed_boost\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_speed_boost():\n",
      "            game.clock.tick(self.original_speed + self.speed_increase if self.boost_active else self.original_speed)\n",
      "            original_update()\n",
      "        game.update = update_with_speed_boost\n",
      "\n",
      "class Enemy:\n",
      "    def __init__(self):\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "        self.direction = (10, 0)  # Can be modified for more movement patterns\n",
      "\n",
      "    def move(self):\n",
      "        head = self.position\n",
      "        new_position = (head[0] + self.direction[0], head[1] + self.direction[1])\n",
      "        # Move to new position (for demonstration. Add boundary checks here).\n",
      "        self.position = new_position\n",
      "\n",
      "class EnemyFeature(Feature):\n",
      "    def __init__(self):\n",
      "        self.enemies = []\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        self.add_enemy()\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_enemies():\n",
      "            for enemy in self.enemies:\n",
      "                if self.snake.check_collision(enemy.position):\n",
      "                    game.running = False\n",
      "                enemy.move()\n",
      "            original_update()\n",
      "        game.update = update_with_enemies\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_enemies():\n",
      "            original_draw()\n",
      "            self.draw_enemies()\n",
      "        self.ui.draw = draw_with_enemies\n",
      "\n",
      "    def add_enemy(self):\n",
      "        self.enemies.append(Enemy())\n",
      "\n",
      "    def draw_enemies(self):\n",
      "        for enemy in self.enemies:\n",
      "            pygame.draw.rect(self.ui.screen, (255, 255, 0), pygame.Rect(enemy.position[0], enemy.position[1], 10, 10))\n",
      "\n",
      "class VarietyFoodFeature(Feature):\n",
      "    def __init__(self):\n",
      "        self.foods = []\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        self.add_variety_food()\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_variety_food():\n",
      "            for food in self.foods:\n",
      "                if self.snake.check_collision(food.position):\n",
      "                    self.snake.grow()\n",
      "                    if food.type == 'big':\n",
      "                        self.snake.grow()\n",
      "                    food.reposition()\n",
      "            original_update()\n",
      "        game.update = update_with_variety_food\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_variety_food():\n",
      "            original_draw()\n",
      "            self.draw_foods()\n",
      "        self.ui.draw = draw_with_variety_food\n",
      "\n",
      "    def add_variety_food(self):\n",
      "        self.foods.append(Food('normal'))\n",
      "        self.foods.append(Food('big'))\n",
      "\n",
      "    def draw_foods(self):\n",
      "        for food in self.foods:\n",
      "            color = (255, 0, 0) if food.type == 'normal' else (0, 0, 255)\n",
      "            pygame.draw.rect(self.ui.screen, color, pygame.Rect(food.position[0], food.position[1], 10, 10))\n",
      "\n",
      "class Feature(ABC):\n",
      "    @abstractmethod\n",
      "    def modify_game(self, game):\n",
      "        pass\n",
      "\n",
      "    def handle_event(self, event):\n",
      "        pass\n",
      "\n",
      "    def update(self):\n",
      "        pass\n",
      "\n",
      "def main():\n",
      "    game = Snake_Game()\n",
      "    \n",
      "    health_feature = HealthFeature(initial_health=3)\n",
      "    speed_boost_feature = SpeedBoostFeature(speed_increase=5)\n",
      "    enemy_feature = EnemyFeature()\n",
      "    variety_food_feature = VarietyFoodFeature()\n",
      "    \n",
      "    game.add_feature(health_feature)\n",
      "    game.add_feature(speed_boost_feature)\n",
      "    game.add_feature(enemy_feature)\n",
      "    game.add_feature(variety_food_feature)\n",
      "    \n",
      "    game.run()\n",
      "    pygame.quit()\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "debug_1\n",
      "debug_2\n",
      "debug_3\n",
      "########################################################################################\n",
      "testing the code.\n",
      "planner finished successfully.\n",
      "final running code:\n",
      "import pygame\n",
      "import random\n",
      "from abc import ABC, abstractmethod\n",
      "\n",
      "class UI:\n",
      "    def __init__(self, screen, snake, food):\n",
      "        self.screen = screen\n",
      "        self.snake = snake\n",
      "        self.food = food\n",
      "\n",
      "    def draw(self):\n",
      "        # Draw snake\n",
      "        for pos in self.snake.positions:\n",
      "            pygame.draw.rect(self.screen, (0, 255, 0), pygame.Rect(pos[0], pos[1], 10, 10))\n",
      "        # Draw food\n",
      "        pygame.draw.rect(self.screen, (255, 0, 0), pygame.Rect(self.food.position[0], self.food.position[1], 10, 10))\n",
      "\n",
      "class Snake_Game:\n",
      "    def __init__(self):\n",
      "        pygame.init()\n",
      "        self.screen = pygame.display.set_mode((800, 600))\n",
      "        self.clock = pygame.time.Clock()\n",
      "        self.snake = Snake()\n",
      "        self.food = Food()\n",
      "        self.ui = UI(self.screen, self.snake, self.food)\n",
      "        self.running = True\n",
      "        self.features = []\n",
      "        self.score = 0\n",
      "\n",
      "    def add_feature(self, feature):\n",
      "        self.features.append(feature)\n",
      "        feature.modify_game(self)\n",
      "\n",
      "    def run(self):\n",
      "        while self.running:\n",
      "            self.handle_events()\n",
      "            self.update()\n",
      "            self.render()\n",
      "            self.clock.tick(15)\n",
      "\n",
      "    def handle_events(self):\n",
      "        for event in pygame.event.get():\n",
      "            if event.type == pygame.QUIT:\n",
      "                self.running = False\n",
      "            for feature in self.features:\n",
      "                feature.handle_event(event)\n",
      "\n",
      "    def update(self):\n",
      "        self.snake.move()\n",
      "        if self.snake.check_collision(self.food.position):\n",
      "            self.snake.grow()\n",
      "            self.food.reposition()\n",
      "            self.score += 1\n",
      "        for feature in self.features:\n",
      "            feature.update()\n",
      "\n",
      "    def render(self):\n",
      "        self.screen.fill((0, 0, 0))\n",
      "        self.ui.draw()\n",
      "        pygame.display.flip()\n",
      "\n",
      "class Snake:\n",
      "    def __init__(self):\n",
      "        self.positions = [(100, 100), (90, 100), (80, 100)]\n",
      "        self.direction = (10, 0)\n",
      "        self.score = 0\n",
      "\n",
      "    def move(self):\n",
      "        head = self.positions[0]\n",
      "        new_head = (head[0] + self.direction[0], head[1] + self.direction[1])\n",
      "        self.positions = [new_head] + self.positions[:-1]\n",
      "\n",
      "    def grow(self):\n",
      "        self.positions.append(self.positions[-1])\n",
      "        self.score += 1\n",
      "\n",
      "    def check_collision(self, position):\n",
      "        return self.positions[0] == position\n",
      "\n",
      "    def set_direction(self, direction):\n",
      "        self.direction = direction\n",
      "\n",
      "    def check_self_collision(self):\n",
      "        return self.positions[0] in self.positions[1:]\n",
      "\n",
      "    def check_wall_collision(self):\n",
      "        head = self.positions[0]\n",
      "        return not (0 <= head[0] < 800 and 0 <= head[1] < 600)\n",
      "\n",
      "class Food:\n",
      "    def __init__(self, food_type='normal'):\n",
      "        self.type = food_type\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "\n",
      "    def reposition(self):\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "\n",
      "class HealthFeature(Feature):\n",
      "    def __init__(self, initial_health=3):\n",
      "        self.health = initial_health\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_health():\n",
      "            original_draw()\n",
      "            self.draw_health()\n",
      "        self.ui.draw = draw_with_health\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_health():\n",
      "            if self.snake.check_self_collision() or self.snake.check_wall_collision():\n",
      "                self.health -= 1\n",
      "                if self.health <= 0:\n",
      "                    game.running = False\n",
      "                else:\n",
      "                    self.respawn_snake()\n",
      "            else:\n",
      "                original_update()\n",
      "        game.update = update_with_health\n",
      "\n",
      "    def draw_health(self):\n",
      "        font = pygame.font.Font(None, 36)\n",
      "        health_text = font.render(f'Health: {self.health}', True, (255, 255, 255))\n",
      "        self.ui.screen.blit(health_text, (10, 10))\n",
      "\n",
      "    def respawn_snake(self):\n",
      "        self.snake.positions = [(100, 100), (90, 100), (80, 100)]\n",
      "        self.snake.direction = (10, 0)\n",
      "\n",
      "class SpeedBoostFeature(Feature):\n",
      "    def __init__(self, speed_increase=10):\n",
      "        self.speed_increase = speed_increase\n",
      "        self.original_speed = 15\n",
      "        self.boost_active = False\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        original_handle_events = game.handle_events\n",
      "\n",
      "        def handle_events_with_speed_boost():\n",
      "            original_handle_events()\n",
      "            keys = pygame.key.get_pressed()\n",
      "            self.boost_active = keys[pygame.K_SPACE]\n",
      "        game.handle_events = handle_events_with_speed_boost\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_speed_boost():\n",
      "            game.clock.tick(self.original_speed + self.speed_increase if self.boost_active else self.original_speed)\n",
      "            original_update()\n",
      "        game.update = update_with_speed_boost\n",
      "\n",
      "class Enemy:\n",
      "    def __init__(self):\n",
      "        self.position = (random.randint(0, 79) * 10, random.randint(0, 59) * 10)\n",
      "        self.direction = (10, 0)  # Can be modified for more movement patterns\n",
      "\n",
      "    def move(self):\n",
      "        head = self.position\n",
      "        new_position = (head[0] + self.direction[0], head[1] + self.direction[1])\n",
      "        # Move to new position (for demonstration. Add boundary checks here).\n",
      "        self.position = new_position\n",
      "\n",
      "class EnemyFeature(Feature):\n",
      "    def __init__(self):\n",
      "        self.enemies = []\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        self.add_enemy()\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_enemies():\n",
      "            for enemy in self.enemies:\n",
      "                if self.snake.check_collision(enemy.position):\n",
      "                    game.running = False\n",
      "                enemy.move()\n",
      "            original_update()\n",
      "        game.update = update_with_enemies\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_enemies():\n",
      "            original_draw()\n",
      "            self.draw_enemies()\n",
      "        self.ui.draw = draw_with_enemies\n",
      "\n",
      "    def add_enemy(self):\n",
      "        self.enemies.append(Enemy())\n",
      "\n",
      "    def draw_enemies(self):\n",
      "        for enemy in self.enemies:\n",
      "            pygame.draw.rect(self.ui.screen, (255, 255, 0), pygame.Rect(enemy.position[0], enemy.position[1], 10, 10))\n",
      "\n",
      "class VarietyFoodFeature(Feature):\n",
      "    def __init__(self):\n",
      "        self.foods = []\n",
      "\n",
      "    def modify_game(self, game):\n",
      "        self.game = game\n",
      "        self.snake = game.snake\n",
      "        self.ui = game.ui\n",
      "        self.add_variety_food()\n",
      "        original_update = game.update\n",
      "\n",
      "        def update_with_variety_food():\n",
      "            for food in self.foods:\n",
      "                if self.snake.check_collision(food.position):\n",
      "                    self.snake.grow()\n",
      "                    if food.type == 'big':\n",
      "                        self.snake.grow()\n",
      "                    food.reposition()\n",
      "            original_update()\n",
      "        game.update = update_with_variety_food\n",
      "        original_draw = self.ui.draw\n",
      "\n",
      "        def draw_with_variety_food():\n",
      "            original_draw()\n",
      "            self.draw_foods()\n",
      "        self.ui.draw = draw_with_variety_food\n",
      "\n",
      "    def add_variety_food(self):\n",
      "        self.foods.append(Food('normal'))\n",
      "        self.foods.append(Food('big'))\n",
      "\n",
      "    def draw_foods(self):\n",
      "        for food in self.foods:\n",
      "            color = (255, 0, 0) if food.type == 'normal' else (0, 0, 255)\n",
      "            pygame.draw.rect(self.ui.screen, color, pygame.Rect(food.position[0], food.position[1], 10, 10))\n",
      "\n",
      "class Feature(ABC):\n",
      "    @abstractmethod\n",
      "    def modify_game(self, game):\n",
      "        pass\n",
      "\n",
      "    def handle_event(self, event):\n",
      "        pass\n",
      "\n",
      "    def update(self):\n",
      "        pass\n",
      "\n",
      "def main():\n",
      "    game = Snake_Game()\n",
      "    \n",
      "    health_feature = HealthFeature(initial_health=3)\n",
      "    speed_boost_feature = SpeedBoostFeature(speed_increase=5)\n",
      "    enemy_feature = EnemyFeature()\n",
      "    variety_food_feature = VarietyFoodFeature()\n",
      "    \n",
      "    game.add_feature(health_feature)\n",
      "    game.add_feature(speed_boost_feature)\n",
      "    game.add_feature(enemy_feature)\n",
      "    game.add_feature(variety_food_feature)\n",
      "    \n",
      "    game.run()\n",
      "    pygame.quit()\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "code interpreter response:\n",
      "The provided code implements a simple 2D \"Snake\" game using the Pygame library in Python. Here's a brief description of its components and functionality:\n",
      "\n",
      "1. **Game Structure**: The main class is `Snake_Game`, which handles game initialization, event handling, game updates, and rendering. The game features a window of size 800x600 pixels.\n",
      "\n",
      "2. **Snake Mechanics**: The `Snake` class represents the snake in the game, managing its position as a list of tuples, direction of movement, growth when eating food, and collision detection with food or walls.\n",
      "\n",
      "3. **Food**: The `Food` class generates food items at random positions on the grid. The snake can grow by eating this food, and the food's position is updated upon being consumed.\n",
      "\n",
      "4. **User Interface (UI)**: The `UI` class is responsible for drawing the snake and food onto the screen. It draws the snake in green and the food in red.\n",
      "\n",
      "5. **Game Features**: The game supports multiple features through an extensible `Feature` class:\n",
      "   - `HealthFeature`: Adds health points to the game, allowing the snake to respawn upon self-collision or wall collision until health runs out.\n",
      "   - `SpeedBoostFeature`: Allows the player to temporarily increase the snake's speed by pressing the spacebar.\n",
      "   - `EnemyFeature`: Introduces enemies that can cause the game to end when colliding with the snake.\n",
      "   - `VarietyFoodFeature`: Gives the snake an additional type of food that can either add extra growth when eaten.\n",
      "\n",
      "6. **Event Handling**: The game listens for events like quitting and enables features to respond to user inputs accordingly.\n",
      "\n",
      "7. **Game Loop**: The game runs in a loop that updates the state, checks collisions, and redraws the screen at a controlled frame rate.\n",
      "\n",
      "8. **Main Function**: The `main()` function initializes the game, adds the various features, and starts the game loop.\n",
      "\n",
      "Overall, this code provides a well-structured framework for a Snake game, making it easy to add or modify gameplay features using an abstract base class for game features.\n",
      "similarity score: 0.6252180384117302\n",
      "test functions:\n",
      "[]\n",
      "True\n",
      "The code's running time exceeded the maximum runtime\n"
     ]
    }
   ],
   "source": [
    "# test the Planner_agent\n",
    "\n",
    "def test_function_for_snake_game(main_file_path):\n",
    "    \"\"\"Statically check that the generated snake game lets the player steer.\n",
    "\n",
    "    The game counts as controllable only if it instantiates ``ControlFeature()``\n",
    "    and passes that feature to ``add_feature``, either directly\n",
    "    (``game.add_feature(ControlFeature())``) or through a variable\n",
    "    (``ctrl = ControlFeature()`` ... ``game.add_feature(ctrl)``).\n",
    "\n",
    "    Args:\n",
    "        main_file_path: path of the generated Python source file to inspect.\n",
    "\n",
    "    Returns:\n",
    "        A ``(passed, message)`` tuple.\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    # Only ASCII identifiers are searched, so undecodable bytes in generated\n",
    "    # code are replaced instead of aborting the check.\n",
    "    with open(main_file_path, 'r', encoding='utf-8', errors='replace') as file:\n",
    "        code = file.read()\n",
    "\n",
    "    if 'ControlFeature()' not in code:\n",
    "        return False, 'ControlFeature() is not used, Player cannot control the snake.'\n",
    "\n",
    "    # Match the call itself rather than one fixed literal: generated code writes\n",
    "    # add_feature(ControlFeature()) or passes a variable holding the instance.\n",
    "    enabled = re.search(r'add_feature\\(\\s*ControlFeature\\b', code) is not None\n",
    "    if not enabled:\n",
    "        instance_names = re.findall(r'([\\w.]+)\\s*=\\s*ControlFeature\\(', code)\n",
    "        enabled = any(\n",
    "            re.search(r'add_feature\\(\\s*' + re.escape(name) + r'\\s*\\)', code)\n",
    "            for name in instance_names\n",
    "        )\n",
    "\n",
    "    if enabled:\n",
    "        return True, 'ControlFeature() is used'\n",
    "    return False, 'ControlFeature() is used, but not enabled.'\n",
    "\n",
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "task_text = \"make a snake game that the snake have health and is able to accelerate, the game have enemy and different type of foods\"\n",
    "py_paths = [r'D:\\Python_project\\snack_game_test\\snack_test04_5.py']\n",
    "use_py_files = False\n",
    "test_functions = [test_function_for_snake_game]\n",
    "\n",
    "# Never hardcode credentials: read the key from the environment, or prompt for it.\n",
    "# Any key previously committed to this notebook must be treated as leaked and revoked.\n",
    "api_key = os.environ.get(\"OPENAI_API_KEY\") or getpass(\"OpenAI API key: \")\n",
    "max_tokens = 4096\n",
    "store_path = r'D:\\Python_project\\llm_agent\\output\\snake_game_test'\n",
    "model = 'gpt-4o-mini'\n",
    "embedding_model = \"text-embedding-3-small\"\n",
    "json_file_paths = [r\"D:\\Python_project\\llm_agent\\json_files\\eval_json_file.json\", r\"D:\\Python_project\\llm_agent\\json_files\\pacman_game.json\", r\"D:\\Python_project\\llm_agent\\json_files\\snack_game.json\"]\n",
    "RAG_config = {\n",
    "    \"target\": \"both\",\n",
    "    \"target_type\": [\"function\", \"class\", \"evaluation\"],\n",
    "    \"weight\": 0.5,\n",
    "    \"keywords_score\": 1,\n",
    "    \"min_threshold\": 0.5,\n",
    "    \"min_num\": 5,\n",
    "    \"keep_scores\": True\n",
    "}\n",
    "RAG_method = \"AGENT\"\n",
    "\n",
    "use_code_interpreter = True\n",
    "\n",
    "max_trying_num = 3\n",
    "max_runtime = 10\n",
    "\n",
    "print_prompt = False\n",
    "\n",
    "planner = Planner_agent(api_key, \n",
    "                        max_tokens=max_tokens, \n",
    "                        store_path=store_path, \n",
    "                        model=model, \n",
    "                        embedding_model=embedding_model, \n",
    "                        json_file_paths=json_file_paths, \n",
    "                        RAG_config=RAG_config)\n",
    "\n",
    "# Top-level await relies on the IPython kernel's autoawait support.\n",
    "result, message = await planner.run_planner(\n",
    "    task_text,\n",
    "    py_paths,\n",
    "    max_runtime=max_runtime,\n",
    "    max_error_num=max_trying_num,\n",
    "    use_py_files=use_py_files,\n",
    "    test_functions=test_functions,\n",
    "    RAG_method=RAG_method,\n",
    "    print_prompt=print_prompt,\n",
    "    use_code_interpreter=use_code_interpreter,\n",
    ")\n",
    "\n",
    "print(result)\n",
    "print(message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
