from tusoai import initialize_llm

### SETUP

import os

# The API key is read from the environment so that it never has to be written
# into (and committed with) this script. The previous placeholder line was not
# valid Python, which made the whole file unparseable until it was edited.
# NOTE(review): LLM_MODEL below is an OpenAI model name, so OPENAI_API_KEY is
# assumed here; change the variable name if your provider differs, or assign
# your key string to api_key directly.
api_key = os.environ.get("OPENAI_API_KEY", "")
if not api_key:
    raise RuntimeError(
        "No API key found: set the OPENAI_API_KEY environment variable "
        "(or assign api_key in the SETUP section) before running this script."
    )

# Client object handed to every tusoai call below.
client = initialize_llm.init(api_key)

# Model name passed as `model=` / `llm_model=` to the tusoai calls below.
LLM_MODEL = 'gpt-4o-mini'


################################### EDIT HERE
# Free-text task name. It is passed to the literature, category, prompt,
# initialization, probability and optimizer calls below, and it is also used
# verbatim as the prefix of the "<task>_summaries.json", "<task>_prompts.json",
# "<task>_initializations.json" and "<task>_probabilities.json" files.
task_description = "single cell RNA-seq imputation"
# Free-text description of the input data; forwarded to the category, prompt,
# initialization and probability calls below.
data_available = "an AnnData object"
# Not referenced anywhere else in the visible part of this script.
features_available = None
# Passed to optimizer.discover_algorithm as reference_filename.
initial_file = 'denoise_initial.py'
# Passed to optimizer.discover_algorithm as filename, and used as the prefix of
# the "<filename>_best_tusoml.py" and "<filename>_full_history.json" outputs.
filename = 'denoise_tusoml'

# Passed to optimizer.discover_algorithm as hints.
hints = ['Make sure to store the denoised data in adata.obsm["denoised"].',
         'Keep the function header, input, output the same.']
######################################         
         
         
### LITERATURE PARSING ###

from tusoai import extract_literature

# Ask tusoai for literature on the task (top_n=10) and print a short overview
# of each returned paper. Each paper is indexed by the keys "title",
# "abstract", "method_lines" and "sections".
papers = extract_literature.run_extraction(task_description, top_n=10)
for paper in papers:
    print(f"\n=== {paper['title']} ===\n")
    # Only the first 40 characters of a non-empty abstract are shown.
    if paper["abstract"]:
        print("ABSTRACT:", paper["abstract"][:40], "...\n")
    print("METHOD LINES:", len(paper["method_lines"]))
    print("SECTIONS:", len(paper["sections"]))

from tusoai import summarize_literature

# Summarise the extracted papers with the configured client and model.
summaries = summarize_literature.summarise_papers(papers, client, LLM_MODEL)

import json

# Persist the summaries next to the script; the file name embeds the raw
# task_description string.
with open(f"{task_description}_summaries.json", "w") as f:
    json.dump(summaries, f, indent=4)

# Reload the file that was just written, so `summaries` from here on is exactly
# what json.load returns.
# NOTE(review): presumably this lets the steps above be skipped on a re-run
# (notebook-style caching) — confirm.
import json
with open(f"{task_description}_summaries.json", "r") as f:
    summaries = json.load(f)

### KNOWLEDGE TREE BUILDING ###

from tusoai import construct_categories

# Number of categories requested from get_task_categories.
num_cat=10
# First pass: categories derived from the task and data descriptions only.
categories = construct_categories.get_task_categories(task_description, data_available, num_cat=num_cat,
                                client=client, model=LLM_MODEL)
print("Categories:", categories)


# Second pass: the first-pass categories are handed back together with the
# literature summaries; the result replaces `categories`.
categories = construct_categories.refine_categories_with_summaries(categories,
                                                                   summaries,
                                                                   task_description,
                                                                   data_available,
                                                                   client=client,
                                                                   model=LLM_MODEL)
print(categories)

from tusoai import construct_prompts

# Fourth positional argument of generate_prompts_for_categories.
# NOTE(review): presumably the number of prompts to generate per category — confirm.
to_generate = 10

# The result is used below as a mapping {category: [prompt, ...]}.
prompts_dict = construct_prompts.generate_prompts_for_categories(task_description,
                                                                data_available,
                                                                categories,
                                                                to_generate,
                                                                client,
                                                                 model=LLM_MODEL
                                                                )

# Pretty print: for each category, the total number of prompts followed by a
# short sample. The header still reports the full count via len(plist).
for cat, plist in prompts_dict.items():
    print(f"\n### {cat} ({len(plist)}) ###")
    # The loop used to print every prompt although it was documented as showing
    # only the first 5; the slice makes the output match that intent.
    for p in plist[:5]:   # show first 5 for brevity
        print("-", p)

