import os
import sys
import json
import asyncio
import aiohttp
import aiofiles
from tqdm import tqdm
import time
import random
import logging


# Configure the root logger once at import time: INFO level, with a timestamp
# and severity prefix on every record. `logger` is this module's own logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

async def download_image(session, url, output_path, semaphore):
    """Download ``url`` to ``output_path``, retrying up to three times.

    The response body is streamed into a uniquely named temporary file next
    to ``output_path`` and renamed into place only after it has been written
    completely. A transfer that fails midway therefore never leaves a
    truncated file under the final name, which matters because
    ``process_dataset`` skips any image whose file already exists.

    Args:
        session: Client session whose ``get(url, headers=...)`` returns an
            async context manager yielding a response with ``status`` and
            ``content.iter_chunked`` (an ``aiohttp.ClientSession`` here).
        url: Address of the image to fetch.
        output_path: Destination file path for the downloaded bytes.
        semaphore: Async context manager bounding concurrent downloads. It is
            held for the whole call, including the backoff sleeps.

    Returns:
        True if the image was saved to ``output_path``; False if every
        attempt ended in a non-200 response or an exception.
    """
    headers = {
        'User-Agent': f'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/{random.randint(80, 91)}.0.{random.randint(1000, 5000)}.{random.randint(10, 200)} Safari/537.36',
        'Accept': 'image/avif,image/webp,image/apng,image/svg+xml,image/*,*/*;q=0.8',
        'Accept-Language': 'en-US,en;q=0.9',
        'Referer': 'https://www.amazon.com/'
    }

    max_attempts = 3
    # Random suffix so two calls targeting the same output_path cannot
    # clobber or delete each other's partial file.
    temp_path = f"{output_path}.{random.getrandbits(32):08x}.part"

    async with semaphore:
        for attempt in range(max_attempts):
            try:
                # Small random jitter so requests are not fired in lockstep.
                await asyncio.sleep(random.uniform(0.05, 0.2))

                async with session.get(url, headers=headers) as response:
                    if response.status == 200:
                        async with aiofiles.open(temp_path, 'wb') as f:
                            async for chunk in response.content.iter_chunked(8192):
                                await f.write(chunk)
                        # Publish the file only once it is complete.
                        os.replace(temp_path, output_path)
                        return True
                    logger.debug(
                        f"Download attempt {attempt+1} for {url} returned HTTP {response.status}"
                    )
            except Exception as e:
                logger.debug(f"Download attempt {attempt+1} failed for {url}: {e}")
            finally:
                # Drop any partial download. After a successful rename the
                # temp file no longer exists, which is fine.
                try:
                    os.remove(temp_path)
                except OSError:
                    pass

            # Exponential backoff with jitter before the next attempt.
            if attempt < max_attempts - 1:
                await asyncio.sleep(2 ** attempt + random.random())

        return False

def _image_url_from_entry(entry):
    """Return ``entry``'s 'large' URL, else its 'hi_res' URL, else None."""
    if not isinstance(entry, dict):
        return None
    return entry.get('large') or entry.get('hi_res') or None


def _url_at(urls, index):
    """Return the truthy value at ``urls[index]``, or None if unavailable."""
    if isinstance(urls, (list, tuple)) and index < len(urls):
        return urls[index] or None
    return None


def _extract_image_url(item):
    """Pick the main image URL from one metadata record, or return None.

    Two layouts of ``item['images']`` are understood:

    * a list of per-image dicts with 'variant', 'large' and 'hi_res' keys;
    * a dict of parallel lists under 'variant', 'large' and 'hi_res'.

    In both layouts the image whose variant is 'MAIN' is preferred, 'large'
    is preferred over 'hi_res', and the first image is the fallback.
    """
    images = item.get('images', [])

    if isinstance(images, list):
        for img in images:
            if not isinstance(img, dict):
                continue
            # 'variant' may be absent or null in a record.
            if str(img.get('variant') or '').strip().upper() == 'MAIN':
                url = _image_url_from_entry(img)
                if url:
                    return url
                break
        return _image_url_from_entry(images[0]) if images else None

    if isinstance(images, dict):
        variants = images.get('variant') or []
        try:
            index = variants.index('MAIN')
        except (ValueError, AttributeError):
            # No MAIN variant listed: fall back to the first image.
            index = 0
        return _url_at(images.get('large'), index) or _url_at(images.get('hi_res'), index)

    return None


def _collect_download_tasks(lines, image_folder, stats):
    """Turn raw JSONL lines into ``(url, path)`` download jobs.

    Updates the 'error', 'skipped' and 'no_image_url' counters in ``stats``
    for every line that does not produce a job.
    """
    jobs = []
    queued_paths = set()

    for line in lines:
        try:
            item = json.loads(line.strip())

            parent_asin = item.get('parent_asin')
            if not parent_asin:
                stats["error"] += 1
                continue

            image_path = os.path.join(image_folder, f"{parent_asin}_MAIN.jpg")
            # Skip files already on disk, and duplicates within this batch so
            # two tasks never write the same file concurrently.
            if image_path in queued_paths or os.path.exists(image_path):
                stats["skipped"] += 1
                continue

            image_url = _extract_image_url(item)
            if image_url:
                jobs.append((image_url, image_path))
                queued_paths.add(image_path)
            else:
                stats["no_image_url"] += 1

        except json.JSONDecodeError:
            stats["error"] += 1
        except Exception as e:
            logger.error(f"Unexpected error parsing item: {e}")
            stats["error"] += 1

    return jobs


