#!/usr/bin/env python3
"""
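Batch Management Tool for OpenAI-style Batch Processing

This command-line tool lists, inspects, cancels, and downloads the results of
batch jobs for the math olympiad project. The provider is selected with
``--provider`` and defaults to ``openai_batch``.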
"""

import asyncio
import argparse
import json
import os
import sys
from datetime import datetime
from pathlib import Path

# Put the `src` directory next to this script at the front of sys.path so the
# `olym_gen` imports below resolve.
sys.path.insert(0, str(Path(__file__).parent / 'src'))

from olym_gen.generator.base_generator import GeneratorBase
from olym_gen.utils.utils import get_logger

logger = get_logger()


class BatchManager:
    """Utility class to manage batch operations through a provider client.

    Every public method is a coroutine that prints a human-readable report to
    stdout. Any ``Exception`` raised while talking to the provider or writing
    files is caught and reported through the module-level ``logger`` instead
    of being propagated to the caller.
    """

    # Statuses from which ``cancel_batch`` will attempt a cancellation.
    _CANCELLABLE_STATUSES = ("validating", "in_progress")

    def __init__(self, provider):
        """Create the underlying generator for ``provider``.

        Args:
            provider: Provider identifier forwarded unchanged to
                ``GeneratorBase``; its ``client`` attribute is used for all
                batch and file calls.
        """
        self.generator = GeneratorBase(provider=provider)

    @staticmethod
    def _to_datetime(timestamp):
        """Convert a Unix timestamp to a naive local ``datetime``.

        Accepts anything ``float()`` accepts (ints, floats, numeric strings)
        so every method formats timestamps the same way.
        """
        return datetime.fromtimestamp(float(timestamp))

    @staticmethod
    def _error_entries(errors):
        """Return the individual error entries of a batch ``errors`` field.

        Handles a missing value, a wrapper exposing the entries under
        ``data`` (attribute or dict key), a plain list/tuple, or a single
        opaque value.
        NOTE(review): the ``data`` wrapper matches the OpenAI Batch object;
        confirm the shape returned by the other providers.
        """
        if not errors:
            return []
        if isinstance(errors, dict):
            entries = errors.get('data')
        else:
            entries = getattr(errors, 'data', None)
        if entries is None:
            # No ``data`` wrapper: treat the value itself as the payload.
            entries = errors
        if isinstance(entries, (list, tuple)):
            return list(entries)
        return [entries]

    @staticmethod
    def _summarize_results(results_content):
        """Count successful and failed entries in JSONL batch output.

        Args:
            results_content: Decoded text of the output file, one JSON
                document per line.

        Returns:
            A ``(success_count, error_count)`` tuple. A line counts as failed
            when its JSON object has a truthy ``error`` field. Blank lines,
            lines that are not valid JSON, and JSON values that are not
            objects are ignored.
        """
        success_count = 0
        error_count = 0

        # Split on '\n' only: str.splitlines() would also break on characters
        # such as U+2028 that may legally appear inside a JSON string.
        for line in results_content.strip().split('\n'):
            if not line.strip():
                continue
            try:
                result = json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed lines, but let unrelated exceptions
                # (e.g. KeyboardInterrupt) propagate.
                continue
            if not isinstance(result, dict):
                continue
            if result.get('error'):
                error_count += 1
            else:
                success_count += 1

        return success_count, error_count

    async def list_batches(self, limit: int = 10) -> None:
        """Print a table of recent batch jobs.

        Args:
            limit: Value passed as ``limit`` to the client's batch listing
                call.
        """
        try:
            batches = await self.generator.client.batches.list(limit=limit)

            print(f"{'Batch ID':<30} {'Status':<12} {'Created':<20} {'Requests':<10}")
            print("-" * 75)

            for batch in batches.data:
                created = self._to_datetime(batch.created_at).strftime('%Y-%m-%d %H:%M')
                request_count = "N/A"
                counts = getattr(batch, 'request_counts', None)
                if counts:
                    total = getattr(counts, 'total', 0)
                    completed = getattr(counts, 'completed', 0)
                    request_count = f"{completed}/{total}"

                print(f"{batch.id:<30} {batch.status:<12} {created:<20} {request_count:<10}")

        except Exception as e:
            logger.error(f"Failed to list batches: {e}")

    async def get_batch_info(self, batch_id: str) -> None:
        """Print detailed information about a specific batch.

        Args:
            batch_id: Identifier passed to the client's batch retrieve call.
        """
        try:
            batch = await self.generator.client.batches.retrieve(batch_id)
            logger.debug(f"Batch details: {batch}")

            print(f"Batch ID: {batch.id}")
            print(f"Status: {batch.status}")
            print(f"Created: {self._to_datetime(batch.created_at)}")

            # Lifecycle timestamps are optional; print only those that are set.
            optional_timestamps = (
                ("Started", 'in_progress_at'),
                ("Completed", 'completed_at'),
                ("Failed", 'failed_at'),
            )
            for label, field in optional_timestamps:
                value = getattr(batch, field, None)
                if value:
                    print(f"{label}: {self._to_datetime(value)}")

            print(f"Completion Window: {batch.completion_window}")

            counts = getattr(batch, 'request_counts', None)
            if counts:
                print(f"Requests - Total: {getattr(counts, 'total', 0)}, "
                      f"Completed: {getattr(counts, 'completed', 0)}, "
                      f"Failed: {getattr(counts, 'failed', 0)}")

            usage = getattr(batch, 'usage', None)
            if usage:
                # ``usage`` may be a plain dict or an object with attributes,
                # depending on the client; support both.
                if isinstance(usage, dict):
                    total_tokens = usage.get('total_tokens', 0)
                else:
                    total_tokens = getattr(usage, 'total_tokens', 0)
                print(f"Usage - Total tokens: {total_tokens}")

            metadata = getattr(batch, 'metadata', None)
            if metadata:
                print("Metadata:")
                for key, value in metadata.items():
                    print(f"  {key}: {value}")

            if batch.status == "failed":
                entries = self._error_entries(getattr(batch, 'errors', None))
                if entries:
                    print("Errors:")
                    for error in entries:
                        print(f"  {error}")

        except Exception as e:
            logger.error(f"Failed to get batch info: {e}")

    async def cancel_batch(self, batch_id: str) -> None:
        """Cancel a batch job if its current status allows it.

        The batch is retrieved first; cancellation is only requested when its
        status is ``validating`` or ``in_progress``.

        Args:
            batch_id: Identifier of the batch to cancel.
        """
        try:
            # First check current status
            batch = await self.generator.client.batches.retrieve(batch_id)
            print(f"Current status: {batch.status}")
            logger.debug(f"Batch details before cancellation: {batch}")

            if batch.status not in self._CANCELLABLE_STATUSES:
                print(f"Cannot cancel batch in status '{batch.status}'")
                return

            # Cancel the batch
            cancelled = await self.generator.client.batches.cancel(batch_id)
            print(f"Cancellation requested. New status: {cancelled.status}")
            logger.debug(f"Batch details after cancellation: {cancelled}")

        except Exception as e:
            logger.error(f"Failed to cancel batch: {e}")

    async def download_file(self, file_id: str, output_file: str) -> None:
        """Download a file by its ID and write its raw bytes to disk.

        Args:
            file_id: Identifier passed to the client's file content call.
            output_file: Destination path, opened in binary write mode.
        """
        try:
            # Download the file
            file_response = await self.generator.client.files.content(file_id)
            file_content = file_response.content

            # Save to file
            with open(output_file, 'wb') as f:
                f.write(file_content)

            print(f"File downloaded to: {output_file}")

        except Exception as e:
            logger.error(f"Failed to download file: {e}")

    async def download_results(self, batch_id: str, output_file: str) -> None:
        """Download the output of a completed batch and print a summary.

        Nothing is downloaded unless the batch status is ``completed`` and it
        has an ``output_file_id``. The content is decoded as UTF-8, written
        to ``output_file``, and then summarized as counts of successful and
        failed entries.

        Args:
            batch_id: Identifier of the batch whose results are wanted.
            output_file: Destination path, written as UTF-8 text.
        """
        try:
            batch = await self.generator.client.batches.retrieve(batch_id)

            if batch.status != "completed":
                print(f"Batch is not completed. Current status: {batch.status}")
                return

            if not batch.output_file_id:
                print("No output file available for this batch")
                return

            # Download the results
            file_response = await self.generator.client.files.content(batch.output_file_id)
            results_content = file_response.content.decode('utf-8')

            # Save to file
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(results_content)

            print(f"Results downloaded to: {output_file}")

            # Show summary
            success_count, error_count = self._summarize_results(results_content)
            print(f"Summary: {success_count} successful, {error_count} failed")

        except Exception as e:
            logger.error(f"Failed to download results: {e}")


async def main():
    """Parse command-line arguments and run the selected batch command.

    Sub-commands: ``list``, ``info``, ``cancel``, ``download`` and ``file``.
    When no sub-command is given the help text is printed and the function
    returns. Identifier and output-path options are required, so a missing
    one is rejected by argparse (usage message, exit status 2) instead of
    being passed on as ``None``.

    Exits with status 1 if constructing the manager or running the command
    raises an exception.
    """
    parser = argparse.ArgumentParser(description="OpenAI Batch Management Tool")
    parser.add_argument(
        '--provider',
        choices=["openai_batch", "siliconflow_batch", "tencent_batch", "ali_batch", "gemini_batch"],
        default="openai_batch",
        help="API provider to use",
    )

    subparsers = parser.add_subparsers(dest='command', help='Available commands')

    # List batches
    list_parser = subparsers.add_parser('list', help='List recent batch jobs')
    list_parser.add_argument('--limit', type=int, default=10, help='Number of batches to show')

    # Get batch info
    info_parser = subparsers.add_parser('info', help='Get detailed batch information')
    info_parser.add_argument('--batch_id', required=True, help='Batch ID to inspect')

    # Cancel batch
    cancel_parser = subparsers.add_parser('cancel', help='Cancel a batch job')
    cancel_parser.add_argument('--batch_id', required=True, help='Batch ID to cancel')

    # Download results
    download_parser = subparsers.add_parser('download', help='Download batch results')
    download_parser.add_argument('--batch_id', required=True, help='Batch ID to download from')
    download_parser.add_argument('--output_file', required=True, help='Output file path')

    # Download file by ID
    file_parser = subparsers.add_parser('file', help='Download a file by its ID')
    file_parser.add_argument('--file_id', required=True, help='File ID to download')
    file_parser.add_argument('--output_file', required=True, help='Output file path')

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    try:
        manager = BatchManager(args.provider)

        if args.command == 'list':
            await manager.list_batches(args.limit)
        elif args.command == 'info':
            await manager.get_batch_info(args.batch_id)
        elif args.command == 'cancel':
            await manager.cancel_batch(args.batch_id)
        elif args.command == 'download':
            await manager.download_results(args.batch_id, args.output_file)
        elif args.command == 'file':
            await manager.download_file(args.file_id, args.output_file)

    except Exception as e:
        logger.error(f"Command failed: {e}")
        sys.exit(1)


# Run the async CLI entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
