# Build the ATLAS-Wiki-QKV dataset: question/answer/key-string items generated from LKG triples.
import asyncio
import csv
import json
import os
import random

from openai import AsyncOpenAI
from tqdm.asyncio import tqdm

# --- LLM endpoint (values redacted) ---
API_KEY = "***"
BASE_URL = "***"
MODEL_NAME = "***"

# --- Input: LKG triple CSV (after a header row, columns 0/1/2 are head, tail, relation) ---
LKG_DIR = "***"
LKG_FILE = "triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept.csv"
# Alternative corpora:
# LKG_FILE = "triple_edges_cc_en_from_json_without_emb_full_concept.csv"
# LKG_FILE = "triple_edges_pes2o_abstract_from_json_without_emb_full_concept.csv"

# --- Output ---
OUTPUT_FOLDER = "***"
# os.path.join works whether or not OUTPUT_FOLDER ends with a separator; plain
# string concatenation produced e.g. "outatlas_wiki_qa.json" when it did not.
OUTPUT_FILE = os.path.join(OUTPUT_FOLDER, "atlas_wiki_qa.json")

# --- Length limits ("tokens" = whitespace-separated words) ---
# Upper bound on key-string length ("the <relation noun> of <name>");
# process_triples drops key strings longer than MAX_TOKENS + 2 tokens.
# MAX_TOKENS = 25  # 30?
MAX_TOKENS = 15
# MIN_TOKENS = 10
MIN_TOKENS = 3  # NOTE(review): not referenced anywhere in this file.
# The visible entity ("name") must have fewer than MAX_NAME_TOKENS tokens...
MAX_NAME_TOKENS = 4
# ...and fewer than MAX_NAME_CHARS characters: a secondary cap, because some
# entities are long strings without spaces that the token count misses.
MAX_NAME_CHARS = MAX_NAME_TOKENS * 20
# The masked entity ("description") must have strictly more than
# MIN_DESCRIPTION_TOKENS and strictly fewer than MAX_DESCRIPTION_TOKENS tokens,
# and fewer than MAX_DESCRIPTION_CHARS characters.
MIN_DESCRIPTION_TOKENS = 10
MAX_DESCRIPTION_TOKENS = 20
MAX_DESCRIPTION_CHARS = MAX_DESCRIPTION_TOKENS * 20

# System prompt used by relation_to_nl: asks the model to turn a relation phrase
# into a noun, depending on whether the head or the tail entity is masked.
# NOTE(review): this prompt labels the position "head entity/event" /
# "tail entity/event", but relation_to_nl's user message sends
# "missing: head" / "missing: tail"; confirm the model maps them reliably.
# NOTE(review): the predefined mapping "has purpose" -> "origin" for a missing
# head looks suspicious next to "has purpose" -> "purpose" for a missing tail;
# confirm it is intended.
# NOTE(review): the examples wrap outputs in double quotes, which can lead the
# model to quote its answer as well.
relation_to_nl_prompt = """
**Task:** Convert a relation phrase to its natural noun equivalent based on entity position.

**Input:**
- `relation`: A verb phrase (e.g., "is participated by", "produces")
- `position`: Either "head entity/event" (missing head) or "tail entity/event" (missing tail)

**Conversion Rules:**
1. **For missing head entity** (position = "head entity/event"):
   - If relation matches predefined pattern → Use corresponding agent noun:
     ```
     "is participated by" → "participation"
     "has purpose"        → "origin"
     ```
   - For passive relations ("is [verb]ed by"):
     - Remove "is " and " by"
     - Convert verb to agent noun:
       - Add "-er"/"-or" (e.g., "govern" → "governor")
       - Add "-ist" (e.g., "specialize" → "specialist")
       - Use irregular forms (e.g., "advise" → "advisor")

2. **For missing tail entity** (position = "tail entity/event"):
   - If relation matches predefined pattern → Use corresponding object noun:
     ```
     "is participated by" → "participant"
     "has purpose"        → "purpose"
     ```
   - For active relations (e.g., "[verb]s"):
     - Remove trailing "s" (if present)
     - Convert verb to object noun:
       - Add "-tion"/"-sion" (e.g., "produce" → "production")
       - Add "-ment" (e.g., "achieve" → "achievement")
       - Use irregular forms (e.g., "sell" → "sale")

**Output:** Only the natural noun (no explanations)

**Examples:**
- Input: `("is participated by", "head entity/event")`, Output: "participation"
- Input: `("is participated by", "tail entity/event")`, Output: "participant"
- Input: `("produces", "head entity/event")`, Output: "producer"
- Input: `("produces", "tail entity/event")`, Output: "product"
"""

