import random
import ast
import re
import os
import torch.multiprocessing as mp
from tqdm import tqdm
import json
import numpy as np
import openai

# First bracketed span in the reply (non-greedy, may span lines) -- expected to
# be the Python list literal the prompt asks for.
_LIST_RE = re.compile(r"\[(.*?)\]", re.S)
# "<number>. <rest of line>" occurrences (numbered-list style replies).
_NUMBERED_RE = re.compile(r"(\d+)\.\s*(.*)")
# Straight or curly quote characters; candidates containing any are rejected.
_QUOTE_RE = re.compile(r"[\"“”‘’]")


def parse_text(text):
    """Extract candidate scenario sentences from a model reply.

    Two formats are recognised and their results concatenated in this order:

    1. The first ``[...]`` span, evaluated with ``ast.literal_eval`` as a
       Python list literal; its non-blank string items are kept.
    2. Every ``<number>. <sentence>`` occurrence; the text after the number
       up to the end of the line is kept when it is non-blank.

    Any candidate containing a straight or curly quote character is dropped.
    On the numbered path this presumably also keeps quoted list-item lines
    (e.g. one containing a decimal such as "3.5") from being re-captured.

    Args:
        text: Raw reply text from the model.

    Returns:
        list[str]: Extracted sentences in discovery order. Duplicates are
        possible when both formats match the same content.
    """
    parsed_list = []
    match = _LIST_RE.search(text)
    if match:
        list_content = "[" + match.group(1) + "]"
        try:
            raw_list = ast.literal_eval(list_content)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # Malformed list literal; the numbered-line pass below may still
            # recover sentences.
            print("❌")
        else:
            # Skip non-string items (numbers, nested lists, ...): running the
            # quote regex on them would raise TypeError and discard the reply.
            parsed_list.extend(
                s
                for s in raw_list
                if isinstance(s, str) and s.strip() and not _QUOTE_RE.search(s)
            )

    for _, sentence in _NUMBERED_RE.findall(text):
        if sentence.strip() and not _QUOTE_RE.search(sentence):
            parsed_list.append(sentence)
    return parsed_list

def generate_and_save(descs_part, process_id, save_dir):
    """Send one chat request per item of *descs_part* and stream the parsed results to disk.

    Each request builds a prompt asking for 25 Chinese descriptions of
    multi-person dialogue environments. The prompt is seeded with three
    distinct example sentences drawn from ``samples``. The request is sampled
    at a random temperature in [0.5, 1.5]. Every sentence that ``parse_text``
    extracts from the reply is stripped and written as one line to
    ``<save_dir>/env_<process_id>.tsv``. The file is truncated at start and
    flushed after each line.

    Args:
        descs_part: Iterable whose length sets how many requests this worker
            makes. The item values themselves are not used.
        process_id: Worker index, used in the output file name and in the
            progress-bar label.
        save_dir: Existing directory that receives the per-worker file.

    Errors from a single request (API, network, or parsing) are printed and
    skipped, so one bad reply does not abort the worker.
    """
    samples = [
        "Some people are talking in the noisy waiting hall of a train station, surrounded by a bustling crowd.",
        "In the quiet reading room of a library, several people are discussing in low voices to avoid disturbing others.",
        "They are walking and talking on a busy street, with constant honking of vehicles.",
        "In a meeting room, multiple people are seated around discussing, and the clock on the wall shows the meeting is about to end.",
        "They are talking in a corner of a café, with gentle background music.",
        "On a park bench outdoors, several people are chatting as a breeze blows gently.",
        "They are conversing while dining in a restaurant, with waiters passing by from time to time.",
        "In the office break room, several people are chatting, accompanied by the sound of the coffee machine.",
        "They are talking in the gym rest area, surrounded by the clashing sounds of equipment.",
        "In the airport waiting lounge, several people are discussing, with flight announcements being broadcast continuously."
    ]
    openai.api_key = os.getenv("OPENAI_API_KEY")
    temp_file = os.path.join(save_dir, f"env_{process_id}.tsv")
    with open(temp_file, "w", encoding="utf-8") as wf:
        for _ in tqdm(descs_part, desc=f"Process-{process_id}"):
            try:
                # Three *distinct* few-shot examples. Drawing indices with
                # replacement could show the model the same example twice.
                examples = random.sample(samples, 3)
                # The suggested pronoun is the gender-neutral "他们": the prompt
                # itself forbids implying gender, which "她们" (feminine) did.
                prompt = f"""
                You are an expert in modeling conversational scenarios. Please generate a series of **descriptions of dialogue scenes that include only the "dialogue environment and external factors"**. 
                Note that the generated content **must not include or imply any of the following information**:
                - Interpersonal relationships and character traits (e.g., age, gender, identity, occupation, personality, social status, cultural background, physical or emotional states, etc.)
                - Dialogue content and context (e.g., conversation topics, communication goals, number of turns, contextual emotions, reasoning, requests, complaints, etc.)
                - Nonverbal interaction details (e.g., gestures, facial expressions, eye contact, distance, feedback responses, etc.)

                You should focus solely on the following "dialogue environment and external factors":
                - Physical environment: e.g., quiet/noisy, public/private space, meeting room, café, inside a car, outdoors, etc.
                - Time pressure: whether the participants are under time constraints or in a hurry
                - Other relevant technical conditions or external environmental changes

                Please output **25 descriptions of multi-person human dialogue scenarios in Chinese**, each as a single sentence. 
                Avoid using professions or specific identities as subjects; instead, use terms like "多人" (multiple people) or "他们" (they). 
                The scenarios should be concise, realistic, and diverse, with each one reflecting a different external factor or combination thereof.

                Finally, format all scenarios as a Python list[str], like:
                [
                    "{examples[0]}",
                    "{examples[1]}",
                    "{examples[2]}",
                    ...
                ]
                """
                response = openai.ChatCompletion.create(
                    model="gpt-4o",
                    messages=[
                        {"role": "user", "content": prompt},
                    ],
                    # Random temperature per request for extra output diversity.
                    temperature=random.randint(50, 150) / 100.0,
                )
                reply = response.choices[0].message.content
                for line in parse_text(reply):
                    print(line.strip(), file=wf, flush=True)
            except Exception as e:
                # Best-effort worker: report the failure and continue with the
                # next request.
                print(e)

