import numpy as np
import json
import os
import asyncio
from tqdm.asyncio import tqdm_asyncio
import argparse
from langchain.schema import SystemMessage
from langchain_openai import ChatOpenAI
import torch
from transformers import AutoTokenizer, AutoModel
from typing import Dict, List, Any, Literal
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from torch.cuda.amp import autocast  
import tiktoken

# Retrieval strategies accepted by --retrieval_method (see main); only "feature" is implemented.
RETRIEVAL_METHODS = ["feature"]

# Default per-attempt timeout, in seconds, for LLM calls made by api_call_with_retry.
DEFAULT_API_TIMEOUT = 30


# Device for the Contriever retriever and its inputs: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Choose the encoding concurrency limit from the total memory of GPU 0 (converted to GiB).
# NOTE(review): _encode_single_batch never awaits while holding encoding_semaphore, so
# this limit appears never to block in practice -- confirm before tuning these values.
if torch.cuda.is_available():
    gpu_info = torch.cuda.get_device_properties(0)
    total_memory_gb = gpu_info.total_memory / (1024**3)
    
    if total_memory_gb > 40:  
        encoding_semaphore_value = 32768  
    elif total_memory_gb > 20: 
        encoding_semaphore_value = 2048
    else:  
        encoding_semaphore_value = 1024
else:
    encoding_semaphore_value = 256  

# api_semaphore caps concurrent LLM requests (bbox_inference);
# encoding_semaphore caps concurrent embedding batches (_encode_single_batch).
# NOTE(review): on Python < 3.10, asyncio.Semaphore binds to the event loop that is
# current at construction (import time here), not the loop later started by
# asyncio.run -- verify the supported Python version if loop-mismatch errors appear.
api_semaphore = asyncio.Semaphore(256)
encoding_semaphore = asyncio.Semaphore(encoding_semaphore_value) 
print(f"Encoding semaphore set to {encoding_semaphore_value} based on {torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else 'CPU'}")

async def api_call_with_retry(llm, system_message, response_field=None, allowed_values=None, timeout=DEFAULT_API_TIMEOUT):
    """Call ``llm`` in JSON mode, retrying on timeouts and invalid field values.

    Up to 3 attempts are made with exponential back-off (2-10 s). An attempt is
    retried when it times out (``asyncio.TimeoutError``) or raises ``ValueError``.

    When both ``response_field`` and ``allowed_values`` are given, the response
    text is parsed as JSON. ``str(result[response_field]).strip()`` must then
    match one of ``allowed_values`` (compared as stripped strings). If it does
    not, the next attempt is sent with an extra instruction that lists the
    allowed values and the rejected prediction. A response that cannot be
    parsed for validation is accepted with a warning.

    Args:
        llm: chat model exposing ``agenerate`` (e.g. a ``ChatOpenAI``).
        system_message: ``SystemMessage`` holding the full prompt.
        response_field: JSON key to validate, or None to skip validation.
        allowed_values: permitted values for ``response_field``.
        timeout: per-attempt timeout in seconds.

    Returns:
        The result object returned by ``llm.agenerate``.

    Raises:
        tenacity.RetryError: once all retryable attempts are exhausted.
        Exception: any non-retryable error from an attempt is re-raised.
    """
    # The retry loop lives inside the function (rather than in a @retry
    # decorator) so the corrective prompt built after an invalid answer is
    # actually used by the next attempt; a decorator would re-invoke the
    # function with the original, unmodified message.
    from tenacity import AsyncRetrying

    current_message = system_message
    retrying = AsyncRetrying(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((asyncio.TimeoutError, ValueError)),
    )
    async for attempt in retrying:
        with attempt:
            try:
                response = await asyncio.wait_for(
                    llm.agenerate([[current_message]], response_format={"type": "json_object"}),
                    timeout=timeout,
                )

                if response_field and allowed_values:
                    allowed = [str(value).strip() for value in allowed_values]
                    try:
                        result_json = json.loads(response.generations[0][0].text.strip())
                        field_value = str(result_json.get(response_field)).strip()
                    except (json.JSONDecodeError, AttributeError, KeyError) as e:
                        # Best effort: an unparseable response is not retried here.
                        print(f"Warning: Could not validate response field: {e}")
                    else:
                        if field_value not in allowed:
                            print(f"Invalid {response_field} '{field_value}'. Retrying with explicit instruction...")
                            # Build from the ORIGINAL prompt so hints do not pile up across attempts.
                            modified_content = system_message.content + f"\n\nIMPORTANT: You MUST select a {response_field} from this exact list: {', '.join(allowed)}. Your previous prediction '{field_value}' was not in the allowed list."
                            current_message = SystemMessage(content=modified_content)
                            raise ValueError(f"Invalid {response_field}: {field_value}")

                return response

            except asyncio.TimeoutError:
                print(f"API call timed out after {timeout}s, retrying...")
                raise
            except ValueError as e:
                print(f"Invalid value error: {e}, retrying...")
                raise
            except Exception as e:
                print(f"API call failed with error: {e}")
                raise

