import pdb
import numpy as np
from tqdm import tqdm
from langchain_community.llms import OpenAIChat, OpenAI
from langchain_community.chat_models import ChatOpenAI, openai
import json
import argparse
import asyncio
from tqdm.asyncio import tqdm_asyncio
import random
import re
import os
import tiktoken
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)
from pathlib import Path
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# Default per-attempt timeout, in seconds, for a single LLM request.
DEFAULT_API_TIMEOUT = 30


@retry(
    retry=retry_if_exception_type(asyncio.TimeoutError),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    stop=stop_after_attempt(10),
)
async def api_call_with_retry(llm, system_message, timeout=DEFAULT_API_TIMEOUT):
    """Send one JSON-mode chat request, bounding each attempt by ``timeout``.

    Only timeouts are retried (up to 10 attempts in total, with exponential
    backoff clamped to 2-10 seconds between attempts). Any other exception is
    printed and propagated immediately without a retry.
    NOTE(review): once all attempts time out, tenacity's default behaviour is
    to raise ``RetryError`` rather than the last ``TimeoutError`` — confirm.

    Args:
        llm: chat model exposing an async ``agenerate`` method.
        system_message: the single message that forms the whole conversation.
        timeout: seconds allowed for one attempt before it is cancelled.

    Returns:
        Whatever ``llm.agenerate`` returns; callers read ``.generations`` on it.
    """
    try:
        # Ask the model for a JSON object so the caller can json.loads() it.
        pending = llm.agenerate(
            [[system_message]],
            response_format={"type": "json_object"},
        )
        outcome = await asyncio.wait_for(pending, timeout=timeout)
    except asyncio.TimeoutError:
        # Re-raise so the @retry decorator schedules another attempt.
        print(f"API call timed out after {timeout}s, retrying...")
        raise
    except Exception as e:
        print(f"API call failed with error: {e}")
        raise
    return outcome

# Load API credentials from the local config file (path is relative to the
# current working directory) and export them as environment variables,
# presumably so ChatOpenAI picks them up. Read via Path.read_text so the file
# handle is closed immediately; the original json.load(open(...)) leaked it.
personal_info = json.loads(
    Path("config/api_info.json").read_text(encoding="utf-8")
)
os.environ["OPENAI_API_KEY"] = personal_info["api_key"]
os.environ["OPENAI_ORGANIZATION"] = personal_info["org_id"]

# Shared cap on how many generate_personalized_reasoning calls may hold the
# semaphore (and thus be in flight against the API) at the same time.
# NOTE(review): on Python < 3.10 a Semaphore created outside a running loop
# binds to get_event_loop(), not the loop created by asyncio.run(); once tasks
# actually have to wait this fails with "attached to a different loop".
# Fine on 3.10+.
semaphore = asyncio.Semaphore(256)

