{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LOAD DATA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import random\n",
    "import pandas as pd\n",
    "from typing import List, Dict, Tuple\n",
    "import openai\n",
    "\n",
    "# Assume RewardModel is already defined as per the user's specification\n",
    "class RewardModel:\n",
    "    def __init__(self, data: pd.DataFrame):\n",
    "        self.data = data\n",
    "        self.reward_model = None\n",
    "\n",
    "    def process_dataset(self):\n",
    "        # Implement dataset preprocessing steps\n",
    "        pass\n",
    "\n",
    "    def train_reward_model(self):\n",
    "        # Implement training of the reward model\n",
    "        pass\n",
    "\n",
    "    def generate_reward(self, input_code: str) -> float:\n",
    "        # Implement reward generation for a single code\n",
    "        return self.reward_model.predict([input_code])[0]\n",
    "\n",
    "    def generate_batch_reward(self, code_list: List[str]) -> List[float]:\n",
    "        # Implement batch reward generation\n",
    "        return self.reward_model.predict(code_list)\n",
    "\n",
    "\n",
    "class LLMAgentFramework:\n",
    "    def __init__(\n",
    "        self,\n",
    "        data: pd.DataFrame,\n",
    "        input_text: str,\n",
    "        llm_params: Dict,\n",
    "        num_codes_per_iteration: int,\n",
    "        max_iterations: int,\n",
    "        score_threshold: float,\n",
    "        score_weights: Dict[str, float],\n",
    "        openai_api_key: str,\n",
    "    ):\n",
    "        self.data = data\n",
    "        self.input_text = input_text\n",
    "        self.llm_params = llm_params\n",
    "        self.num_codes = num_codes_per_iteration\n",
    "        self.max_iterations = max_iterations\n",
    "        self.threshold = score_threshold\n",
    "        self.weights = score_weights\n",
    "        self.reward_model = RewardModel(data)\n",
    "        self.best_code = None\n",
    "        self.best_score = -float('inf')\n",
    "        self.failed_test_cases = []\n",
    "        openai.api_key = openai_api_key\n",
    "\n",
    "    async def preprocess(self):\n",
    "        print(\"Preprocessing dataset and training reward model...\")\n",
    "        self.reward_model.process_dataset()\n",
    "        self.reward_model.train_reward_model()\n",
    "        print(\"Preprocessing and training completed.\")\n",
    "\n",
    "    async def split_requirements(self) -> List[str]:\n",
    "        prompt = f\"\"\"\n",
    "        You are a professional Python algorithm engineer. \n",
    "        Based on the following task description, split the requirements into separate parts. \n",
    "        If the task is simple, return a single requirement.\n",
    "\n",
    "        Task Description:\n",
    "        {self.input_text}\n",
    "\n",
    "        Requirements:\n",
    "        \"\"\"\n",
    "        response = await self.async_generate_text(prompt)\n",
    "        requirements = self.extract_requirements(response)\n",
    "        print(f\"Split into requirements: {requirements}\")\n",
    "        return requirements\n",
    "\n",
    "    async def generate_test_cases(self, requirement: str, sampled_codes: List[str]) -> List[Dict]:\n",
    "        # Prepare context by sampling some code snippets\n",
    "        context = \"\\n\\n\".join(random.sample(sampled_codes, min(3, len(sampled_codes))))\n",
    "        prompt = f\"\"\"\n",
    "        You are a professional Python algorithm engineer.\n",
    "        Based on the following requirement and code examples, generate a set of test cases.\n",
    "        Each test case should include:\n",
    "        1. Test case data\n",
    "        2. A Python code snippet that uses the test case data to output True if the test passes, or False otherwise.\n",
    "\n",
    "        Requirement:\n",
    "        {requirement}\n",
    "\n",
    "        Code Examples:\n",
    "        {context}\n",
    "\n",
    "        Test Cases:\n",
    "        \"\"\"\n",
    "        response = await self.async_generate_text(prompt)\n",
    "        test_cases = self.extract_test_cases(response)\n",
    "        print(f\"Generated {len(test_cases)} test cases for requirement: {requirement}\")\n",
    "        return test_cases\n",
    "\n",
    "    async def validate_test_cases(self, test_cases: List[Dict]) -> List[Dict]:\n",
    "        valid_test_cases = []\n",
    "        for test_case in test_cases:\n",
    "            pass_count = 0\n",
    "            for code in self.data['code']:\n",
    "                try:\n",
    "                    # Dynamically execute the code snippet to test the test case\n",
    "                    exec_globals = {}\n",
    "                    exec(code, exec_globals)\n",
    "                    test_func = exec_globals.get('test_function')  # Assume each code defines a test_function\n",
    "                    if test_func and test_func(test_case['data']):\n",
    "                        pass_count += 1\n",
    "                except Exception:\n",
    "                    continue\n",
    "            pass_rate = pass_count / len(self.data)\n",
    "            if pass_rate >= 0.8:  # Arbitrary threshold, can be adjusted\n",
    "                valid_test_cases.append(test_case)\n",
    "        print(f\"Validated test cases: {len(valid_test_cases)} out of {len(test_cases)}\")\n",
    "        return valid_test_cases\n",
    "\n",
    "    async def generate_codes(self, requirements: List[str], test_cases: List[Dict]) -> List[str]:\n",
    "        # Combine requirements and test cases into a prompt\n",
    "        test_case_descriptions = \"\\n\".join(\n",
    "            [f\"Test Case Data: {tc['data']}\\nSnippet: {tc['snippet']}\" for tc in test_cases]\n",
    "        )\n",
    "        prompt = f\"\"\"\n",
    "        You are a professional Python algorithm engineer.\n",
    "        Based on the following requirements and test cases, generate {self.num_codes} Python code implementations.\n",
    "        Ensure that the code is compatible with the test cases.\n",
    "\n",
    "        Requirements:\n",
    "        {', '.join(requirements)}\n",
    "\n",
    "        Test Cases:\n",
    "        {test_case_descriptions}\n",
    "\n",
    "        Generated Codes:\n",
    "        \"\"\"\n",
    "        response = await self.async_generate_text(prompt)\n",
    "        codes = self.extract_generated_codes(response, self.num_codes)\n",
    "        print(f\"Generated {len(codes)} candidate codes.\")\n",
    "        return codes\n",
    "\n",
    "    async def evaluate_codes(self, codes: List[str], test_cases: List[Dict]) -> List[Tuple[str, float, float]]:\n",
    "        # Evaluate each code against the test cases\n",
    "        code_rewards = self.reward_model.generate_batch_reward(codes)\n",
    "        code_scores = []\n",
    "        for idx, code in enumerate(codes):\n",
    "            pass_count = 0\n",
    "            for tc in test_cases:\n",
    "                try:\n",
    "                    exec_globals = {}\n",
    "                    exec(code, exec_globals)\n",
    "                    func = exec_globals.get('generated_function')  # Assume generated code defines this function\n",
    "                    if func and func(tc['data']):\n",
    "                        # Execute the snippet to get True/False\n",
    "                        snippet = tc['snippet']\n",
    "                        snippet_globals = {'generated_function': func}\n",
    "                        result = eval(snippet, snippet_globals)\n",
    "                        if result:\n",
    "                            pass_count += 1\n",
    "                except Exception:\n",
    "                    continue\n",
    "            test_case_pass_rate = pass_count / len(test_cases) if test_cases else 0\n",
    "            reward = code_rewards[idx]\n",
    "            final_score = self.weights.get(\"test_case\", 0) * test_case_pass_rate + \\\n",
    "                          self.weights.get(\"reward_model\", 0) * reward\n",
    "            code_scores.append((code, test_case_pass_rate, reward))\n",
    "            if final_score > self.best_score:\n",
    "                self.best_score = final_score\n",
    "                self.best_code = code\n",
    "        print(f\"Evaluated {len(codes)} codes.\")\n",
    "        return code_scores\n",
    "\n",
    "    async def iterate(self, requirements: List[str], test_cases: List[Dict]):\n",
    "        for iteration in range(1, self.max_iterations + 1):\n",
    "            print(f\"--- Iteration {iteration} ---\")\n",
    "            codes = await self.generate_codes(requirements, test_cases)\n",
    "            evaluated = await self.evaluate_codes(codes, test_cases)\n",
    "            print(f\"Best score this iteration: {self.best_score}\")\n",
    "            if self.best_score >= self.threshold:\n",
    "                print(\"Threshold met. Stopping iterations.\")\n",
    "                break\n",
    "            # Incorporate failed test cases for next iteration\n",
    "            self.failed_test_cases = [\n",
    "                tc for tc in test_cases if not self.evaluate_single_code(self.best_code, tc)\n",
    "            ]\n",
    "            if not self.failed_test_cases:\n",
    "                print(\"No failed test cases. Stopping iterations.\")\n",
    "                break\n",
    "            # Optionally, refine test cases or requirements based on failed cases\n",
    "        print(\"Iteration complete.\")\n",
    "\n",
    "    def evaluate_single_code(self, code: str, test_case: Dict) -> bool:\n",
    "        try:\n",
    "            exec_globals = {}\n",
    "            exec(code, exec_globals)\n",
    "            func = exec_globals.get('generated_function')  # Assume generated code defines this function\n",
    "            if func:\n",
    "                snippet = test_case['snippet']\n",
    "                snippet_globals = {'generated_function': func}\n",
    "                return eval(snippet, snippet_globals)\n",
    "        except Exception:\n",
    "            return False\n",
    "        return False\n",
    "\n",
    "    async def run(self):\n",
    "        await self.preprocess()\n",
    "        requirements = await self.split_requirements()\n",
    "        sampled_codes = self.data['code'].sample(n=10).tolist()  # Sample 10 codes for context\n",
    "        test_cases = []\n",
    "        for req in requirements:\n",
    "            generated = await self.generate_test_cases(req, sampled_codes)\n",
    "            valid = await self.validate_test_cases(generated)\n",
    "            test_cases.extend(valid)\n",
    "        await self.iterate(requirements, test_cases)\n",
    "        print(f\"Best Code:\\n{self.best_code}\\nWith Score: {self.best_score}\")\n",
    "\n",
    "    async def async_generate_text(self, prompt: str) -> str:\n",
    "        response = await openai.ChatCompletion.acreate(\n",
    "            model=\"gpt-4\",\n",
    "            messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "            **self.llm_params\n",
    "        )\n",
    "        return response['choices'][0]['message']['content'].strip()\n",
    "\n",
    "    def extract_requirements(self, response: str) -> List[str]:\n",
    "        # Implement extraction logic, e.g., splitting by bullet points or numbering\n",
    "        lines = response.split('\\n')\n",
    "        requirements = [line.strip('- ').strip() for line in lines if line.strip()]\n",
    "        return requirements\n",
    "\n",
    "    def extract_test_cases(self, response: str) -> List[Dict]:\n",
    "        # Parse the response to extract test cases\n",
    "        # Assume each test case is separated by a line\n",
    "        test_cases = []\n",
    "        lines = response.split('\\n')\n",
    "        current_tc = {}\n",
    "        for line in lines:\n",
    "            if line.startswith(\"Test case data:\") or line.startswith(\"Test case data\"):\n",
    "                current_tc['data'] = line.split(':', 1)[1].strip()\n",
    "            elif line.startswith(\"Snippet:\") or line.startswith(\"Snippet\"):\n",
    "                current_tc['snippet'] = line.split(':', 1)[1].strip()\n",
    "                if 'data' in current_tc and 'snippet' in current_tc:\n",
    "                    test_cases.append(current_tc)\n",
    "                    current_tc = {}\n",
    "        return test_cases\n",
    "\n",
    "    def extract_generated_codes(self, response: str, num_codes: int) -> List[str]:\n",
    "        # Extract code blocks from the response\n",
    "        codes = []\n",
    "        code_blocks = response.split(\"```python\")\n",
    "        for block in code_blocks[1:]:\n",
    "            code = block.split(\"```\")[0].strip()\n",
    "            codes.append(code)\n",
    "            if len(codes) >= num_codes:\n",
    "                break\n",
    "        return codes\n",
    "\n",
    "\n",
    "# Example usage\n",
    "if __name__ == \"__main__\":\n",
    "    import pandas as pd\n",
    "\n",
    "    # Sample dataset\n",
    "    data = pd.DataFrame({\n",
    "        'code': [\n",
    "            \"\"\"\n",
    "def generated_function(x):\n",
    "    return x > 0\n",
    "\n",
    "def test_function(data):\n",
    "    return generated_function(data) == (data > 0)\n",
    "\"\"\",\n",
    "            # Add more code samples as needed\n",
    "        ],\n",
    "        'one_hot_state': [[]],  # Placeholder\n",
    "        'one_hot_action': [[]],  # Placeholder\n",
    "    })\n",
    "\n",
    "    input_text = \"Create a function that determines if a number is positive.\"\n",
    "\n",
    "    llm_generation_params = {\n",
    "        \"temperature\": 0.7,\n",
    "        \"max_tokens\": 150,\n",
    "        \"n\": 1,\n",
    "        \"stop\": None,\n",
    "    }\n",
    "\n",
    "    framework = LLMAgentFramework(\n",
    "        data=data,\n",
    "        input_text=input_text,\n",
    "        llm_params=llm_generation_params,\n",
    "        num_codes_per_iteration=5,\n",
    "        max_iterations=10,\n",
    "        score_threshold=0.95,\n",
    "        score_weights={\"test_case\": 0.5, \"reward_model\": 0.5},\n",
    "        openai_api_key=\"your-openai-api-key\",\n",
    "    )\n",
    "\n",
    "    asyncio.run(framework.run())\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
