{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6a561849",
   "metadata": {},
   "outputs": [],
   "source": [
    "from LCDP import LCDP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5d74927d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "def append_dict_to_json_list_file(dictionary_to_append, json_file_path):\n",
    "    # Validate input type\n",
    "    if not isinstance(dictionary_to_append, dict):\n",
    "        raise TypeError(\"The data to append must be a dictionary.\")\n",
    "\n",
    "    data_list = []\n",
    "\n",
    "    # Check if the file exists and is not empty\n",
    "    if os.path.exists(json_file_path) and os.path.getsize(json_file_path) > 0:\n",
    "        try:\n",
    "            with open(json_file_path, 'r', encoding='utf-8') as f:\n",
    "                # Attempt to load existing data\n",
    "                existing_data = json.load(f)\n",
    "                # Ensure it's a list, otherwise start fresh (or raise error)\n",
    "                if isinstance(existing_data, list):\n",
    "                    data_list = existing_data\n",
    "                else:\n",
    "                    # If it's not a list, we'll overwrite with a new list.\n",
    "                    # Alternatively, you could raise an error here if strict format is required.\n",
    "                    print(f\"Warning: File '{json_file_path}' did not contain a list. \"\n",
    "                          \"It will be overwritten with a new list.\")\n",
    "                    data_list = [] \n",
    "        except json.JSONDecodeError as e:\n",
    "            # If JSON is malformed, print a warning and overwrite with a new list.\n",
    "            print(f\"Warning: File '{json_file_path}' contained invalid JSON. \"\n",
    "                  f\"It will be overwritten. Error: {e}\")\n",
    "            data_list = []\n",
    "        except IOError as e:\n",
    "            # Catch other potential file reading issues\n",
    "            raise IOError(f\"Could not read from file '{json_file_path}': {e}\")\n",
    "    elif os.path.exists(json_file_path) and os.path.getsize(json_file_path) == 0:\n",
    "        # File exists but is empty. Treat as a new list.\n",
    "        print(f\"Info: File '{json_file_path}' is empty. Initializing a new list.\")\n",
    "        data_list = []\n",
    "\n",
    "    # Append the new dictionary to the list\n",
    "    data_list.append(dictionary_to_append)\n",
    "\n",
    "    # Write the updated list back to the JSON file\n",
    "    try:\n",
    "        with open(json_file_path, 'w', encoding='utf-8') as f:\n",
    "            json.dump(data_list, f, indent=4, ensure_ascii=False)\n",
    "    except IOError as e:\n",
    "        # Catch potential file writing issues\n",
    "        raise IOError(f\"Could not write to file '{json_file_path}': {e}\")\n",
    "\n",
    "def load_dicts_from_json_file(json_file_path):\n",
    "    if not os.path.exists(json_file_path) or os.path.getsize(json_file_path) == 0:\n",
    "        # File doesn't exist or is empty\n",
    "        return []\n",
    "\n",
    "    try:\n",
    "        with open(json_file_path, 'r', encoding='utf-8') as f:\n",
    "            data = json.load(f)\n",
    "            if isinstance(data, list):\n",
    "                return data\n",
    "            else:\n",
    "                # If the file contains JSON but not a list, return empty list and warn.\n",
    "                print(f\"Warning: File '{json_file_path}' does not contain a list. \"\n",
    "                      \"Returning an empty list.\")\n",
    "                return []\n",
    "    except json.JSONDecodeError as e:\n",
    "        # Malformed JSON\n",
    "        print(f\"Warning: Could not decode JSON from '{json_file_path}'. Error: {e}. \"\n",
    "              \"Returning an empty list.\")\n",
    "        return []\n",
    "    except IOError as e:\n",
    "        # Other file reading issues\n",
    "        raise IOError(f\"Could not read from file '{json_file_path}': {e}\")\n",
    "    \n",
    "def analyze_code_generation_attempts(data_list):\n",
    "    summary_counts = {}\n",
    "    finished_rounds_per_entry = []\n",
    "\n",
    "    for entry in data_list:\n",
    "        found_finish_round_for_this_entry = False\n",
    "        \n",
    "        # Extract iteration numbers (keys like '0', '1', ...) and sort them numerically\n",
    "        iteration_numbers = []\n",
    "        for key_str in entry.keys():\n",
    "            if key_str.isdigit(): # Ensure the key is a string representing an integer\n",
    "                iteration_numbers.append(int(key_str))\n",
    "        iteration_numbers.sort() # Sorts numerically, e.g., [0, 1, 2, ..., 10, 11]\n",
    "\n",
    "        for iter_num in iteration_numbers:\n",
    "            iter_key_str = str(iter_num) # Convert back to string to use as dict key\n",
    "            \n",
    "            # Check if the expected nested structure and 'pass_rate' key exist\n",
    "            if iter_key_str in entry and \\\n",
    "               isinstance(entry[iter_key_str], dict) and \\\n",
    "               'code_0' in entry[iter_key_str] and \\\n",
    "               isinstance(entry[iter_key_str]['code_0'], dict) and \\\n",
    "               'pass_rate' in entry[iter_key_str]['code_0']:\n",
    "                \n",
    "                pass_rate = entry[iter_key_str]['code_0']['pass_rate']\n",
    "                \n",
    "                # Check if the code is runnable\n",
    "                if pass_rate == 1.0:\n",
    "                    # Iteration '0' corresponds to round 1, '1' to round 2, etc.\n",
    "                    finish_round_number = iter_num + 1\n",
    "                    finished_rounds_per_entry.append(finish_round_number)\n",
    "                    \n",
    "                    summary_key = f\"finished at round {finish_round_number}\"\n",
    "                    summary_counts[summary_key] = summary_counts.get(summary_key, 0) + 1\n",
    "                    \n",
    "                    found_finish_round_for_this_entry = True\n",
    "                    break # Found the first successful iteration, move to the next entry\n",
    "            # If structure is not as expected or pass_rate is not 1.0, continue to next iteration\n",
    "            \n",
    "        if not found_finish_round_for_this_entry:\n",
    "            # If no iteration achieved pass_rate = 1.0 for this entry\n",
    "            finished_rounds_per_entry.append(\"not finished\")\n",
    "            summary_key = \"not finished at all\"\n",
    "            summary_counts[summary_key] = summary_counts.get(summary_key, 0) + 1\n",
    "            \n",
    "    return summary_counts, finished_rounds_per_entry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8b9b0644",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_function_temp = \"\"\"import traceback\n",
    "def test_function(func):\n",
    "    try:\n",
    "        func()\n",
    "        return True, \"code is runnable\"\n",
    "    except Exception as e:\n",
    "        error_message_with_traceback = traceback.format_exc()\n",
    "        return False, \"code is not runnable, error: \" + str(error_message_with_traceback)\n",
    "\"\"\"\n",
    "# print(test_function_temp)\n",
    "forced_test_case = {\"runnability_check\":{'test_type': 'correctness', 'purpose': 'check if the code is runnable', 'test_function': test_function_temp, 'all_pass_times': 0, 'all_fail_times': 0}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cc7eacb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating codes:   0%|          | 0/10 [00:00<?, ?it/s]\n"
     ]
    }
   ],
   "source": [
    "test_task_description = \"\"\"Develop an image segmentation model to accurately and efficiently identify and segment FLAIR abnormalities in brain MR images from the LGG Segmentation Dataset.\n",
    "\n",
    "Dataset Format & Structure:\n",
    "\n",
    "Base Path: r\"E:\\python_project_new\\brain_case_study\\lgg-mri-segmentation\\case_study\"\n",
    "\n",
    "Organization: \n",
    "The base path contains train/, val/, and test/ subfolders.\n",
    "Each of these subfolders contains two further subfolders: images/ and masks/.\n",
    "\n",
    "Image Files (images/):\n",
    "Format: 3-channel .tif files.\n",
    "Naming: TCGA_<institution-code>_<patient-id>_<slice-number>.tif.\n",
    "\n",
    "Mask Files (masks/):\n",
    "Format: Binary, 1-channel .tif files.\n",
    "Naming: TCGA_<institution-code>_<patient-id>_<slice-number>_mask.tif.\n",
    "\n",
    "Create the code directly and The script should not run the main training or evaluation logic directly when the script file is executed. Instead, it should define all necessary functions with a main function main() that can be run WITHOUT ANY input, the main function need to load data, train the model, and evaluate model, and then return the model and Dice score on test data. DO NOT USE tensorflow. Use provided data path in your code.\n",
    "\n",
    "And then, write another function to plot some prediction result in test dataset\"\"\"\n",
    "\n",
    "lcdp = LCDP(api_key=None, model=\"gpt-4.1\", max_workers=1, ignore_advice=True, use_pr_predictor=False, use_web_search=False)\n",
    "\n",
    "# for i in range(5):\n",
    "#     best_codes = await lcdp.run(\n",
    "#         task_description=test_task_description,\n",
    "#         max_iterations=5,\n",
    "#         num_plans=3,\n",
    "#         num_tests=1,\n",
    "#         num_codes=1,\n",
    "#         refine_rounds=3,\n",
    "#         use_pass_rate_for_train=False,\n",
    "#         test_timeout=100,\n",
    "#         # use_example=True,\n",
    "#         # example_dataset=example_codes,\n",
    "#         use_async_generation=False,\n",
    "#         min_tests = 1,\n",
    "#         max_tests = 1,\n",
    "#         ai4s = True,\n",
    "#         use_llm_for_refine=False,\n",
    "#         # use_data_format_extract=[\"beach_profile_data/processed_data/beachdata_train.xlsx\", \"beach_profile_data/processed_data/beachdata_test.xlsx\"],\n",
    "#         record_all_results=True,\n",
    "#         forced_test_cases=forced_test_case,\n",
    "#     )\n",
    "#     # print(best_codes)\n",
    "#     # Save the best code dict into a file\n",
    "#     append_dict_to_json_list_file(best_codes, \"best_code.json\")\n",
    "\n",
    "best_codes = await lcdp.run(\n",
    "    task_description=test_task_description,\n",
    "    max_iterations=5,\n",
    "    num_plans=3,\n",
    "    num_tests=1,\n",
    "    num_codes=1,\n",
    "    refine_rounds=3,\n",
    "    use_pass_rate_for_train=False,\n",
    "    test_timeout=100,\n",
    "    # use_example=True,\n",
    "    # example_dataset=example_codes,\n",
    "    use_async_generation=False,\n",
    "    min_tests = 1,\n",
    "    max_tests = 1,\n",
    "    ai4s = True,\n",
    "    use_llm_for_refine=False,\n",
    "    # use_data_format_extract=[\"beach_profile_data/processed_data/beachdata_train.xlsx\", \"beach_profile_data/processed_data/beachdata_test.xlsx\"],\n",
    "    record_all_results=True,\n",
    "    forced_test_cases=forced_test_case,\n",
    "    prompt_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "329345c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Role ===\n",
      "You are a highly skilled coding assistant designed to generate clear, efficient, and correct code based on structured task descriptions and detailed plans provided by the user. Your responses must precisely follow the instructions, formats, and constraints given by the user, and you must strictly adhere to input-output formats, workflows, and specific guidelines outlined.\n",
      "\n",
      "=== Task Description ===\n",
      "Develop an image segmentation model to accurately and efficiently identify and segment FLAIR abnormalities in brain MR images from the LGG Segmentation Dataset.\n",
      "\n",
      "Dataset Format & Structure:\n",
      "\n",
      "Base Path: r\"E:\\python_project_nerain_case_study\\lgg-mri-segmentation\\case_study\"\n",
      "\n",
      "Organization: \n",
      "The base path contains train/, val/, and test/ subfolders.\n",
      "Each of these subfolders contains two further subfolders: images/ and masks/.\n",
      "\n",
      "Image Files (images/):\n",
      "Format: 3-channel .tif files.\n",
      "Naming: TCGA_<institution-code>_<patient-id>_<slice-number>.tif.\n",
      "\n",
      "Mask Files (masks/):\n",
      "Format: Binary, 1-channel .tif files.\n",
      "Naming: TCGA_<institution-code>_<patient-id>_<slice-number>_mask.tif.\n",
      "\n",
      "Create the code directly and The script should not run the main training or evaluation logic directly when the script file is executed. Instead, it should define all necessary functions with a main function main() that can be run WITHOUT ANY input, the main function need to load data, train the model, and evaluate model, and then return the model and Dice score on test data. DO NOT USE tensorflow. Use provided data path in your code.\n",
      "\n",
      "And then, write another function to plot some prediction result in test dataset\n",
      "\n",
      "=== Extra information ===\n",
      "\n",
      "- **Image Segmentation Task:** This is a pixel-wise classification problem where the goal is to delineate FLAIR abnormalities (typically regions of abnormal high signal intensity in the brain MRIs) using annotated binary masks.\n",
      "- **Evaluation Metric:** The Dice coefficient (or Dice Similarity Coefficient, DSC) is standard for segmentation. It is given by:  \n",
      "  \\[\n",
      "  Dice = \\frac{2|P \\cap G|}{|P| + |G|}\n",
      "  \\]\n",
      "  where \\(P\\) is the predicted mask, and \\(G\\) is the ground truth mask.\n",
      "- **Domain Knowledge:** FLAIR stands for Fluid-Attenuated Inversion Recovery, an MRI sequence sensitive to certain brain lesions such as tumors or edema. LGG refers to Low Grade Glioma, a type of brain tumor frequently segmented for clinical/research purposes.\n",
      "- **Model Architecture:** U-Net (and derivatives) are the de-facto standard for biomedical image segmentation, especially when data is limited.\n",
      "- **Loss Function:** Combination of Dice loss and Binary Cross-Entropy is common, as it helps address class imbalance between foreground (abnormality) and background.\n",
      "\n",
      "=== Feature analysis ===\n",
      "\n",
      "- **Images:** Each sample is a 3-channel (most likely RGB or pseudo-color) .tif image of a brain MRI slice. Each image is a 2D slice and the 3 channels may contain identical or different MRI information depending on the preprocessing, but typically for FLAIR images they are grayscale replicated to three channels.\n",
      "- **Masks:** Each mask is a 1-channel binary .tif image with pixel values 0 (background) and 1 (region of FLAIR abnormality/tumor).\n",
      "- **File Naming:** Each image and its corresponding mask share a common basename (up to the slice number). This allows for programmatic pairing.\n",
      "- **Data Splits:** Data is organized into train, val, and test splits, each with their own images/ and masks/ subfolders, which is crucial for reproducibility and fair evaluation.\n",
      "\n",
      "=== Extra Advice ===\n",
      "\n",
      "- **Data Preprocessing:** Normalize image intensity values (min-max or z-score normalization). Ensure all images are resized/cropped to the same shape (commonly 256x256 or as appropriate for U-Net input). If channel information is redundant, you could reduce it to one channel for efficiency.\n",
      "- **Label Preparation:** Ensure mask values are binary (0 and 1). Double-check that mask and image align correctly for every slice.\n",
      "- **Data Augmentation:** To combat overfitting and improve generalization, apply transformations such as flipping, rotation, scaling, and intensity shifts, ensuring they are applied equally to images and masks.\n",
      "- **Model Choice:** Use U-Net (or a lightweight modification thereof) due to its strong performance in medical image segmentation with relatively small datasets.\n",
      "- **Loss Function:** Use a combination of Binary Cross Entropy and Dice Loss to penalize both pixel-wise errors and poor overlap.\n",
      "- **Batching:** Due to likely high resolution and small dataset, use small batch sizes and possibly gradient accumulation.\n",
      "- **Evaluation:** Report Dice on the test set (already specified). Consider visualizing a few results for qualitative assessment, e.g., overlay masks or compare ground truth to prediction.\n",
      "- **Visualization:** Save (or plot) a grid of test images, ground truth, and predictions for several slices to inspect the model’s performance.\n",
      "\n",
      "By thoroughly preparing the dataset, using appropriate augmentations, and selecting a suitable architecture and losses, you can build a robust FLAIR abnormality segmentation model for the LGG dataset. Ensure reproducibility by setting random seeds and keeping all data splits strictly separate during model development and evaluation.\n",
      "\n",
      "=== Components ===\n",
      "\n",
      "**Component: load_lgg_dataset**\n",
      "Step Task Description: Load and preprocess the LGG MRI segmentation dataset from the specified directory. This involves pairing images and masks, normalizing the images, resizing images and masks to a consistent shape, binarizing masks, and preparing PyTorch datasets for train, val, and test splits.\n",
      "Input Format:\n",
      "- Argument 1: str with no fixed shape\n",
      "- Argument 2: tuple with shape=[2]\n",
      "Output Format:\n",
      "- Output 1: torch.utils.data.Dataset with no fixed shape\n",
      "- Output 2: torch.utils.data.Dataset with no fixed shape\n",
      "- Output 3: torch.utils.data.Dataset with no fixed shape\n",
      "Workflow Steps:\n",
      "- Iterate through the train, val, and test subfolders within the dataset base path.\n",
      "- For each split, gather all image-mask pairs by matching filenames.\n",
      "- Read each .tif image as a 3-channel ndarray, normalize pixel values to [0,1] using min-max normalization, and resize to the given size (e.g., (256, 256)).\n",
      "- Read each mask as a 1-channel ndarray, binarize to {0,1}, and resize/cast to the given size. Ensure masks align with images.\n",
      "- Create a custom torch.utils.data.Dataset class that, on __getitem__, returns normalized image tensor and corresponding mask tensor.\n",
      "- Return train, val, and test dataset objects.\n",
      "\n",
      "**Component: augmentation_transform**\n",
      "Step Task Description: Define augmentation transformations to be applied to image/mask pairs during training such as random horizontal/vertical flip, rotation, scaling, and intensity shifts. Ensure that identical transforms are applied to both image and corresponding mask.\n",
      "Input Format:\n",
      "- Argument 1: torch.Tensor with no fixed shape\n",
      "- Argument 2: torch.Tensor with no fixed shape\n",
      "- Argument 3: bool with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: torch.Tensor with no fixed shape\n",
      "- Output 2: torch.Tensor with no fixed shape\n",
      "Workflow Steps:\n",
      "- If in training mode, randomly apply each augmentation (flip, rotate, scale, intensity shift) with predefined probabilities.\n",
      "- Apply the same spatial transformations to both the image and the mask to ensure alignment.\n",
      "- Return the transformed image and mask.\n",
      "\n",
      "**Component: unet_model**\n",
      "Step Task Description: Implement a configurable U-Net architecture suitable for 2D binary segmentation in PyTorch.\n",
      "Input Format:\n",
      "- Argument 1: int with no fixed shape\n",
      "- Argument 2: int with no fixed shape\n",
      "- Argument 3: int with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: torch.nn.Module with no fixed shape\n",
      "Workflow Steps:\n",
      "- Define encoder and decoder blocks with skip connections.\n",
      "- Ensure input channels (e.g., 3), output channels (1), and intermediate features are configurable.\n",
      "- Initialize and return the complete model.\n",
      "\n",
      "**Component: dice_loss**\n",
      "Step Task Description: Implement the Dice loss function for binary image segmentation tasks in PyTorch.\n",
      "Input Format:\n",
      "- Argument 1: torch.Tensor with no fixed shape\n",
      "- Argument 2: torch.Tensor with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: torch.Tensor with no fixed shape\n",
      "Workflow Steps:\n",
      "- Flatten prediction and mask tensors.\n",
      "- Compute Dice coefficient as 2*|P&Y|/(|P|+|Y|).\n",
      "- Return 1 - Dice coefficient as loss.\n",
      "\n",
      "**Component: bce_dice_loss**\n",
      "Step Task Description: Combine binary cross-entropy loss and Dice loss for better segmentation robustness.\n",
      "Input Format:\n",
      "- Argument 1: torch.Tensor with no fixed shape\n",
      "- Argument 2: torch.Tensor with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: torch.Tensor with no fixed shape\n",
      "Workflow Steps:\n",
      "- Compute BCE loss between predicted mask and ground truth.\n",
      "- Compute Dice loss using the custom dice_loss function.\n",
      "- Sum the two losses (optionally with weights) and return.\n",
      "\n",
      "**Component: train_one_epoch**\n",
      "Step Task Description: Train the model for one epoch on the training data loader using specified loss function, optimizer, and optional scheduler.\n",
      "Input Format:\n",
      "- Argument 1: torch.nn.Module with no fixed shape\n",
      "- Argument 2: torch.utils.data.DataLoader with no fixed shape\n",
      "- Argument 3: torch.nn.Module with no fixed shape\n",
      "- Argument 4: torch.optim.Optimizer with no fixed shape\n",
      "- Argument 5: callable with no fixed shape\n",
      "- Argument 6: torch.device with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: float with no fixed shape\n",
      "Workflow Steps:\n",
      "- Set model to train mode.\n",
      "- Iterate over DataLoader, move images/masks to device.\n",
      "- Apply forward pass, compute loss, backward, and optimizer step.\n",
      "- Optionally update scheduler.\n",
      "- Track and return average loss for the epoch.\n",
      "\n",
      "**Component: evaluate_model**\n",
      "Step Task Description: Evaluate the model on a given data loader, computing average Dice coefficient and loss.\n",
      "Input Format:\n",
      "- Argument 1: torch.nn.Module with no fixed shape\n",
      "- Argument 2: torch.utils.data.DataLoader with no fixed shape\n",
      "- Argument 3: callable with no fixed shape\n",
      "- Argument 4: torch.device with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: float with no fixed shape\n",
      "- Output 2: float with no fixed shape\n",
      "Workflow Steps:\n",
      "- Set model to evaluation mode.\n",
      "- Iterate over the validation/test DataLoader.\n",
      "- Compute predicted mask, calculate Dice and loss per batch.\n",
      "- Average metrics over the entire set.\n",
      "\n",
      "**Component: train_model**\n",
      "Step Task Description: Train the U-Net model across multiple epochs using train and validation datasets, tracking the best validation dice score and returning the best model.\n",
      "Input Format:\n",
      "- Argument 1: torch.nn.Module with no fixed shape\n",
      "- Argument 2: torch.utils.data.DataLoader with no fixed shape\n",
      "- Argument 3: torch.utils.data.DataLoader with no fixed shape\n",
      "- Argument 4: callable with no fixed shape\n",
      "- Argument 5: int with no fixed shape\n",
      "- Argument 6: torch.device with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: torch.nn.Module with no fixed shape\n",
      "- Output 2: float with no fixed shape\n",
      "Workflow Steps:\n",
      "- Loop for a set number of epochs.\n",
      "- At each epoch, call train_one_epoch, then evaluate_model on validation data.\n",
      "- Track and keep the best scoring model (by Dice).\n",
      "- Return the best model and best Dice score observed.\n",
      "\n",
      "**Component: predict_on_dataset**\n",
      "Step Task Description: Generate predictions for each sample in a DataLoader using the trained model, collecting images, ground-truth masks, and predicted masks for later analysis or visualization.\n",
      "Input Format:\n",
      "- Argument 1: torch.nn.Module with no fixed shape\n",
      "- Argument 2: torch.utils.data.DataLoader with no fixed shape\n",
      "- Argument 3: torch.device with no fixed shape\n",
      "Output Format:\n",
      "- Output 1: list with no fixed shape\n",
      "- Output 2: list with no fixed shape\n",
      "- Output 3: list with no fixed shape\n",
      "Workflow Steps:\n",
      "- Set model to eval; for each batch, move to device.\n",
      "- Compute predicted mask, binarize outputs.\n",
      "- Collect input images, ground truth masks, and predictions into lists.\n",
      "- Return all for further visualization or analysis.\n",
      "\n",
      "**Component: plot_predictions_grid**\n",
      "Step Task Description: Create a visual grid comparing original images, ground-truth masks, and predicted masks for several examples in the test set.\n",
      "Input Format:\n",
      "- Argument 1: list with no fixed shape\n",
      "- Argument 2: list with no fixed shape\n",
      "- Argument 3: list with no fixed shape\n",
      "- Argument 4: int with no fixed shape\n",
      "Output Format:\n",
      "\n",
      "Workflow Steps:\n",
      "- Select N random indices from dataset.\n",
      "- For each selected sample, plot the image, ground truth mask, and predicted mask side-by-side.\n",
      "- Optionally overlay masks on image.\n",
      "- Display result with clear titles/labels for visual assessment.\n",
      "\n",
      "**Component: main**\n",
      "Step Task Description: High-level orchestration function that loads the data, creates model and augmentations, trains the model, evaluates on test set, and returns the trained model and test set Dice score.\n",
      "Input Format:\n",
      "\n",
      "Output Format:\n",
      "- Output 1: torch.nn.Module with no fixed shape\n",
      "- Output 2: float with no fixed shape\n",
      "Workflow Steps:\n",
      "- Set random seeds and select computation device.\n",
      "- Load and preprocess datasets using load_lgg_dataset.\n",
      "- Wrap datasets with DataLoaders and apply augmentation transforms for training.\n",
      "- Instantiate U-Net architecture.\n",
      "- Select optimizer, define loss function (bce_dice_loss), and set training config parameters.\n",
      "- Call train_model using train and validation sets.\n",
      "- Evaluate best model on test set using evaluate_model.\n",
      "- Return best model and the test set Dice score.\n",
      "\n",
      "=== Overall Plan ===\n",
      "Input Format:\n",
      "\n",
      "Output Format:\n",
      "- Output 1: torch.nn.Module with no fixed shape\n",
      "- Output 2: float with no fixed shape\n",
      "Components Order: load_lgg_dataset, augmentation_transform, unet_model, dice_loss, bce_dice_loss, train_one_epoch, evaluate_model, train_model, predict_on_dataset, plot_predictions_grid, main\n",
      "Plan Steps:\n",
      "- Call main(), which sets up device and random seeds for reproducibility.\n",
      "- Use load_lgg_dataset with the provided data base path and image size to prepare PyTorch Dataset objects for train/val/test.\n",
      "- Wrap train/val/test datasets into DataLoader objects; attach augmentation_transform to the training dataset for data augmentation.\n",
      "- Create a U-Net model instance via unet_model, specifying input/output channels.\n",
      "- Define a suitable optimizer (e.g., Adam), and combine dice_loss and BCE using bce_dice_loss.\n",
      "- Call train_model with the model, DataLoaders, loss function, epoch count, and device. This will internally use train_one_epoch for training and evaluate_model for validation.\n",
      "- After training, use evaluate_model on the test data to compute Dice and average loss.\n",
      "- Return best model and Dice score.\n",
      "- Separately, use predict_on_dataset to gather images, masks, and predictions from test set. Use plot_predictions_grid to visualize and compare true and predicted segmentations.\n",
      "\n",
      "=== Instructions ===\n",
      "Generate the COMPLETE code based on the components and plan above.\n",
      "DO MAKE SURE the complete code is a runnable function, all components are correctly integrated with in this function.\n",
      "The complete function should take the input arguments as specified in the overall plan and return the output as specified.\n",
      "Please add as much comments as possible to your code to explain the logic and any critical steps.\n",
      "Structure your response as follows:\n",
      "<Code>\n",
      "Your code here. DO make sure the output is a single function that integrates all components.\n",
      "</Code>\n",
      "<Planning>\n",
      "A detailed step-by-step explanation of the code's workflow.\n",
      "</Planning>\n",
      "<Main Function Name>\n",
      "The name of the main function that integrates all components.\n",
      "</Main Function Name>\n",
      "Provide the reasoning, code, planning, function name with the SAME indicator and structure as shown in Instructions. DO NOT return any test cases or example usages in your code!\n"
     ]
    }
   ],
   "source": [
    "print(best_codes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8126c6f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_file_load = load_dicts_from_json_file(\"best_code.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "19698d4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "({'finished at round 1': 6, 'finished at round 2': 4, 'finished at round 3': 5, 'not finished at all': 5}, [1, 1, 2, 3, 'not finished', 2, 1, 2, 3, 1, 3, 1, 'not finished', 2, 'not finished', 3, 1, 'not finished', 3, 'not finished'])\n"
     ]
    }
   ],
   "source": [
    "output_temp = analyze_code_generation_attempts(result_file_load)\n",
    "print(output_temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "72354229",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['0', '1', '2'])\n"
     ]
    }
   ],
   "source": [
    "codes_dict = result_file_load[18]\n",
    "print(codes_dict.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "04d59290",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "import os\n",
      "import numpy as np\n",
      "import torch\n",
      "import torch.nn as nn\n",
      "import torch.nn.functional as F\n",
      "from torch.utils.data import Dataset, DataLoader\n",
      "from PIL import Image\n",
      "import random\n",
      "import copy\n",
      "import matplotlib.pyplot as plt\n",
      "\n",
      "# ------- Dataset Class --------\n",
      "class BrainMRISegmentationDataset(Dataset):\n",
      "    \"\"\"\n",
      "    Custom PyTorch Dataset for loading pairs of brain MRI slices and corresponding masks.\n",
      "    Handles per-image normalization, resizing, and (optionally) augmentation.\n",
      "    \"\"\"\n",
      "    def __init__(\n",
      "        self,\n",
      "        pairs,\n",
      "        image_size=128,\n",
      "        is_train=False,\n",
      "        augment_prob=0.5\n",
      "    ):\n",
      "        self.pairs = pairs\n",
      "        self.image_size = image_size\n",
      "        self.is_train = is_train\n",
      "        self.augment_prob = augment_prob\n",
      "\n",
      "    def __len__(self):\n",
      "        return len(self.pairs)\n",
      "\n",
      "    def __getitem__(self, idx):\n",
      "        img_path, mask_path = self.pairs[idx]\n",
      "        # Load image (as RGB)\n",
      "        image = Image.open(img_path).convert('RGB')\n",
      "        # Load mask (as single channel)\n",
      "        mask = Image.open(mask_path).convert('L')\n",
      "\n",
      "        # Convert to numpy\n",
      "        image = np.array(image).astype(np.float32)\n",
      "        mask = np.array(mask).astype(np.uint8)\n",
      "\n",
      "        # Ensure binary mask [0,1]\n",
      "        mask = (mask > 0).astype(np.float32)\n",
      "\n",
      "        # Per-image normalization (z-score for all channels)\n",
      "        img_mean = image.mean()\n",
      "        img_std = image.std()\n",
      "        if img_std < 1e-6: img_std = 1.0\n",
      "        image = (image - img_mean) / img_std\n",
      "\n",
      "        # Data augmentation for training\n",
      "        if self.is_train:\n",
      "            # Horizontal flip\n",
      "            if random.random() < self.augment_prob:\n",
      "                image = np.flip(image, axis=1)\n",
      "                mask = np.flip(mask, axis=1)\n",
      "            # Vertical flip\n",
      "            if random.random() < self.augment_prob:\n",
      "                image = np.flip(image, axis=0)\n",
      "                mask = np.flip(mask, axis=0)\n",
      "            # Random rotation\n",
      "            if random.random() < self.augment_prob:\n",
      "                k = random.choice([1, 2, 3])  # 90,180,270 degrees\n",
      "                image = np.rot90(image, k, axes=(0, 1))\n",
      "                mask = np.rot90(mask, k, axes=(0, 1))\n",
      "\n",
      "        # Resize: for images, use bilinear; for masks, use nearest\n",
      "        image = Image.fromarray(\n",
      "            np.clip(\n",
      "                (((image - image.min()) / (image.max() - image.min() + 1e-8)) * 255),\n",
      "                0, 255).astype(np.uint8)\n",
      "        ).resize((self.image_size, self.image_size), Image.BILINEAR)\n",
      "        image = np.array(image).astype(np.float32)\n",
      "        # After resize: normalization again (for safety)\n",
      "        img_mean = image.mean()\n",
      "        img_std = image.std()\n",
      "        image = (image - img_mean) / (img_std + 1e-8)\n",
      "\n",
      "        mask = Image.fromarray((mask * 255).astype(np.uint8))\n",
      "        mask = mask.resize((self.image_size, self.image_size), Image.NEAREST)\n",
      "        mask = np.array(mask).astype(np.float32)\n",
      "        mask = (mask > 127).astype(np.float32)\n",
      "\n",
      "        # HWC -> CHW\n",
      "        if image.ndim == 2:\n",
      "            image = np.stack([image, image, image], axis=2)\n",
      "        image = np.transpose(image, (2, 0, 1))  # (3,H,W)\n",
      "        mask = mask[None, :, :]  # (1,H,W)\n",
      "\n",
      "        return torch.from_numpy(image).float(), torch.from_numpy(mask).float()\n",
      "\n",
      "def create_pairs(images_dir, masks_dir, max_samples=None):\n",
      "    \"\"\"\n",
      "    Get list of (image_path, mask_path) pairs. Limit to max_samples if set.\n",
      "    \"\"\"\n",
      "    image_files = [f for f in os.listdir(images_dir) if f.endswith('.tif')]\n",
      "    pairs = []\n",
      "    for fname in image_files:\n",
      "        base = fname[:-4]\n",
      "        mask_fname = f\"{base}_mask.tif\"\n",
      "        img_path = os.path.join(images_dir, fname)\n",
      "        mask_path = os.path.join(masks_dir, mask_fname)\n",
      "        if os.path.exists(mask_path):\n",
      "            pairs.append((img_path, mask_path))\n",
      "    # Optionally shuffle and subsample\n",
      "    random.shuffle(pairs)\n",
      "    if max_samples is not None:\n",
      "        pairs = pairs[:max_samples]\n",
      "    return pairs\n",
      "\n",
      "def load_datasets(\n",
      "    base_path, image_size=128, batch_size=4, num_workers=0, max_train=100, max_val=20, max_test=20\n",
      "):\n",
      "    \"\"\"\n",
      "    Load datasets and DataLoaders for train/val/test sets.\n",
      "    \"\"\"\n",
      "    splits = ['train', 'val', 'test']\n",
      "    max_by_split = {'train': max_train, 'val': max_val, 'test': max_test}\n",
      "    dataset_dict = {}\n",
      "    dataloader_dict = {}\n",
      "    for split in splits:\n",
      "        img_dir = os.path.join(base_path, split, 'images')\n",
      "        mask_dir = os.path.join(base_path, split, 'masks')\n",
      "        max_samples = max_by_split[split]\n",
      "        pairs = create_pairs(img_dir, mask_dir, max_samples=max_samples)\n",
      "        is_train = (split == 'train')\n",
      "        dataset = BrainMRISegmentationDataset(pairs, image_size=image_size, is_train=is_train)\n",
      "        shuffle = is_train\n",
      "        loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=False)\n",
      "        dataset_dict[split] = dataset\n",
      "        dataloader_dict[split] = loader\n",
      "    return dataset_dict, dataloader_dict\n",
      "\n",
      "# ------- UNet model (Lightweight) --------\n",
      "class DoubleConv(nn.Module):\n",
      "    \"\"\"(Conv2d -> BN -> ReLU) * 2\"\"\"\n",
      "    def __init__(self, in_ch, out_ch):\n",
      "        super().__init__()\n",
      "        self.double_conv = nn.Sequential(\n",
      "            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),\n",
      "            nn.BatchNorm2d(out_ch),\n",
      "            nn.ReLU(inplace=True),\n",
      "            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),\n",
      "            nn.BatchNorm2d(out_ch),\n",
      "            nn.ReLU(inplace=True)\n",
      "        )\n",
      "    def forward(self, x):\n",
      "        return self.double_conv(x)\n",
      "\n",
      "class UNetLite(nn.Module):\n",
      "    \"\"\"Lightweight U-Net for 2D segmentation\"\"\"\n",
      "    def __init__(self, in_ch=3, out_ch=1):\n",
      "        super().__init__()\n",
      "        chs = [16, 32, 64, 128, 256]\n",
      "        # Encoder\n",
      "        self.inc = DoubleConv(in_ch, chs[0])\n",
      "        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(chs[0], chs[1]))\n",
      "        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(chs[1], chs[2]))\n",
      "        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(chs[2], chs[3]))\n",
      "        self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(chs[3], chs[4]))\n",
      "        # Decoder\n",
      "        self.up1 = nn.ConvTranspose2d(chs[4], chs[3], kernel_size=2, stride=2)\n",
      "        self.conv1 = DoubleConv(chs[4], chs[3])\n",
      "        self.up2 = nn.ConvTranspose2d(chs[3], chs[2], kernel_size=2, stride=2)\n",
      "        self.conv2 = DoubleConv(chs[3], chs[2])\n",
      "        self.up3 = nn.ConvTranspose2d(chs[2], chs[1], kernel_size=2, stride=2)\n",
      "        self.conv3 = DoubleConv(chs[2], chs[1])\n",
      "        self.up4 = nn.ConvTranspose2d(chs[1], chs[0], kernel_size=2, stride=2)\n",
      "        self.conv4 = DoubleConv(chs[1], chs[0])\n",
      "        self.outc = nn.Conv2d(chs[0], out_ch, kernel_size=1)\n",
      "    def forward(self, x):\n",
      "        x1 = self.inc(x)\n",
      "        x2 = self.down1(x1)\n",
      "        x3 = self.down2(x2)\n",
      "        x4 = self.down3(x3)\n",
      "        x5 = self.down4(x4)\n",
      "        u1 = self.up1(x5)\n",
      "        u1 = torch.cat([u1, x4], dim=1)\n",
      "        u1 = self.conv1(u1)\n",
      "        u2 = self.up2(u1)\n",
      "        u2 = torch.cat([u2, x3], dim=1)\n",
      "        u2 = self.conv2(u2)\n",
      "        u3 = self.up3(u2)\n",
      "        u3 = torch.cat([u3, x2], dim=1)\n",
      "        u3 = self.conv3(u3)\n",
      "        u4 = self.up4(u3)\n",
      "        u4 = torch.cat([u4, x1], dim=1)\n",
      "        u4 = self.conv4(u4)\n",
      "        return self.outc(u4)\n",
      "\n",
      "# ------- Dice Coefficient --------\n",
      "def dice_coefficient(pred_logits, target, threshold=0.5, eps=1e-7):\n",
      "    \"\"\"\n",
      "    pred_logits: model output, shape (B,1,H,W)\n",
      "    target: ground truth, shape (B,1,H,W)\n",
      "    Returns: average Dice over batch\n",
      "    \"\"\"\n",
      "    pred = torch.sigmoid(pred_logits)\n",
      "    pred = (pred > threshold).float()\n",
      "    target = (target > 0.5).float()\n",
      "    intersection = (pred * target).sum(dim=(2, 3))\n",
      "    unionset = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))\n",
      "    dice = (2 * intersection + eps) / (unionset + eps)\n",
      "    # If both mask and pred are all zero, count as perfect Dice = 1.0\n",
      "    zeros = ((target.sum(dim=(2,3))==0) & (pred.sum(dim=(2,3))==0))\n",
      "    dice = torch.where(zeros, torch.ones_like(dice), dice)\n",
      "    return dice.mean().item()\n",
      "\n",
      "class DiceLoss(nn.Module):\n",
      "    def __init__(self, eps=1e-7):\n",
      "        super().__init__()\n",
      "        self.eps = eps\n",
      "    def forward(self, logits, targets):\n",
      "        inputs = torch.sigmoid(logits)\n",
      "        targets = (targets > 0.5).float()\n",
      "        intersection = (inputs * targets).sum(dim=(2,3))\n",
      "        unionset = inputs.sum(dim=(2,3)) + targets.sum(dim=(2,3))\n",
      "        dice = (2 * intersection + self.eps) / (unionset + self.eps)\n",
      "        loss = 1 - dice\n",
      "        return loss.mean()\n",
      "\n",
      "# ------- Train and Evaluate --------\n",
      "def train_model(\n",
      "    model,\n",
      "    dataloader_dict,\n",
      "    device,\n",
      "    n_epochs=10,\n",
      "    lr=1e-3,\n",
      "    patience=2,\n",
      "):\n",
      "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
      "    bce_loss = nn.BCEWithLogitsLoss()\n",
      "    dice_loss = DiceLoss()\n",
      "    best_state = copy.deepcopy(model.state_dict())\n",
      "    best_val_dice = -np.inf\n",
      "    epochs_no_improve = 0\n",
      "    for ep in range(n_epochs):\n",
      "        model.train()\n",
      "        train_loss = []\n",
      "        train_dice = []\n",
      "        for x, y in dataloader_dict['train']:\n",
      "            x = x.to(device)\n",
      "            y = y.to(device)\n",
      "            optimizer.zero_grad()\n",
      "            pred = model(x)\n",
      "            loss = 0.5*bce_loss(pred, y) + 0.5*dice_loss(pred, y)\n",
      "            loss.backward()\n",
      "            optimizer.step()\n",
      "            train_loss.append(loss.item())\n",
      "            train_dice.append(dice_coefficient(pred, y))\n",
      "        # Validation\n",
      "        model.eval()\n",
      "        val_loss = []\n",
      "        val_dice = []\n",
      "        with torch.no_grad():\n",
      "            for x, y in dataloader_dict['val']:\n",
      "                x = x.to(device)\n",
      "                y = y.to(device)\n",
      "                pred = model(x)\n",
      "                loss = 0.5*bce_loss(pred, y) + 0.5*dice_loss(pred, y)\n",
      "                val_loss.append(loss.item())\n",
      "                val_dice.append(dice_coefficient(pred, y))\n",
      "        avg_val_dice = np.mean(val_dice)\n",
      "        # Early stopping\n",
      "        if avg_val_dice > best_val_dice + 1e-4:\n",
      "            best_val_dice = avg_val_dice\n",
      "            best_state = copy.deepcopy(model.state_dict())\n",
      "            epochs_no_improve = 0\n",
      "        else:\n",
      "            epochs_no_improve += 1\n",
      "        if epochs_no_improve >= patience:\n",
      "            break\n",
      "    model.load_state_dict(best_state)\n",
      "    return model\n",
      "\n",
      "def evaluate_model(model, dataloader, device):\n",
      "    \"\"\"\n",
      "    Return average Dice score on dataloader set.\n",
      "    \"\"\"\n",
      "    model.eval()\n",
      "    dices = []\n",
      "    with torch.no_grad():\n",
      "        for x, y in dataloader:\n",
      "            x = x.to(device)\n",
      "            y = y.to(device)\n",
      "            pred = model(x)\n",
      "            dice = dice_coefficient(pred, y)\n",
      "            dices.append(dice)\n",
      "    mean_dice = float(np.mean(dices)) if len(dices) > 0 else 0.0\n",
      "    return mean_dice\n",
      "\n",
      "# ------- Visualization ---------\n",
      "def plot_test_predictions(model, dataloader, device, n_samples=4):\n",
      "    \"\"\"\n",
      "    Show input/test image, ground truth, and model prediction side by side.\n",
      "    \"\"\"\n",
      "    model.eval()\n",
      "    shown = 0\n",
      "    with torch.no_grad():\n",
      "        for x, y in dataloader:\n",
      "            x = x.to(device)\n",
      "            y = y.to(device)\n",
      "            pred = model(x)\n",
      "            pred_probs = torch.sigmoid(pred)\n",
      "            pred_mask = (pred_probs > 0.5).float()\n",
      "            bs = x.shape[0]\n",
      "            for i in range(bs):\n",
      "                if shown >= n_samples:\n",
      "                    return\n",
      "                img_np = x[i].cpu().numpy()\n",
      "                img_np = img_np - img_np.min()\n",
      "                img_np = img_np / (img_np.max() + 1e-8)\n",
      "                img_np = np.transpose(img_np, (1,2,0))  # HWC\n",
      "                gt_mask = y[i,0].cpu().numpy()\n",
      "                pr_mask = pred_mask[i,0].cpu().numpy()\n",
      "                # Plot\n",
      "                fig, axs = plt.subplots(1,3, figsize=(12,4))\n",
      "                axs[0].imshow(img_np, cmap='gray')\n",
      "                axs[0].set_title('Input Image')\n",
      "                axs[0].axis('off')\n",
      "                axs[1].imshow(gt_mask, cmap='Reds')\n",
      "                axs[1].set_title('Ground truth')\n",
      "                axs[1].axis('off')\n",
      "                axs[2].imshow(img_np, cmap='gray')\n",
      "                axs[2].imshow(pr_mask, cmap='Blues', alpha=0.5)\n",
      "                axs[2].set_title('Prediction Overlay')\n",
      "                axs[2].axis('off')\n",
      "                plt.tight_layout()\n",
      "                plt.show()\n",
      "                shown += 1\n",
      "\n",
      "# ------------- MAIN PIPELINE ---------------\n",
      "def main():\n",
      "    \"\"\"\n",
      "    Loads data, trains segmentation model, returns model and average Dice on test set.\n",
      "    \"\"\"\n",
      "    # ---- Settings (tuned for quick/robust test) ----\n",
      "    DATA_PATH = r\"E:\\python_project_new\\brain_case_study\\lgg-mri-segmentation\\case_study\"\n",
      "    IMG_SIZE = 128        # Small input size for speed\n",
      "    BATCH_SIZE = 4\n",
      "    N_EPOCHS = 10\n",
      "    LR = 1e-3\n",
      "    PATIENCE = 2\n",
      "    MAX_TRAIN = 100      # Dataset size limits for speed\n",
      "    MAX_VAL = 20\n",
      "    MAX_TEST = 20\n",
      "    NUM_WORKERS = 0      # for compatibility\n",
      "    # ---- Device ----\n",
      "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
      "    # ---- Data load ----\n",
      "    dataset_dict, dataloader_dict = load_datasets(\n",
      "        DATA_PATH,\n",
      "        image_size=IMG_SIZE,\n",
      "        batch_size=BATCH_SIZE,\n",
      "        num_workers=NUM_WORKERS,\n",
      "        max_train=MAX_TRAIN,\n",
      "        max_val=MAX_VAL,\n",
      "        max_test=MAX_TEST\n",
      "    )\n",
      "    # ---- Model ----\n",
      "    model = UNetLite(in_ch=3, out_ch=1).to(device)\n",
      "    # ---- Train ----\n",
      "    model = train_model(\n",
      "        model,\n",
      "        dataloader_dict,\n",
      "        device,\n",
      "        n_epochs=N_EPOCHS,\n",
      "        lr=LR,\n",
      "        patience=PATIENCE\n",
      "    )\n",
      "    # ---- Evaluate test set ----\n",
      "    test_dice = evaluate_model(model, dataloader_dict['test'], device)\n",
      "    return model, test_dice\n",
      "\n",
      "# --------- Separate function for plotting ---------------\n",
      "def plot_test_results(model, dataloader_dict, device, n_samples=4):\n",
      "    \"\"\"\n",
      "    Plots segmentation results for a few test samples.\n",
      "    \"\"\"\n",
      "    plot_test_predictions(model, dataloader_dict['test'], device, n_samples=n_samples)\n"
     ]
    }
   ],
   "source": [
    "# Outer key \"2\" is presumably the refinement round; \"code_0\" selects the candidate\n",
    "code_dict = codes_dict[\"2\"][\"code_0\"]\n",
    "code = code_dict[\"code\"]\n",
    "print(code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5b687b9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'runnability_check': {'test_type': 'correctness', 'purpose': 'check if the code is runnable', 'test_function': 'def test_function(func):\\n    try:\\n        func()\\n        return True, \"code is runnable\"\\n    except Exception as e:\\n        return False, \"code is not runnable, error: \" + str(e)\\n', 'all_pass_times': 0, 'all_fail_times': 2}}\n",
      "1\n",
      "Test ID: runnability_check\n",
      "Test Case:\n",
      "def test_function(func):\n",
      "    try:\n",
      "        func()\n",
      "        return True, \"code is runnable\"\n",
      "    except Exception as e:\n",
      "        return False, \"code is not runnable, error: \" + str(e)\n",
      "\n",
      "Test Type: correctness\n",
      "----------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): `lcdp` must already exist (presumably an LCDP instance from an earlier cell)\n",
    "test_cases = lcdp.test_cases\n",
    "print(test_cases)\n",
    "print(len(test_cases))\n",
    "for test_id, test_case_dict in test_cases.items():\n",
    "    print(f\"Test ID: {test_id}\")\n",
    "    print(\"Test Case:\", test_case_dict[\"test_function\"], sep=\"\\n\")\n",
    "    print(\"Test Type:\", test_case_dict[\"test_type\"])\n",
    "    print(\"-\" * 40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "473c5ed8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['code_0', 'code_1', 'code_2', 'code_3', 'code_4', 'code_5', 'code_6', 'code_7', 'code_8', 'code_9'])\n",
      "dict_keys(['code', 'plan', 'main_function_name', 'think'])\n",
      "Scores:\n",
      "\n",
      "Overall score:\n"
     ]
    },
    {
     "ename": "KeyError",
     "evalue": "'score'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[3], line 8\u001b[0m\n\u001b[0;32m      6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mScores:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m      7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mOverall score:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mbest_codes\u001b[49m\u001b[43m[\u001b[49m\u001b[43mchecking_code\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mscore\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m)\n\u001b[0;32m      9\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mPass rate score:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     10\u001b[0m \u001b[38;5;28mprint\u001b[39m(best_codes[checking_code][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpass_rate_score\u001b[39m\u001b[38;5;124m'\u001b[39m])\n",
      "\u001b[1;31mKeyError\u001b[0m: 'score'"
     ]
    }
   ],
   "source": [
    "checking_code = \"code_1\"\n",
    "print(best_codes.keys())\n",
    "print(best_codes[checking_code].keys())\n",
    "\n",
    "# Stored candidates do not always carry evaluation fields (e.g. code_1 only has\n",
    "# 'code', 'plan', 'main_function_name', 'think'), so indexing them raised\n",
    "# KeyError: 'score'. Look them up with .get() and report what is missing.\n",
    "candidate = best_codes[checking_code]\n",
    "score_fields = [\n",
    "    (\"Overall score\", \"score\"),\n",
    "    (\"Pass rate score\", \"pass_rate_score\"),\n",
    "    (\"Prediction score\", \"prediction_score\"),\n",
    "    (\"Pylint score\", \"pylint_score\"),\n",
    "    (\"Radon score\", \"radon_score\"),\n",
    "]\n",
    "print(\"Scores:\")\n",
    "for label, key in score_fields:\n",
    "    print(f\"\\n{label}:\")\n",
    "    print(candidate.get(key, \"<not available>\"))\n",
    "\n",
    "print(\"\\nTest case results:\")\n",
    "test_case_results = candidate.get(\"test_case_results\")\n",
    "print(test_case_results)\n",
    "if test_case_results is not None:\n",
    "    print(sum(v[\"success\"] for v in test_case_results.values()))\n",
    "    print(len(test_case_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c78281b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "def append_dict_to_json_list_file(dictionary_to_append, json_file_path):\n",
    "    \"\"\"Append one dict to a JSON file that stores a list of dicts.\n",
    "\n",
    "    A missing or empty file starts a new list. If the file holds valid JSON\n",
    "    that is not a list, or invalid JSON, a warning is printed and the file is\n",
    "    rewritten as a new list containing only the new dict.\n",
    "\n",
    "    The updated list is serialized before anything is written, then written to\n",
    "    a temporary sibling file that replaces the target, so a value json cannot\n",
    "    serialize no longer leaves the existing file truncated.\n",
    "\n",
    "    Args:\n",
    "        dictionary_to_append: dict to append; must be JSON-serializable.\n",
    "        json_file_path: path of the JSON file (str or os.PathLike).\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if dictionary_to_append is not a dict, or holds values json\n",
    "            cannot serialize (the file on disk is then left unchanged).\n",
    "        IOError: if the file cannot be read or written.\n",
    "    \"\"\"\n",
    "    # Validate input type\n",
    "    if not isinstance(dictionary_to_append, dict):\n",
    "        raise TypeError(\"The data to append must be a dictionary.\")\n",
    "\n",
    "    data_list = []\n",
    "\n",
    "    # Check if the file exists and is not empty\n",
    "    if os.path.exists(json_file_path) and os.path.getsize(json_file_path) > 0:\n",
    "        try:\n",
    "            with open(json_file_path, 'r', encoding='utf-8') as f:\n",
    "                # Attempt to load existing data\n",
    "                existing_data = json.load(f)\n",
    "                # Keep it only if it is a list; otherwise start fresh\n",
    "                if isinstance(existing_data, list):\n",
    "                    data_list = existing_data\n",
    "                else:\n",
    "                    print(f\"Warning: File '{json_file_path}' did not contain a list. \"\n",
    "                          \"It will be overwritten with a new list.\")\n",
    "                    data_list = []\n",
    "        except json.JSONDecodeError as e:\n",
    "            # If JSON is malformed, print a warning and overwrite with a new list.\n",
    "            print(f\"Warning: File '{json_file_path}' contained invalid JSON. \"\n",
    "                  f\"It will be overwritten. Error: {e}\")\n",
    "            data_list = []\n",
    "        except IOError as e:\n",
    "            # Other file reading issues; chain the original error\n",
    "            raise IOError(f\"Could not read from file '{json_file_path}': {e}\") from e\n",
    "    elif os.path.exists(json_file_path) and os.path.getsize(json_file_path) == 0:\n",
    "        # File exists but is empty. Treat as a new list.\n",
    "        print(f\"Info: File '{json_file_path}' is empty. Initializing a new list.\")\n",
    "        data_list = []\n",
    "\n",
    "    # Append the new dictionary to the list\n",
    "    data_list.append(dictionary_to_append)\n",
    "\n",
    "    # Serialize first. json.dump streams into a file that 'w' has already\n",
    "    # truncated, so a TypeError on an unserializable value used to leave a\n",
    "    # partial file that the next call treated as invalid JSON and discarded.\n",
    "    payload = json.dumps(data_list, indent=4, ensure_ascii=False)\n",
    "\n",
    "    # Write to a sibling temp file, then swap it over the target so the target\n",
    "    # is never left half-written.\n",
    "    tmp_path = os.fspath(json_file_path) + \".tmp\"\n",
    "    try:\n",
    "        with open(tmp_path, 'w', encoding='utf-8') as f:\n",
    "            f.write(payload)\n",
    "        os.replace(tmp_path, json_file_path)\n",
    "    except IOError as e:\n",
    "        # Catch potential file writing issues; chain the original error\n",
    "        raise IOError(f\"Could not write to file '{json_file_path}': {e}\") from e\n",
    "\n",
    "def load_dicts_from_json_file(json_file_path):\n",
    "    \"\"\"Load the list stored in a JSON file.\n",
    "\n",
    "    Returns [] if the file is missing or empty, and also (after printing a\n",
    "    warning) if it holds invalid JSON or valid JSON that is not a list.\n",
    "\n",
    "    Raises:\n",
    "        IOError: if the file exists but cannot be read.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(json_file_path) or os.path.getsize(json_file_path) == 0:\n",
    "        # File doesn't exist or is empty\n",
    "        return []\n",
    "\n",
    "    try:\n",
    "        with open(json_file_path, 'r', encoding='utf-8') as f:\n",
    "            data = json.load(f)\n",
    "            if isinstance(data, list):\n",
    "                return data\n",
    "            else:\n",
    "                # If the file contains JSON but not a list, return empty list and warn.\n",
    "                print(f\"Warning: File '{json_file_path}' does not contain a list. \"\n",
    "                      \"Returning an empty list.\")\n",
    "                return []\n",
    "    except json.JSONDecodeError as e:\n",
    "        # Malformed JSON\n",
    "        print(f\"Warning: Could not decode JSON from '{json_file_path}'. Error: {e}. \"\n",
    "              \"Returning an empty list.\")\n",
    "        return []\n",
    "    except IOError as e:\n",
    "        # Other file reading issues; chain the original error\n",
    "        raise IOError(f\"Could not read from file '{json_file_path}': {e}\") from e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "74ae6168",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "import pandas as pd\n",
      "import numpy as np\n",
      "from keras.models import Sequential\n",
      "from keras.layers import Dense, Conv1D, Flatten\n",
      "from keras.utils import to_categorical\n",
      "from keras.optimizers import Adam\n",
      "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
      "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
      "\n",
      "def load_and_prepare_data(train_path, test_path):\n",
      "    # Load data from the Excel files\n",
      "    train_df = pd.read_excel(train_path)\n",
      "    test_df = pd.read_excel(test_path)\n",
      "    \n",
      "    # Handle missing values by dropping rows with NaN for simplicity\n",
      "    train_df.dropna(inplace=True)\n",
      "    test_df.dropna(inplace=True)\n",
      "    \n",
      "    # One-hot encode the 'Dominant Wave Direction' categorical variable\n",
      "    encoder = OneHotEncoder(sparse=False, handle_unknown='ignore')\n",
      "    train_encoded = encoder.fit_transform(train_df[['Dominant Wave Direction']])\n",
      "    test_encoded = encoder.transform(test_df[['Dominant Wave Direction']])\n",
      "    \n",
      "    # Add the one-hot encoded columns back to the DataFrames\n",
      "    train_encoded_df = pd.DataFrame(train_encoded, columns=encoder.get_feature_names_out(['Dominant Wave Direction']))\n",
      "    test_encoded_df = pd.DataFrame(test_encoded, columns=encoder.get_feature_names_out(['Dominant Wave Direction']))\n",
      "    train_df = train_df.join(train_encoded_df).drop(columns=['Dominant Wave Direction'])\n",
      "    test_df = test_df.join(test_encoded_df).drop(columns=['Dominant Wave Direction'])\n",
      "    \n",
      "    return train_df, test_df\n",
      "\n",
      "def calculate_bruun_model_feature(df):\n",
      "    # Extract 'Intertidal Slope' and 'Deep Water Wave Height Hd'\n",
      "    slope = df['Intertidal Slope']\n",
      "    sea_level_rise = df['Deep Water Wave Height Hd']  # Assuming Hd as proxy for sea level rise\n",
      "    \n",
      "    # Calculate shoreline retreat using Bruun's rule: y = S / tan(β)\n",
      "    df['Bruun_Shoreline_Retreat'] = sea_level_rise / np.tan(np.radians(slope))\n",
      "    \n",
      "    return df\n",
      "\n",
      "def calculate_dean_model_feature(df):\n",
      "    # Example calculation based on Dean's model principles\n",
      "    df['Dean_Wave_Sediment'] = df['Mean Grain Size'] * df['Mean Wave Height']\n",
      "    \n",
      "    return df\n",
      "\n",
      "def build_deep_learning_model(input_shape, output_shape):\n",
      "    # Create a Sequential model\n",
      "    model = Sequential()\n",
      "    # Add a Conv1D layer to handle spatial features\n",
      "    model.add(Conv1D(64, kernel_size=2, activation='relu', input_shape=(input_shape, 1)))\n",
      "    model.add(Flatten())\n",
      "    model.add(Dense(128, activation='relu'))\n",
      "    model.add(Dense(64, activation='relu'))\n",
      "    # Output layer for regression output (predicting elevation 'y')\n",
      "    model.add(Dense(output_shape, activation='linear'))\n",
      "    \n",
      "    # Compile model for regression (with mean squared error loss)\n",
      "    model.compile(optimizer=Adam(), loss='mse', metrics=['mae'])\n",
      "    return model\n",
      "\n",
      "def train_model(model, train_df, test_df):\n",
      "    # Separate features and target\n",
      "    X_train = train_df.drop(columns=['y']).values\n",
      "    y_train = train_df['y'].values\n",
      "    \n",
      "    # Standardize features\n",
      "    scaler = StandardScaler()\n",
      "    X_train_scaled = scaler.fit_transform(X_train)\n",
      "    \n",
      "    # Reshape for Conv1D input\n",
      "    X_train_scaled = X_train_scaled.reshape(X_train_scaled.shape[0], X_train_scaled.shape[1], 1)\n",
      "    \n",
      "    # Train the model\n",
      "    model.fit(X_train_scaled, y_train, epochs=100, batch_size=32, validation_split=0.2)\n",
      "    \n",
      "    return model\n",
      "\n",
      "def evaluate_model(model, test_df):\n",
      "    # Separate features and target\n",
      "    X_test = test_df.drop(columns=['y']).values\n",
      "    y_true = test_df['y'].values\n",
      "    \n",
      "    # Standardize features using the same parameters as training\n",
      "    scaler = StandardScaler()\n",
      "    X_test_scaled = scaler.fit_transform(X_test)\n",
      "    \n",
      "    # Reshape for Conv1D input\n",
      "    X_test_scaled = X_test_scaled.reshape(X_test_scaled.shape[0], X_test_scaled.shape[1], 1)\n",
      "    \n",
      "    # Predict on the test data\n",
      "    y_pred = model.predict(X_test_scaled)\n",
      "    \n",
      "    # Calculate metrics\n",
      "    rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
      "    mae = mean_absolute_error(y_true, y_pred)\n",
      "    r2 = r2_score(y_true, y_pred)\n",
      "    \n",
      "    # Return metrics as dictionary\n",
      "    return {'RMSE': rmse, 'MAE': mae, 'R^2': r2}\n",
      "\n",
      "def main_model_function(train_path, test_path):\n",
      "    # Load and prepare data\n",
      "    train_df, test_df = load_and_prepare_data(train_path, test_path)\n",
      "    \n",
      "    # Calculate Bruun model feature\n",
      "    train_df = calculate_bruun_model_feature(train_df)\n",
      "    test_df = calculate_bruun_model_feature(test_df)\n",
      "    \n",
      "    # Calculate Dean model feature\n",
      "    train_df = calculate_dean_model_feature(train_df)\n",
      "    test_df = calculate_dean_model_feature(test_df)\n",
      "    \n",
      "    # Determine input shape\n",
      "    input_shape = train_df.drop(columns=['y']).shape[1]\n",
      "    output_shape = 1  # Single value regression target\n",
      "    \n",
      "    # Build deep learning model\n",
      "    model = build_deep_learning_model(input_shape, output_shape)\n",
      "    \n",
      "    # Train model\n",
      "    model = train_model(model, train_df, test_df)\n",
      "    \n",
      "    # Evaluate model\n",
      "    performance_metrics = evaluate_model(model, test_df)\n",
      "    \n",
      "    return performance_metrics\n"
     ]
    }
   ],
   "source": [
    "# Print the stored source of one candidate; best_codes presumably comes from an earlier cell (not visible here).\n",
    "checking_code = \"code_0\"\n",
    "print(best_codes[checking_code][\"code\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c3fab755",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load E:\\python_project_new\\AI4SLCDP\\MBPP_data\\humaneval.jsonl (one JSON object per line).\n",
    "# NOTE: the result is kept in `mbpp_et_data` (the name later cells use) even though the\n",
    "# file actually loaded is humaneval.jsonl, not MBPP_ET.jsonl.\n",
    "import json\n",
    "import os\n",
    "\n",
    "def load_jsonl(file_path, encoding='utf-8'):\n",
    "    \"\"\"Read a JSON Lines file into a list with one parsed value per non-blank line.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path to the .jsonl file.\n",
    "        encoding: Text encoding used to open the file. Defaults to UTF-8 (the JSON\n",
    "            standard encoding) so the result does not depend on the platform locale.\n",
    "\n",
    "    Returns:\n",
    "        list: The parsed JSON values, in file order.\n",
    "\n",
    "    Raises:\n",
    "        json.JSONDecodeError: If a non-blank line is not valid JSON.\n",
    "    \"\"\"\n",
    "    data = []\n",
    "    with open(file_path, 'r', encoding=encoding) as file:\n",
    "        for line in file:\n",
    "            # Skip blank lines (e.g. extra newlines at EOF) instead of failing in json.loads.\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            data.append(json.loads(line))\n",
    "    return data\n",
    "\n",
    "mbpp_et_data = load_jsonl(\"E:\\\\python_project_new\\\\AI4SLCDP\\\\MBPP_data\\\\humaneval.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6ac75fb7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "164\n"
     ]
    }
   ],
   "source": [
    "# Number of loaded problems (the saved output shows 164 for humaneval.jsonl).\n",
    "print(len(mbpp_et_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4e4aa3b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt asking the LLM to summarise the beach-profile dataset columns and suggest how to use them.\n",
    "# The task description is wrapped in a matching <task description> ... </task description> pair.\n",
    "test_prompt = \"\"\"I am working on a code generation task:\n",
    "<task description>\n",
    "Please refer to the Bruun model and the Dean model to build a mathematical model to discuss the change of sea level height with respect to the distance from the starting point, and build a deep learning model based on our data. The following is our data path and structure:\n",
    "\n",
    "Data File Paths:\n",
    "\n",
    "Training Data: \"beach_profile_data/processed_data/beachdata_train.xlsx\"\n",
    "Test Data: \"beach_profile_data/processed_data/beachdata_test.xlsx\"\n",
    "Data Structure Insights:\n",
    "The training and testing datasets contain following columns:\n",
    "\n",
    "[\n",
    "    \"Profile Azimuth (º)\",\n",
    "    \"Profile Length (m)\",\n",
    "    \"Longitude (º)\",\n",
    "    \"Latitude (º)\",\n",
    "    \"x\",\n",
    "    \"y\",\n",
    "    \"Intertidal Slope\",\n",
    "    \"Mean Grain Size\",\n",
    "    \"Mean Grain Size (Mz)\",\n",
    "    \"Sorting Coefficient (δ)\",\n",
    "    \"Skewness (Sk)\",\n",
    "    \"Kurtosis (Kg)\",\n",
    "    \"Corresponding Tidal Station Index\",\n",
    "    \"Annual Mean Spring Tidal Range (m)\",\n",
    "    \"Annual Mean Tidal Range (m)\",\n",
    "    \"Dominant Wave Direction\",\n",
    "    \"Dominant Wave Direction (Degrees)\",\n",
    "    \"Frequency\",\n",
    "    \"Mean Wave Height (m)\",\n",
    "    \"Period (s)\",\n",
    "    \"Relative Tidal Range (RTR)\",\n",
    "    \"Dimensionless Settling Velocity (Ω)\",\n",
    "    \"High Tide Sediment Settling Velocity ωs (cm/s)\",\n",
    "    \"Breaker Wave Height Hb (m)\",\n",
    "    \"Deep Water Wave Height Hd (m)\",\n",
    "    \"Annual Mean Period (s)\"\n",
    "]\n",
    "\n",
    "create the code directly and The script should not run the main training or evaluation logic directly when the script file is executed. Instead, it should define all necessary functions.\n",
    "</task description>\n",
    "\n",
    "I need you to briefly introduce the features of our data, and then think about how can we use them for our project.\"\"\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8042266e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "To begin with, let's introduce the features of your data, provided in your datasets. These features are crucial for building your mathematical and deep learning models for sea-level change prediction:\n",
      "\n",
      "1. **Profile Azimuth (°)**: This is the orientation of the beach profile in degrees. It can provide insight into the direction the profile faces relative to wave and wind direction, which can influence sediment transport.\n",
      "\n",
      "2. **Profile Length (m)**: The distance over which the profile is measured. This can indicate the scale of coastal processes and extent of possible erosion or deposition.\n",
      "\n",
      "3. **Longitude (°) and Latitude (°)**: The geographical coordinates of the profile. These are essential for relating the data to specific locations and for potential integration with geographic information systems (GIS).\n",
      "\n",
      "4. **x, y**: Possible transformed or projected coordinates used for spatial analysis.\n",
      "\n",
      "5. **Intertidal Slope**: The slope of the beach between high and low tide marks. This can affect wave energy dissipation and sediment dynamics.\n",
      "\n",
      "6. **Mean Grain Size and Mean Grain Size (Mz)**: Two measures describing the average size of sediment grains. Grain size influences sediment transport and settling velocity.\n",
      "\n",
      "7. **Sorting Coefficient (δ)**, **Skewness (Sk)**, **Kurtosis (Kg)**: Statistical measures of sediment grain size distribution. These features help understand sediment texture and stability.\n",
      "\n",
      "8. **Corresponding Tidal Station Index**: Links the profile to a specific tidal station, which is useful for tide-related data integration.\n",
      "\n",
      "9. **Annual Mean Spring Tidal Range (m)** and **Annual Mean Tidal Range (m)**: Average ranges of spring and regular tides, respectively. Important for understanding tidal influences on beach profile changes.\n",
      "\n",
      "10. **Dominant Wave Direction and Dominant Wave Direction (Degrees)**: Describes the primary direction from which waves approach. This influences sediment transport direction and profile shape.\n",
      "\n",
      "11. **Frequency**: Likely the frequency of occurrence of certain wave directions/events.\n",
      "\n",
      "12. **Mean Wave Height (m)** and **Breaker Wave Height Hb (m)**: Describes average and breaker wave heights, affecting energy impacting the beach and erosion potential.\n",
      "\n",
      "13. **Period (s)** and **Annual Mean Period (s)**: Time intervals between waves; relates to wave energy and sediment movement.\n",
      "\n",
      "14. **Relative Tidal Range (RTR)**: Represents spring tidal range normalized by the mean tide range, indicating tidal dominance.\n",
      "\n",
      "15. **Dimensionless Settling Velocity (Ω)** and **High Tide Sediment Settling Velocity ωs (cm/s)**: Relate to the movement of suspended sediments, affecting sediment deposition rates.\n",
      "\n",
      "16. **Deep Water Wave Height Hd (m)**: Height of waves in deep water before reaching the shore, relevant for wave energy onshore.\n",
      "\n",
      "These features provide a comprehensive understanding of the physical environment and processes at your sites. Here’s how we can use them in your project:\n",
      "\n",
      "### Mathematical Model:\n",
      "- **Bruun Model**: Typically used to predict shoreline retreat due to sea level rise based on profile slope and sediment dynamics. Here, features like the intertidal slope, sediment characteristics, and wave data can be integrated with this model to estimate profile changes.\n",
      "  \n",
      "- **Dean Model**: Focuses on equilibrium profile shape influenced by sediment characteristics and wave conditions. Utilize mean grain size, wave heights, and periods to simulate profile adaptations to changing sea levels.\n",
      "\n",
      "### Deep Learning Model:\n",
      "- **Feature Selection**: Use all relevant features to predict sea-level changes and profile alterations. This would involve correlating how changes in tidal range, wave conditions, and sediment characteristics guide profile changes.\n",
      "\n",
      "- **Data Preprocessing**: Normalize or standardize numerical values. Given lat/lon, proper spatial temporal integration might be necessary.\n",
      "\n",
      "- **Model Architecture**: Design a neural network that can understand non-linear relationships within these features to predict outcomes. Consider LSTM or GRU networks if temporal predictions are necessary (e.g., based on time series of profile changes).\n",
      "\n",
      "- **Model Training**: Input training data to teach the model relationships between conditions and sea-level profile changes. Test and validate with a separate dataset to ensure accuracy and generalization.\n",
      "\n",
      "By deeply leveraging the structures and relationships encoded in these features, you can build a sophisticated model to simulate and predict the impact of changing sea levels on beach morphology.\n"
     ]
    }
   ],
   "source": [
    "# Ask gpt-4o about the dataset using test_prompt (async call).\n",
    "# Top-level `await` relies on IPython's autoawait support; `llm` is presumably an\n",
    "# LLM client created in an earlier cell that is not visible here - confirm.\n",
    "output = await llm.LLM_response_async(test_prompt, model=\"gpt-4o\")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ed3a87b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Type>\n",
      "correctness\n",
      "</Type>\n",
      "<Planning>\n",
      "Test the function with a simple input where a valid subsequence exists and the product is within the limit.\n",
      "</Planning>\n",
      "<Code>\n",
      "def test_case(func):\n",
      "    nums = [1, 2, 3]\n",
      "    k = 2\n",
      "    limit = 10\n",
      "    expected_output = 6\n",
      "    result = func(nums, k, limit)\n",
      "    return result == expected_output\n",
      "</Code>\n",
      "\n",
      "<separator>\n",
      "\n",
      "<Type>\n",
      "correctness\n",
      "</Type>\n",
      "<Planning>\n",
      "Test the function with an input where no valid subsequence exists.\n",
      "</Planning>\n",
      "<Code>\n",
      "def test_case(func):\n",
      "    nums = [0, 2, 3]\n",
      "    k = -5\n",
      "    limit = 12\n",
      "    expected_output = -1\n",
      "    result = func(nums, k, limit)\n",
      "    return result == expected_output\n",
      "</Code>\n",
      "\n",
      "<separator>\n",
      "\n",
      "<Type>\n",
      "correctness\n",
      "</Type>\n",
      "<Planning>\n",
      "Test the function with an input where multiple valid subsequences exist.\n",
      "</Planning>\n",
      "<Code>\n",
      "def test_case(func):\n",
      "    nums = [2, 2, 3, 3]\n",
      "    k = 0\n",
      "    limit = 9\n",
      "    expected_output = 9\n",
      "    result = func(nums, k, limit)\n",
      "    return result == expected_output\n",
      "</Code>\n"
     ]
    }
   ],
   "source": [
    "# Non-async variant of the call above, using gpt-3.5-turbo.\n",
    "# NOTE(review): the saved output of this cell lists generated test cases, not a\n",
    "# description of the dataset, so it was presumably produced by a different prompt (stale).\n",
    "output = llm.LLM_response(test_prompt, model=\"gpt-3.5-turbo\")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42127fc5",
   "metadata": {},
   "source": [
    "### Extraction test: parsing test cases from LLM output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5ca4dc70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample LLM output: two test cases (correctness, edge_case) separated by <separator>,\n",
    "# used as input for extract_test_cases.\n",
    "text_temp = \"\"\"<Type>\n",
    "correctness\n",
    "</Type>\n",
    "<Planning>\n",
    "This test case will cover the scenario where a subsequence with an alternating sum equal to k exists and the product does not exceed the limit.\n",
    "\n",
    "I will create a test case with nums = [1, 2, 3], k = 2, limit = 10.\n",
    "The expected output is 6 because the subsequence [1, 2, 3] has an alternating sum of 2 and the product of these numbers is 6, which is within the limit.\n",
    "</Planning>\n",
    "<Code>\n",
    "def test_case(func):\n",
    "    nums = [1, 2, 3]\n",
    "    k = 2\n",
    "    limit = 10\n",
    "    expected_output = 6\n",
    "\n",
    "    result = func(nums, k, limit)\n",
    "    \n",
    "    return result == expected_output\n",
    "</Code>\n",
    "\n",
    "<separator>\n",
    "\n",
    "<Type>\n",
    "edge_case\n",
    "</Type>\n",
    "<Planning>\n",
    "This test case will cover the scenario where no subsequence with an alternating sum equal to k exists.\n",
    "\n",
    "I will create a test case with nums = [0, 2, 3], k = -5, limit = 12.\n",
    "The expected output is -1 because there is no subsequence with an alternating sum of -5.\n",
    "</Planning>\n",
    "<Code>\n",
    "def test_case(func):\n",
    "    nums = [0, 2, 3]\n",
    "    k = -5\n",
    "    limit = 12\n",
    "    expected_output = -1\n",
    "\n",
    "    result = func(nums, k, limit)\n",
    "    \n",
    "    return result == expected_output\n",
    "</Code>\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "418fb7dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_test_cases(output_text):\n",
    "    \"\"\"\n",
    "    Extracts test cases from LLM output text with flexible tag handling.\n",
    "    Supports case-insensitive tags, missing <Type> tags, and multi-separators.\n",
    "\n",
    "    Splitting: if the text contains at least one <separator> tag, ONLY separator\n",
    "    tags delimit test cases, so blank lines between the <Type>, <Planning> and\n",
    "    <Code> sections of one test case no longer split it apart. Without any\n",
    "    separator tag, blocks are split on blank lines instead; newlines inside the\n",
    "    protected tags and ```python fences are hidden first so they never split.\n",
    "\n",
    "    A block yields a test case only if both a test type and test code are found:\n",
    "      - type: <type>...</type>, else the first tag that is not a known structural tag\n",
    "      - purpose: <planning>/<reasoning> content (empty string if absent)\n",
    "      - code: <test_function> > <code> > a bare ```python fence; a fence inside\n",
    "        either tag is unwrapped\n",
    "\n",
    "    Args:\n",
    "        output_text: Raw LLM response text.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping 'test_case_<i>' (i = 1-based block index, so numbering can\n",
    "        have gaps when blocks are skipped) to a dict with keys 'test_type',\n",
    "        'purpose' and 'test_function'; or False if nothing could be extracted\n",
    "        (including empty/None input).\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    if not output_text:\n",
    "        return False\n",
    "\n",
    "    test_cases = {}\n",
    "\n",
    "    tags_to_protect_newlines_in = ['type', 'planning', 'reasoning', 'code', 'test_function']\n",
    "    # Hide newlines inside protected tags and ```python fences behind a placeholder so\n",
    "    # that the blank-line split below can never cut through them.\n",
    "    def preprocess_text(text):\n",
    "        placeholder = \"###NL###\"\n",
    "        \n",
    "        # This inner function will be used by re.sub to process the content of matched tags\n",
    "        def replace_newlines_in_content(match_obj):\n",
    "            opening_tag = match_obj.group(1)\n",
    "            content = match_obj.group(2)\n",
    "            closing_tag = match_obj.group(3)\n",
    "            \n",
    "            processed_content = content.replace(\"\\n\", placeholder)\n",
    "            return f\"{opening_tag}{processed_content}{closing_tag}\"\n",
    "\n",
    "        # Iteratively protect newlines within the specified XML/HTML-like tags\n",
    "        for tag_name_to_protect in tags_to_protect_newlines_in:\n",
    "            tag_pattern = re.compile(\n",
    "                r'(<\\s*' + re.escape(tag_name_to_protect) + r'\\b[^>]*>)(.*?)(<\\s*/\\s*' + re.escape(tag_name_to_protect) + r'\\s*>)',\n",
    "                flags=re.IGNORECASE | re.DOTALL\n",
    "            )\n",
    "            text = tag_pattern.sub(replace_newlines_in_content, text)\n",
    "        \n",
    "        # Also protect newlines within ```python ... ``` markdown code blocks\n",
    "        def replace_newlines_in_markdown_code(match_obj):\n",
    "            block = match_obj.group(0) # The whole matched block ```python ... ```\n",
    "            return block.replace(\"\\n\", placeholder)\n",
    "        \n",
    "        markdown_code_pattern = re.compile(r'(```python.*?```)', flags=re.IGNORECASE | re.DOTALL)\n",
    "        text = markdown_code_pattern.sub(replace_newlines_in_markdown_code, text)\n",
    "        \n",
    "        return text, placeholder\n",
    "\n",
    "    # Pre-process: hide newlines inside protected tags and code fences\n",
    "    modified_text, placeholder = preprocess_text(output_text)\n",
    "    \n",
    "    # Split into blocks. Prefer <separator> tags whenever any are present, so blank\n",
    "    # lines inside a single test case cannot split it; fall back to blank lines only\n",
    "    # when the output contains no separators at all.\n",
    "    separator_pattern = r'(?:<\\s*/\\s*separator\\s*>|<\\s*separator\\s*>|<\\s*separator\\s*/>)'\n",
    "    if re.search(separator_pattern, modified_text, flags=re.IGNORECASE):\n",
    "        split_pattern = separator_pattern\n",
    "    else:\n",
    "        split_pattern = r'\\n\\s*\\n\\s*'\n",
    "    test_case_blocks = re.split(split_pattern, modified_text, flags=re.IGNORECASE)\n",
    "    test_case_blocks = [b.strip() for b in test_case_blocks if b.strip()]\n",
    "    \n",
    "    # Restore the hidden newlines inside each block\n",
    "    test_case_blocks = [b.replace(placeholder, \"\\n\") for b in test_case_blocks]\n",
    "\n",
    "    # Structural tags that never name a test type (used by the Case 2 fallback)\n",
    "    known_tags = set(tags_to_protect_newlines_in) | {'separator'}\n",
    "    fenced_code_pattern = re.compile(r'```python\\s*(.*?)\\s*```', re.DOTALL)\n",
    "\n",
    "    def unwrap_fenced_code(content):\n",
    "        # Return the body of the first ```python fence in content, or content unchanged\n",
    "        fenced = fenced_code_pattern.search(content)\n",
    "        return fenced.group(1).strip() if fenced else content\n",
    "\n",
    "    for idx, block in enumerate(test_case_blocks, 1):\n",
    "        # 1. Extract test_type\n",
    "        test_type = None\n",
    "        \n",
    "        # Case 1: explicit <type>value</type>\n",
    "        type_match = re.search(\n",
    "            r'<\\s*type\\s*>(.*?)<\\s*/\\s*type\\s*>', \n",
    "            block, \n",
    "            re.IGNORECASE | re.DOTALL\n",
    "        )\n",
    "        if type_match:\n",
    "            test_type = type_match.group(1).strip()\n",
    "        else:\n",
    "            # Case 2: use the first tag that is not a known structural tag (e.g. <edge_case>)\n",
    "            for tag_match in re.finditer(r'<\\s*([^\\s>/]+)\\s*.*?>', block, re.IGNORECASE):\n",
    "                tag_name = tag_match.group(1).lower()\n",
    "                if tag_name not in known_tags:\n",
    "                    test_type = tag_name\n",
    "                    break\n",
    "            \n",
    "        if not test_type:  # Skip blocks that carry no test type\n",
    "            continue\n",
    "        \n",
    "        # 2. Extract reasoning (supports <planning> and <reasoning>)\n",
    "        reasoning_match = re.search(\n",
    "            r'<\\s*(?:reasoning|planning)\\s*>(.*?)<\\s*/\\s*(?:reasoning|planning)\\s*>',\n",
    "            block, \n",
    "            re.IGNORECASE | re.DOTALL\n",
    "        )\n",
    "        reasoning = reasoning_match.group(1).strip() if reasoning_match else \"\"\n",
    "        \n",
    "        # 3. Extract test_function (priority: <test_function> tag > <code> tag > bare fence)\n",
    "        test_func = None\n",
    "        for code_tag in ('test_function', 'code'):\n",
    "            code_tag_match = re.search(\n",
    "                r'<\\s*' + code_tag + r'\\s*>(.*?)<\\s*/\\s*' + code_tag + r'\\s*>',\n",
    "                block,\n",
    "                re.IGNORECASE | re.DOTALL\n",
    "            )\n",
    "            if code_tag_match:\n",
    "                test_func = unwrap_fenced_code(code_tag_match.group(1).strip())\n",
    "                break\n",
    "        else:\n",
    "            # Neither tag present: look for a bare ```python ... ``` block\n",
    "            code_block = fenced_code_pattern.search(block)\n",
    "            if code_block:\n",
    "                test_func = code_block.group(1).strip()\n",
    "        \n",
    "        if test_type and test_func:\n",
    "            test_cases[f'test_case_{idx}'] = {\n",
    "                'test_type': test_type,\n",
    "                'purpose': reasoning,\n",
    "                'test_function': test_func\n",
    "            }\n",
    "    \n",
    "    if not test_cases:\n",
    "        # Nothing extracted: signal failure with False rather than an empty dict\n",
    "        return False\n",
    "\n",
    "    return test_cases\n",
    "\n",
    "# Run the extractor on the sample LLM output stored in text_temp\n",
    "extracted_cases = extract_test_cases(text_temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "04a4ed0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'test_case_1': {'test_type': 'correctness', 'purpose': 'This test case will cover the scenario where a subsequence with an alternating sum equal to k exists and the product does not exceed the limit.\\n\\nI will create a test case with nums = [1, 2, 3], k = 2, limit = 10.\\nThe expected output is 6 because the subsequence [1, 2, 3] has an alternating sum of 2 and the product of these numbers is 6, which is within the limit.', 'test_function': 'def test_case(func):\\n    nums = [1, 2, 3]\\n    k = 2\\n    limit = 10\\n    expected_output = 6\\n\\n    result = func(nums, k, limit)\\n    \\n    return result == expected_output'}, 'test_case_2': {'test_type': 'edge_case', 'purpose': 'This test case will cover the scenario where no subsequence with an alternating sum equal to k exists.\\n\\nI will create a test case with nums = [0, 2, 3], k = -5, limit = 12.\\nThe expected output is -1 because there is no subsequence with an alternating sum of -5.', 'test_function': 'def test_case(func):\\n    nums = [0, 2, 3]\\n    k = -5\\n    limit = 12\\n    expected_output = -1\\n\\n    result = func(nums, k, limit)\\n    \\n    return result == expected_output'}}\n"
     ]
    }
   ],
   "source": [
    "# Show the extracted test cases (the saved output contains test_case_1 and test_case_2).\n",
    "print(extracted_cases)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bb401448",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'<Think>\n",
      "The last code failed to pass the test function because the calculate_xor_pairs_sum function is not calculating the sum of all pairs of numbers correctly. It is only calculating the xor of pairs within the first k elements of the list. To fix this, we need to iterate over all pairs of numbers in the given list and calculate the xor sum accordingly.\n",
      "</Think>\n",
      "\n",
      "<Code>\n",
      "```python\n",
      "def pair_xor_Sum(nums, k):\n",
      "    # Function to calculate the sum of xor of all pairs of numbers in the given list\n",
      "    def calculate_xor_pairs_sum(nums):\n",
      "        sum_xor_pairs = 0\n",
      "        # Iterate through all pairs of numbers in the list\n",
      "        for i in range(len(nums)):\n",
      "            for j in range(i+1, len(nums)):\n",
      "                # Calculate the xor of each pair and add it to sum_xor_pairs\n",
      "                sum_xor_pairs += nums[i] ^ nums[j]\n",
      "        return sum_xor_pairs\n",
      "    \n",
      "    return calculate_xor_pairs_sum(nums)\n",
      "```\n",
      "</Code>\n",
      "\n",
      "<Planning>\n",
      "1. Define a function `pair_xor_Sum` that takes a list of numbers `nums` and an integer `k` as input.\n",
      "2. Define a nested function `calculate_xor_pairs_sum` within `pair_xor_Sum` to calculate the sum of xor of all pairs of numbers in the given list.\n",
      "3. Initialize `sum_xor_pairs` to 0.\n",
      "4. Iterate through all pairs of numbers in the `nums` list using nested loops.\n",
      "5. Calculate the xor of each pair of numbers and add it to `sum_xor_pairs`.\n",
      "6. Return the final `sum_xor_pairs`.\n",
      "7. Return the result from `calculate_xor_pairs_sum` as the final output.\n",
      "</Planning>\n",
      "\n",
      "<Main Function Name>\n",
      "pair_xor_Sum\n",
      "</Main Function Name>'\n"
     ]
    }
   ],
   "source": [
    "# Raw LLM repair response, apparently pasted from a repr (note the wrapping single quotes);\n",
    "# printing it renders the embedded \\n escapes as real line breaks.\n",
    "temp_text = \"\"\"'<Think>\\nThe last code failed to pass the test function because the calculate_xor_pairs_sum function is not calculating the sum of all pairs of numbers correctly. It is only calculating the xor of pairs within the first k elements of the list. To fix this, we need to iterate over all pairs of numbers in the given list and calculate the xor sum accordingly.\\n</Think>\\n\\n<Code>\\n```python\\ndef pair_xor_Sum(nums, k):\\n    # Function to calculate the sum of xor of all pairs of numbers in the given list\\n    def calculate_xor_pairs_sum(nums):\\n        sum_xor_pairs = 0\\n        # Iterate through all pairs of numbers in the list\\n        for i in range(len(nums)):\\n            for j in range(i+1, len(nums)):\\n                # Calculate the xor of each pair and add it to sum_xor_pairs\\n                sum_xor_pairs += nums[i] ^ nums[j]\\n        return sum_xor_pairs\\n    \\n    return calculate_xor_pairs_sum(nums)\\n```\\n</Code>\\n\\n<Planning>\\n1. Define a function `pair_xor_Sum` that takes a list of numbers `nums` and an integer `k` as input.\\n2. Define a nested function `calculate_xor_pairs_sum` within `pair_xor_Sum` to calculate the sum of xor of all pairs of numbers in the given list.\\n3. Initialize `sum_xor_pairs` to 0.\\n4. Iterate through all pairs of numbers in the `nums` list using nested loops.\\n5. Calculate the xor of each pair of numbers and add it to `sum_xor_pairs`.\\n6. Return the final `sum_xor_pairs`.\\n7. Return the result from `calculate_xor_pairs_sum` as the final output.\\n</Planning>\\n\\n<Main Function Name>\\npair_xor_Sum\\n</Main Function Name>'\"\"\"\n",
    "print(temp_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f5ebb86",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "def find_tuple_intersection(list1, list2):\n",
    "    \"\"\"Return the tuples that occur in both lists, ignoring element order.\n",
    "\n",
    "    Every tuple is normalised by sorting its elements, so (5, 4) and (4, 5) count\n",
    "    as the same item; the returned set holds the normalised (sorted) tuples.\n",
    "    \"\"\"\n",
    "    def normalise(items):\n",
    "        return set(map(lambda item: tuple(sorted(item)), items))\n",
    "\n",
    "    return normalise(list1) & normalise(list2)\n",
    "\n",
    "def tuple_intersection(list1, list2):\n",
    "    \"\"\"Alias for find_tuple_intersection.\"\"\"\n",
    "    return find_tuple_intersection(list1, list2)\n",
    "\n",
    "print(tuple_intersection([(3, 4), (5, 6), (9, 10), (4, 5)] , [(5, 4), (3, 4), (6, 5), (9, 11)]) == {(4, 5), (3, 4), (5, 6)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1f4efc97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Error Type', 'Error message']\n"
     ]
    }
   ],
   "source": [
    "# The two field labels, built as a single list literal\n",
    "parts = [\"Error Type\", \"Error message\"]\n",
    "print(parts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d8b5915f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}\n",
      "(None, \"Compilation Error: name 'tuple_intersection' is not defined\")\n"
     ]
    }
   ],
   "source": [
    "# Sample generated test function. Its last line calls test_case(tuple_intersection), so\n",
    "# exec()-ing this string in a fresh namespace raises NameError (tuple_intersection is undefined there).\n",
    "test_fun_str = \"\"\"def test_case(func):\n",
    "    input_1 = [(3, 4), (5, 6), (9, 10), (4, 5)]\n",
    "    input_2 = [(5, 4), (3, 4), (6, 5), (9, 11)]\n",
    "    expected_output = {(4, 5), (3, 4), (5, 6)}\n",
    "\n",
    "    result = func(input_1, input_2)\n",
    "    if result == expected_output:\n",
    "        return True, \"Test passed: output matches the expected set of intersection tuples\"\n",
    "    else:\n",
    "        return False, f\"Test failed: expected {expected_output} but got {result}\"\n",
    "\n",
    "test_case(tuple_intersection)\"\"\"\n",
    "\n",
    "# Minimal dataset in the {code_id: {test_function, test_type, purpose}} shape\n",
    "dataset_temp = {\n",
    "    \"code_0\": {\n",
    "        \"test_function\": test_fun_str,\n",
    "        \"test_type\": \"correctness\",\n",
    "        \"purpose\": \"Test the correctness of the tuple_intersection function with a specific input.\"\n",
    "    }\n",
    "}\n",
    "\n",
    "def _filter_test_cases(dataset):\n",
    "    \"\"\"Keep only entries whose 'test_function' code executes and defines a callable.\n",
    "\n",
    "    Each entry's code string is exec()-ed in a fresh namespace. An entry is dropped\n",
    "    if exec raises any Exception (syntax or runtime error) or if the code defines\n",
    "    nothing callable (e.g. only plain statements or an empty string).\n",
    "\n",
    "    SECURITY: this runs arbitrary (possibly LLM-generated) code in this process;\n",
    "    only use it on trusted input or inside a sandbox.\n",
    "\n",
    "    Args:\n",
    "        dataset: dict mapping code_id -> attribute dict holding a \"test_function\" string.\n",
    "\n",
    "    Returns:\n",
    "        dict: the subset of `dataset` (same attribute dict objects) that passed.\n",
    "    \"\"\"\n",
    "    runnable_entries = {}\n",
    "    for code_id, attributes in dataset.items():\n",
    "        test_code = attributes.get(\"test_function\", \"\")\n",
    "        namespace = {}\n",
    "        try:\n",
    "            exec(test_code, namespace)\n",
    "        except Exception:\n",
    "            # Code that fails to compile or raises while running is not runnable.\n",
    "            continue\n",
    "        # Require at least one callable (e.g. test_case). exec() injects a\n",
    "        # __builtins__ entry into the namespace, but it is not callable.\n",
    "        if any(callable(obj) for obj in namespace.values()):\n",
    "            runnable_entries[code_id] = attributes\n",
    "    return runnable_entries\n",
    "\n",
    "def compile_code_test(code_str, main_function_name=None):\n",
    "    \"\"\"Exec `code_str` in a fresh namespace and return (callable, error_message).\n",
    "\n",
    "    Args:\n",
    "        code_str: Python source to execute.\n",
    "        main_function_name: Optional name of the callable to return. When the code\n",
    "            defines a callable under this name it is returned; otherwise (or when\n",
    "            None) the first callable in namespace insertion order is returned.\n",
    "\n",
    "    Returns:\n",
    "        (obj, None) on success; (None, \"No callable found in code\") when nothing\n",
    "        callable was defined; (None, \"Compilation Error: <msg>\") when exec raised\n",
    "        any Exception (runtime errors included, despite the message prefix).\n",
    "\n",
    "    SECURITY: runs arbitrary code in this process; use only on trusted input.\n",
    "    \"\"\"\n",
    "    local_vars = {}\n",
    "    try:\n",
    "        exec(code_str, local_vars)\n",
    "    except Exception as e:\n",
    "        return None, f\"Compilation Error: {str(e)}\"\n",
    "\n",
    "    # Honour the requested entry point first, so a helper or imported callable that\n",
    "    # happens to be defined earlier is not returned instead of the main function.\n",
    "    if main_function_name is not None:\n",
    "        requested = local_vars.get(main_function_name)\n",
    "        if callable(requested):\n",
    "            return requested, None\n",
    "    for obj in local_vars.values():\n",
    "        if callable(obj):\n",
    "            return obj, None\n",
    "    return None, \"No callable found in code\"\n",
    "\n",
    "# Both calls exec test_fun_str, whose trailing test_case(tuple_intersection) call raises NameError\n",
    "# in the fresh exec namespace, so the entry is filtered out and compile_code_test reports an error.\n",
    "print(_filter_test_cases(dataset_temp))\n",
    "print(compile_code_test(test_fun_str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b062dbe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<components>\n",
      "{\n",
      "  \"calculate_distance\": {\n",
      "    \"step_task_description\": \"Calculate the Euclidean distance between two points in a 2D plane.\",\n",
      "    \"input_format\": [[\"tuple\", (2,)]],\n",
      "    \"output_format\": [[\"float\", null]],\n",
      "    \"work_flow\": [\n",
      "      \"Extract the x and y coordinates of both points.\",\n",
      "      \"Calculate the square of the horizontal distance (x1 - x2) ^ 2.\",\n",
      "      \"Calculate the square of the vertical distance (y1 - y2) ^ 2.\",\n",
      "      \"Calculate the Euclidean distance as the square root of the sum of the above distances.\"\n",
      "    ],\n",
      "    \"test_case_generation_advise\": [\n",
      "      \"Test with points having positive and negative coordinates.\",\n",
      "      \"Test with points at the origin.\",\n",
      "      \"Test with points in different quadrants.\"\n",
      "    ]\n",
      "  },\n",
      "  \"min_Jumps\": {\n",
      "    \"step_task_description\": \"Calculate the minimum number of jumps of a given length required to reach a specific point from the origin.\",\n",
      "    \"input_format\": [[\"tuple\", (2,)], [\"float\", null]],\n",
      "    \"output_format\": [[\"float\", null]],\n",
      "    \"work_flow\": [\n",
      "      \"Calculate the distance between the given point and the origin using \\'calculate_distance\\' component.\",\n",
      "      \"Divide the distance by the given length of each jump to find the number of jumps required.\"\n",
      "    ],\n",
      "    \"test_case_generation_advise\": [\n",
      "      \"Test with different point coordinates and jump lengths.\",\n",
      "      \"Test with the same point and different jump lengths.\",\n",
      "      \"Test with the same jump length but different points.\"\n",
      "    ]\n",
      "  }\n",
      "}\n",
      "</components>\n",
      "\n",
      "<overall_plan>\n",
      "{\n",
      "  \"input_format\": [[\"tuple\", (2,)], [\"float\", null]],\n",
      "  \"output_format\": [[\"float\", null]],\n",
      "  \"components\": [\"calculate_distance\", \"min_Jumps\"],\n",
      "  \"plan\": [\n",
      "    \"Calculate the Euclidean distance between the input point and the origin using the \\'calculate_distance\\' component.\",\n",
      "    \"Calculate the minimum number of jumps required based on the given jump length using the \\'min_Jumps\\' component.\"\n",
      "  ],\n",
      "  \"test_case_generation_advise\": [\n",
      "    \"Test with various points and jump lengths to cover different scenarios.\",\n",
      "    \"Include edge cases such as points at the origin and points at different distances from the origin.\",\n",
      "    \"Test with negative values for distances and jump lengths.\"\n",
      "  ]\n",
      "}\n",
      "</overall_plan>\n"
     ]
    }
   ],
   "source": [
    "print(\"\"\"<components>\\\\n{\\\\n  \\\"calculate_distance\\\": {\\\\n    \\\"step_task_description\\\": \\\"Calculate the Euclidean distance between two points in a 2D plane.\\\",\\\\n    \\\"input_format\\\": [[\\\"tuple\\\", (2,)]],\\\\n    \\\"output_format\\\": [[\\\"float\\\", null]],\\\\n    \\\"work_flow\\\": [\\\\n      \\\"Extract the x and y coordinates of both points.\\\",\\\\n      \\\"Calculate the square of the horizontal distance (x1 - x2) ^ 2.\\\",\\\\n      \\\"Calculate the square of the vertical distance (y1 - y2) ^ 2.\\\",\\\\n      \\\"Calculate the Euclidean distance as the square root of the sum of the above distances.\\\"\\\\n    ],\\\\n    \\\"test_case_generation_advise\\\": [\\\\n      \\\"Test with points having positive and negative coordinates.\\\",\\\\n      \\\"Test with points at the origin.\\\",\\\\n      \\\"Test with points in different quadrants.\\\"\\\\n    ]\\\\n  },\\\\n  \\\"min_Jumps\\\": {\\\\n    \\\"step_task_description\\\": \\\"Calculate the minimum number of jumps of a given length required to reach a specific point from the origin.\\\",\\\\n    \\\"input_format\\\": [[\\\"tuple\\\", (2,)], [\\\"float\\\", null]],\\\\n    \\\"output_format\\\": [[\\\"float\\\", null]],\\\\n    \\\"work_flow\\\": [\\\\n      \\\"Calculate the distance between the given point and the origin using \\\\'calculate_distance\\\\' component.\\\",\\\\n      \\\"Divide the distance by the given length of each jump to find the number of jumps required.\\\"\\\\n    ],\\\\n    \\\"test_case_generation_advise\\\": [\\\\n      \\\"Test with different point coordinates and jump lengths.\\\",\\\\n      \\\"Test with the same point and different jump lengths.\\\",\\\\n      \\\"Test with the same jump length but different points.\\\"\\\\n    ]\\\\n  }\\\\n}\\\\n</components>\\\\n\\\\n<overall_plan>\\\\n{\\\\n  \\\"input_format\\\": [[\\\"tuple\\\", (2,)], [\\\"float\\\", null]],\\\\n  \\\"output_format\\\": [[\\\"float\\\", null]],\\\\n  
\\\"components\\\": [\\\"calculate_distance\\\", \\\"min_Jumps\\\"],\\\\n  \\\"plan\\\": [\\\\n    \\\"Calculate the Euclidean distance between the input point and the origin using the \\\\'calculate_distance\\\\' component.\\\",\\\\n    \\\"Calculate the minimum number of jumps required based on the given jump length using the \\\\'min_Jumps\\\\' component.\\\"\\\\n  ],\\\\n  \\\"test_case_generation_advise\\\": [\\\\n    \\\"Test with various points and jump lengths to cover different scenarios.\\\",\\\\n    \\\"Include edge cases such as points at the origin and points at different distances from the origin.\\\",\\\\n    \\\"Test with negative values for distances and jump lengths.\\\"\\\\n  ]\\\\n}\\\\n</overall_plan>\"\"\".replace(\"\\\\n\", \"\\n\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04425685",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
