import csv
import json
import random
import asyncio
import psutil
import yaml
import os
import glob
from pathlib import Path
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm
from tqdm import tqdm as normal_tqdm

# Load configuration from YAML file.
# NOTE: the path is relative to the current working directory (not to this
# file), so the script must be started from a directory where
# ../config/full_atlas_config.yaml resolves. This runs at import time.
with open('../config/full_atlas_config.yaml', 'r') as file:
    config = yaml.safe_load(file)

# LLM Config: credentials, endpoint and chat model passed to AsyncOpenAI.
API_KEY = config['openai']['api_key']
BASE_URL = config['openai']['base_url']
MODEL_NAME = config['openai']['chat_model']
# File Config: input CSV location (LKG_DIR/LKG_FILE) and output location.
LKG_DIR = config['kg2kv']['lkg_dir']
LKG_FILE = config['kg2kv']['lkg_file']
OUTPUT_FOLDER = config['kg2kv']['output_folder']
# Plain string concatenation: output_folder must already end with a path
# separator for OUTPUT_FILE to land inside that folder.
OUTPUT_FILE = OUTPUT_FOLDER + config['kg2kv']['output_file']
# Length limits that keep names/descriptions short for efficient processing.
# IF_CHECK turns the length filter on; IF_NAIVE skips the LLM call and uses
# the raw relation phrase instead. The MAX_* values are exclusive upper
# bounds (a value equal to the limit is rejected); "tokens" are
# whitespace-separated words as counted by count_tokens().
IF_CHECK = config['kg2kv']['if_check']
IF_NAIVE = config['kg2kv']['if_naive']
MAX_NAME_TOKENS = config['kg2kv']['max_name_tokens']
MAX_NAME_CHARS = config['kg2kv']['max_name_chars']
MAX_DESCRIPTION_TOKENS = config['kg2kv']['max_description_tokens']
MAX_DESCRIPTION_CHARS = config['kg2kv']['max_description_chars']


# Detailed system prompt describing how to turn a relation phrase into a noun
# for the missing head/tail slot. No function visible in this file references
# it; relation_to_nl() sends relation_to_nl_prompt_short instead.
relation_to_nl_prompt = """
**Task:** Convert a relation phrase to its natural noun equivalent based on entity position.

**Input:**
- `relation`: A verb phrase (e.g., "is participated by", "produces")
- `position`: Either "head entity/event" (missing head) or "tail entity/event" (missing tail)

**Conversion Rules:**
1. **For missing head entity** (position = "head entity/event"):
   - If relation matches predefined pattern → Use corresponding agent noun:
     ```
     "is participated by" → "participation"
     "has purpose"        → "origin"
     ```
   - For passive relations ("is [verb]ed by"):
     - Remove "is " and " by"
     - Convert verb to agent noun:
       - Add "-er"/"-or" (e.g., "govern" → "governor")
       - Add "-ist" (e.g., "specialize" → "specialist")
       - Use irregular forms (e.g., "advise" → "advisor")

2. **For missing tail entity** (position = "tail entity/event"):
   - If relation matches predefined pattern → Use corresponding object noun:
     ```
     "is participated by" → "participant"
     "has purpose"        → "purpose"
     ```
   - For active relations (e.g., "[verb]s"):
     - Remove trailing "s" (if present)
     - Convert verb to object noun:
       - Add "-tion"/"-sion" (e.g., "produce" → "production")
       - Add "-ment" (e.g., "achieve" → "achievement")
       - Use irregular forms (e.g., "sell" → "sale")

**Output:** Only the natural noun (no explanations)

**Examples:**
- Input: `("is participated by", "head entity/event")`, Output: "participation"
- Input: `("is participated by", "tail entity/event")`, Output: "participant"
- Input: `("produces", "head entity/event")`, Output: "producer"
- Input: `("produces", "tail entity/event")`, Output: "product"
"""