def setup_environment(config_path="config/api_info.json"):
    """Setup API keys and environment.

    Reads a JSON config file and exports its credentials as the
    ``OPENAI_API_KEY`` and ``OPENAI_ORGANIZATION`` environment variables.

    Args:
        config_path: path to a JSON object with ``api_key`` and ``org_id`` keys.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        KeyError: if ``api_key`` or ``org_id`` is missing.
    """
    # Use a context manager so the config file handle is closed promptly.
    with open(config_path, "r") as f:
        personal_info = json.load(f)
    os.environ["OPENAI_API_KEY"] = personal_info["api_key"]
    os.environ["OPENAI_ORGANIZATION"] = personal_info["org_id"]

def load_prompt(file_path):
    """Return the full text of the prompt template stored at ``file_path``."""
    with open(file_path) as template_file:
        template = template_file.read()
    return template

def load_data(file_path):
    """Parse the JSON document at ``file_path`` and return the decoded object."""
    with open(file_path) as source:
        decoded = json.load(source)
    return decoded

def save_data(data, file_path):
    """Write ``data`` to ``file_path`` as JSON indented by two spaces, replacing any existing file."""
    with open(file_path, mode="w") as sink:
        json.dump(data, sink, indent=2)


def preprocess_data(test_data_path, personalized_reasoning_path):
    """Load the test set and the personalized-reasoning set from JSON files.

    Announces each load on stdout before reading it.

    Returns:
        tuple: ``(test_data, personalized_reasoning)``, each the decoded JSON.
    """
    print(f"Loading test data from {test_data_path}")
    tests = load_data(test_data_path)

    print(f"Loading personalized reasoning data from {personalized_reasoning_path}")
    reasoning = load_data(personalized_reasoning_path)

    return (tests, reasoning)


def extract_factors_from_item(item):
    """Flatten the factors of every feature in ``item`` into one list.

    Empty/falsy factors and the placeholder ``"Unclassified"`` are dropped;
    the remaining factors keep their original order.
    """
    return [
        factor
        for feature in item["feature"]
        for factor in feature["factor"]
        if factor and factor != "Unclassified"
    ]

def extract_features_text(item):
    """Render the features of ``item`` as one space-separated string.

    Each feature that has a ``feature_name`` contributes that name, optionally
    followed by ``[factors: ...]`` (list factors comma-joined, anything else
    ``str()``-ed) and ``[context: ...]``. Features without a name are skipped.
    """
    segments = []
    for feat in item["feature"]:
        if "feature_name" not in feat:
            continue

        segment = feat["feature_name"]
        if "factor" in feat:
            raw_factors = feat["factor"]
            if isinstance(raw_factors, list):
                joined = ", ".join(raw_factors)
            else:
                joined = str(raw_factors)
            segment += f" [factors: {joined}]"
        if "context" in feat:
            segment += f" [context: {feat['context']}]"

        segments.append(segment)

    return " ".join(segments)

def batchify(texts, batch_size=8):
    """Cut ``texts`` into consecutive slices of at most ``batch_size`` items.

    The last slice may be shorter; each slice has the same sequence type as ``texts``.
    """
    offsets = range(0, len(texts), batch_size)
    return list(map(lambda offset: texts[offset:offset + batch_size], offsets))

