from copy import deepcopy
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans, AgglomerativeClustering
import numpy as np
import random, re
import json, os
from openai import OpenAI
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

# All prompts have been removed for brevity and to improve maintainability.
# In a real-world scenario, these would be loaded from a separate configuration
# or template file.

def _slice_section(text, header, start, other_starts):
    """
    Return the stripped body of the section whose header begins at `start`.

    The section ends at the nearest position in `other_starts` that lies after
    `start`, or at the end of `text` when no other header follows. Positions
    of -1 (header not found) are ignored because they are never > `start`.
    """
    following = [pos for pos in other_starts if pos > start]
    end = min(following) if following else len(text)
    return text[start:end].replace(header, "").strip()

def extract_sections_and_concat_prompt(risk_type:str, need_to_concat_prompt:str):
    """
    Fill the `${RISK_TYPE}` and `${NOTICES}` placeholders of a prompt template.

    The function loads `need_to_review_risk_from_data.json` from the first of
    three candidate relative paths that exists, picks the first row whose
    `content_type` equals `risk_type`, and reads that row's `question` text.
    From that text it extracts the sections introduced by the headers
    "# 管控类型" (type), "# 注意事项" (notices, optional) and "# 给定信息" (info).

    Each section ends at the nearest header that follows it, so the result does
    not depend on the order in which the headers appear in the text.

    Args:
        risk_type: Value matched against each row's `content_type`.
        need_to_concat_prompt: Template containing `${RISK_TYPE}` and,
            optionally, a "## 注意事项\\n${NOTICES}\\n" block.

    Returns:
        The populated template. The template is returned unchanged when the
        data file is missing, no row matches, the matched row has no text, or
        the type or info header is absent. When the notice header is absent,
        the "## 注意事项\\n${NOTICES}\\n" block is removed from the template.
    """
    # The data file is not shipped with this repository; paths are relative to
    # the current working directory.
    possible_paths = [
        "data/dataset/need_to_review_risk_from_data.json",
        "../../data/dataset/need_to_review_risk_from_data.json",
        "../../../data/dataset/need_to_review_risk_from_data.json"
    ]

    datalist = None
    for path in possible_paths:
        try:
            with open(path, 'r', encoding='utf-8') as f:
                datalist = json.load(f)
                break
        except FileNotFoundError:
            continue

    if datalist is None:
        # Degrade gracefully: the caller still gets a usable (unfilled) prompt.
        print("Warning: Configuration file 'need_to_review_risk_from_data.json' not found.")
        return need_to_concat_prompt

    text = None
    for row in datalist:
        if row.get('content_type') == risk_type:
            text = row.get('question', '')
            break

    if not text:
        print(f"Warning: No content found for risk_type '{risk_type}'.")
        return need_to_concat_prompt

    type_header = "# 管控类型"
    notice_header = "# 注意事项"
    info_header = "# 给定信息"

    type_start = text.find(type_header)
    notice_start = text.find(notice_header)
    info_start = text.find(info_header)

    # Both the type and the info header are required; notices are optional.
    if type_start == -1 or info_start == -1:
        return need_to_concat_prompt

    type_content = _slice_section(text, type_header, type_start, (notice_start, info_start))
    need_to_concat_prompt = need_to_concat_prompt.replace("${RISK_TYPE}", type_content)

    if notice_start != -1:
        notice_content = _slice_section(text, notice_header, notice_start, (type_start, info_start))
        need_to_concat_prompt = need_to_concat_prompt.replace("${NOTICES}", notice_content)
    else:
        # No notices available: drop the whole notices block from the template.
        need_to_concat_prompt = need_to_concat_prompt.replace('''## 注意事项\n${NOTICES}\n''', "")

    return need_to_concat_prompt

def call_openai_api(prompt: str, image_url=None, model_name="gpt-4o-0513"):
    """
    Send a single user message through the OpenAI client and return the reply text.

    The API key is read from the `OPENAI_API_KEY` environment variable and the
    client is pointed at a fixed base URL. The completion is requested in
    streaming mode and the streamed text fragments are joined into one string.

    Args:
        prompt: Text of the user message.
        image_url: Optional image URL. When given, the message content becomes
            a text part plus an image part requested with detail "high".
        model_name: Model identifier passed to the API.

    Returns:
        The concatenated response text, or None if the request or the stream
        raised an exception (the error is printed).

    Raises:
        ValueError: If `OPENAI_API_KEY` is unset or empty.
    """
    api_key = os.environ.get('OPENAI_API_KEY')
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable not set.")

    client = OpenAI(
        api_key=api_key,
        base_url="https://idealab.alibaba-inc.com/api/openai/v1",  # This URL is specific to an internal environment.
    )

    messages = [{"role": "user", "content": prompt}]
    if image_url:
        messages[0]["content"] = [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": image_url, "detail": "high"}}
        ]

    try:
        completion = client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_tokens=8192,
            stream=True
        )
        pieces = []
        for chunk in completion:
            # A streamed chunk can carry an empty `choices` list; indexing it
            # would raise and discard everything received so far.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece:
                pieces.append(piece)
        return "".join(pieces)
    except Exception as e:
        # Best-effort wrapper: report the failure and signal it with None.
        print(f"An error occurred while calling the API: {e}")
        return None