# Condensed system prompt; relation_to_nl() sends it as the system message of
# every chat request, followed by a user message of the form
# "relation: <relation>, missing: <head|tail>".
relation_to_nl_prompt_short = """
**Task:** Convert relation phrase to natural noun based on missing entity position.

**Rules:**
- **Missing head**: Passive relations → agent nouns ("is governed by" → "governor", "is participated by" → "participation")
- **Missing tail**: Active relations → object nouns ("produces" → "product", "achieves" → "achievement")

**Output:** Natural noun only.

**Examples:**
- ("is participated by", "head") → "participation"
- ("is participated by", "tail") → "participant"  
- ("produces", "head") → "producer"
- ("produces", "tail") → "product"
"""


def check_memory(threshold=85):
    """Report whether system memory usage is still below the safety limit.

    Only memory is inspected (via ``psutil.virtual_memory()``); CPU load is
    not checked.

    Args:
        threshold: Used-memory percentage above which the check fails.
            Defaults to 85, the limit that used to be hard-coded.

    Returns:
        True when the used-memory percentage does not exceed ``threshold``.
        False otherwise, after printing a red warning. The function never
        terminates the process itself; callers decide how to react.
    """
    usage_percent = psutil.virtual_memory().percent
    if usage_percent > threshold:
        print(f"\033[91mWarning: Memory usage {usage_percent}%. Shutdown...\033[0m")
        return False
    return True

def _write_raw_batch(triples, batch_num):
    """Dump one list of raw triples to raw_batch_<batch_num>.json in OUTPUT_FOLDER."""
    raw_batch_file = os.path.join(OUTPUT_FOLDER, f"raw_batch_{batch_num:06d}.json")
    with open(raw_batch_file, 'w', encoding='utf-8') as batch_f:
        json.dump(triples, batch_f, ensure_ascii=False, indent=2)
    print(f"Saved raw batch {batch_num}, containing {len(triples)} triples")


def _raw_triple_within_limits(head, tail, missing):
    """Return True when the name and description sides are under the configured caps."""
    # The entity that stays visible is the "name"; the hidden one is the
    # "description" that will later become the answer.
    name, description = (tail, head) if missing == 'head' else (head, tail)
    if count_tokens(name) >= MAX_NAME_TOKENS or count_chars(name) >= MAX_NAME_CHARS:
        return False
    if (count_tokens(description) >= MAX_DESCRIPTION_TOKENS
            or count_chars(description) >= MAX_DESCRIPTION_CHARS):
        return False
    return True


def save_raw_triples_batches(n=150000, batch_size=10000):
    """Step 1: Sample triples from the CSV and save them to disk in batches.

    Reads ``LKG_DIR/LKG_FILE`` row by row (the first row is treated as a
    header and skipped), picks at random whether the head or the tail is the
    "missing" side, optionally applies the length filter (``IF_CHECK``) and
    writes the accepted triples to ``raw_batch_NNNNNN.json`` files inside
    ``OUTPUT_FOLDER``. Batch files left over from earlier runs are not
    removed.

    Args:
        n: Maximum number of accepted triples.
        batch_size: Number of triples per batch file.

    Returns:
        The number of batch files written (0 if nothing was accepted).
    """
    print("=== Step 1: Saving raw triple batches ===")
    print(f"Target count: {n}, Batch size: {batch_size}")

    # Ensure output directory exists
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    # Querying psutil for every row is expensive on multi-million-row inputs;
    # sampling it every 1000 rows bounds memory growth just as well.
    memory_check_interval = 1000

    batch_num = 0
    total_processed = 0
    rows_seen = 0
    triples_batch = []

    # newline='' lets the csv module handle line endings inside quoted fields.
    with open(f"{LKG_DIR}/{LKG_FILE}", 'r', encoding='utf-8', newline='') as csv_f:
        reader = csv.reader(csv_f)
        next(reader, None)  # skip the header row; tolerate an empty file

        for row in reader:
            if total_processed >= n:
                break

            if rows_seen % memory_check_interval == 0 and not check_memory():
                print("Insufficient memory, stopping processing")
                break
            rows_seen += 1

            # Blank or truncated lines cannot form a triple; skip them
            # instead of failing with IndexError.
            if len(row) < 3:
                continue

            # Column order as read here is head, tail, relation.
            head, tail, relation = row[0], row[1], row[2]
            missing = random.choice(['head', 'tail'])

            if IF_CHECK and not _raw_triple_within_limits(head, tail, missing):
                continue

            triples_batch.append({'head': head, 'tail': tail, 'relation': relation, 'missing': missing})
            print(f"Processed triples: {total_processed}", end="\r")
            total_processed += 1

            # When batch is full, save it and start a fresh list
            if len(triples_batch) >= batch_size:
                _write_raw_batch(triples_batch, batch_num)
                batch_num += 1
                triples_batch = []

    # Save the last incomplete batch
    if triples_batch:
        _write_raw_batch(triples_batch, batch_num)
        batch_num += 1

    print(f"Raw triple batches saved! Total {batch_num} batches, {total_processed} triples")
    return batch_num

