import os
import json
import random
import time
import gc
import asyncio
import aiohttp
from datasets import load_dataset
from tqdm import tqdm
import argparse
import sys
from openai import OpenAI, AsyncOpenAI
from typing import List, Dict, Any


def print_sample_triplet(x, y_w, y_l, index):
    """Show one (prompt, winner, loser) triplet on stdout for manual spot-checking.

    Args:
        x: The prompt text.
        y_w: The preferred (winning) response.
        y_l: The dispreferred (losing) response.
        index: Running sample number shown in the header.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80
    sections = [
        f"PROMPT (x):\n{x}",
        f"WINNING RESPONSE (y_w):\n{y_w}",
        f"LOSING RESPONSE (y_l):\n{y_l}",
    ]

    print("\n" + heavy_rule)
    print(f"SAMPLE #{index}")
    print(heavy_rule)
    for position, section in enumerate(sections):
        # Separate consecutive sections with a thin rule.
        if position:
            print(light_rule)
        print(section)
    print(heavy_rule + "\n")


async def generate_single_response(client: "AsyncOpenAI", prompt: str, system_prompt: str, model_id: str, semaphore: asyncio.Semaphore, max_retries: int = 3) -> str:
    """Generate one chat completion for ``prompt``, retrying transient failures.

    The whole request/retry cycle runs while holding ``semaphore``, so callers
    can cap the number of concurrent requests. Errors are never raised.
    A descriptive string is returned instead, so one bad prompt does not
    abort a batch.

    Args:
        client: Async OpenAI-compatible client.
        prompt: User message.
        system_prompt: System message.
        model_id: Model name sent with the request.
        semaphore: Concurrency limiter shared by the caller's requests.
        max_retries: Total number of attempts (default 3).

    Returns:
        One of the following:
        - The stripped completion text.
        - ``"No response generated"`` if the API returned no choices or
          empty content.
        - ``"Authentication failed: ..."`` on an apparent auth error (not
          retried).
        - ``"Failed after N attempts: ..."`` once every attempt has failed.
    """
    async with semaphore:  # Limit concurrent requests
        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1
            try:
                # Add a small delay between retries to avoid overwhelming the server
                if attempt > 0:
                    await asyncio.sleep(1 + attempt)

                response = await client.chat.completions.create(
                    model=model_id,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": prompt}
                    ],
                    max_tokens=500,
                    temperature=0.7,
                    top_p=1.0,
                    frequency_penalty=0,
                    presence_penalty=0
                )

                # Extract the response content
                if response.choices:
                    content = response.choices[0].message.content
                    if content:
                        return content.strip()
                return "No response generated"

            except Exception as e:
                error_msg = str(e)
                lowered = error_msg.lower()

                # Errors are classified by message text, so this works with any
                # OpenAI-compatible backend. Backoff sleeps are skipped on the
                # final attempt: sleeping just before giving up only wastes time
                # while holding the semaphore.
                if "rate_limit" in lowered or "rate limit" in lowered or "429" in error_msg:
                    if is_last_attempt:
                        print(f"Rate limit hit on final attempt ({attempt + 1}/{max_retries})")
                    else:
                        print(f"Rate limit hit, waiting 30 seconds... (attempt {attempt + 1}/{max_retries})")
                        await asyncio.sleep(30)
                elif "auth" in lowered or "401" in error_msg:
                    # Retrying cannot fix bad credentials.
                    print(f"Authentication error - check API key and endpoint: {error_msg}")
                    return f"Authentication failed: {error_msg}"
                elif "connection" in lowered or "timeout" in lowered:
                    print(f"Connection error on attempt {attempt + 1}/{max_retries}: {error_msg}")
                    if not is_last_attempt:
                        await asyncio.sleep(2 ** attempt + 1)  # Exponential backoff with base delay
                else:
                    print(f"Unexpected error on attempt {attempt + 1}/{max_retries}: {error_msg}")
                    if not is_last_attempt:
                        await asyncio.sleep(2)

                if is_last_attempt:
                    return f"Failed after {max_retries} attempts: {error_msg}"

        # Only reachable when max_retries < 1 (no attempt was made).
        return "Failed to generate response: Max retries exceeded"


async def generate_responses_batch_async(client: AsyncOpenAI, prompts: List[str], system_prompt: str, model_id: str, max_concurrent: int = 10) -> List[str]:
    """Fan out one completion request per prompt and collect the answers in order.

    A semaphore created for this call caps in-flight requests at
    ``max_concurrent``. Any exception that escapes an individual request is
    converted to a ``"Failed to generate response: ..."`` string. The returned
    list therefore lines up one-to-one with ``prompts``.
    """
    limiter = asyncio.Semaphore(max_concurrent)

    # gather() preserves input order regardless of completion order.
    outcomes = await asyncio.gather(
        *(
            generate_single_response(client, text, system_prompt, model_id, limiter)
            for text in prompts
        ),
        return_exceptions=True,
    )

    return [
        f"Failed to generate response: {str(outcome)}" if isinstance(outcome, Exception) else outcome
        for outcome in outcomes
    ]


def _make_clients(api_key, endpoint):
    """Create a sync and an async OpenAI-compatible client with identical settings."""
    settings = {
        "api_key": api_key,
        "base_url": endpoint,
        "timeout": 60.0,  # 60 second timeout
        "max_retries": 3,
    }
    return OpenAI(**settings), AsyncOpenAI(**settings)


def _ping(client, model_id):
    """Send a minimal chat completion to verify key, endpoint and model (raises on failure)."""
    client.chat.completions.create(
        model=model_id,
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=10
    )


def _format_duration(seconds):
    """Render a duration as ``"{h}h {m}m {s}s"``; negative values are clamped to 0."""
    seconds = max(0.0, seconds)
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    return f"{hours}h {minutes}m {secs}s"


def _read_json(path):
    """Load a UTF-8 JSON document from ``path``."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _write_json(path, data):
    """Write ``data`` to ``path`` as pretty-printed UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def _intermediate_count(filename):
    """Return N for ``..._N.json``, or ``None`` if the name has no numeric suffix."""
    stem, ext = os.path.splitext(filename)
    suffix = stem.rsplit('_', 1)[-1]
    if ext != '.json' or not suffix.isdigit():
        return None
    return int(suffix)


def _load_prompts(max_samples):
    """Load instruction-only Alpaca prompts, capped at ``max_samples``.

    Tries the full ``vicgalle/alpaca-gpt4`` dataset first. If that fails, it
    falls back to the ``tatsu-lab/alpaca_eval`` evaluation split.
    """
    print("Loading full Alpaca-GPT4 dataset (this may take a while)...")
    try:
        dataset = load_dataset("vicgalle/alpaca-gpt4")
        print(f"Full dataset loaded with {len(dataset['train'])} examples")
        data_split = 'train'  # The main split in the full dataset
    except Exception as e:
        print(f"Error loading full dataset: {e}")
        print("Falling back to evaluation subset...")
        dataset = load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval", trust_remote_code=True)
        print(f"Evaluation subset loaded with {len(dataset['eval'])} examples")
        data_split = 'eval'  # The main split in the eval dataset

    rows = dataset[data_split]
    print(f"Filtering {len(rows)} examples...")
    # Keep instructions only (no input field, or an empty one)
    prompts = [
        item['instruction']
        for item in tqdm(rows, desc="Filtering dataset")
        if 'input' not in item or not item['input']
    ]
    print(f"Filtered dataset contains {len(prompts)} examples")

    if len(prompts) > max_samples:
        print(f"Limiting dataset to {max_samples} examples")
        prompts = prompts[:max_samples]
    return prompts


def _load_existing_pairs(output_dir, resume_from):
    """Return previously generated pairs to resume from (empty list if none).

    An explicit ``resume_from`` path wins. Otherwise the
    ``pairwise_dataset_intermediate_N.json`` file with the largest N in
    ``output_dir`` is used. Files without a numeric N are ignored instead of
    crashing.
    """
    if resume_from:
        if not os.path.exists(resume_from):
            print(f"Warning: Resume file {resume_from} not found. Starting from scratch.")
            return []
        print(f"Resuming from file: {resume_from}")
        pairs = _read_json(resume_from)
        print(f"Loaded {len(pairs)} existing pairs")
        return pairs

    best_name, best_count = None, -1
    for name in os.listdir(output_dir):
        if not name.startswith("pairwise_dataset_intermediate_"):
            continue
        count = _intermediate_count(name)
        if count is not None and count > best_count:
            best_name, best_count = name, count
    if best_name is None:
        return []

    resume_path = os.path.join(output_dir, best_name)
    print(f"Found latest intermediate file: {resume_path}")
    pairs = _read_json(resume_path)
    print(f"Loaded {len(pairs)} existing pairs")
    return pairs


def _save_recovery_checkpoint(output_dir, pairs, loaded_count):
    """Best-effort dump of ``pairs`` as an intermediate checkpoint after a failure.

    This is skipped when nothing was generated beyond the ``loaded_count``
    pairs read at start-up. The file uses the intermediate naming scheme, so
    the next run resumes from it automatically.
    """
    if len(pairs) <= loaded_count:
        return
    path = os.path.join(output_dir, f"pairwise_dataset_intermediate_{len(pairs)}.json")
    try:
        os.makedirs(output_dir, exist_ok=True)
        _write_json(path, pairs)
        print(f"Partial results saved to {path}")
    except Exception as save_error:
        print(f"Could not save partial results: {save_error}")


async def generate_pairwise_dataset(w=0.6, batch_size=16, output_dir="datasets/gpt4/", sample_interval=50, resume_from=None, max_samples=10000, openai_key=None, endpoint=None, model_id=None, max_concurrent=10):
    """Generate a pairwise-preference dataset with an OpenAI-compatible chat API.

    Every prompt is answered twice: once under an "elementary school student"
    system prompt (group 1) and once under a "PhD Student" system prompt
    (group 2). Group 1's answer becomes the winner ``y_w`` with probability
    ``w``; otherwise group 2's answer does. Each record holds ``x``, ``y_w``,
    ``y_l``, ``y_1`` (group-1 answer) and ``y_2`` (group-2 answer).

    Args:
        w: Probability that the group-1 response is marked as preferred.
        batch_size: Prompts per batch. Each prompt issues two requests.
        output_dir: Directory for intermediate checkpoints and the final file.
        sample_interval: Print samples 1, k+1, 2k+1, ... for inspection.
            A value of 0 or less disables printing.
        resume_from: Optional JSON file of previously generated pairs. If not
            given, the newest intermediate checkpoint in ``output_dir`` is used
            when one exists.
        max_samples: Upper bound on the number of pairs in the dataset.
        openai_key: API key.
        endpoint: API base URL.
        model_id: Model name used for every request.
        max_concurrent: Maximum in-flight requests per response group. The two
            groups run in parallel, so up to twice this many requests can be
            in flight.

    Raises:
        ValueError: If ``openai_key``, ``endpoint`` or ``model_id`` is missing.
        SystemExit: If the connection test fails both with and without the
            ``"Bearer "`` key prefix.

    Errors after the connection test are printed with a traceback rather than
    raised. Before that, pairs generated since the last checkpoint are written
    to an intermediate file so a rerun can resume from them.
    """
    if not openai_key:
        raise ValueError("API key is required")
    if not endpoint:
        raise ValueError("API endpoint is required")
    if not model_id:
        raise ValueError("Model ID is required")

    print(f"Using endpoint: {endpoint}")
    print(f"Using model: {model_id}")
    # Show only a short prefix. That is enough to recognise the key type
    # without leaking most of a short key into logs.
    print(f"API key format: {openai_key[:4]}...")

    # NOTE(review): the first attempt sends the key with a literal "Bearer "
    # prefix and falls back to the raw key if the test request fails.
    # Presumably some gateways need one form and some the other.
    sync_client, async_client = _make_clients("Bearer " + openai_key, endpoint)
    print("OpenAI-compatible clients initialized successfully.")

    print("Testing API connection...")
    try:
        _ping(sync_client, model_id)
        print("API connection test successful!")
    except Exception as e:
        print(f"API connection test failed with Bearer prefix: {e}")
        print("Trying without Bearer prefix...")
        try:
            sync_client, async_client = _make_clients(openai_key, endpoint)
            _ping(sync_client, model_id)
            print("API connection test successful without Bearer prefix!")
        except Exception as e2:
            print(f"API connection test failed without Bearer prefix: {e2}")
            print("Please check your API key, endpoint, and model ID.")
            sys.exit(1)

    group1_system_prompt = "Generate a response that can be easily understood by an elementary school student."
    group2_system_prompt = "Generate a response that only a PhD Student in that specific field could understand."

    # Initialised before the try so the recovery handlers can always see them.
    pairwise_data = []
    start_idx = 0
    try:
        filtered_data = _load_prompts(max_samples)
        os.makedirs(output_dir, exist_ok=True)

        pairwise_data = _load_existing_pairs(output_dir, resume_from)
        # One pair is produced per prompt, so the number of stored pairs is
        # also the index of the first prompt that still needs processing.
        start_idx = len(pairwise_data)

        print(f"Using max concurrent requests: {max_concurrent}")
        start_time = time.time()

        pbar = tqdm(range(start_idx, len(filtered_data), batch_size), desc="Processing batches")
        for i in pbar:
            batch_number = i // batch_size + 1
            batch_prompts = filtered_data[i:i + batch_size]

            batch_start = time.time()
            print(f"\nGenerating responses for batch {batch_number} (both groups in parallel)...")
            group1_responses, group2_responses = await asyncio.gather(
                generate_responses_batch_async(
                    async_client, batch_prompts, group1_system_prompt, model_id,
                    max_concurrent=max_concurrent,
                ),
                generate_responses_batch_async(
                    async_client, batch_prompts, group2_system_prompt, model_id,
                    max_concurrent=max_concurrent,
                ),
            )

            batch_time = time.time() - batch_start
            examples_per_second = len(batch_prompts) / batch_time
            print(f"Batch processed in {batch_time:.2f}s ({examples_per_second:.2f} examples/second)")

            # Add a small delay between batches to avoid overwhelming the server
            await asyncio.sleep(0.5)

            elapsed_time = time.time() - start_time
            processed_examples = min(i + batch_size, len(filtered_data)) - start_idx
            if processed_examples > 0:
                # Clamp at 0: on the final (partial) batch, i + batch_size
                # overshoots the dataset length.
                remaining_examples = max(0, len(filtered_data) - (i + batch_size))
                eta = elapsed_time / processed_examples * remaining_examples
                print(f"ETA: {_format_duration(eta)}")

            # NOTE: failed requests come back as error strings
            # (e.g. "Failed after 3 attempts: ...") and are stored like normal
            # responses. Filter them downstream if needed.
            for j, prompt in enumerate(batch_prompts):
                if j >= len(group1_responses) or j >= len(group2_responses):
                    print(f"Warning: Missing response for prompt {j+1} in batch {batch_number}")
                    continue

                y_1, y_2 = group1_responses[j], group2_responses[j]
                # Group 1 is preferred with probability w
                if random.random() < w:
                    y_w, y_l = y_1, y_2
                else:
                    y_w, y_l = y_2, y_1

                pairwise_data.append({
                    "x": prompt,
                    "y_w": y_w,
                    "y_l": y_l,
                    "y_1": y_1,
                    "y_2": y_2
                })

                # (count - 1) % k == 0 prints samples 1, k+1, 2k+1, ...
                # Unlike count % k == 1, it also works for k == 1.
                sample_count = len(pairwise_data)
                if sample_interval > 0 and (sample_count - 1) % sample_interval == 0:
                    print_sample_triplet(prompt, y_w, y_l, sample_count)

            pbar.set_postfix({
                "samples": f"{len(pairwise_data)}/{max_samples}",
                "progress": f"{(len(pairwise_data)/max_samples)*100:.1f}%"
            })

            # Save intermediate results every 10 batches
            if batch_number % 10 == 0:
                intermediate_file = os.path.join(output_dir, f"pairwise_dataset_intermediate_{len(pairwise_data)}.json")
                _write_json(intermediate_file, pairwise_data)
                print(f"Intermediate results saved to {intermediate_file}")

            if len(pairwise_data) >= max_samples:
                print(f"\nReached maximum number of samples ({max_samples}). Stopping generation.")
                break

        if pairwise_data:
            output_file = os.path.join(output_dir, f"pairwise_dataset_gpt4_w{w}_n{len(pairwise_data)}.json")
            _write_json(output_file, pairwise_data)

            print(f"Dataset saved to {output_file}")
            print(f"Total number of pairs: {len(pairwise_data)}")

            total_time = time.time() - start_time
            print(f"Total processing time: {_format_duration(total_time)}")
            # Only pairs generated in this run count toward throughput, not resumed ones.
            print(f"Average processing speed: {(len(pairwise_data) - start_idx) / total_time:.2f} examples/second")
        else:
            print("No data was generated successfully.")

    except Exception as e:
        print(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()
        _save_recovery_checkpoint(output_dir, pairwise_data, start_idx)
    except BaseException:
        # Covers cancellation and Ctrl-C (on Python 3.11+, asyncio.run reports
        # Ctrl-C by cancelling the main task). Keep the partial work, then
        # propagate.
        _save_recovery_checkpoint(output_dir, pairwise_data, start_idx)
        raise


if __name__ == "__main__":
    # Command-line entry point: parse options and run the async generator.
    parser = argparse.ArgumentParser(description="Generate pairwise comparison dataset using custom OpenAI-compatible API")
    parser.add_argument("--w", type=float, default=0.6, help="Probability of selecting y_1 as preferred")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for processing")
    parser.add_argument("--output_dir", type=str, default="datasets/gpt4/1/", help="Output directory for the dataset")
    # %(default)s keeps the help text in sync with the real defaults
    # (the old text claimed 50 and 10000).
    parser.add_argument("--sample_interval", type=int, default=100, help="How often to print samples for quality checking (default: every %(default)s)")
    parser.add_argument("--max_samples", type=int, default=30000, help="Maximum number of samples to generate (default: %(default)s)")
    parser.add_argument("--resume_from", type=str, default=None, help="JSON file of previously generated pairs to resume from (default: newest intermediate file in --output_dir)")
    # Environment fallbacks let the secret stay off the command line, where it
    # would be visible in the process list and shell history.
    parser.add_argument("--openai_key", type=str, default=os.environ.get("OPENAI_API_KEY", ""), help="API key for authentication (default: $OPENAI_API_KEY)")
    parser.add_argument("--endpoint", type=str, default=os.environ.get("OPENAI_BASE_URL", ""), help="API endpoint URL (default: $OPENAI_BASE_URL)")
    parser.add_argument("--model_id", type=str, default="gpt-4.1", help="Model ID to use for generation")
    parser.add_argument("--max_concurrent", type=int, default=4, help="Maximum number of concurrent API requests")

    args = parser.parse_args()

    # Run the async function
    asyncio.run(generate_pairwise_dataset(
        args.w,
        args.batch_size,
        args.output_dir,
        args.sample_interval,
        resume_from=args.resume_from,
        max_samples=args.max_samples,
        openai_key=args.openai_key,
        endpoint=args.endpoint,
        model_id=args.model_id,
        max_concurrent=args.max_concurrent
    ))