def merge_files(save_dir, final_output):
    """Merge the per-worker ``env_*.tsv`` files in *save_dir* into *final_output*.

    Part files are read in sorted file-name order. Each distinct non-blank
    line is written once, at its first occurrence, and always ends with a
    newline. This holds even when a part file lacks a trailing newline. The
    part files are deleted only after the merged file has been fully written
    and closed, so an error during the merge leaves them on disk.

    Args:
        save_dir: Directory containing the ``env_*.tsv`` part files.
        final_output: Path of the merged output file (overwritten). If it
            happens to match the part-file pattern inside *save_dir*, it is
            excluded from the inputs.
    """
    output_abspath = os.path.abspath(final_output)
    part_paths = [
        os.path.join(save_dir, name)
        for name in sorted(os.listdir(save_dir))
        if name.startswith("env_")
        and name.endswith(".tsv")
        and os.path.abspath(os.path.join(save_dir, name)) != output_abspath
    ]

    lines_seen = set()
    with open(final_output, "w", encoding="utf-8") as wf:
        for path in part_paths:
            with open(path, "r", encoding="utf-8") as rf:
                for raw_line in rf:
                    # Normalise the newline so a part file without a trailing
                    # newline neither merges with the next file's first line
                    # nor defeats the duplicate check.
                    line = raw_line.rstrip("\n")
                    if not line or line in lines_seen:
                        continue
                    lines_seen.add(line)
                    wf.write(line + "\n")

    # Delete the parts only once the merged output is safely on disk.
    for path in part_paths:
        os.remove(path)

# Multi-process launch.
def main(total_requests=1500, num_processes=32, save_dir="./datas/scripts/"):
    """Spread the API requests over worker processes, then merge their outputs.

    ``total_requests`` is divided as evenly as possible across at most
    ``num_processes`` workers, so chunk sizes differ by at most one. With the
    defaults that is 28 workers × 47 requests plus 4 × 46, instead of
    31 × 46 plus one straggler with 74. Each worker runs ``generate_and_save``
    and writes ``env_<id>.tsv`` into ``save_dir``. After all workers have
    exited, ``merge_files`` combines the parts into
    ``<save_dir>/final_env.tsv``.

    Args:
        total_requests: Total number of model requests to issue.
        num_processes: Maximum number of worker processes. Must be >= 1.
        save_dir: Output directory. It is created if missing.
    """
    final_output = os.path.join(save_dir, "final_env.tsv")

    os.makedirs(save_dir, exist_ok=True)

    # Workers are started with "spawn" (force=True overrides any start method
    # set earlier). Spawned children re-import this module, which is why the
    # entry point sits behind the __main__ guard.
    mp.set_start_method("spawn", force=True)

    chunks = np.array_split(np.arange(total_requests), num_processes)
    workers = []
    for process_id, chunk in enumerate(chunks):
        if len(chunk) == 0:
            # More workers than requests: don't spawn idle processes.
            continue
        p = mp.Process(target=generate_and_save, args=(chunk, process_id, save_dir))
        p.start()
        workers.append((process_id, p))

    for process_id, p in workers:
        p.join()
        if p.exitcode != 0:
            # Surface crashed workers; their partial output is still merged.
            print(f"⚠️ worker {process_id} exited with code {p.exitcode}")

    merge_files(save_dir, final_output)
    print(f"✅ done and save to {final_output}")

if __name__ == "__main__":
    # Run only when executed as a script. With the "spawn" start method, each
    # worker re-imports this module, and this guard keeps the workers from
    # relaunching the whole pipeline.
    main()