def get_raw_batch_files(output_folder):
    """Return the sorted paths of all ``raw_batch_*.json`` files in *output_folder*."""
    matches = glob.glob(os.path.join(output_folder, "raw_batch_*.json"))
    matches.sort()
    return matches

def get_processed_batch_files(output_folder):
    """Return the sorted paths of all ``processed_batch_*.json`` files in *output_folder*."""
    matches = glob.glob(os.path.join(output_folder, "processed_batch_*.json"))
    matches.sort()
    return matches

def count_tokens(text):
    """Return the number of whitespace-separated words in *text*."""
    words = text.split()
    return len(words)

def count_chars(text):
    """Return ``len(text)``, used as the character count of a string."""
    length = len(text)
    return length

async def relation_to_nl(client, semaphore, relation, missing='head'):
    """Ask the chat model for the noun form of *relation*.

    Args:
        client: Object whose ``chat.completions.create`` is awaited (an
            ``AsyncOpenAI`` client at the call sites in this file).
        semaphore: Async context manager limiting concurrent requests.
        relation: Relation phrase taken from the triple.
        missing: Which side of the triple is hidden, 'head' or 'tail'.

    Returns:
        The model's reply with surrounding whitespace stripped. When the
        reply carries no text (``content`` is None or blank) the original
        relation phrase is returned, which is what the ``IF_NAIVE`` path
        uses as well.
    """
    async with semaphore:
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": relation_to_nl_prompt_short},
                {"role": "user", "content": f"relation: {relation}, missing: {missing}"}
            ]
        )

    # message.content is nullable in the chat completions schema; calling
    # .strip() on None would raise and abort the whole gathered batch.
    content = response.choices[0].message.content
    noun = content.strip() if content else ""
    return noun or relation