# Condensed variant of the relation-to-noun system prompt.
# NOTE(review): not referenced anywhere in this file (relation_to_nl sends
# relation_to_nl_prompt); keep it as an alternative or remove it deliberately.
relation_to_nl_prompt_short = """
**Task:** Convert relation phrase to natural noun based on missing entity position.

**Rules:**
- **Missing head**: Passive relations → agent nouns ("is governed by" → "governor", "is participated by" → "participation")
- **Missing tail**: Active relations → object nouns ("produces" → "product", "achieves" → "achievement")

**Output:** Natural noun only.

**Examples:**
- ("is participated by", "head") → "participation"
- ("is participated by", "tail") → "participant"  
- ("produces", "head") → "producer"
- ("produces", "tail") → "product"
"""


def _normalize_relation_nl(text):
    """Reduce a raw model reply to a bare noun phrase.

    Takes the first line that is non-empty after cleaning, and strips
    surrounding whitespace, quote/backtick characters and periods from it.
    The system prompt's examples show quoted outputs (Output: "participation"),
    so models often echo the quotes. Left in place, these would end up inside
    key strings such as 'the "producer" of X'.

    Args:
        text: the reply content; may be None.

    Returns:
        The cleaned noun phrase, or "" if nothing usable remains.
    """
    if not text:
        return ""
    for line in text.splitlines():
        noun = line.strip(" \t\"'`“”.")
        if noun:
            return noun
    return ""


async def relation_to_nl(client, semaphore, relation, missing='head'):
    """Ask the LLM for the noun form of *relation* given which entity is masked.

    Args:
        client: an AsyncOpenAI-compatible client (``chat.completions.create``).
        semaphore: asyncio.Semaphore bounding the number of in-flight requests.
        relation: relation phrase from the triple, e.g. "produces".
        missing: "head" or "tail" -- which entity of the triple is masked.

    Returns:
        The cleaned noun phrase (never empty), e.g. "producer".

    Raises:
        ValueError: if the reply has no content or no usable noun after cleaning.
        Any exception raised by the client call propagates unchanged.
    """
    async with semaphore:
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": relation_to_nl_prompt},
                # NOTE(review): the system prompt names positions
                # "head entity/event" / "tail entity/event", while this message
                # sends "missing: head" / "missing: tail"; confirm the model
                # maps them reliably.
                {"role": "user", "content": f"relation: {relation}, missing: {missing}"},
            ],
        )
    # message.content is Optional in the OpenAI SDK; the original `.strip()`
    # raised AttributeError on None.
    raw = response.choices[0].message.content
    noun = _normalize_relation_nl(raw)
    if not noun:
        raise ValueError(
            f"no usable relation noun for ({relation!r}, {missing!r}): {raw!r}"
        )
    return noun

def count_tokens(text):
    """Return how many whitespace-separated words *text* contains.

    This is a cheap word count, not a model tokenizer count; runs of
    whitespace count as a single separator.
    """
    pieces = text.split()
    return len(pieces)

def count_chars(text):
    """Return the length of *text* in characters (Python code points)."""
    n_chars = len(text)
    return n_chars

def _passes_length_filters(name, description):
    """Return True if *name* and *description* satisfy the length limits.

    *name* is the entity kept visible in the question; it must have fewer than
    MAX_NAME_TOKENS tokens and fewer than MAX_NAME_CHARS characters.
    *description* is the masked entity (the answer); it must have strictly more
    than MIN_DESCRIPTION_TOKENS and strictly fewer than MAX_DESCRIPTION_TOKENS
    tokens, and fewer than MAX_DESCRIPTION_CHARS characters.
    """
    if count_tokens(name) >= MAX_NAME_TOKENS or count_chars(name) >= MAX_NAME_CHARS:
        return False
    n_desc_tokens = count_tokens(description)
    return (
        MIN_DESCRIPTION_TOKENS < n_desc_tokens < MAX_DESCRIPTION_TOKENS
        and count_chars(description) < MAX_DESCRIPTION_CHARS
    )