async def _encode_single_batch(texts, tokenizer, model):
    """Encode a single batch of texts into mean-pooled embeddings.

    Tokenizes ``texts`` with padding/truncation, runs ``model`` without
    gradients on ``device``, and averages the per-token hidden states over the
    positions where ``attention_mask`` is set.

    Args:
        texts: list of strings forming one batch.
        tokenizer: tokenizer callable returning a batch with ``attention_mask``.
        model: encoder whose output exposes ``last_hidden_state``.

    Returns:
        numpy array with one float32 row per input text.
    """
    async with encoding_semaphore:
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)

        # torch.cuda.amp.autocast is deprecated and, on CPU-only hosts, only warns
        # and disables itself; the device-generic API is enabled on CUDA only.
        with torch.no_grad(), torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
            outputs = model(**inputs)

        attention_mask = inputs["attention_mask"]
        # Pool in float32: summing many half-precision activations loses precision.
        last_hidden = outputs.last_hidden_state.float()
        embeddings = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
        sum_embeddings = embeddings.sum(dim=1)
        # clamp guards against division by zero for a row with no attended tokens.
        sum_mask = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
        mean_embeddings = sum_embeddings / sum_mask

        return mean_embeddings.cpu().numpy()

async def encode_text_batches(texts, tokenizer, model, batch_size=32):
    """Embed ``texts`` batch by batch and stack the per-batch results row-wise.

    On a CUDA host the ``batch_size`` argument is ignored in favour of a fixed
    batch of 32768 texts, and the chosen size is printed. On CPU, ``batch_size``
    is used as given. All batches are encoded via ``_encode_single_batch`` and
    awaited together.

    Returns:
        numpy array with one row per text, or an empty array when ``texts`` is empty.
    """
    if not texts:
        return np.array([])

    adjusted_batch_size = batch_size
    if torch.cuda.is_available():
        total_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        adjusted_batch_size = 32768
        print(f"Encoding texts with batch size {adjusted_batch_size} on {total_memory_gb:.1f}GB GPU")

    pending = [
        _encode_single_batch(chunk, tokenizer, model)
        for chunk in batchify(texts, adjusted_batch_size)
    ]
    encoded = await asyncio.gather(*pending)
    return np.vstack(encoded) if encoded else np.array([])

async def retrieve_by_feature(user_id, test_item, user_samples, tokenizer, model, top_k=3, args=None):
    """Return up to ``top_k`` of ``user_samples`` most similar to ``test_item`` by feature text.

    Items are rendered with ``extract_features_text``. The test text and every
    non-empty candidate text are embedded together with ``encode_text_batches``,
    and candidates are ranked by dot product with the test embedding, highest first.

    Args:
        user_id: used only in log messages.
        test_item: item to find examples for.
        user_samples: the user's candidate samples; those with empty feature text are skipped.
        tokenizer: passed through to the encoder.
        model: passed through to the encoder.
        top_k: number of examples wanted; values <= 0 yield no examples (zero-shot).
        args: unused.

    Returns:
        List of samples, possibly shorter than ``top_k``; empty when nothing can
        be retrieved. If fewer than two embeddings come back, the first ``top_k``
        candidates are returned unranked.
    """
    # Zero-shot: nothing to retrieve. Without this guard the slice [-0:] below
    # would select *every* candidate instead of none.
    if top_k <= 0:
        return []

    if not user_samples:
        print(f"No samples found for user {user_id}")
        return []

    test_features_text = extract_features_text(test_item)
    if not test_features_text:
        print(f"No feature text for test item {test_item.get('item_id')}")
        return []

    candidate_samples = []
    candidate_texts = []
    for sample in user_samples:
        features_text = extract_features_text(sample)
        if features_text:
            candidate_texts.append(features_text)
            candidate_samples.append(sample)

    if not candidate_samples:
        return []

    # Row 0 is the test item; rows 1.. align with candidate_samples.
    texts_to_encode = [test_features_text] + candidate_texts
    all_embeddings = await encode_text_batches(texts_to_encode, tokenizer, model)

    if len(all_embeddings) < 2:
        return candidate_samples[:top_k]

    test_embedding = all_embeddings[0:1]
    candidate_embeddings = all_embeddings[1:]
    similarities = np.dot(test_embedding, candidate_embeddings.T)[0]

    # argsort is ascending: take the last top_k and reverse for best-first order.
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    return [candidate_samples[idx] for idx in top_indices]