# Forwarded as n_new_min / n_new_max below.
# NOTE(review): presumably bounds on the number of new prompts to create — confirm.
n_new_min = 5
n_new_max = 15

# The result is consumed below as a nested mapping
# {paper title: {category: [prompt, ...]}}.
new_prompt_dict = construct_prompts.generate_prompts_from_summaries(
    summaries,
    categories,
    data_available,
    existing_prompts=prompts_dict,   # re-use previously generated prompts
    n_new_min=n_new_min,
    n_new_max=n_new_max,
    client=client,
    model=LLM_MODEL
)

# Preview: for the first paper only, print each category with its prompt count
# and up to three of its prompts. Nothing is printed when there are no papers.
first_paper = next(iter(new_prompt_dict.items()), None)
if first_paper is not None:
    title, cat_map = first_paper
    print(f"\n### {title}")
    for cat, plist in cat_map.items():
        print(f"{cat} ({len(plist)}) ?", plist[:3])

# Fold the per-paper prompts into the shared per-category lists, creating a
# list for any category that prompts_dict does not contain yet.
for paper_title, cat_prompts in new_prompt_dict.items():
    for cat, new_list in cat_prompts.items():
        prompts_dict.setdefault(cat, []).extend(new_list)
        

import json
from pathlib import Path
from typing import Dict, List

# File that save_prompts_dict writes; the name embeds the raw task_description
# and is relative, so it resolves against the current working directory.
PROMPTS_PATH = Path(f"{task_description}_prompts.json")

def save_prompts_dict(prompts_dict: Dict[str, List[str]]):
    """Write *prompts_dict* as indented JSON to the module-level PROMPTS_PATH.

    The file is opened for writing as UTF-8, so an existing file at that path
    is overwritten. Nothing is returned.
    """
    with PROMPTS_PATH.open("w", encoding="utf-8") as handle:
        json.dump(prompts_dict, handle, indent=2)

# Persist the merged prompts (generated per category plus derived from papers).
save_prompts_dict(prompts_dict)

# These imports and the constant repeat the definitions a few lines above with
# identical values.
# NOTE(review): presumably so this part can be run on its own (notebook-style
# cell) — confirm.
import json
from pathlib import Path
from typing import Dict, List
PROMPTS_PATH = Path(f"{task_description}_prompts.json")

def load_prompts_dict() -> Dict[str, List[str]]:
    """Return the JSON content of the module-level PROMPTS_PATH.

    An empty dict is returned when the file does not exist; otherwise the file
    is read as UTF-8 and whatever json.load yields is returned unchanged.
    """
    if not PROMPTS_PATH.exists():
        return {}
    with PROMPTS_PATH.open("r", encoding="utf-8") as handle:
        stored = json.load(handle)
    return stored
# Prompts as read back from PROMPTS_PATH ({} if the file is missing); this is
# the value later passed to optimizer.discover_algorithm as prompts.
prompts = load_prompts_dict()


### INITIAL SOLUTIONS ###

from tusoai import construct_initializations

# Third positional argument of get_initializations.
# NOTE(review): presumably the number of initial solutions to request — confirm.
num_init        = 5
# First pass: initializations from the task and data descriptions only.
initializations = construct_initializations.get_initializations(task_description,
                                      data_available,
                                      num_init,
                                     client=client,
                                     model=LLM_MODEL)
print("Initializations:", initializations)

# Second pass: the first-pass initializations are handed back together with
# the literature summaries; the result replaces `initializations`.
initializations = construct_initializations.refine_initializations_with_summaries(
    initializations,
    summaries,
    task_description=task_description,
    data_available=data_available,
    client=client,
    model=LLM_MODEL
)


# Persist the refined initializations; the same file is read back further down
# before the optimizer is started.
with open(f'{task_description}_initializations.json', 'w') as f:
    json.dump(initializations, f, indent=2)

from tusoai import get_probabilities

# Ask tusoai for per-category weights given the categories and the refined
# initializations. The result is saved with save_probabilities below, which
# annotates it as Dict[str, float].
weights = get_probabilities.get_category_probabilities(
    task_description=task_description,
    data_available=data_available,
    categories=categories,
    initializations=initializations,
    client=client,
    model=LLM_MODEL
)
print("Category probabilities:", weights)