def cluster_and_sample_texts(text_list, n_clusters, random_seed=42):
    """
    Cluster texts with KMeans over sentence embeddings and pick one text per cluster.

    Args:
        text_list: Sequence of strings to cluster.
        n_clusters: Requested number of clusters. It is capped at
            `len(text_list)` because KMeans cannot fit more clusters than
            samples.
        random_seed: Seed used both for KMeans and for picking the
            representative of each cluster.

    Returns:
        A list with one randomly chosen text from every non-empty cluster, in
        cluster-label order. Returns an empty list when `text_list` is empty or
        `n_clusters` is not positive.
    """
    if not text_list or n_clusters <= 0:
        return []

    # Asking for more clusters than samples makes KMeans raise.
    n_clusters = min(n_clusters, len(text_list))

    model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
    embeddings = model.encode(text_list, show_progress_bar=True)

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_seed, n_init='auto')
    labels = kmeans.fit_predict(embeddings)

    # Pre-create every bucket so the output follows label order 0..n_clusters-1.
    clusters = {i: [] for i in range(n_clusters)}
    for text, label in zip(text_list, labels):
        clusters[int(label)].append(text)

    # A dedicated generator yields the same picks as seeding the module-level
    # RNG, without resetting random state for the rest of the process.
    rng = random.Random(random_seed)
    return [rng.choice(members) for members in clusters.values() if members]

def generate_templates_with_two_stages(risk_type:str, prompt_first:str, prompt_second:str, generate_nums=10):
    """
    Produce mock templates through a first-stage / second-stage pipeline.

    Stage one populates the first-stage prompt via
    `extract_sections_and_concat_prompt` and then uses a fixed list of two
    mock templates. Stage two copies each template, adds a
    `meta_template_name` key and appends a trailing step.

    `prompt_second` and `generate_nums` are accepted but not used by this
    simplified version.

    Returns:
        The list of second-stage template dicts.
    """
    # The populated prompt is not consumed by the mock stages below.
    first_stage_prompt = extract_sections_and_concat_prompt(risk_type, deepcopy(prompt_first))

    # Stand-in for real first-stage generation output.
    mock_specs = (
        ("mock-template-1", ["Step 1", "Step 2"]),
        ("mock-template-2", ["Alt Step A", "Alt Step B"]),
    )
    first_stage_template_list = [
        {"template_name": name, "steps": steps} for name, steps in mock_specs
    ]

    print(f"First stage completed, generated {len(first_stage_template_list)} mock templates.")

    def _second_stage(template):
        # Work on a copy so the first-stage template stays untouched.
        refined = deepcopy(template)
        refined["meta_template_name"] = "meta-" + template["template_name"]
        refined["steps"].append("Processed in second stage")
        return refined

    second_stage_template_list = [_second_stage(item) for item in first_stage_template_list]

    print(f"Second stage completed, processed {len(second_stage_template_list)} templates.")
    return second_stage_template_list

if __name__ == '__main__':
    # Demo driver for the template workflow:
    #   1. define placeholder prompts,
    #   2. run the two-stage generator,
    #   3. dump the templates to a JSON file,
    #   4. cluster the template texts and print one sample per cluster.
    #
    # OPENAI_API_KEY is required by call_openai_api:
    #   export OPENAI_API_KEY='your-api-key'

    # Placeholder prompts; a real application would load them from configuration.
    prompt_4_first_stage_for_VL_task_v2 = "This is a placeholder for the first stage prompt."
    prompt_4_second_stage = "This is a placeholder for the second stage prompt."

    target_risk_type = "disease"
    output_filename = target_risk_type + "-generated-templates.json"

    print(f"Starting template generation for risk type: '{target_risk_type}'...")

    generated_templates = generate_templates_with_two_stages(
        target_risk_type,
        prompt_4_first_stage_for_VL_task_v2,
        prompt_4_second_stage,
        generate_nums=10,
    )

    if not generated_templates:
        print("Template generation failed or produced no results.")
    else:
        print(f"Successfully generated {len(generated_templates)} templates.")

        with open(output_filename, 'w', encoding='utf-8') as out_file:
            json.dump(generated_templates, out_file, ensure_ascii=False, indent=2)

        print(f"Templates saved to '{output_filename}'")

        # One text per template: its steps joined by newlines.
        template_texts = []
        for template in generated_templates:
            template_texts.append("\n".join(template.get('steps', [])))

        if template_texts:
            print("\nClustering templates to find diverse examples...")
            # Never request more clusters than there are texts.
            num_clusters = min(len(template_texts), 2)
            sampled_texts = cluster_and_sample_texts(template_texts, n_clusters=num_clusters)
            print("\nSampled template texts after clustering:")
            for sample_no, sample_text in enumerate(sampled_texts, start=1):
                print(f"\n--- Sample {sample_no} ---")
                print(sample_text)