import random
import ast
import re
import os
import torch.multiprocessing as mp
from tqdm import tqdm
import json
import numpy as np
import openai

def parse_text(text):
    """Extract quote-free sentences from a block of text.

    Two sources are scanned, in this order:

    1. The first ``[...]`` span in ``text`` is evaluated as a Python list
       literal and its string items are collected.
    2. Every ``<digits>. <rest of line>`` occurrence in ``text`` contributes
       the part after the number.

    A candidate is dropped when it contains a straight double quote or a
    curly single/double quote.

    Args:
        text: Raw text to scan.

    Returns:
        The kept sentences: items from the bracketed list first, followed by
        the numbered sentences, each group in order of appearance.
    """
    # Compiled once; used for every candidate from both sources.
    quote_pattern = re.compile(r"[\"“”‘’]")
    parsed_list = []

    match = re.search(r"\[(.*?)\]", text, re.S)
    if match:
        list_content = "[" + match.group(1) + "]"
        try:
            raw_list = ast.literal_eval(list_content)
        except (SyntaxError, ValueError, TypeError):
            # TypeError covers literals such as ``[{{1}: 2}]`` where
            # literal_eval fails while building an unhashable dict key.
            print("❌ ")
        else:
            # Non-string items (numbers, None, nested containers) cannot be
            # regex-searched; skip them instead of raising TypeError.
            parsed_list.extend(
                item
                for item in raw_list
                if isinstance(item, str) and not quote_pattern.search(item)
            )

    # Without re.S, ``.*`` stops at the end of the line, so each match
    # captures one numbered line.
    for _, sentence in re.findall(r"(\d+)\.\s*(.*)", text):
        if not quote_pattern.search(sentence):
            parsed_list.append(sentence)
    return parsed_list

def extract_and_convert_json(text):
    """Parse JSON out of ``text``, tolerating surrounding non-JSON content.

    The whitespace-stripped text is first parsed as a whole. If that fails,
    the widest span that starts with ``[`` followed by ``{`` and ends with
    ``}`` followed by ``]`` (whitespace allowed between the brackets and
    braces) is parsed instead.

    Args:
        text: String expected to contain JSON.

    Returns:
        The decoded JSON value, or ``None`` when no such span exists or the
        span is not valid JSON (a marker is printed in both failure cases).
    """
    stripped = text.strip()

    # Fast path: the whole payload is already valid JSON.
    try:
        return json.loads(stripped)
    except json.JSONDecodeError:
        pass

    # Fallback: pull out an embedded array of objects.
    found = re.search(r'\[\s*\{.*\}\s*\]', stripped, re.DOTALL)
    if found is None:
        print("❌")
        return None

    try:
        decoded = json.loads(found.group(0))
    except json.JSONDecodeError as err:
        print("❌", err)
        return None
    return decoded

def generate_prompt(topic):
    """Build the character-sheet completion prompt for one topic.

    A speaker count between 1 and 4 is drawn at random. The prompt embeds a
    skeleton array with that many blank character entries; each entry's
    ``Relationships`` list holds one blank relationship record plus one empty
    ``{}`` placeholder per additional speaker.

    Args:
        topic: Theme text interpolated into the instructions.

    Returns:
        The prompt text with leading and trailing whitespace removed.
    """
    speaker_count = random.randint(1, 4)

    # One empty "{}" relationship placeholder per additional speaker.
    extra_slots = ", {}" * (speaker_count - 1)
    blank_character = (
        '{"CharacterName": "", "Attributes": {"Personality": "", "BasicAttributes": "", "CulturalBackground": ""}, '
        '"Relationships": [{"Target": "", "Closeness": "", "PowerRelation": "", "InteractionStance": ""}'
        + extra_slots
        + "]}"
    )
    skeleton = "[" + ",".join([blank_character] * speaker_count) + "]"

    # NOTE(review): json.dumps receives a str here, so indent=4 has no effect
    # and the result is the skeleton wrapped in one pair of double quotes once
    # the escaped quotes are restored - confirm this is intended.
    skeleton_text = json.dumps(skeleton, ensure_ascii=False, indent=4).replace(r"\"", '"')

    prompt = f"""You are an expert in social role modeling and emotional interaction design. Based on the following list of characters, please complete each character’s three attribute fields—**Personality**, **Basic Attributes**, and **Cultural Background**—under the theme of “{topic}”. Additionally, fill in their **interpersonal relationships** with all other characters, including:
- **Closeness** (e.g., family, friend, colleague, stranger)
- **Power Relation** (e.g., superior-subordinate, equal, none)
- **Interaction Stance** (e.g., cooperative, adversarial, neutral)

Please output the result as a **JSON array**, ensuring a consistent structure. Each character must have a fully filled `Attributes` field, and in `Relationships`, list their relationship with every other character. The stances can be varied, and educational background should **not** include specific university names.

Reference definitions:
- **Closeness**: Measures the intimacy of relationships (e.g., family, friend, colleague, stranger)
- **Power Relation**: Indicates power differences (e.g., superior-subordinate, teacher-student, equal, none)
- **Interaction Stance**: Describes the attitude of interaction (e.g., cooperative, adversarial, neutral)
- **Personality Traits**: e.g., extroverted, introverted, sensitive, impulsive
- **Basic Attributes**: e.g., gender, age, social role (such as job position)
- **Cultural Background**: e.g., place of origin, educational background, upbringing environment

Please complete the missing fields based on the following existing information:
{skeleton_text}
"""
    return prompt.strip()