def load_prompt(file_path):
    """Return the entire contents of a UTF-8 encoded text file.

    Args:
        file_path: path of the prompt template to read.

    Returns:
        The file's text as a single string (newlines normalised by text mode).
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents

async def generate_personalized_reasoning(llm, user_id, item_id, question, options, answer, features, factors):
    """Generate personalized reasoning for a survey response based on country factors.

    Fills the ``prompt/15_reasoning_personalization.txt`` template with the
    item and the user's factors, sends it as a single system message, and
    parses the model's reply as JSON. The module-level ``semaphore`` is held
    for the whole request, which bounds concurrency across gathered tasks.

    Args:
        llm: chat model passed through to ``api_call_with_retry``.
        user_id, item_id: identifiers copied into the returned record.
        question, answer: the survey question text and the chosen answer.
        options: answer options; JSON-encoded into the prompt.
        features: per-item features; JSON-encoded into the prompt.
        factors: per-user factor data; JSON-encoded into the prompt.

    Returns:
        A dict with keys ``user_id``, ``item_id``, ``question``, ``options``,
        ``answer``, ``feature``, ``factors`` and ``reasoning`` (in that order).

    Raises:
        json.JSONDecodeError: if the model output is not valid JSON.
        KeyError: if the parsed JSON lacks a ``"reasoning"`` key.
        Any exception propagated from ``api_call_with_retry``.
    """
    async with semaphore:
        template = load_prompt("prompt/15_reasoning_personalization.txt")
        # ensure_ascii=False keeps non-ASCII text readable inside the prompt.
        filled_prompt = template.format(
            question=question,
            options=json.dumps(options, ensure_ascii=False),
            answer=answer,
            features=json.dumps(features, ensure_ascii=False),
            factors=json.dumps(factors, ensure_ascii=False),
        )

        llm_response = await api_call_with_retry(llm, SystemMessage(content=filled_prompt))

        # A single prompt was sent, so the first generation of the first
        # prompt carries the answer text.
        raw_text = llm_response.generations[0][0].text
        parsed = json.loads(raw_text.strip())

        record = {
            "user_id": user_id,
            "item_id": item_id,
            "question": question,
            "options": options,
            "answer": answer,
            "feature": features,
            "factors": factors,
            "reasoning": parsed["reasoning"],
        }
        return record

async def process_data(args, feature_data, factor_data):
    """Generate reasoning for every answered profile item, grouped by user.

    Args:
        args: parsed CLI namespace; only ``args.model_name`` is read here.
        feature_data: sequence of user entries, each with ``user_id`` and a
            ``profile`` list of item dicts.
        factor_data: mapping keyed by stringified user id; only entries that
            contain a ``"factors"`` key are used.

    Returns:
        A list of ``{"user_id": ..., "profile": [...]}`` dicts, one per user
        that had at least one item containing both ``question`` and
        ``answer``, in first-seen order. Each profile element is the
        generated record minus its ``user_id`` key.
    """
    llm = ChatOpenAI(temperature=0.0, model_name=args.model_name)

    # Keep only users whose factor entry actually carries "factors".
    # NOTE(review): the whole entry (not entry["factors"]) is forwarded to the
    # prompt and stored under "factors" in the output — confirm this is intended.
    factors_by_user = {
        uid: factor_data[uid]
        for uid in factor_data
        if "factors" in factor_data[uid]
    }

    coroutines = []
    owners = []  # owners[k] is the user_id that coroutines[k] belongs to
    for entry in feature_data:
        uid = entry['user_id']
        items = entry['profile']
        user_factors = factors_by_user.get(str(uid), {})

        for item in items:
            # Items without both a question and an answer are skipped.
            if 'question' in item and 'answer' in item:
                coroutines.append(
                    generate_personalized_reasoning(
                        llm,
                        uid,
                        item['item_id'],
                        item['question'],
                        item.get('options', []),
                        item['answer'],
                        item.get('feature', []),
                        user_factors,
                    )
                )
                owners.append(uid)

    # gather returns results in submission order, so they line up with owners.
    outputs = await tqdm_asyncio.gather(*coroutines)

    grouped = {}
    for uid, output in zip(owners, outputs):
        bucket = grouped.setdefault(uid, {"user_id": uid, "profile": []})
        bucket["profile"].append(
            {key: value for key, value in output.items() if key != 'user_id'}
        )

    return list(grouped.values())

async def main():
    """CLI entry point: read inputs, generate reasoning, and write grouped JSON.

    Reads the feature file (``--feature_path``) and the factor file
    (``--factor_path``), runs ``process_data`` with ``--model_name``, and
    writes the per-user results to ``--output_path`` as indented JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--feature_path", type=str, default="result/goqa_feature.json")
    parser.add_argument("--factor_path", type=str, default="result/goqa_factor.json")
    parser.add_argument("--output_path", type=str, default="result/goqa_personalized_reasoning.json")
    parser.add_argument("--model_name", type=str, default="gpt-4o-mini")
    args = parser.parse_args()

    # Explicit UTF-8 everywhere (matching load_prompt). The default encoding
    # is locale-dependent, and the output is dumped with ensure_ascii=False,
    # so non-ASCII text would raise UnicodeEncodeError on e.g. a cp1252
    # locale. That failure would only happen after every API call finished.
    with open(args.feature_path, 'r', encoding='utf-8') as f:
        feature_data = json.load(f)

    with open(args.factor_path, 'r', encoding='utf-8') as f:
        factor_data = json.load(f)

    # Ensure the output directory exists before the (slow, paid) API run,
    # so a bad path fails fast instead of discarding the finished results.
    Path(args.output_path).parent.mkdir(parents=True, exist_ok=True)

    results = await process_data(args, feature_data, factor_data)

    with open(args.output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    asyncio.run(main())