import pdb
import numpy as np
from tqdm import tqdm
from langchain_community.llms import OpenAIChat, OpenAI
from langchain_community.chat_models import ChatOpenAI, openai
import json
import argparse
import asyncio
from tqdm.asyncio import tqdm_asyncio
import random
import re
import os
import tiktoken
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)
from pathlib import Path
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# Default per-call timeout, in seconds, used by api_call_with_retry when the
# caller does not pass an explicit ``timeout``.
DEFAULT_API_TIMEOUT = 30


@retry(
    retry=retry_if_exception_type(asyncio.TimeoutError),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    stop=stop_after_attempt(10),
)
async def api_call_with_retry(llm, system_message, timeout=DEFAULT_API_TIMEOUT):
    """Run one JSON-mode generation on ``llm``, retrying on timeouts.

    The call is bounded by ``timeout`` seconds. Only ``asyncio.TimeoutError``
    triggers a retry (up to 10 attempts with exponential backoff between
    attempts); any other exception is logged and propagated immediately.

    Args:
        llm: Chat model object exposing an async ``agenerate`` method.
        system_message: The single message sent as the whole conversation.
        timeout: Maximum number of seconds to wait for one attempt.

    Returns:
        Whatever ``llm.agenerate`` returns for the one-message batch.

    Raises:
        Exception: Any non-timeout error raised by the underlying call.
        Timeouts are re-raised so that the ``retry`` decorator can handle
        them; what is raised once attempts are exhausted is decided by
        tenacity (the decorator does not set ``reraise``).
    """
    try:
        # One batch containing one conversation with one message.
        pending_call = llm.agenerate(
            [[system_message]],
            response_format={"type": "json_object"},
        )
        generation = await asyncio.wait_for(pending_call, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"API call timed out after {timeout}s, retrying...")
        raise
    except Exception as error:
        print(f"API call failed with error: {error}")
        raise
    else:
        return generation

# Load API credentials from the local config file and export them as
# environment variables (presumably picked up by the OpenAI/langchain client
# -- verify against the client library in use). The file is opened in a
# context manager so the handle is closed deterministically instead of being
# left to the garbage collector.
with open("config/api_info.json", "r", encoding="utf-8") as api_info_file:
    personal_info = json.load(api_info_file)
os.environ["OPENAI_API_KEY"] = personal_info["api_key"]
os.environ["OPENAI_ORGANIZATION"] = personal_info["org_id"]

# Caps how many async_extract_feature coroutines may be inside their
# ``async with semaphore`` section at the same time.
# NOTE(review): on Python < 3.10 an asyncio.Semaphore created at import time
# may bind to a different event loop than the one asyncio.run() creates --
# confirm the minimum supported Python version.
semaphore = asyncio.Semaphore(256)

def load_prompt(file_path):
    """Return the full text of the UTF-8 encoded file at ``file_path``."""
    with open(file_path, mode='r', encoding='utf-8') as prompt_file:
        contents = prompt_file.read()
    return contents


async def async_extract_feature(llm, user_id, item_id, question, answer, options):
    """Ask the model for the features of one question and tag them.

    The prompt template is filled with ``question`` and ``options`` only;
    ``answer`` is not sent to the model and is just echoed in the result.
    The whole operation runs while holding the module-level ``semaphore``.

    Args:
        llm: Chat model forwarded to ``api_call_with_retry``.
        user_id: Identifier copied onto every feature and into the result.
        item_id: Identifier copied onto every feature and into the result.
        question: Text substituted for ``{question}`` in the prompt.
        answer: Passed through unchanged into the returned record.
        options: Value substituted for ``{options}`` in the prompt.

    Returns:
        A dict with keys ``user_id``, ``item_id``, ``question``, ``options``,
        ``answer`` and ``feature`` (the ``"features"`` value parsed from the
        model's JSON reply, each element annotated with both ids).
    """
    async with semaphore:
        template = load_prompt("prompt/11_extract_feature.txt")
        prompt_text = template.format(question=question, options=options)

        response = await api_call_with_retry(
            llm, SystemMessage(content=prompt_text), timeout=30
        )

        # First candidate of the first (and only) prompt in the batch.
        raw_reply = response.generations[0][0].text.strip()
        parsed_reply = json.loads(raw_reply)
        features = parsed_reply["features"]

        # Stamp each feature so it can be traced back to its source item.
        for extracted in features:
            extracted["user_id"] = user_id
            extracted["item_id"] = item_id

        record = {
            "user_id": user_id,
            "item_id": item_id,
            "question": question,
            "options": options,
            "answer": answer,
            "feature": features
        }
        return record


async def extract_feature(args, input_data):
    """Extract features for every profile item and group results per user.

    Args:
        args: Namespace providing ``model_name`` (model passed to
            ``ChatOpenAI``) and ``num_users`` (when > 0, only the first
            ``num_users`` entries of ``input_data`` are processed).
        input_data: List of user records, each with a ``user_id`` and a
            ``profile`` list whose items carry ``id``, ``question``,
            ``answer`` and ``options``.

    Returns:
        A list of ``{'user_id': ..., 'profile': [...]}`` dicts, one per user
        id, ordered by first appearance in the gathered results. Each
        ``profile`` holds the records returned by ``async_extract_feature``
        for that user. Users with an empty profile produce no entry.
    """
    llm = ChatOpenAI(temperature=0, model_name=args.model_name)
    tasks = []

    if args.num_users > 0:
        input_data = input_data[:args.num_users]

    for sample in input_data:
        user_id = sample['user_id']
        profile = sample['profile']

        for item in profile:
            tasks.append(
                async_extract_feature(
                    llm,
                    user_id,
                    item['id'],
                    item['question'],
                    item['answer'],
                    item['options'],
                )
            )

    sample_with_feature = await tqdm_asyncio.gather(*tasks)

    # Group by user with a dict lookup instead of rescanning the result list
    # for every sample (which was quadratic in the number of users). Dicts
    # preserve insertion order, so users come out in first-seen order exactly
    # as before. Requires user ids to be hashable (str/int from JSON are).
    grouped = {}
    for sample in sample_with_feature:
        user_id = sample['user_id']
        if user_id not in grouped:
            grouped[user_id] = {'user_id': user_id, 'profile': []}
        grouped[user_id]['profile'].append(sample)

    return list(grouped.values())

async def main():
    """Read the input JSON, run feature extraction, and write the output JSON.

    Reads ``input_path`` and ``output_path`` from the module-level ``args``
    namespace, which is populated by the ``__main__`` guard before this
    coroutine is run.

    Raises:
        OSError: If the input file cannot be read, or the output directory
            or file cannot be created.
        json.JSONDecodeError: If the input file is not valid JSON.
    """
    with open(args.input_path, 'r', encoding='utf-8') as f:
        input_data = json.load(f)

    # Create the output directory *before* the API calls are made: otherwise
    # a missing directory only surfaces after extraction has finished and
    # every result is lost.
    output_path = Path(args.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    extracted_features = await extract_feature(args, input_data)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(extracted_features, f, indent=2, ensure_ascii=False)

if __name__ == "__main__":
    # Command-line options as (flag, type, default) triples.
    cli_options = (
        ("--num_users", int, 0),
        ("--input_path", str, "dataset/goqa_train.json"),
        ("--output_path", str, "result/goqa_only_feature.json"),
        ("--model_name", str, "gpt-4o-mini"),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    # ``args`` must stay a module-level name: main() reads it as a global.
    args = parser.parse_args()

    asyncio.run(main())