async def process_single_raw_batch(raw_batch_file, batch_num, max_concurrent=50):
    """Step 2 worker: turn one raw batch file into a processed batch file.

    Loads the triples from *raw_batch_file*, obtains a natural-language noun
    for each relation (through the LLM unless ``IF_NAIVE`` is set, in which
    case the raw relation is used), builds the name/description/Q/A records
    and writes them to ``processed_batch_<batch_num>.json`` in
    ``OUTPUT_FOLDER``.

    Args:
        raw_batch_file: Path of the raw batch JSON file to read.
        batch_num: Number used in log messages and the output file name.
        max_concurrent: Upper bound on simultaneous LLM requests.

    Returns:
        False when the batch is skipped because the memory check failed,
        True once the processed file has been written.
    """
    print(f"Starting to process batch {batch_num}: {raw_batch_file}")

    # Bail out early when the machine is already short on memory
    if not check_memory():
        print(f"Insufficient memory, skipping batch {batch_num}")
        return False

    with open(raw_batch_file, 'r', encoding='utf-8') as raw_f:
        triples_batch = json.load(raw_f)

    print(f"Batch {batch_num} contains {len(triples_batch)} triples")

    if IF_NAIVE:
        # Naive mode: keep the relation phrase untouched
        relation_nls = [triple['relation'] for triple in triples_batch]
    else:
        # The semaphore caps the number of in-flight requests
        limiter = asyncio.Semaphore(max_concurrent)
        async with AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL) as client:
            pending = [
                relation_to_nl(client, limiter, triple['relation'], triple['missing'])
                for triple in triples_batch
            ]
            relation_nls = await tqdm.gather(*pending, desc=f"Batch {batch_num} - Generating NL relations")

    # Build one record per triple
    records = []
    for triple, raw_nl in zip(triples_batch, relation_nls):
        head = triple['head']
        tail = triple['tail']
        relation_nl = raw_nl.replace('"', '')

        # The visible entity is the name; the hidden one is the description
        if triple['missing'] == 'head':
            name, description = f"{tail}", f"{head}"
        else:
            name, description = f"{head}", f"{tail}"

        key_string = f"the {relation_nl} of {name}"
        records.append({
            "name": name,
            "description_type": f"{relation_nl}",
            "description": description,
            "Q": f"What is {key_string}?",
            "A": f"{key_string} is {description}",
            "key_string": key_string,
            "extended_Q": "",
            "extended_A": ""
        })

    # Persist the processed batch
    processed_batch_file = os.path.join(OUTPUT_FOLDER, f"processed_batch_{batch_num:06d}.json")
    with open(processed_batch_file, 'w', encoding='utf-8') as out_f:
        json.dump(records, out_f, ensure_ascii=False, indent=2)

    print(f"Batch {batch_num} processing completed, saved to {processed_batch_file}")
    return True

async def process_all_raw_batches():
    """Step 2: Run process_single_raw_batch over every raw batch file.

    Files are taken from ``OUTPUT_FOLDER`` in sorted order and numbered by
    their position in that list. Processing stops early when the memory
    check fails after a batch.

    Returns:
        The number of batches processed successfully (0 when no raw batch
        files exist).
    """
    print("=== Step 2: Processing all raw batch files ===")

    raw_batch_files = get_raw_batch_files(OUTPUT_FOLDER)
    if not raw_batch_files:
        print("No raw batch files found")
        return 0

    print(f"Found {len(raw_batch_files)} raw batch files")

    processed_count = 0
    # The enumerate index doubles as the batch number
    for batch_num, raw_batch_file in enumerate(raw_batch_files):
        if await process_single_raw_batch(raw_batch_file, batch_num):
            processed_count += 1

        # Stop the whole run once memory gets tight
        if not check_memory():
            print("Insufficient memory, stopping processing")
            break

    print(f"Batch processing completed! Successfully processed {processed_count} batches")
    return processed_count

def merge_processed_batches_with_cmd():
    """Step 3: Merge all processed batch files into one JSON Lines file.

    First tries the external ``jq`` tool, writing one compact JSON object
    per line to ``OUTPUT_FILE`` with a trailing ``.json`` swapped for
    ``.jsonl``. If ``jq`` is missing or exits with an error, any partial
    output is removed and merge_processed_batches_with_python() is used
    instead.

    Returns:
        True when jq succeeded, otherwise whatever the Python fallback
        returns. False when there are no processed batch files.
    """
    import subprocess

    print("=== Step 3: Merging processed batch files ===")

    processed_batch_files = get_processed_batch_files(OUTPUT_FOLDER)
    if not processed_batch_files:
        print("No processed batch files found")
        return False

    print(f"Found {len(processed_batch_files)} processed batch files")

    # Swap only a trailing ".json" for ".jsonl". A blanket str.replace would
    # also rewrite ".json" inside directory names and turn ".jsonl" into
    # ".jsonll".
    root, ext = os.path.splitext(OUTPUT_FILE)
    output_file_jsonl = root + '.jsonl' if ext == '.json' else OUTPUT_FILE

    # `jq -c '.[]' f1 f2 ...` prints every element of each file's top-level
    # array on its own line. Passing an argument list (no shell) avoids the
    # bash-only `<( ... )` process substitution, which /bin/sh may not
    # support, and keeps paths with spaces or shell metacharacters intact.
    cmd = ["jq", "-c", ".[]", *processed_batch_files]
    print(f"Trying jq command: {' '.join(cmd)} > {output_file_jsonl}")

    try:
        with open(output_file_jsonl, 'wb') as output_f:
            result = subprocess.run(cmd, stdout=output_f, stderr=subprocess.PIPE, text=True)
    except (OSError, subprocess.SubprocessError) as e:
        # Typically: jq is not installed, or the output file cannot be opened
        print(f"Error executing jq command: {e}")
    else:
        if result.returncode == 0:
            print(f"jq merge completed! Results saved to {output_file_jsonl}")
            return True
        print(f"jq merge failed: {result.stderr}")

    # Best-effort cleanup so a truncated jq output is not mistaken for a
    # finished merge; a missing file is fine here.
    try:
        os.remove(output_file_jsonl)
    except OSError:
        pass

    print("Trying Python method to merge...")
    return merge_processed_batches_with_python()