def _select_triples(path, n, rng):
    """Read up to *n* usable triples from the LKG CSV at *path*.

    The first row is skipped as a header. For each later row, columns 0/1/2
    are taken as head/tail/relation. A row is kept only if:
    - neither of its entities was used by an earlier kept triple;
    - after randomly choosing which side to mask, it passes
      _passes_length_filters.

    Returns a list of dicts with keys 'head', 'tail', 'relation', 'missing'.
    """
    selected = []
    seen_entities = set()
    # newline='' is required by the csv module so quoted fields containing
    # newlines are parsed correctly.
    with open(path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; tolerate an empty file
        for row in reader:
            if len(selected) >= n:
                break
            if len(row) < 3:
                continue  # blank or malformed line (previously an IndexError)
            head, tail, relation = row[0], row[1], row[2]

            # Each entity may appear in at most one selected triple.
            if head in seen_entities or tail in seen_entities:
                continue

            # Randomly mask one side; the masked side becomes the description.
            missing = rng.choice(['head', 'tail'])
            name, description = (tail, head) if missing == 'head' else (head, tail)
            if not _passes_length_filters(name, description):
                continue

            seen_entities.add(head)
            seen_entities.add(tail)
            selected.append({'head': head, 'tail': tail, 'relation': relation, 'missing': missing})
            print(f"Processed triples: {len(selected)}", end='\r', flush=True)
    print()  # finish the carriage-return progress line
    return selected


async def _relation_nl_or_empty(client, semaphore, relation, missing):
    """Call relation_to_nl, turning any failure into "".

    A single failed request would otherwise abort the whole tqdm.gather
    batch and discard every completed result. CancelledError is a
    BaseException, so it is not caught here.
    """
    try:
        return await relation_to_nl(client, semaphore, relation, missing)
    except Exception as err:  # batch boundary: log and skip this pair
        print(f"\n[relation_to_nl] failed for ({relation!r}, {missing!r}): {err!r}")
        return ""


def _build_item(triple, relation_nl):
    """Build one dataset item from *triple* and its relation noun.

    Returns None when the key string is longer than MAX_TOKENS + 2 tokens.
    """
    if triple['missing'] == 'head':
        name, description = triple['tail'], triple['head']
    else:
        name, description = triple['head'], triple['tail']
    key_string = f"the {relation_nl} of {name}"
    # The +2 presumably leaves room for the "the"/"of" scaffolding words -- TODO confirm.
    if count_tokens(key_string) > MAX_TOKENS + 2:
        return None
    return {
        "name": name,
        "description_type": relation_nl,
        "description": description,
        "Q": f"What is {key_string}?",
        "A": f"{key_string} is {description}",
        "key_string": key_string,
        "extended_Q": "",
        "extended_A": "",
    }


async def process_triples(n=150000, seed=None):
    """Build the QA dataset from the LKG triple CSV and write it to OUTPUT_FILE.

    Steps:
    1. Select up to *n* triples (see _select_triples).
    2. Checkpoint them to OUTPUT_FOLDER/triples_to_process.json.
    3. Ask the LLM once per distinct (relation, missing) pair for a relation
       noun.
    4. Build the items, shuffle them, and dump them as JSON.

    Args:
        n: maximum number of triples to select.
        seed: optional seed for the masking choice and the final shuffle.
            None (the default) keeps the previous non-reproducible behaviour.
    """
    rng = random.Random(seed)
    triples = _select_triples(os.path.join(LKG_DIR, LKG_FILE), n, rng)

    # Checkpoint the selection so it can be inspected or reused.
    with open(os.path.join(OUTPUT_FOLDER, "triples_to_process.json"), 'w', encoding='utf-8') as f:
        json.dump(triples, f, ensure_ascii=False, indent=2)

    # Many triples share a (relation, missing) pair, so ask the LLM once per
    # distinct pair. This cuts API calls and gives each pair a consistent noun.
    unique_pairs = list(dict.fromkeys((t['relation'], t['missing']) for t in triples))
    semaphore = asyncio.Semaphore(256)
    async with AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL) as client:
        tasks = [
            _relation_nl_or_empty(client, semaphore, relation, missing)
            for relation, missing in unique_pairs
        ]
        nouns = await tqdm.gather(*tasks, desc="Generating NL relations")
    noun_for_pair = dict(zip(unique_pairs, nouns))

    results = []
    for triple in triples:
        relation_nl = noun_for_pair[(triple['relation'], triple['missing'])]
        if not relation_nl:
            continue  # the LLM call failed or returned nothing usable
        item = _build_item(triple, relation_nl)
        if item is not None:
            results.append(item)
    rng.shuffle(results)

    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

if __name__ == "__main__":
    # Script entry point: run the async pipeline on up to 150,000 selected triples.
    pipeline = process_triples(n=150000)
    asyncio.run(pipeline)
