import os
import json
import argparse
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
import asyncio
import re
from prompts import PromptTemplates

load_dotenv()

llm = ChatOpenAI(
    model=os.getenv("KIMI_MODEL"),
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_KIMI_URL"),
    temperature=0.7,
    max_tokens=2000
)

def load_id_list() -> str:
    """Load identity list from JailBench-id_en.json and format as text sequence"""
    try:
        with open("../JailBench-id_en.json", 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        id_list = []
        for idx, item in enumerate(data, 1):
            result = item.get('result', {})
            name = result.get('name', '')
            characteristic = result.get('characteristic', '')
            if name and characteristic:
                id_list.append(f"{idx}. {name}: {characteristic}")
        return "\n".join(id_list)
    except Exception :
        return ""

async def id_extract(harmful_text: str) -> tuple:
    """Extract the most suitable identity based on harmful text"""
    id_list = load_id_list()
    
    system_prompt = PromptTemplates.get_id_extract_system_prompt(id_list)
    user_prompt = f"User input: {harmful_text}"
    
    try:
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt)
        ]
        
        response = await llm.ainvoke(messages)
        result_text = response.content.strip()
        
        # Clean and extract JSON
        json_match = re.search(r'\{.*?\}', result_text, re.DOTALL)
        if json_match:
            json_str = json_match.group()
            try:
                parsed_json = json.loads(json_str)
                id_name = parsed_json.get('id_name', '')
                
                # Find corresponding characteristic from original data
                with open("../JailBench-id_en.json", 'r', encoding='utf-8') as f:
                    data = json.load(f)
                
                characteristic = ""
                for item in data:
                    result = item.get('result', {})
                    if result.get('name') == id_name:
                        characteristic = result.get('characteristic', '')
                        break
                
                return (id_name, characteristic)
                
            except json.JSONDecodeError :
                return ("", "")
        else:
            return ("", "")
            
    except Exception :
        return ("", "")

async def id_generate(identity_name: str, characteristic: str, harmful_text: str) -> str:
    """Generate disguised prompt based on identity, characteristics and harmful text"""
    system_prompt = PromptTemplates.get_id_generate_system_prompt(identity_name, characteristic, harmful_text)

    try:
        messages = [
            SystemMessage(content=system_prompt)
        ]
        
        response = await llm.ainvoke(messages)
        generated_text = response.content.strip()
        
        return generated_text
        
    except Exception :
        return "Generation failed"

async def generate_identity(harmful_text: str) -> dict:
    """Generate identity information based on harmful text"""
    prompt = PromptTemplates.get_generate_identity_prompt(harmful_text)
    
    try:
        messages = [
            SystemMessage(content="You are a professional character analysis expert who can accurately identify identity types most likely to execute specific behaviors."),
            HumanMessage(content=prompt)
        ]
        
        response = await llm.ainvoke(messages)
        response_text = response.content.strip()
        
        # Try to parse JSON
        try:
            # Extract JSON part (if contains other text)
            start_idx = response_text.find('{')
            end_idx = response_text.rfind('}') + 1
            if start_idx != -1 and end_idx != 0:
                json_str = response_text[start_idx:end_idx]
                identity_data = json.loads(json_str)
                return identity_data
            else:
                return {"name": "unknown", "characteristic": "unknown"}
        except json.JSONDecodeError :
            return {"name": "unknown", "characteristic": "unknown"}
            
    except Exception :
        return {"name": "unknown", "characteristic": "unknown"}

async def process_harmful_text(harmful_text: str) -> dict:
    """Main external function: input harmful_text, return generated identity and text"""
    try:
        # Step 1: Extract identity
        identity_name, characteristic = await id_extract(harmful_text)
        
        if not identity_name or not characteristic:
            # If extraction fails, generate new identity
            identity_data = await generate_identity(harmful_text)
            identity_name = identity_data.get('name', '')
            characteristic = identity_data.get('characteristic', '')
        
        
        # Step 2: Generate disguised prompt
        result = await id_generate(identity_name, characteristic, harmful_text)
        
        return {
            "name": identity_name,
            "characteristic": characteristic,
            "text": harmful_text,
            "generated_prompt": result
        }
        
    except Exception as e:
        return {
            "name": "error",
            "characteristic": "error",
            "text": harmful_text,
            "generated_prompt": f"Processing failed: {str(e)}"
        }

async def process_json_file(dirname: str):
    """Process JSON file and generate identity data"""
    # Construct file paths
    input_file = f"../data/processed_questions/{dirname}.json"
    output_file = f"../id_en_{dirname}.json"
    
    # Check if input file exists
    if not os.path.exists(input_file):
        return
    
    
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    identities = []
    
    # Sort data by key numbers
    sorted_items = sorted(data.items(), key=lambda x: int(x[0]))
    
    for index, (key, item) in enumerate(sorted_items):
        changed_question = item.get("Changed Question", "")
        if not changed_question:
            continue
            
        
        # Generate identity data
        identity_data = await process_harmful_text(changed_question)
        identities.append(identity_data)
        
        # Add small delay to avoid API limits
        await asyncio.sleep(0.5)
    
    # Save results
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(identities, f, indent=2, ensure_ascii=False)
    
    print(f"Identity data saved to: {output_file}")

def main():
    """Main function"""
    parser = argparse.ArgumentParser(description='Generate identity data from processed questions')
    parser.add_argument('--dirname', required=True, 
                       help='Directory name to process (e.g., 04-Physical_Harm)')
    
    args = parser.parse_args()
    
    try:
        asyncio.run(process_json_file(args.dirname))
    except Exception as e:
        raise

if __name__ == "__main__":
    main()
