import pdb
import numpy as np
from tqdm import tqdm
from langchain_community.llms import OpenAIChat, OpenAI
from langchain_community.chat_models import ChatOpenAI, openai
import json
import argparse
import asyncio
from tqdm.asyncio import tqdm_asyncio
import random
import re
import os
from collections import Counter, defaultdict
import tiktoken
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)
from pathlib import Path
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# Default per-call timeout (seconds) used by api_call_with_retry.
DEFAULT_API_TIMEOUT = 30
# Project root: the parent of the directory that contains this script.
# Relative prompt/JSON paths elsewhere in the file are joined onto it.
BASE_DIR = Path(os.path.abspath(__file__)).parent.parent

@retry(
    stop=stop_after_attempt(10),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type(asyncio.TimeoutError),
)
async def api_call_with_retry(llm, system_message, timeout=DEFAULT_API_TIMEOUT):
    """Send one system message to ``llm`` in JSON mode, bounded by ``timeout``.

    Only ``asyncio.TimeoutError`` triggers a tenacity retry (up to 10 attempts
    with exponential back-off between 2 and 10 seconds). Every other exception
    is printed and propagated unchanged.

    Args:
        llm: chat model exposing ``agenerate``.
        system_message: the single message sent as a one-message conversation.
        timeout: seconds to wait for each attempt.

    Returns:
        Whatever ``llm.agenerate`` returns.
    """
    try:
        llm_output = await asyncio.wait_for(
            llm.agenerate([[system_message]], response_format={"type": "json_object"}),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        # Re-raise so the @retry decorator schedules another attempt.
        print(f"API call timed out after {timeout}s, retrying...")
        raise
    except Exception as err:
        print(f"API call failed with error: {err}")
        raise
    return llm_output

# Load API credentials and export them for the OpenAI client. The `with`
# block closes the config file (the previous bare open() leaked the handle).
with open(os.path.join(BASE_DIR, "config/api_info.json"), "r", encoding="utf-8") as _api_info_file:
    personal_info = json.load(_api_info_file)
os.environ["OPENAI_API_KEY"] = personal_info["api_key"]
os.environ["OPENAI_ORGANIZATION"] = personal_info["org_id"]

# Caps the number of analyze_sample_features calls holding an API slot at once.
# NOTE(review): created at import time; on Python < 3.10 asyncio primitives
# bind to the loop current at construction, which differs from the loop
# asyncio.run() creates — confirm the target Python version.
semaphore = asyncio.Semaphore(256)

def load_prompt(file_path):
    """Return the UTF-8 text of a prompt template.

    ``file_path`` is taken relative to BASE_DIR. An absolute path is used
    as-is.
    """
    return Path(BASE_DIR, file_path).read_text(encoding='utf-8')

def load_json_file(filepath):
    """Decode and return the JSON document stored at ``filepath`` (UTF-8).

    Relative paths are resolved against BASE_DIR; absolute paths are used
    unchanged. The resolved path is printed before the file is opened.
    """
    if os.path.isabs(filepath):
        resolved = filepath
    else:
        resolved = os.path.join(BASE_DIR, filepath)
    print(f"Loading JSON from {resolved}")
    with open(resolved, encoding='utf-8') as fp:
        payload = json.load(fp)
    return payload

def load_feature_data(feature_path):
    """Build an ``item_id -> item record`` lookup from a feature JSON file.

    The file (read via ``load_json_file``) is a list of user records. Each
    record has a ``user_id`` and a ``profile`` list. Every profile entry that
    carries an ``item_id`` is indexed. The first occurrence of an item id wins
    and later duplicates are ignored. Users with a missing or null
    ``profile``, and non-dict profile entries, are skipped rather than
    crashing the run.

    Returns:
        dict mapping item_id to a dict with keys ``user_id``, ``options``,
        ``answer``, ``question`` and ``features``. ``features`` holds the
        entry's ``feature`` value and may be None when absent.
    """
    feature_map = {}

    for user_data in load_json_file(feature_path):
        user_id = user_data.get("user_id")

        for profile_item in user_data.get("profile") or []:
            if not isinstance(profile_item, dict) or "item_id" not in profile_item:
                continue
            item_id = profile_item["item_id"]
            if item_id in feature_map:
                continue  # keep the first record seen for this item
            feature_map[item_id] = {
                "user_id": user_id,
                "options": profile_item.get("options"),
                "answer": profile_item.get("answer"),
                "question": profile_item.get("question"),
                "features": profile_item.get("feature"),
            }

    return feature_map

def _empty_analysis():
    """Return a fresh "nothing influenced" result for analyze_sample_features."""
    return {
        "influenced_features": [],
        "influenced_factors": set(),
        "evaluation_map": {},
    }


def _format_features_for_prompt(features):
    """Render each feature as a numbered ``Feature i:`` section for the prompt."""
    sections = []
    for i, feature in enumerate(features):
        feature_text = ""
        try:
            if isinstance(feature, dict):
                if "feature_name" in feature:
                    feature_text += f"Feature Name: {feature['feature_name']}\n"
                if "context" in feature:
                    feature_text += f"Context: {feature['context']}\n"
                if not feature_text:
                    # No known keys: fall back to the raw JSON of the feature.
                    feature_text = json.dumps(feature)
            else:
                feature_text = str(feature)
        except Exception as e:
            # Best effort: one unformattable feature must not sink the prompt.
            print(f"Error formatting feature: {e}")
            feature_text = str(feature)
        sections.append(f"Feature {i}:\n{feature_text}")
    return "\n\n".join(sections)


def _coerce_feature_index(value, num_features):
    """Return ``value`` as an index in ``[0, num_features)``, or None if invalid.

    Accepts ints, integral floats and numeric strings, since LLM JSON output
    is not always strictly typed. Booleans and negative or out-of-range
    values are rejected.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    elif isinstance(value, str):
        try:
            value = int(value.strip())
        except ValueError:
            return None
    if not isinstance(value, int):
        return None
    return value if 0 <= value < num_features else None


def _parse_influences(raw_text, num_features, feature_factor_map):
    """Convert the model's JSON reply into the analyze_sample_features result.

    Expects ``{"influences": [{"feature_index": ..., "influenced": ...,
    "evaluation": ...}, ...]}``. Malformed entries are skipped one by one
    instead of discarding the whole reply.

    Raises:
        ValueError: ``raw_text`` is not valid JSON, or its top level is not
            an object.
    """
    json_result = json.loads(raw_text)
    if not isinstance(json_result, dict):
        raise ValueError("expected a JSON object at the top level")

    influenced_features = []
    influenced_factors = set()
    evaluation_map = {}

    for influence in json_result.get("influences") or []:
        if not isinstance(influence, dict) or not influence.get("influenced"):
            continue
        feature_index = _coerce_feature_index(influence.get("feature_index"), num_features)
        if feature_index is None:
            continue

        evaluation = influence.get("evaluation", "neu")
        if isinstance(evaluation, str):
            # Tolerate "Pos" / " neg " so downstream pos/neu/neg checks match.
            evaluation = evaluation.strip().lower()

        influenced_features.append(feature_index)
        evaluation_map[feature_index] = evaluation
        if feature_index in feature_factor_map:
            influenced_factors.add(feature_factor_map[feature_index])

    return {
        "influenced_features": influenced_features,
        "influenced_factors": influenced_factors,
        "evaluation_map": evaluation_map,
    }


async def analyze_sample_features(llm, options, answer, features, feature_factor_map):
    """Ask the LLM which of ``features`` influenced ``answer`` for one sample.

    Args:
        llm: chat model passed through to ``api_call_with_retry``.
        options: the sample's answer options, substituted into the prompt.
        answer: the sample's chosen answer, substituted into the prompt.
        features: list of features. Dicts contribute their ``feature_name``
            and ``context`` (or their JSON); anything else is ``str()``-ed.
        feature_factor_map: ``feature index -> factor name`` used to fill
            ``influenced_factors``.

    Returns:
        dict with ``influenced_features`` (list of valid indices into
        ``features``), ``influenced_factors`` (set of factor names) and
        ``evaluation_map`` (index -> evaluation label, default ``"neu"``).
        An empty result is returned when there are no features or options,
        or when the API call or response parsing fails; failures are printed.
    """
    if not features or not options:
        return _empty_analysis()

    all_features_text = _format_features_for_prompt(features)

    async with semaphore:
        try:
            system_prompt = load_prompt("prompt/14_factor_statistics.txt").format(
                options=options, answer=answer, features=all_features_text)
            response = await api_call_with_retry(llm, SystemMessage(content=system_prompt))
            result = response.generations[0][0].text.strip()
        except Exception as e:
            # Per-sample boundary: one failed call must not abort the gather.
            print(f"Error in API call: {e}")
            return _empty_analysis()

    try:
        return _parse_influences(result, len(features), feature_factor_map)
    except Exception as e:
        print(f"Error parsing API response: {e}")
        return _empty_analysis()

def collect_sample_tasks(llm, factor_data, feature_map):
    """Create one (not yet awaited) analysis coroutine per (user, item) pair.

    Args:
        llm: chat model forwarded to ``analyze_sample_features``.
        factor_data: list of user records with ``user_id`` and
            ``factorization.factors`` (factor name -> list of items with an
            ``item_id``).
        feature_map: ``item_id -> item record`` as built by
            ``load_feature_data``.

    Returns:
        ``(sample_tasks, user_item_factors)``:
        ``sample_tasks`` is a list of ``(user_id, item_id, coroutine)``, and
        ``user_item_factors`` maps ``(user_id, item_id)`` to
        ``{factor_name: set of feature indices}``.

    Records or items with missing/null fields are skipped instead of raising.
    """
    sample_tasks = []
    user_item_factors = {}

    for user_data in factor_data:
        user_id = user_data.get("user_id")
        if not user_id:
            continue

        factorization = user_data.get("factorization") or {}
        factor_features = factorization.get("factors") or {}

        # An item can be listed under several factors; analyse it once per user record.
        processed_items = set()

        for factor_items in factor_features.values():
            for item in factor_items or []:
                item_id = item.get("item_id") if isinstance(item, dict) else None
                if not item_id or item_id not in feature_map:
                    continue
                if item_id in processed_items:
                    continue
                processed_items.add(item_id)

                item_data = feature_map[item_id]
                options = item_data.get("options", "")
                answer = item_data.get("answer", "")
                # The stored "features" value may be None (key present, value
                # null), so a .get() default alone would not apply.
                all_features = item_data.get("features") or []

                feature_factor_map = {}
                expected_factors = defaultdict(set)

                for i, feature in enumerate(all_features):
                    if not isinstance(feature, dict):
                        continue
                    feature_factors = feature.get("factor") or []
                    if isinstance(feature_factors, str):
                        # A bare string is one factor name, not a sequence of characters.
                        feature_factors = [feature_factors]
                    for f_name in feature_factors:
                        # The analyzer receives one factor per feature (last wins);
                        # full multi-factor membership is kept in expected_factors.
                        feature_factor_map[i] = f_name
                        expected_factors[f_name].add(i)

                user_item_factors[(user_id, item_id)] = dict(expected_factors)

                task = analyze_sample_features(llm, options, answer, all_features, feature_factor_map)
                sample_tasks.append((user_id, item_id, task))

    print(f"Created {len(sample_tasks)} sample analysis tasks")
    return sample_tasks, user_item_factors

def organize_sample_results(sample_results, user_item_factors):
    """Aggregate per-sample LLM analyses into per-user, per-factor statistics.

    Args:
        sample_results: iterable of ``(user_id, item_id, result)`` where
            ``result`` holds ``influenced_features`` (feature indices),
            ``influenced_factors`` (factor names) and ``evaluation_map``
            (feature index -> label).
        user_item_factors: ``(user_id, item_id) -> {factor_name: set of
            feature indices}``.

    Returns:
        dict keyed by ``str(user_id)``. Each value has ``user_id`` and
        ``factors``, which maps a factor name to ``count`` (samples containing
        the factor), ``directly_influenced`` (``"k/n(p%)"``) and
        ``evaluation_distribution`` (pos/neu/neg shares of the labelled
        features of influenced samples).
    """
    user_factor_stats = defaultdict(lambda: defaultdict(lambda: {
        "total": 0,
        "influenced": 0,
        "pos": 0,
        "neu": 0,
        "neg": 0,
    }))

    for user_id, item_id, result in sample_results:
        influenced_factors = result.get("influenced_factors", set())
        influenced_indices = set(result.get("influenced_features", []))
        evaluation_map = result.get("evaluation_map", {})

        expected_factors = user_item_factors.get((user_id, item_id), {})

        for factor_name, feature_indices in expected_factors.items():
            stats = user_factor_stats[user_id][factor_name]
            stats["total"] += 1

            # influenced_factors can miss a factor when a feature belongs to
            # several factors but only one name was mapped for it. Checking
            # the influenced feature indices counts every owning factor.
            factor_influenced = (
                factor_name in influenced_factors
                or not influenced_indices.isdisjoint(feature_indices)
            )
            if not factor_influenced:
                continue

            stats["influenced"] += 1
            for feature_index in feature_indices:
                evaluation = evaluation_map.get(feature_index)
                if evaluation in ("pos", "neu", "neg"):
                    stats[evaluation] += 1

    statistics = {}
    for user_id, factors in user_factor_stats.items():
        user_stats = {
            "user_id": user_id,
            "factors": {},
        }

        for factor_name, stats in factors.items():
            total = stats["total"]
            influenced = stats["influenced"]
            pos_count = stats["pos"]
            neu_count = stats["neu"]
            neg_count = stats["neg"]
            total_count = pos_count + neu_count + neg_count

            influence_percentage = (influenced / total) * 100 if total > 0 else 0

            # Guard on the labelled count itself: a factor can be influenced
            # while none of its features carries a pos/neu/neg label, which
            # previously divided by zero.
            if total_count > 0:
                pos_percentage = (pos_count / total_count) * 100
                neu_percentage = (neu_count / total_count) * 100
                neg_percentage = (neg_count / total_count) * 100
            else:
                pos_percentage = neu_percentage = neg_percentage = 0

            user_stats["factors"][factor_name] = {
                "count": total,
                "directly_influenced": f"{influenced}/{total}({influence_percentage:.1f}%)",
                "evaluation_distribution": {
                    "positive": f"pos: {pos_count}/{total_count}({pos_percentage:.1f}%)",
                    "neutral": f"neu: {neu_count}/{total_count}({neu_percentage:.1f}%)",
                    "negative": f"neg: {neg_count}/{total_count}({neg_percentage:.1f}%)",
                },
            }

        statistics[str(user_id)] = user_stats

    return statistics

async def calculate_factor_statistics(args):
    """Run the full pipeline: load inputs, analyse every sample, aggregate.

    Args:
        args: namespace with ``model_name``, ``factor_path``, ``feature_path``
            and ``num_users`` (<= 0 means all users).

    Returns:
        The per-user statistics dict produced by ``organize_sample_results``.
    """
    llm = ChatOpenAI(temperature=0, model_name=args.model_name)

    print(f"Loading factor data from {args.factor_path}")
    factor_data = load_json_file(args.factor_path)

    print(f"Loading feature data from {args.feature_path}")
    feature_map = load_feature_data(args.feature_path)
    print(f"Created mapping for {len(feature_map)} items")

    if args.num_users <= 0:
        print(f"Processing all {len(factor_data)} users")
    else:
        factor_data = factor_data[:args.num_users]
        print(f"Processing {args.num_users} users")

    sample_tasks, user_item_factors = collect_sample_tasks(llm, factor_data, feature_map)
    coroutines = [coro for _, _, coro in sample_tasks]

    sample_results = []
    if not coroutines:
        print("No tasks to process!")
    else:
        outcomes = await tqdm_asyncio.gather(*coroutines, desc="Analyzing samples")
        # gather preserves input order, so results line up with sample_tasks.
        sample_results = [
            (uid, iid, outcome)
            for (uid, iid, _), outcome in zip(sample_tasks, outcomes)
        ]

    return organize_sample_results(sample_results, user_item_factors)

async def main(cli_args=None):
    """Compute factor statistics and write them as pretty-printed UTF-8 JSON.

    Args:
        cli_args: parsed command-line namespace. When omitted, the
            module-level ``args`` bound by the ``__main__`` entry point is
            used, so the existing ``main()`` call keeps working.
    """
    if cli_args is None:
        cli_args = args

    statistics = await calculate_factor_statistics(cli_args)

    # Resolve a relative output path against BASE_DIR, the same root used for
    # the relative input paths, so results land beside the inputs regardless
    # of the current working directory. Create the directory if needed.
    output_path = Path(cli_args.output_path)
    if not output_path.is_absolute():
        output_path = BASE_DIR / output_path
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(statistics, f, indent=2, ensure_ascii=False)
    

if __name__ == "__main__":
    # CLI entry point; main() reads the module-level `args` bound below.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--num_users", type=int, default=0)
    for flag, default_path in (
        ("--factor_path", "result/goqa_only_factor.json"),
        ("--feature_path", "result/goqa_feature.json"),
        ("--output_path", "result/goqa_factor.json"),
    ):
        cli_parser.add_argument(flag, type=str, default=default_path)
    cli_parser.add_argument("--model_name", type=str, default="gpt-4o-mini")
    args = cli_parser.parse_args()

    asyncio.run(main())