def merge_processed_batches_with_python():
    """Fallback merge: collect the processed batch files and stream-merge them.

    Returns:
        False when ``OUTPUT_FOLDER`` holds no processed batch files,
        otherwise the result of merge_processed_batches_streaming().
    """
    print("Using Python streaming merge method...")

    batch_files = get_processed_batch_files(OUTPUT_FOLDER)
    if batch_files:
        # The streaming merge is used regardless of dataset size
        return merge_processed_batches_streaming(batch_files)

    print("No processed batch files found")
    return False


def merge_processed_batches_streaming(processed_batch_files):
    """Merge processed batch files into ``OUTPUT_FILE``, one JSON object per line.

    Only one batch file is held in memory at a time. Files are visited in
    random order and each file's records are shuffled before being written,
    so the result is shuffled per file rather than globally. The output is
    JSON Lines and is written to ``OUTPUT_FILE`` exactly as configured.

    Args:
        processed_batch_files: Paths of the processed batch JSON files. The
            list is not modified.

    Returns:
        True when the merge ran to completion (unreadable batch files are
        reported and skipped). False when it was abandoned because the
        memory check failed.

    Raises:
        OSError: If ``OUTPUT_FILE`` cannot be opened or written.
    """
    print("Using streaming merge method...")

    # First pass: count total records (informational only)
    total_records = 0
    for batch_file in processed_batch_files:
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                total_records += len(json.load(f))
        except Exception as e:
            print(f"Error: Cannot read batch file {batch_file}: {e}")

    print(f"Total records to merge: {total_records}")

    # Shuffle a copy so the caller's list keeps its order
    shuffled_files = list(processed_batch_files)
    random.shuffle(shuffled_files)

    # Second pass: streaming merge to JSONL format
    records_written = 0
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as output_f:
        for batch_file in tqdm(shuffled_files, desc="Streaming merge to JSONL"):
            # Only loading/shuffling is best-effort. Write failures must
            # surface instead of being reported as read errors and skipped.
            try:
                with open(batch_file, 'r', encoding='utf-8') as f:
                    batch_data = json.load(f)
                random.shuffle(batch_data)
            except Exception as e:
                print(f"Error: Cannot read batch file {batch_file}: {e}")
                continue

            for item in batch_data:
                # Write each item as a separate line in JSONL format
                output_f.write(json.dumps(item, ensure_ascii=False) + '\n')
                records_written += 1

                # Check memory periodically. Report failure to the caller
                # rather than exiting the interpreter with a success status.
                if records_written % 1000 == 0 and not check_memory():
                    print("Insufficient memory during merge, stopping...")
                    return False

    print(f"Streaming merge completed! Total {records_written} records saved to {OUTPUT_FILE}")
    return True

