import argparse
import asyncio
import functools
import json
import os
import pdb

import numpy as np
import tiktoken
from langchain.schema import SystemMessage
from langchain_community.chat_models import ChatOpenAI
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tqdm.asyncio import tqdm_asyncio

# Seconds allowed for a single LLM request (used as the asyncio.wait_for timeout).
DEFAULT_API_TIMEOUT = 30


# Load API credentials at import time and export them as environment variables
# (presumably picked up by ChatOpenAI). The `with` block closes the file handle,
# which the previous bare `json.load(open(...))` leaked.
with open("config/api_info.json", "r", encoding="utf-8") as _api_info_file:
    personal_info = json.load(_api_info_file)
os.environ["OPENAI_API_KEY"] = personal_info["api_key"]
os.environ["OPENAI_ORGANIZATION"] = personal_info["org_id"]

# Upper bound on concurrently in-flight LLM requests; every API-calling coroutine
# in this module acquires it around its request.
# NOTE(review): on Python < 3.10 a Semaphore created outside the running loop is
# bound to a different loop than the one asyncio.run() creates, which fails under
# contention — confirm the target Python version is >= 3.10.
semaphore = asyncio.Semaphore(256)

@retry(
    stop=stop_after_attempt(10),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type(asyncio.TimeoutError),
)
async def api_call_with_retry(llm, system_message, timeout=DEFAULT_API_TIMEOUT):
    """Send one system message to ``llm`` and return the raw generation result.

    The request asks for a JSON-object response and must finish within
    ``timeout`` seconds. Only timeouts are retried, by the tenacity decorator:
    at most 10 attempts, with exponential back-off clamped to 2-10 seconds. Any
    other exception is logged and propagated immediately. If every attempt times
    out, tenacity raises its ``RetryError`` because ``reraise`` is not set.

    Returns:
        Whatever ``llm.agenerate`` returns; callers read
        ``.generations[0][0].text`` from it.
    """
    try:
        return await asyncio.wait_for(
            llm.agenerate([[system_message]], response_format={"type": "json_object"}),
            timeout=timeout,
        )
    except Exception as exc:
        # Log the failure kind, then let it propagate so the decorator can decide
        # whether to retry (timeouts) or give up (everything else).
        if isinstance(exc, asyncio.TimeoutError):
            print(f"API call timed out after {timeout}s, retrying...")
        else:
            print(f"API call failed with error: {exc}")
        raise