def generate_and_save(topics, process_id, save_dir):
    """Generate character sheets for ``topics`` and checkpoint them to disk.

    For each topic a prompt is built with ``generate_prompt`` and sent to the
    ``gpt-4o`` chat model, using a temperature drawn at random from
    0.50 to 1.50 in steps of 0.01. The reply is parsed with
    ``extract_and_convert_json`` and stored under the topic; a reply that
    cannot be parsed is stored as ``None``. Any exception raised by the
    request or the parsing is printed and the topic is skipped.

    The accumulated results are written to
    ``<save_dir>/role_<process_id>.json`` after topics 0, 10, 20, ... and once
    more after the last topic.

    Args:
        topics: Topic strings to process.
        process_id: Identifier used in the progress-bar label and the
            output file name.
        save_dir: Directory that receives the output file.
    """
    openai.api_key = os.getenv("OPENAI_API_KEY")
    temp_file = os.path.join(save_dir, f"role_{process_id}.json")
    temp_all_json_info = {}

    def _write_checkpoint():
        # Rewrite the whole file on every checkpoint. Opening it only here
        # (rather than holding an outer handle open around the loop) means
        # exactly one handle on the file exists at a time.
        with open(temp_file, "w", encoding="utf-8") as wf:
            json.dump(temp_all_json_info, wf, ensure_ascii=False, indent=4)

    for i, topic in enumerate(tqdm(topics, desc=f"Process-{process_id}")):
        prompt = generate_prompt(topic)
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4o",
                messages=[
                    {"role": "user", "content": prompt},
                ],
                temperature=random.randint(50, 150) / 100.0,
            )
            content = response.choices[0].message.content
            temp_all_json_info[topic] = extract_and_convert_json(content)
        except Exception as e:
            # Best effort per topic: one failed request must not abort the
            # rest of this worker's batch.
            print(e)
        if i % 10 == 0:
            _write_checkpoint()

    # Final write; also creates the file when ``topics`` is empty.
    _write_checkpoint()

def merge_files(save_dir, final_output):
    """Merge per-process partial result files into a single JSON file.

    Files in ``save_dir`` whose names start with ``role_``, end with
    ``.json`` and do not contain ``ds`` are loaded in sorted name order and
    folded into one dict (later files override earlier ones on duplicate
    keys). The merged dict is written to ``final_output``.

    Partial files are deleted only after the merged output has been written,
    so a failure while writing cannot lose the collected data. A partial file
    that is not valid JSON is reported and left on disk for inspection.

    Args:
        save_dir: Directory containing the partial files.
        final_output: Path of the merged JSON file to write.
    """
    final_data = {}
    merged_paths = []
    for file in sorted(os.listdir(save_dir)):
        is_partial = file.startswith("role_") and file.endswith(".json")
        if not is_partial or "ds" in file:
            continue
        file_path = os.path.join(save_dir, file)
        try:
            with open(file_path, "r", encoding="utf-8") as rf:
                final_data.update(json.load(rf))
        except json.JSONDecodeError as e:
            # Keep the unreadable file and say which one it was.
            print(f"❌ {file_path}: {e}")
            continue
        merged_paths.append(file_path)

    with open(final_output, "w", encoding="utf-8") as wf:
        json.dump(final_data, wf, ensure_ascii=False, indent=4)

    # Clean up only once the merged output is safely on disk. Never delete
    # the output itself, in case its name matches the partial-file pattern.
    output_path = os.path.abspath(final_output)
    for file_path in merged_paths:
        if os.path.abspath(file_path) != output_path:
            os.remove(file_path)

def main():
    """Fan the topics out to worker processes and merge their results.

    Reads ``./datas/scripts/final_context.tsv``, keeping each line whose raw
    length (including its line ending) exceeds 2 characters, stripped of
    surrounding whitespace. The topics are split into contiguous chunks whose
    sizes differ by at most one, and one ``generate_and_save`` process is
    started per non-empty chunk (at most 32). After every worker has
    finished, the per-process files are merged into
    ``./datas/scripts/final_role.json`` by ``merge_files``.
    """
    num_processes = 32
    save_dir = "./datas/scripts/"
    context = os.path.join(save_dir, "final_context.tsv")
    final_output = os.path.join(save_dir, "final_role.json")

    # Explicit UTF-8, matching every other file this script reads or writes.
    with open(context, "r", encoding="utf-8") as rf:
        contexts = [line.strip() for line in rf if len(line) > 2]
    total_requests = len(contexts)
    os.makedirs(save_dir, exist_ok=True)

    mp.set_start_method("spawn", force=True)
    processes = []

    # Balanced split: the first ``extra`` workers take one more topic than
    # the rest, so no single worker ends up with the whole remainder.
    base, extra = divmod(total_requests, num_processes)
    start_idx = 0
    for process_id in range(num_processes):
        end_idx = start_idx + base + (1 if process_id < extra else 0)
        if end_idx == start_idx:
            # Fewer topics than workers: every remaining chunk is empty too.
            break
        p = mp.Process(
            target=generate_and_save,
            args=(contexts[start_idx:end_idx], process_id, save_dir),
        )
        p.start()
        processes.append(p)
        start_idx = end_idx

    for p in processes:
        p.join()

    merge_files(save_dir, final_output)
    print(f"✅ {final_output}")

# Run the pipeline only when this file is executed as a script. main() selects
# the "spawn" start method, under which worker processes import this module
# afresh; this guard keeps those imports from starting the pipeline again.
if __name__ == "__main__":
    main()