import openai
import random
import ast
import re
import os
import torch.multiprocessing as mp
from tqdm import tqdm
import json
import numpy as np
import string

def parse_text(text):
    """Extract candidate sentences from free-form LLM output.

    Two sources are combined, in this order:

    1. The first ``[...]`` span in *text* (matched non-greedily, across
       newlines) is parsed with :func:`ast.literal_eval`; its string
       elements are kept.
    2. Every ``<number>. <sentence>`` match found anywhere in *text*.

    Sentences containing straight or curly quote characters are dropped
    from both sources.

    Args:
        text: Raw model output.

    Returns:
        A list of sentence strings (possibly empty). Duplicates are not
        removed, so a sentence found by both sources appears twice.
    """
    quote_re = re.compile(r"[\"“”‘’]")
    parsed_list = []

    match = re.search(r"\[(.*?)\]", text, re.S)
    if match:
        list_content = "[" + match.group(1) + "]"
        try:
            raw_list = ast.literal_eval(list_content)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # Malformed list literal: report and fall through to the
            # numbered-line scan below.
            print("❌")
        else:
            # literal_eval may yield non-string items (numbers, nested lists);
            # re.search would raise TypeError on those, so keep strings only.
            parsed_list.extend(
                s for s in raw_list if isinstance(s, str) and not quote_re.search(s)
            )

    for _, sentence in re.findall(r"(\d+)\.\s*(.*)", text):
        if not quote_re.search(sentence):
            parsed_list.append(sentence)
    return parsed_list

import json
import ast
import re

def extract_and_convert_json(text):
    """Best-effort parse of a JSON-ish LLM response into Python objects.

    Strategies are tried in order; the first truthy result wins:

    1. ``json.loads`` on the text, after unwrapping the first fenced
       Markdown code block if one is present.
    2. ``json.loads`` after heuristic repairs (curly quotes, single quotes,
       stray separators).
    3. ``ast.literal_eval`` on the unwrapped text (handles Python-literal
       output such as single-quoted dicts or trailing commas).
    4. Every ``[ { ... } ]`` chunk is repaired and parsed separately, and
       the resulting lists are concatenated.

    Args:
        text: Raw model response.

    Returns:
        The parsed object (typically a list of dicts), or ``None`` when every
        strategy fails or yields only an empty/falsy value. In that case the
        text is printed for inspection.
    """
    # Errors a parse attempt may legitimately raise on malformed input.
    # ValueError also covers json.JSONDecodeError. TypeError comes from
    # unhashable keys in literal_eval and from values json.dumps cannot
    # serialise (sets, bytes, complex). literal_eval may also raise
    # MemoryError or RecursionError on pathological input.
    parse_errors = (ValueError, SyntaxError, TypeError, MemoryError, RecursionError)

    def clean_markdown_code_block(t):
        # Take the body of the first ``` fenced block wherever it sits;
        # models often add prose before or after the fence.
        fenced = re.search(r"```[^\n]*\n(.*?)```", t, re.DOTALL)
        return fenced.group(1) if fenced else t

    def sanitize_text(t):
        t = t.replace('“', '"').replace('”', '"').replace('‘', "'").replace('’', "'")

        # Convert Python-style single quotes into JSON double quotes.
        # NOTE(review): this also rewrites apostrophes inside values ("don't").
        t = re.sub(r'(?<!\\)\'', '"', t)
        # Normalise spacing between a key and a quoted value. The value's
        # opening quote must be kept, otherwise every string value breaks.
        t = re.sub(r'"(\s*:\s*)"', '": "', t)
        # Collapse a doubled separator such as `"key": ": value"`.
        t = re.sub(r'"\s*:\s*"\s*:\s*', '": "', t)
        # speed/emotion values that still contain a colon get defaults.
        t = re.sub(r'"speed"\s*:\s*"[^"]*:\s*[^"]*"', '"speed": "1.0"', t)
        t = re.sub(r'"emotion"\s*:\s*"[^"]*:\s*[^"]*"', '"emotion": "neutral"', t)

        # Repair `),{` between objects, a trailing comma before `]`, and a
        # leading comma after `[`.
        t = re.sub(r'\),\s*\{', '}, {', t)
        t = re.sub(r'\},\s*\]', '}]', t)
        t = re.sub(r'\[\s*,', '[', t)
        return t

    def try_parse_json(t):
        try:
            return json.loads(t)
        except parse_errors:
            return None

    def try_eval(t):
        try:
            obj = ast.literal_eval(t)
            # Round-trip through JSON so only JSON types remain
            # (tuples become lists, non-string keys become strings).
            return json.loads(json.dumps(obj, ensure_ascii=False))
        except parse_errors:
            return None

    def extract_array_like_chunks(t):
        blocks = re.findall(r'\[\s*\{.*?\}\s*\]', t, re.DOTALL)
        results = []
        for b in blocks:
            fixed = sanitize_text(b)
            result = try_parse_json(fixed)
            if not result:
                result = try_eval(fixed)
            if result:
                results.extend(result)
        return results if results else None

    text = clean_markdown_code_block(text).strip()

    res = try_parse_json(text)
    if res:
        return res

    sanitized = sanitize_text(text)
    res = try_parse_json(sanitized)
    if res:
        return res

    res = try_eval(text)
    if res:
        return res

    res = extract_array_like_chunks(text)
    if res:
        return res

    print("❌")
    print(text)
    return None