async def retrieve_few_shot_examples(user_id, test_item, user_samples, tokenizer, model, top_k=3, method="hybrid", args=None):
    """Dispatch few-shot retrieval to the strategy named by ``method``.

    Only ``"feature"`` is implemented (delegating to ``retrieve_by_feature``).
    Any other name, including the default ``"hybrid"``, prints a warning and
    yields no examples.
    """
    if method != "feature":
        print(f"Invalid retrieval method: {method}")
        return []
    return await retrieve_by_feature(user_id, test_item, user_samples, tokenizer, model, top_k, args)


def _feature_records(item):
    """Project each entry of ``item["feature"]`` onto its feature_name/factor/context keys.

    Keys absent from an entry are omitted rather than filled with defaults.
    """
    kept_keys = ("feature_name", "factor", "context")
    return [
        {key: feature[key] for key in kept_keys if key in feature}
        for feature in item["feature"]
    ]


def _format_example(number, sample):
    """Render one retrieved sample as the numbered few-shot block used in the prompt.

    ``sample`` must contain ``feature`` and ``factors``; ``question``,
    ``options``, ``reasoning`` and ``answer`` are optional.
    """
    lines = [
        f"Example {number}:",
        f"question: {sample.get('question', '')}",
        f"options: {sample.get('options', [])}",
        f"user preference factors and statistics: {sample['factors']}",
        f"question features-factors: {_feature_records(sample)}",
    ]
    if "reasoning" in sample:
        lines.append(f"reasoning: {sample['reasoning']}")
    lines.append(f"predicted_answer: {sample.get('answer')}")
    return "\n".join(lines)


async def bbox_inference(llm, test_item, similar_samples, user_factors, args=None):
    """Predict the user's answer to ``test_item`` with one LLM call.

    Fills ``prompt/22_bbox_inference.txt`` with the test item's features,
    question and options, the user's factors, and the formatted few-shot
    examples. It then asks the LLM (via ``api_call_with_retry``) for a JSON
    object with ``predicted_answer`` and ``reasoning``. The whole call runs
    under ``api_semaphore``.

    Args:
        llm: chat model passed through to ``api_call_with_retry``.
        test_item: dict with ``feature``, ``question``, ``item_id``, ``answer``
            and optionally ``options``.
        similar_samples: retrieved few-shot samples.
        user_factors: the user's factor data, serialized into the prompt as JSON.
        args: unused.

    Returns:
        dict with item_id, question, options, actual_answer, predicted_answer
        (always a string), reasoning and input_prompt. If the LLM call or the
        JSON parsing fails, predicted_answer is "error", reasoning holds the
        error text, and input_prompt is omitted. Errors raised while building
        the prompt propagate to the caller.
    """
    async with api_semaphore:
        examples = [
            _format_example(number, sample)
            for number, sample in enumerate(similar_samples, start=1)
        ]
        # Precede EVERY example with its own dashed rule. (Concatenating
        # "-"*50 with "\n\n".join(examples) would emit a single rule glued
        # directly onto "Example 1:" and no separators between examples.)
        separator = "\n\n" + "-" * 50 + "\n\n"
        formatted_examples = "".join(separator + example for example in examples)

        test_features = _feature_records(test_item)
        question = test_item["question"]
        options = test_item.get("options", [])

        prompt = load_prompt("prompt/22_bbox_inference.txt").format(
            features=json.dumps(test_features),
            question=question,
            options=json.dumps(options),
            user_factors=json.dumps(user_factors),
            similar_samples_with_feedbacks=formatted_examples,
        )

        system_message = SystemMessage(content=prompt)
        try:
            response = await api_call_with_retry(llm, system_message, response_field="predicted_answer")

            result_json = json.loads(response.generations[0][0].text.strip())
            predicted_answer = str(result_json.get("predicted_answer"))
            reasoning = result_json.get("reasoning")

            return {
                "item_id": test_item["item_id"],
                "question": question,
                "options": options,
                "actual_answer": test_item["answer"],
                "predicted_answer": predicted_answer,
                "reasoning": reasoning,
                "input_prompt": prompt
            }

        except Exception as e:
            print(f"Error in inference: {str(e)}")
            return {
                "item_id": test_item["item_id"],
                "question": question,
                "options": options,
                "actual_answer": test_item["answer"],
                "predicted_answer": "error",
                "reasoning": f"Error: {str(e)}"
            }


