{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29cd87ef-b84b-4fed-a144-70a2a49547da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "from typing import Dict, List\n",
    "\n",
    "from tusoai import tusoai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72f602d3-be48-4513-b036-8587f8b36d28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================================\n",
    "# Credentials\n",
    "# ==========================================================\n",
    "# Keys are read from environment variables instead of being typed into the\n",
    "# notebook, so they never end up in a saved or committed copy.\n",
    "# An unset variable falls back to an empty string.\n",
    "openrouter = False   # Whether or not the LLM is accessed through OpenRouter\n",
    "api_key = os.environ.get(\"OPENROUTER_API_KEY\" if openrouter else \"OPENAI_API_KEY\", \"\")\n",
    "semantic_scholar_api_key = os.environ.get(\"SEMANTIC_SCHOLAR_API_KEY\", \"\")\n",
    "\n",
    "# ==========================================================\n",
    "# Configuration for TusoAI\n",
    "# ==========================================================\n",
    "\n",
    "# LLM settings\n",
    "LLM_MODEL       = \"gpt-4o-mini\"   # LLM to use\n",
    "temperature     = 0.5             # LLM sampling temperature\n",
    "\n",
    "# Task settings (natural-language descriptions passed to the tusoai helpers)\n",
    "task_description   = \"single cell RNA-seq imputation\"\n",
    "data_available     = \"an AnnData object\"   # Description of the input data\n",
    "features_available = None                  # Describe existing features if relevant (NOTE: not passed to any call below)\n",
    "\n",
    "# File paths\n",
    "initial_file = \"denoise_initial.py\"           # Initial template file path\n",
    "filename     = \"single_cell_denoise/denoise\"  # File path used by the optimization (passed as `filename`)\n",
    "\n",
    "# Optimization hints (guide TusoAI)\n",
    "hints = [\n",
    "    'Make sure to store the denoised data in adata.obsm[\"denoised\"].',\n",
    "    \"Keep the function header, input, output the same.\",\n",
    "]\n",
    "\n",
    "# Search / instruction parameters\n",
    "instruction_count = 10    # Number of instructions per category (per refinement)\n",
    "paper_searches    = 10    # Max papers to extract\n",
    "num_cat           = 10    # Number of categories (before refinement)\n",
    "num_init          = 5     # Number of initial solutions to construct\n",
    "\n",
    "# Evolution / optimization loop parameters\n",
    "n_generations            = 10000  # Number of cluster-evolve rounds\n",
    "children_per_model       = 1      # Children each model spawns per generation\n",
    "bug_retries              = 3      # Attempts at fixing bugs during optimization\n",
    "initial_bug_fix_attempts = 5      # Attempts at fixing bugs for initial solutions\n",
    "timeout                  = 120    # Timeout for each execution (seconds)\n",
    "skip_timeout             = True   # Skip runs that time out instead of debugging them\n",
    "drop_island_iter         = 2      # Iterations before lowering the number of solution pools\n",
    "\n",
    "# Prompt sampling / feedback\n",
    "n_feedback_buffer  = 5       # Number of feedback samples\n",
    "prompt_samples     = 3       # Instructions sampled per iteration\n",
    "alter_info_samples = 3       # Diagnostic prompts sampled per iteration\n",
    "prompt_decay       = 1.1     # Update prior of category usefulness\n",
    "\n",
    "# Control options\n",
    "use_initial   = False            # Keep original template code as an initial solution\n",
    "TIME_LIMIT    = 60 * 60 * 0.05   # Total runtime limit for optimization (0.05 h; presumably seconds, i.e. 180 s)\n",
    "val_limit     = 1.0              # Validation metric limit (avoid overfitting)\n",
    "debug_mode    = False            # Print debugging statements if True\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea78939-afc3-43fc-9a52-a36abf24b495",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the LLM client that every later tusoai call receives as `client`\n",
    "client = tusoai.initialize(\n",
    "    api_key,\n",
    "    openrouter=openrouter,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31df3999-1298-4300-aeaa-b766b52042f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paper summaries for the task, limited to `paper_searches` papers (top_n)\n",
    "summaries = tusoai.make_summaries(\n",
    "    task_description,\n",
    "    api_key=semantic_scholar_api_key,\n",
    "    top_n=paper_searches,\n",
    "    client=client,\n",
    "    model=LLM_MODEL,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad08082b-7013-4a86-bc60-839063f3a725",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Method categories (num_cat of them) built from the task, data description and paper summaries\n",
    "categories = tusoai.make_categories(\n",
    "    task_description,\n",
    "    data_available,\n",
    "    num_cat=num_cat,\n",
    "    summaries=summaries,\n",
    "    client=client,\n",
    "    model=LLM_MODEL,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c825ed3-9a13-416f-856c-577dee783b86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instructions for each category (instruction_count per category)\n",
    "instructions = tusoai.make_instructions(\n",
    "    task_description,\n",
    "    data_available,\n",
    "    categories,\n",
    "    summaries,\n",
    "    instruction_count,\n",
    "    client=client,\n",
    "    model=LLM_MODEL,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dc70f77-031a-462c-92d9-f4ca359fb5a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial candidate solutions (num_init of them), passed to discover_method as `initialisations`\n",
    "solutions = tusoai.make_solutions(\n",
    "    task_description,\n",
    "    data_available,\n",
    "    num_init,\n",
    "    summaries,\n",
    "    client=client,\n",
    "    model=LLM_MODEL,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1daa2bf-d1e8-4927-be2b-c4cefa7f84c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Category probabilities, passed to discover_method as `probabilities`\n",
    "probabilities = tusoai.make_probabilities(\n",
    "    task_description,\n",
    "    data_available,\n",
    "    categories,\n",
    "    solutions,\n",
    "    client=client,\n",
    "    model=LLM_MODEL,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adcfddcd-4a07-4c35-8a09-672287b75b79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diagnostic prompts, passed to discover_method as `alter_info_prompts`.\n",
    "# The path is relative, so the kernel's working directory must contain\n",
    "# the `tusoai/` folder.\n",
    "json_path = Path(\"tusoai\") / \"diagnostic_prompts.json\"\n",
    "alter_info_prompts = json.loads(json_path.read_text(encoding=\"utf-8\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57c6c0db-7ac6-48d7-a981-41f20cfeb68f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run the TusoAI search using every setting from the configuration cell and\n",
    "# the artifacts built above; the call returns a (best_model, history) pair.\n",
    "best_model, optimization_history = tusoai.discover_method(\n",
    "    llm_model=LLM_MODEL,\n",
    "    temperature=temperature,\n",
    "    client=client,\n",
    "    prompts=instructions,\n",
    "    probabilities=probabilities,\n",
    "    reference_filename=initial_file,\n",
    "    initialisations=solutions,\n",
    "    n_generations=n_generations,\n",
    "    children_per_model=children_per_model,\n",
    "    bug_retries=bug_retries,\n",
    "    initial_bug_fix_attempts=initial_bug_fix_attempts,\n",
    "    timeout=timeout,\n",
    "    n_feedback_buffer=n_feedback_buffer,\n",
    "    skip_timeout=skip_timeout,\n",
    "    drop_island_iter=drop_island_iter,\n",
    "    prompt_samples=prompt_samples,\n",
    "    alter_info_samples=alter_info_samples,\n",
    "    prompt_decay=prompt_decay,\n",
    "    hints=hints,\n",
    "    filename=filename,\n",
    "    use_initial=use_initial,\n",
    "    TIME_LIMIT=TIME_LIMIT,\n",
    "    task_description=task_description,\n",
    "    val_limit=val_limit,\n",
    "    debug=debug_mode,\n",
    "    alter_info_prompts=alter_info_prompts,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (aide3)",
   "language": "python",
   "name": "aide3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