def save_probabilities(weights: Dict[str, float], task_description: str, directory: str = ".") -> str:
    """Write *weights* as JSON to ``<directory>/<task_description>_probabilities.json``.

    The task description is used verbatim in the file name; it is NOT
    sanitized, so it must not contain path separators or other characters
    that are invalid in a file name.

    Args:
        weights: Mapping to serialize; keys are written in sorted order.
        task_description: Prefix of the output file name.
        directory: Existing directory to write into. An empty string means the
            current working directory.

    Returns:
        The path of the written file, as a string.

    Raises:
        OSError: If the file cannot be opened for writing (for example when
            *directory* does not exist).
    """
    import os  # local import: this helper is the only user of os.path here

    fname = f"{task_description}_probabilities.json"
    # os.path.join handles a trailing separator and, unlike the former string
    # concatenation, does not turn an empty directory into the filesystem
    # root ("/<fname>").
    path = os.path.join(directory, fname)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(weights, f, indent=2, sort_keys=True)
    return path

def load_probabilities(path: str) -> Dict[str, float]:
    """Read a JSON object from *path* and return it with every value as a float.

    Values are converted with float(); they are not rescaled or otherwise
    normalized. Keys are kept as found in the file.
    """
    with open(path, "r", encoding="utf-8") as handle:
        raw = json.load(handle)
    # Coerce each entry so that integers and numeric strings come back as floats.
    converted = {}
    for name, value in raw.items():
        converted[name] = float(value)
    return converted


# Persist the category weights using the default directory (".").
path = save_probabilities(weights, task_description)
print("Saved to:", path)

from tusoai import optimizer

# Re-read the initializations written earlier, so `initializations` from here
# on is exactly what json.load returns.
with open(f'{task_description}_initializations.json', 'r') as f:
    initializations = json.load(f)


import json
from pathlib import Path

# Path to the JSON file
# The path is relative, so it resolves against the current working directory:
# the script has to be started from a directory that contains "tusoai/".
json_path = Path("tusoai") / "diagnostic_prompts.json"

# Load the JSON
# The content is later passed to optimizer.discover_algorithm as alter_info_prompts.
with open(json_path, "r", encoding="utf-8") as f:
    alter_info_prompts = json.load(f)


# Read back the file written by save_probabilities above; no directory prefix
# is given here, which matches the default directory used when saving.
probabilities = load_probabilities(f'{task_description}_probabilities.json')
print("Reloaded probabilities:", probabilities)


### OPTIMIZATION LOOP ###

# Run the optimizer with everything prepared above. It returns two values: a
# record whose .code attribute is written to disk below, and a history that is
# serialized at the end of the script.
# NOTE(review): the meaning and units of the numeric settings are defined by
# tusoai.optimizer, not by this script — confirm them there before tuning.
best_model, full_history = optimizer.discover_algorithm(
    llm_model=LLM_MODEL,
    temp=0.5,
    client=client,
    prompts=prompts,
    probabilities=probabilities,
    reference_filename=initial_file,
    initialisations=initializations,
    n_generations=10000,             # Number of cluster-evolve rounds
    children_per_model=1,        # NOTE(review): presumably children spawned per model per generation; the value passed is 1
    bug_retries=3,
    initial_bug_fix_attempts=5,
    timeout=120,
    n_feedback_buffer=5,
    skip_timeout=True,
    alter_info_prompts=alter_info_prompts,
    drop_island_iter=2,
    prompt_samples=3,
    alter_info_samples=3,
    prompt_decay=1.1,
    hints=hints,
    filename=filename,
    use_initial=False,
    TIME_LIMIT=60 * 60 * 8,  # 28800; 8 hours if the unit is seconds — TODO confirm
    task_description=task_description,
    val_limit=1.0,
    debug = True
)

# Assuming best_model.code is a string
# The output name is "<filename>_best_tusoml.py"; with the filename configured
# in this script that is "denoise_tusoml_best_tusoml.py".
with open(f"{filename}_best_tusoml.py", "w", encoding="utf-8") as f:
    f.write(best_model.code)

# Define a standalone to_dict function
def modelrecord_to_dict(record):
    """Return a plain dict snapshot of *record*.

    The attributes code, file, accuracy, model_info and lineage are copied
    under keys of the same name, in that order. The file attribute is converted
    with str(); the other values are passed through unchanged.
    """
    field_names = ("code", "file", "accuracy", "model_info", "lineage")
    snapshot = {name: getattr(record, name) for name in field_names}
    # Only the file entry is stringified; re-assigning an existing key keeps
    # its position in the dict.
    snapshot["file"] = str(snapshot["file"])
    return snapshot

# Build a JSON-friendly copy of full_history: the two levels of keys are kept
# as they are, and every record in the innermost lists is replaced by the dict
# that modelrecord_to_dict returns for it.
serializable_history = {}
for outer_k, inner_v in full_history.items():
    converted_level = {}
    for inner_k, record_list in inner_v.items():
        converted_level[inner_k] = [modelrecord_to_dict(record) for record in record_list]
    serializable_history[outer_k] = converted_level

# Write the converted history to "<filename>_full_history.json".
with open(f"{filename}_full_history.json", "w") as f:
    json.dump(serializable_history, f, indent=4)