def calculate_metrics(predictions, references):
    """Compute exact-match accuracy of ``predictions`` against ``references``.

    Pairs are compared by their stripped string form. Predictions from
    ``bbox_inference`` are always strings, while references keep whatever type
    the dataset uses (possibly non-string). A raw ``==`` would therefore count
    a correct "2" vs 2 as wrong.

    Args:
        predictions: predicted answers.
        references: gold answers, aligned with ``predictions``.

    Returns:
        ``{"acc": float}``: matching pairs divided by ``len(predictions)``;
        ``{"acc": 0}`` when either list is empty.
    """
    if not predictions or not references:
        return {"acc": 0}

    correct = sum(
        str(pred).strip() == str(ref).strip()
        for pred, ref in zip(predictions, references)
    )
    return {"acc": correct / len(predictions)}


async def compute_embeddings_batch(texts, tokenizer, model, batch_size):
    """Best-effort embedding of ``texts`` via ``encode_text_batches``.

    Returns an empty array when ``texts`` is empty. If encoding raises, the
    error is printed and an empty array is returned instead.
    """
    if texts:
        try:
            return await encode_text_batches(texts, tokenizer, model, batch_size)
        except Exception as e:
            print(f"Error computing embeddings: {str(e)}")
    return np.array([])

async def process_test_item(user_id, test_item, user_samples, user_factors, tokenizer, model, top_k, llm, retrieval_method, args=None):
    """Retrieve few-shot examples for one test item and run LLM inference on it.

    Any exception raised during retrieval or inference is printed and turned
    into a record whose ``predicted_answer`` is ``"error"``.

    Returns:
        The ``bbox_inference`` result augmented with ``user_id`` and
        ``retrieval_method``, or the error record described above.
    """
    try:
        few_shot = await retrieve_few_shot_examples(
            user_id, test_item, user_samples, tokenizer, model,
            top_k=top_k, method=retrieval_method, args=args,
        )
        record = await bbox_inference(llm, test_item, few_shot, user_factors, args)
        record.update(user_id=user_id, retrieval_method=retrieval_method)

        print(f"User {user_id}, Item {test_item['item_id']}: predicted={record['predicted_answer']}, actual={test_item['answer']} (method: {retrieval_method})")
        return record
    except Exception as e:
        print(f"Error processing test item {test_item['item_id']} for user {user_id}: {str(e)}")
        failure = {
            "user_id": user_id,
            "item_id": test_item["item_id"],
            "question": test_item.get("question", ""),
            "options": test_item.get("options", []),
            "actual_answer": test_item.get("answer", "unknown"),
            "predicted_answer": "error",
            "reasoning": f"Error: {str(e)}",
            "retrieval_method": retrieval_method,
        }
        return failure

async def process_user(user, personalized_reasoning_by_user, user_factors_data, tokenizer, model, top_k, llm, batch_size, semaphore, retrieval_method, args=None):
    """Run every test item in ``user["profile"]`` concurrently while holding ``semaphore``.

    The user's candidate samples come from ``personalized_reasoning_by_user``
    (keyed by ``user_id``). The factors come from ``user_factors_data`` (keyed
    by ``str(user_id)``). ``batch_size`` is accepted but not used.

    Returns:
        List of per-item result dicts. Items whose task raised an ``Exception``
        are reported on stdout and omitted.
    """
    async with semaphore:
        user_id = user.get("user_id")
        samples = personalized_reasoning_by_user.get(user_id, [])
        factors = user_factors_data.get(str(user_id))

        pending = [
            process_test_item(
                user_id, item, samples, factors,
                tokenizer, model, top_k, llm, retrieval_method, args
            )
            for item in user.get("profile")
        ]
        if not pending:
            return []

        collected = []
        for outcome in await asyncio.gather(*pending, return_exceptions=True):
            if isinstance(outcome, Exception):
                print(f"Error processing item for user {user_id}: {str(outcome)}")
                continue
            collected.append(outcome)
        return collected