def generate_dialogue_template(num_speakers, total_turns, environment, speakers, context, character_data):
    """Build the GPT prompt for a multi-speaker dialogue with emotion shifts.

    A skeleton dialogue is generated first. Every speaker gets one turn, in
    random order. Then ``total_turns - num_speakers`` extra turns are added,
    each with a randomly chosen speaker different from the previous turn's
    speaker. Each turn holds 2-5 empty segment slots
    (``lines``/``emotion``/``speed``) for the model to fill in. The skeleton
    is serialised as JSON and embedded in the prompt together with the
    environment, context and character description.

    Args:
        num_speakers: Number of characters. Also used to name placeholder
            speakers (``roleA``, ``roleB``, ...) when *speakers* is empty.
        total_turns: Total number of dialogue turns; must be at least
            *num_speakers*.
        environment: Setting / external-factor description (inserted as text).
        speakers: Character names, or a falsy value to use placeholders.
        context: Dialogue content / situation description (inserted as text).
        character_data: Relationship / trait description, interpolated into
            the prompt via f-string formatting.

    Returns:
        The prompt string.

    Raises:
        ValueError: If *total_turns* is smaller than *num_speakers*.
    """
    if total_turns < num_speakers:
        raise ValueError(
            f"total_turns ({total_turns}) must be >= num_speakers ({num_speakers})"
        )

    emotions = ["sad", "angry", "neutral", "happy", "surprised"]

    if not speakers:
        speakers = [f"role{ch}" for ch in string.ascii_uppercase[:num_speakers]]

    def make_turn(speaker):
        # One turn: 2-5 blank segment slots for the model to fill in.
        return {
            "speaker": speaker,
            "text": [
                {"lines": "", "emotion": "", "speed": "0.5~2.0"}
                for _ in range(random.randint(2, 5))
            ],
        }

    # Every speaker talks once first, in random order.
    initial_speakers = list(speakers)
    random.shuffle(initial_speakers)
    dialogue = [make_turn(speaker) for speaker in initial_speakers]
    last_speaker = initial_speakers[-1] if initial_speakers else None

    # Remaining turns: random speaker, never the same one twice in a row.
    # With a single speaker there is no alternative, so reuse the full list
    # rather than crash on random.choice([]).
    for _ in range(total_turns - num_speakers):
        candidates = [s for s in speakers if s != last_speaker] or list(speakers)
        last_speaker = random.choice(candidates)
        dialogue.append(make_turn(last_speaker))

    dialogue_json = json.dumps(dialogue, ensure_ascii=False, indent=2)
    # NOTE(review): the prompt asks for a "lines_seg" field and 2~4 segments,
    # while the skeleton uses the key "lines" and 2-5 segments. Confirm which
    # is intended before changing the prompt text.
    prompt = f"""
You are a scriptwriter tasked with creating emotionally expressive **single-sentence dialogues with internal emotion shifts**. Your output should be grounded in the following:  
- **Dialogue environment and external factors**,  
- **Dialogue content and situational context**,  
- **Interpersonal relationships and character traits**.

# Output Format Template
Each dialogue entry consists of **a list of sentence segments**, where **each segment is labeled with its corresponding emotion and speaking speed**. The entire list represents a single sentence spoken by a character.

Example:
{dialogue_json}

# Key Task Requirements
- There are {num_speakers} characters: {', '.join(speakers)}
- Dialogue alternates between speakers; **no speaker may speak twice in succession**
- Each sentence must be internally segmented (2~4 segments) and exhibit **clear emotion transitions**
- Each segment must include:
  - **lines_seg**: a span of 2~4 words, with punctuation only at the end of the last segment
  - **emotion**: the expressed emotion in this segment, chosen from: {emotions}
  - **speed**: the speaking rate for this segment (range: 0.5 to 2.0, where 0.5 = very fast, 1 = normal, 2.0 = very slow)

# Dialogue Environment and External Factors
{environment}

# Dialogue Content and Context
{context}

# Interpersonal Relationships and Character Traits
{character_data}
"""
    return prompt