async def _download_chunk(download_tasks, concurrency_limit, semaphore, stats, description):
    """Download one batch of ``(url, path)`` jobs, updating ``stats``."""
    # NOTE(review): ssl=False turns off TLS certificate verification for every
    # request. Kept to preserve existing behaviour -- confirm it is intended.
    connector = aiohttp.TCPConnector(
        limit=concurrency_limit,
        limit_per_host=8,
        ssl=False,
        ttl_dns_cache=300
    )

    async with aiohttp.ClientSession(connector=connector) as session:
        tasks = [
            asyncio.create_task(download_image(session, url, path, semaphore))
            for url, path in download_tasks
        ]

        with tqdm(total=len(tasks), desc=description) as progress:
            for finished in asyncio.as_completed(tasks):
                try:
                    success = await finished
                except Exception as e:
                    logger.error(f"Task error: {e}")
                    success = False
                stats["success" if success else "error"] += 1
                progress.update(1)


async def process_dataset(dataset_name, concurrency_limit=20, chunk_size=2000,
                          base_dir="/home/yqiao47/dataset"):
    """Download the main image of every item in a dataset's metadata file.

    Reads ``{base_dir}/{dataset_name}/meta_{dataset_name}.jsonl`` in chunks of
    ``chunk_size`` lines and saves each item's image as
    ``{base_dir}/{dataset_name}/images/{parent_asin}_MAIN.jpg``. Images whose
    file already exists are skipped.

    Args:
        dataset_name: Name of the dataset directory and metadata file suffix.
        concurrency_limit: Maximum number of simultaneous downloads; also the
            connection limit of the HTTP connector. Must be at least 1.
        chunk_size: Number of metadata lines handled per batch. Must be at
            least 1.
        base_dir: Directory containing the per-dataset folders.

    Returns:
        A dict with the counters 'total', 'success', 'skipped',
        'no_image_url' and 'error', or None if the metadata file is missing
        or cannot be read.

    Raises:
        ValueError: If ``concurrency_limit`` or ``chunk_size`` is below 1.
    """
    # A semaphore of 0 would block every download forever, and a chunk size
    # of 0 would divide by zero below.
    if concurrency_limit < 1:
        raise ValueError(f"concurrency_limit must be >= 1, got {concurrency_limit}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    json_path = f"{base_dir}/{dataset_name}/meta_{dataset_name}.jsonl"
    image_folder = f"{base_dir}/{dataset_name}/images/"

    # Check the input first so a missing dataset does not create directories.
    if not os.path.exists(json_path):
        logger.error(f"File not found: {json_path}")
        return

    os.makedirs(image_folder, exist_ok=True)

    logger.info(f"Processing dataset: {dataset_name}")
    logger.info(f"Reading from: {json_path}")
    logger.info(f"Saving images to: {image_folder}")

    # Counters
    stats = {
        "total": 0,
        "success": 0,
        "skipped": 0,
        "no_image_url": 0,
        "error": 0
    }

    # First pass: count lines so progress can be reported per chunk.
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            stats["total"] = sum(1 for _ in f)
    except (OSError, UnicodeError) as e:
        logger.error(f"Error counting items: {e}")
        return

    semaphore = asyncio.Semaphore(concurrency_limit)
    total_chunks = (stats["total"] + chunk_size - 1) // chunk_size

    # Second pass: keep the file open and read each chunk where the previous
    # one stopped, instead of reopening and skipping lines for every chunk.
    with open(json_path, 'r', encoding='utf-8') as f:
        for chunk_idx in range(total_chunks):
            logger.info(f"Processing chunk {chunk_idx+1}/{total_chunks}")

            lines = []
            for line in f:
                lines.append(line)
                if len(lines) == chunk_size:
                    break

            download_tasks = _collect_download_tasks(lines, image_folder, stats)

            if download_tasks:
                await _download_chunk(
                    download_tasks,
                    concurrency_limit,
                    semaphore,
                    stats,
                    f"Chunk {chunk_idx+1}/{total_chunks}",
                )

            # Short random pause between chunks.
            if chunk_idx < total_chunks - 1:
                delay = random.uniform(1, 2)
                logger.info(f"Waiting {delay:.1f}s before next chunk...")
                await asyncio.sleep(delay)

    logger.info(f"\nDownload complete for {dataset_name}:")
    logger.info(f"  Success: {stats['success']}")
    logger.info(f"  Skipped (already exists): {stats['skipped']}")
    logger.info(f"  No image URL found: {stats['no_image_url']}")
    logger.info(f"  Failed downloads: {stats['error']}")
    logger.info(f"  Total processed: {stats['total']}")

    return stats

def main():
    """Command-line entry point: ``DATASET_NAME [CONCURRENCY]``.

    Reads the dataset name and an optional concurrency limit (default 50)
    from ``sys.argv`` and runs ``process_dataset``. Exits with status 1 and a
    usage message when the dataset name is missing or CONCURRENCY is not a
    positive integer.
    """
    usage = "Usage: python download_images_async.py DATASET_NAME [CONCURRENCY]"

    if len(sys.argv) < 2:
        print(usage)
        sys.exit(1)

    dataset_name = sys.argv[1]
    concurrency = 50

    if len(sys.argv) > 2:
        try:
            concurrency = int(sys.argv[2])
        except ValueError:
            concurrency = 0
        # Zero would create a semaphore no download can ever acquire, and a
        # negative value is rejected by asyncio, so refuse both up front.
        if concurrency < 1:
            print(f"CONCURRENCY must be a positive integer, got {sys.argv[2]!r}")
            print(usage)
            sys.exit(1)

    asyncio.run(process_dataset(dataset_name, concurrency_limit=concurrency))

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()