async def main():
    """Main execution function.

    Parses CLI arguments, loads the datasets, user factors and the Contriever
    retriever, then runs retrieval plus LLM inference for every test item of
    every selected user. It prints accuracy over the non-error predictions and
    writes metrics plus per-item results to a JSON file. When ``--output_path``
    ends in ``.json``, the retrieval method is appended to the file name.
    """
    parser = argparse.ArgumentParser(description="Run few-shot retrieval and inference on test data")
    parser.add_argument("--model_name", type=str, default="gpt-4o-mini", help="Model name for inference")
    parser.add_argument("--test_data_path", type=str, default="result/goqa_test_feature.json", help="Path to test data")
    parser.add_argument("--personalized_reasoning_path", type=str, default="result/goqa_personalized_reasoning.json",
                       help="Path to personalized reasoning data")
    parser.add_argument("--factor_path", type=str, default="result/goqa_factor.json",
                       help="Path to user factors data")
    parser.add_argument("--output_path", type=str, default="result/goqa_bbox_inference.json", help="Output path")
    parser.add_argument("--num_users", type=int, default=0, help="Number of users to process (0 for all)")
    parser.add_argument("--top_k", type=int, default=3, help="Number of similar examples to use")
    parser.add_argument("--max_parallel_users", type=int, default=128, help="Maximum number of users to process in parallel")
    parser.add_argument("--retrieval_method", type=str, default="feature",
                       choices=RETRIEVAL_METHODS, help="Method to retrieve few-shot examples")

    args = parser.parse_args()

    setup_environment()

    test_data, personalized_reasoning = preprocess_data(args.test_data_path, args.personalized_reasoning_path)

    print(f"Loading user factors from {args.factor_path}")
    user_factors_data = load_data(args.factor_path)

    llm = ChatOpenAI(temperature=0.0, model_name=args.model_name)

    print("Loading Contriever model")
    tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
    # float16 saves GPU memory, but older torch releases lack Half kernels for
    # some CPU ops (e.g. addmm), so load in float32 when running on CPU.
    model_dtype = torch.float16 if device.type == "cuda" else torch.float32
    model = AutoModel.from_pretrained("facebook/contriever", torch_dtype=model_dtype).to(device)
    print(f"Successfully loaded Contriever model on {device} with dtype={model.dtype}")

    if args.num_users > 0:
        test_data = test_data[:args.num_users]

    personalized_reasoning_by_user = {p_user["user_id"]: p_user["profile"] for p_user in personalized_reasoning}

    batch_size = 32768

    user_semaphore = asyncio.Semaphore(args.max_parallel_users)

    print(f"Using retrieval method: {args.retrieval_method}")

    user_tasks = []
    total_items = 0
    for user in test_data:
        user_tasks.append(process_user(
            user, personalized_reasoning_by_user, user_factors_data,
            tokenizer, model, args.top_k, llm, batch_size, user_semaphore, args.retrieval_method, args
        ))
        total_items += len(user.get("profile"))

    print(f"Processing {len(user_tasks)} users with up to {args.max_parallel_users} in parallel")
    print(f"Total test items to process: {total_items}")

    user_results = await tqdm_asyncio.gather(*user_tasks, desc="Processing users")

    all_results = []
    for result in user_results:
        if isinstance(result, Exception):
            print(f"Error processing user: {str(result)}")
            continue
        all_results.extend(result)

    if all_results:
        # Accuracy is computed over non-error predictions only.
        valid_results = [r for r in all_results if r["predicted_answer"] != "error"]
        predictions = [r["predicted_answer"] for r in valid_results]
        references = [r["actual_answer"] for r in valid_results]

        metrics = calculate_metrics(predictions, references)
        print(f"\nOverall Metrics: Accuracy = {metrics['acc']:.4f}")
        print(f"Total samples: {len(all_results)}, Valid samples: {len(valid_results)}")
    else:
        metrics = {"acc": 0}
        print("No results to calculate metrics")

    output = {
        "metrics": metrics,
        "retrieval_method": args.retrieval_method,
        "results": all_results
    }

    output_path = args.output_path
    if output_path.endswith(".json"):
        base_path = output_path.rsplit(".", 1)[0]
        output_path = f"{base_path}_{args.retrieval_method}.json"

    # Make sure the destination directory exists; otherwise the save fails only
    # after all (paid) API calls have already been made.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Saving results to {output_path}")
    save_data(output, output_path)

if __name__ == "__main__":
    # Script entry point: run the async pipeline in a fresh event loop.
    asyncio.run(main())