def _write_json_atomically(obj, path):
    """Write *obj* as indented UTF-8 JSON to *path* without exposing a partial file.

    The data is written to ``path + ".tmp"`` first and then moved into place
    with :func:`os.replace`. A crash mid-write therefore leaves the previous
    version of *path* intact instead of a truncated file.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as wf:
        json.dump(obj, wf, ensure_ascii=False, indent=4)
    os.replace(tmp_path, path)


def generate_and_save(role_keys, envs, contexts, process_id, save_dir, start_idx, json_data):
    """Worker: request one dialogue script per (role, env, context) triple.

    For each item, a turn count is drawn with
    ``random.randint(int(n * 1.5), int(n * 2.5))``, where ``n`` is the number
    of characters. A prompt is built with :func:`generate_dialogue_template`,
    sent to ``gpt-4o``, and the reply is parsed with
    :func:`extract_and_convert_json`. Results are keyed by ``i + start_idx``
    so that shards from different workers do not collide. They are
    checkpointed to ``save_dir/whole_lines_<process_id>.json`` every 10 items
    and once more at the end.

    Args:
        role_keys: Keys into *json_data*; one request per key.
        envs: Environment descriptions, parallel to *role_keys*.
        contexts: Context descriptions, parallel to *role_keys*.
        process_id: Worker index; used in the shard filename and progress bar.
        save_dir: Directory for the shard file.
        start_idx: Global index of this worker's first item.
        json_data: Mapping from role key to an iterable of character dicts,
            each with at least a ``"CharacterName"`` entry.

    Any exception raised while handling an item is printed and that item is
    skipped.
    """
    openai.api_key = os.getenv("OPENAI_API_KEY")
    temp_file = os.path.join(save_dir, f"whole_lines_{process_id}.json")
    temp_all_json_info = {}
    # Start from a valid, empty shard. A stale file from a previous run is
    # then never merged, and merge_files never sees a zero-byte file if this
    # worker dies before its first checkpoint.
    _write_json_atomically(temp_all_json_info, temp_file)

    triples = zip(role_keys, envs, contexts)
    progress = tqdm(triples, desc=f"Process-{process_id}", total=len(role_keys))
    for i, (role_key, env, context) in enumerate(progress):
        character_data = json_data[role_key]
        try:
            role_names = [item["CharacterName"] for item in character_data]
            role_count = len(role_names)
            turns = random.randint(int(role_count * 1.5), int(role_count * 2.5))
            prompt = generate_dialogue_template(
                num_speakers=role_count,
                total_turns=turns,
                context=context,
                environment=env,
                speakers=role_names,
                character_data=character_data,
            )
            response = openai.ChatCompletion.create(
                model="gpt-4o",
                messages=[
                    {"role": "user", "content": prompt},
                ],
            )
            response_text = response.choices[0].message.content
            result = extract_and_convert_json(response_text)
            temp_all_json_info[i + start_idx] = {
                "script": result,
                "character": character_data,
                "env": env,
                "context": context,
                "spk_num": role_count,
                "turns": turns,
            }
        except Exception as e:
            # Best effort: one failed request must not abort the whole shard.
            print(e)
        if i % 10 == 0:
            _write_json_atomically(temp_all_json_info, temp_file)

    _write_json_atomically(temp_all_json_info, temp_file)

def merge_files(save_dir, final_output):
    """Merge per-process shard files into a single JSON file.

    Every ``whole_lines_*.json`` file in *save_dir* is loaded in sorted
    filename order. Its top-level object is merged into one dict via
    ``dict.update``, so later files win on duplicate keys, and the result is
    written to *final_output*.

    Shards are deleted only after the merged file has been written
    successfully. A shard that fails to parse is reported and left on disk,
    so its data is not lost.

    Args:
        save_dir: Directory containing the shard files.
        final_output: Path of the merged JSON file to write.
    """
    final_data = {}
    merged_paths = []
    output_abspath = os.path.abspath(final_output)
    for file in sorted(os.listdir(save_dir)):
        if not (file.startswith("whole_lines_") and file.endswith(".json")):
            continue
        file_path = os.path.join(save_dir, file)
        with open(file_path, "r", encoding="utf-8") as rf:
            try:
                json_data = json.load(rf)
            except json.JSONDecodeError as e:
                # Keep the unreadable shard for manual recovery.
                print(f"❌: {file} : {e}")
                continue
        final_data.update(json_data)
        merged_paths.append(file_path)

    with open(final_output, "w", encoding="utf-8") as wf:
        json.dump(final_data, wf, ensure_ascii=False, indent=4)

    # Remove shards only now that their data is safely in final_output.
    for file_path in merged_paths:
        if os.path.abspath(file_path) != output_abspath:
            os.remove(file_path)
    print(f"✅: {final_output}")

def is_english_only(text):
    """Return True when *text* contains no CJK Unified Ideographs.

    Only the U+4E00–U+9FFF block is checked. Other non-ASCII scripts (kana,
    Hangul, Cyrillic, accented Latin, ...) are still accepted, so the name
    overstates what is actually verified.
    """
    cjk_hit = re.search(r'[\u4e00-\u9fff]', text)
    return cjk_hit is None

def process_texts(text_list):
    """Strip, filter and de-duplicate strings while preserving input order.

    Each item is whitespace-stripped. Items that end up empty or are
    rejected by ``is_english_only`` are dropped, and only the first
    occurrence of each remaining string is kept.
    """
    kept = []
    already_seen = set()
    for raw in text_list:
        item = str.strip(raw)
        if not item or not is_english_only(item):
            continue
        if item in already_seen:
            continue
        already_seen.add(item)
        kept.append(item)
    return kept

def main():
    """Generate dialogue scripts in parallel and merge them into one file.

    Steps:

    1. Load role definitions (JSON) plus environment and context lists (one
       per line).
    2. Filter and de-duplicate the text lists and truncate all three inputs
       to a common length, since items are paired positionally.
    3. Fan the work out to ``generate_and_save`` worker processes.
    4. Merge the per-process shard files into ``scripts.json``.
    """
    num_processes = 32
    save_dir = "./datas/scripts/whole/"
    final_output = os.path.join(save_dir, "scripts.json")

    role_path = r"./datas/scripts/final_role.json"
    with open(role_path, "r", encoding="utf-8") as rf:
        role_data = json.load(rf)
    role_keys = list(role_data.keys())[:20000]

    env_path = r"./datas/scripts/final_env.tsv"
    with open(env_path, "r", encoding="utf-8") as rf:
        # The length check runs before stripping, so it counts the newline.
        envs = [line.strip() for line in rf if len(line) > 2]
    envs = process_texts(envs)[:20000]

    context_path = r"./datas/scripts/final_context.tsv"
    with open(context_path, "r", encoding="utf-8") as rf:
        contexts = [line.strip() for line in rf if len(line) > 2]
    contexts = process_texts(contexts)[:15000]

    # Items are paired positionally, so all three lists must have equal length.
    total_requests = min(len(role_keys), len(envs), len(contexts))
    role_keys = role_keys[:total_requests]
    envs = envs[:total_requests]
    contexts = contexts[:total_requests]

    os.makedirs(save_dir, exist_ok=True)
    mp.set_start_method("spawn", force=True)

    # Never start more workers than there are requests. Otherwise
    # requests_per_process is 0 and the last worker receives everything.
    num_processes = max(1, min(num_processes, total_requests))
    requests_per_process = total_requests // num_processes
    processes = []
    for process_id in range(num_processes):
        start_idx = process_id * requests_per_process
        is_last = process_id == num_processes - 1
        end_idx = total_requests if is_last else start_idx + requests_per_process
        chunk_keys = role_keys[start_idx:end_idx]
        # With "spawn", every argument is pickled to the child process, so
        # send only the roles this worker actually looks up.
        chunk_roles = {key: role_data[key] for key in chunk_keys}
        p = mp.Process(target=generate_and_save, args=(
            chunk_keys, envs[start_idx:end_idx], contexts[start_idx:end_idx],
            process_id, save_dir, start_idx, chunk_roles,
        ))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    merge_files(save_dir, final_output)
    print(f"✅ {final_output}")

# Script entry point: run the generate-and-merge pipeline only when this file
# is executed directly. The guard matters because main() selects the "spawn"
# start method, which imports this module again in every worker process.
if __name__ == "__main__":
    main()