@functools.lru_cache(maxsize=32)
def load_prompt(file_path):
    """Return the UTF-8 text of the prompt template at ``file_path``.

    The result is cached per path. Every extraction/assignment coroutine calls
    this for each request, so without the cache the same template file would be
    re-read synchronously thousands of times inside the event loop. Edits made to
    a template while the process is running are therefore not picked up.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read()

async def async_extract_feature(llm, user_id, item_id, question, options, answer):
    """Ask the LLM for the features of one test question and tag them with their owner.

    Args:
        llm: Chat model passed through to ``api_call_with_retry``.
        user_id: Owner of the question; copied onto every extracted feature.
        item_id: Profile item the question came from; copied onto every feature.
        question: Question text substituted into the extraction prompt.
        options: Answer options substituted into the extraction prompt.
        answer: Stored unchanged on the returned sample (not sent to the LLM).

    Returns:
        A sample dict with the input fields plus ``"feature"``: the list of
        feature dicts taken from the model's ``{"features": [...]}`` JSON reply.
        If the reply is not valid JSON or has no usable ``features`` list, a
        warning is printed and ``"feature"`` is ``[]``. A single bad reply
        therefore does not abort the whole ``gather`` and discard every other
        result.
    """
    async with semaphore:
        system_prompt = load_prompt("prompt/21_extract_test_feature.txt").format(
            question=question, options=options)
        system_message = SystemMessage(content=system_prompt)
        response = await api_call_with_retry(llm, system_message, timeout=DEFAULT_API_TIMEOUT)

    # Parsing needs no API slot, so it happens after the semaphore is released.
    result = response.generations[0][0].text.strip()

    try:
        features = json.loads(result)["features"]
    except (json.JSONDecodeError, KeyError, TypeError) as err:
        # TypeError covers replies that are valid JSON but not an object.
        print(f"Warning: could not parse features for user {user_id}, item {item_id}: {err!r}")
        features = []

    if not isinstance(features, list):
        print(f"Warning: 'features' is not a list for user {user_id}, item {item_id}; ignoring it")
        features = []

    # Only dict entries can carry the user/item tags downstream code relies on.
    dict_features = [feature for feature in features if isinstance(feature, dict)]
    if len(dict_features) != len(features):
        print(f"Warning: dropped {len(features) - len(dict_features)} non-object feature(s) "
              f"for user {user_id}, item {item_id}")

    for feature in dict_features:
        feature["user_id"] = user_id
        feature["item_id"] = item_id

    return {
        "user_id": user_id,
        "item_id": item_id,
        "question": question,
        "options": options,
        "answer": answer,
        "feature": dict_features
    }

def _first_factor_index(assignment_value):
    """Return the first 1-based factor index in an ``assignments`` value, or None.

    Accepts a list (its first element is used), a bare int, or a comma-separated
    string such as ``"2, 5"``. Numeric strings inside a list are accepted too.
    Unparseable values such as ``"none"``, empty lists, and bools yield None.
    """
    if isinstance(assignment_value, list):
        if not assignment_value:
            return None
        assignment_value = assignment_value[0]
    # bool is a subclass of int; True must not be read as index 1.
    if isinstance(assignment_value, bool):
        return None
    if isinstance(assignment_value, int):
        return assignment_value
    if isinstance(assignment_value, str):
        first_assignment = assignment_value.replace(" ", "").split(",")[0]
        try:
            return int(first_assignment)
        except ValueError:
            return None
    return None


async def async_assign_to_factor(llm, feature, factors):
    """Ask the LLM which of ``factors`` best describes ``feature``.

    Args:
        llm: Chat model passed through to ``api_call_with_retry``.
        feature: Feature dict. Its ``feature_name`` and optional ``context``
            are put into the prompt.
        factors: Candidate factor descriptions, shown to the model as a
            1-based numbered list.

    Returns:
        A list with at most one element: the factor chosen by the first index
        in the model's ``assignments`` field. Returns ``[]`` in any of these
        cases:

        - there are no factors (no API call is made);
        - the reply cannot be parsed as a JSON object;
        - no valid index is present;
        - the index is out of range.
    """
    if not factors:
        # Every index would be out of range, so the API call would be wasted.
        return []

    async with semaphore:
        formatted_factors = "\n".join(
            f"{number}. {desc}" for number, desc in enumerate(factors, start=1))

        feature_summary = f"Feature: {feature.get('feature_name')}"
        if 'context' in feature:
            feature_summary += f" (Context: {feature['context']})"

        system_prompt = load_prompt("prompt/12_assign_to_factors.txt").format(
            feature=feature_summary,
            formatted_factors=formatted_factors
        )

        system_message = SystemMessage(content=system_prompt)
        response = await api_call_with_retry(llm, system_message, timeout=DEFAULT_API_TIMEOUT)

    result = response.generations[0][0].text.strip()

    try:
        json_result = json.loads(result)
    except json.JSONDecodeError as err:
        print(f"Warning: could not parse factor assignment for {feature_summary}: {err!r}")
        return []
    if not isinstance(json_result, dict):
        return []

    idx = _first_factor_index(json_result.get("assignments"))
    if idx is not None and 1 <= idx <= len(factors):
        return [factors[idx - 1]]
    return []

def _collect_user_factors(factor_data):
    """Return a ``{user_id: selected_factors}`` dict built from the factor file.

    Entries are skipped when they have no truthy ``user_id``, or no
    ``factorization`` object containing ``selected_factors``.
    """
    user_factors = {}
    for user_data in factor_data:
        user_id = user_data.get('user_id')
        factorization = user_data.get('factorization')
        if user_id and isinstance(factorization, dict) and "selected_factors" in factorization:
            user_factors[user_id] = factorization["selected_factors"]
    return user_factors


async def process_test_data(args):
    """Extract features for every test question and assign each feature to a factor.

    Reads the test users from ``args.test_data_path`` and the per-user factor
    definitions from ``args.factor_path``. It makes one feature-extraction LLM
    call per profile item that has a non-blank string question, then one
    factor-assignment call per extracted feature. Users without factors of their
    own fall back to the first non-empty factor list in the factor file.

    Args:
        args: Parsed CLI namespace. Reads ``test_data_path``, ``factor_path``,
            ``model_name`` and ``num_users``; ``num_users > 0`` keeps only the
            first N users.

    Returns:
        A list of ``{"user_id": ..., "profile": [sample, ...]}`` dicts, in
        test-data order. Each sample is the dict produced by
        ``async_extract_feature``, and each of its features carries a
        ``"factor"`` list.
    """
    with open(args.test_data_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    with open(args.factor_path, 'r', encoding='utf-8') as f:
        factor_data = json.load(f)

    llm = ChatOpenAI(temperature=0, model_name=args.model_name)

    if args.num_users > 0:
        test_data = test_data[:args.num_users]

    user_factors = _collect_user_factors(factor_data)
    print(f"Found factor data for {len(user_factors)} users")

    # The fallback (first non-empty factor list, in file order) is the same for
    # every user, so compute it once instead of rescanning per user.
    fallback_user_id, fallback_factors = next(
        ((uid, facs) for uid, facs in user_factors.items() if facs),
        (None, None),
    )

    print("Preparing extraction tasks for all users...")

    user_sample_map = {}
    extract_tasks = []
    extract_owners = []  # user_id of each entry in extract_tasks, same order

    for user in test_data:
        user_id = user.get('user_id')
        user_sample_map[user_id] = []

        for item in user.get('profile', []):
            question = item.get('question')
            # Skip missing or blank questions, and non-string ones (e.g. JSON
            # null) that would otherwise crash on .strip().
            if not isinstance(question, str) or not question.strip():
                continue
            item_id = item.get('item_id', item.get('id', 'unknown'))
            extract_tasks.append(async_extract_feature(
                llm, user_id, item_id, question, item.get('options'), item.get('answer')))
            extract_owners.append(user_id)

    print(f"Extracting features for {len(extract_tasks)} samples across {len(test_data)} users...")

    # Like asyncio.gather, results come back in task order, so zip re-pairs them.
    all_samples = await tqdm_asyncio.gather(*extract_tasks, desc="Extracting features")
    for user_id, sample in zip(extract_owners, all_samples):
        user_sample_map[user_id].append(sample)

    print("Preparing factor assignment tasks...")

    assign_tasks = []
    assign_targets = []  # feature dict that receives each task's result

    for user_id, samples in user_sample_map.items():
        user_specific_factors = user_factors.get(user_id, [])
        if not user_specific_factors:
            print(f"Warning: No factors found for user {user_id}, using global factors as fallback")
            if fallback_factors:
                user_specific_factors = fallback_factors
                print(f"Using factors from user {fallback_user_id} as fallback")

        for sample in samples:
            for feature in sample['feature']:
                assign_tasks.append(async_assign_to_factor(llm, feature, user_specific_factors))
                assign_targets.append(feature)

    print(f"Assigning factors for {len(assign_tasks)} features using user-specific factors...")

    all_factor_assignments = await tqdm_asyncio.gather(*assign_tasks, desc="Assigning factors")
    # Each target is the same dict object stored in its sample, so this
    # mutates the samples in user_sample_map in place.
    for feature, factors in zip(assign_targets, all_factor_assignments):
        feature['factor'] = factors

    return [
        {'user_id': user.get('user_id'), 'profile': user_sample_map[user.get('user_id')]}
        for user in test_data
    ]

async def main():
    """CLI entry point: run the feature/factor pipeline and save the output.

    Parses the CLI options and awaits ``process_test_data``. The result is
    written as indented UTF-8 JSON to ``--output_path``, creating its parent
    directory if needed. Summary counts and the elapsed wall-clock time are
    printed. Any exception is reported and then re-raised, so the traceback is
    shown and the process exits with a non-zero status rather than appearing
    to succeed.
    """
    import datetime

    parser = argparse.ArgumentParser(description="Process test data to extract features and assign factors")
    parser.add_argument("--num_users", type=int, default=0,
                        help="Number of users to process (0 for all)")
    parser.add_argument("--test_data_path", type=str, default="dataset/goqa_test.json",
                        help="Path to the test data file")
    parser.add_argument("--factor_path", type=str, default="result/goqa_only_factor.json",
                        help="Path to the factor definitions file")
    parser.add_argument("--output_path", type=str, default="result/goqa_test_feature.json",
                        help="Path to save the results")
    parser.add_argument("--model_name", type=str, default="gpt-4o-mini",
                        help="Name of the language model to use")

    args = parser.parse_args()

    try:
        start_time = datetime.datetime.now()

        result = await process_test_data(args)

        total_samples = sum(len(user.get('profile', [])) for user in result)
        total_features = sum(
            len(item.get('feature', []))
            for user in result
            for item in user.get('profile', [])
        )

        # Avoid losing a long, paid-for run because the output folder is missing.
        output_dir = os.path.dirname(args.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(args.output_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)

        print(f"Results saved to {args.output_path}")
        print(f"Processed {len(result)} users")
        print(f"Processed {total_samples} samples")
        print(f"Processed {total_features} features")
        print(f"Elapsed time: {datetime.datetime.now() - start_time}")

    except Exception as e:
        print(f"Error: {e}")
        # Propagate so the failure is visible (traceback, non-zero exit status)
        # instead of being swallowed as a successful run.
        raise


if __name__ == "__main__":
    asyncio.run(main())