async def process_triples(n=150000):
    """Single-pass (non-batched) pipeline: sample, rewrite and save triples.

    Reads up to *n* accepted triples from ``LKG_DIR/LKG_FILE`` (the first
    row is treated as a header), asks the LLM for a noun form of every
    relation, builds the name/description/Q/A records, shuffles them and
    writes them as one JSON array to ``OUTPUT_FILE``. main() in this file
    does not call this function.

    Args:
        n: Maximum number of triples to collect.

    Raises:
        SystemExit: With status 1 when the memory check fails while reading
            the CSV or right before the LLM calls.
    """
    triples_to_process = []

    # newline='' lets the csv module handle line endings inside quoted fields.
    with open(f"{LKG_DIR}/{LKG_FILE}", 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        # A bare next(reader) on an empty file raises StopIteration, which a
        # coroutine surfaces as RuntimeError; the default avoids that.
        next(reader, None)
        for row in reader:
            if not check_memory():
                # Abort with a failure status; the site-provided exit() is
                # meant for interactive use and reported success (0).
                raise SystemExit(1)
            if len(triples_to_process) >= n:
                break
            # Blank or truncated lines cannot form a triple
            if len(row) < 3:
                continue
            # Column order as read here is head, tail, relation.
            head, tail, relation = row[0], row[1], row[2]
            missing = random.choice(['head', 'tail'])
            # Length check: the visible entity is held to the name limits,
            # the hidden one to the description limits.
            if IF_CHECK:
                name_side, description_side = (tail, head) if missing == 'head' else (head, tail)
                if count_tokens(name_side) >= MAX_NAME_TOKENS or count_chars(name_side) >= MAX_NAME_CHARS:
                    continue
                if (count_tokens(description_side) >= MAX_DESCRIPTION_TOKENS
                        or count_chars(description_side) >= MAX_DESCRIPTION_CHARS):
                    continue
            triples_to_process.append({'head': head, 'tail': tail, 'relation': relation, 'missing': missing})
            print(f"Processed triples: {len(triples_to_process)}")

    # check memory before API call
    if not check_memory():
        raise SystemExit(1)

    semaphore = asyncio.Semaphore(512)
    async with AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL) as client:
        tasks = [relation_to_nl(client, semaphore, t['relation'], t['missing']) for t in triples_to_process]
        relation_nls = await tqdm.gather(*tasks, desc="Generating NL relations")

    results = []
    for triple_info, raw_nl in zip(triples_to_process, relation_nls):
        head = triple_info['head']
        tail = triple_info['tail']
        missing = triple_info['missing']
        relation_nl = raw_nl.replace('"', '')

        name = f"{tail}" if missing == 'head' else f"{head}"
        description = f"{head}" if missing == 'head' else f"{tail}"
        key_string = f"the {relation_nl} of {name}"
        results.append({
            "name": name,
            "description_type": f"{relation_nl}",
            "description": description,
            "Q": f"What is {key_string}?",
            "A": f"{key_string} is {description}",
            "key_string": key_string,
            "extended_Q": "",
            "extended_A": ""
        })

    random.shuffle(results)
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

async def main():
    """Main function: Three-step streaming processing pipeline.

    Runs the steps in order and stops at the first one that produces
    nothing: (1) save raw triple batches, (2) process every raw batch,
    (3) merge the processed batches. Progress and failures are reported
    with coloured console messages; nothing is returned.
    """
    print("\033[94m[Starting three-step streaming processing for large dataset]\033[0m")

    # Step 1: Save raw triple batches. Counts are passed as ints so the
    # log shows 10000000 rather than 10000000.0.
    print("\033[94m[Step 1: Saving raw triple batches]\033[0m")
    batch_count = save_raw_triples_batches(n=10_000_000, batch_size=1_000_000)
    if batch_count <= 0:
        print("\033[91m[Raw batch saving failed]\033[0m")
        return

    # Step 2: Process all raw batch files
    print("\033[94m[Step 2: Processing all raw batch files]\033[0m")
    processed_count = await process_all_raw_batches()
    if processed_count <= 0:
        print("\033[91m[Batch processing failed]\033[0m")
        return

    # Step 3: Merge processed batch files using command line
    print("\033[94m[Step 3: Merging processed batch files]\033[0m")
    if merge_processed_batches_with_cmd():
        print("\033[92m[Three-step streaming processing completed!]\033[0m")
    else:
        print("\033[91m[Merge step failed]\033[0m")

# Script entry point: start the async pipeline only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())