#!/usr/bin/env python3

import torch
import os
import sys
import numpy as np
from pathlib import Path
import time
from tqdm import tqdm
import torchvision
import json
import random

# Add the current directory to the path to import VQ-Diffusion
sys.path.append(os.path.dirname(__file__))

from inference_VQ_Diffusion import VQ_Diffusion


def load_coco_captions(coco_path):
    """Load every caption string from an MS COCO captions annotation file.

    Args:
        coco_path: Path to a COCO captions JSON file whose top-level
            ``annotations`` list contains dicts with a ``caption`` key.

    Returns:
        The caption strings in file order. If the file cannot be read or does
        not have the expected structure, the error is printed and ``[]`` is
        returned (callers treat an empty list as failure).
    """
    try:
        # Be explicit: COCO JSON is UTF-8 and must not depend on the locale.
        with open(coco_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        captions = [annotation['caption'] for annotation in data['annotations']]
    except (OSError, ValueError, KeyError, TypeError) as e:
        # OSError: missing/unreadable file; ValueError: bad JSON or bad UTF-8;
        # KeyError/TypeError: JSON that lacks the expected COCO structure.
        print(f"Error loading COCO captions: {e}")
        return []

    print(f"Loaded {len(captions)} captions from MS COCO dataset")
    return captions


def setup_environment(config):
    """Prepare the process for generation: GPU selection, RNG seeds, output dir.

    Args:
        config: Object exposing ``GPU_ID`` (None or a value for
            ``CUDA_VISIBLE_DEVICES``), ``DEVICE``, ``SEED`` and ``OUTPUT_DIR``.

    Returns:
        ``torch.device(config.DEVICE)``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    # Choose visible GPUs *before* querying CUDA: torch.cuda.is_available() may
    # initialize the CUDA driver, after which changes to CUDA_VISIBLE_DEVICES
    # no longer take effect in this process.
    if config.GPU_ID is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(config.GPU_ID)
        print(f"Using GPU {config.GPU_ID}")
    else:
        cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
        if cuda_devices:
            print(f"Using GPU(s) from environment: {cuda_devices}")
        else:
            print("Using default GPU configuration")

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Exiting.")

    device = torch.device(config.DEVICE)
    # Seed every RNG the script relies on. Python's `random` drives prompt and
    # class sampling in main(), so it must be seeded for SEED to be meaningful.
    random.seed(config.SEED)
    torch.manual_seed(config.SEED)
    np.random.seed(config.SEED)
    # Speed over bit-exact reproducibility: cuDNN may pick different
    # convolution algorithms from run to run.
    torch.backends.cudnn.benchmark = True

    Path(config.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    print(f"Generated images will be saved to: {config.OUTPUT_DIR}")
    return device


def generate_images(model, prompts, truncation_rate, save_root, replicate=1, guidance_scale=1.0, return_tokens=False, start_index=0):
    """Generate one batch of images from text prompts and save them as PNGs.

    Files are written to ``save_root`` as ``{index:06d}.png`` with indices
    running consecutively from ``start_index``.

    Args:
        model: Wrapper whose ``.model`` has a ``guidance_scale`` attribute and a
            ``generate_content(...)`` method (called with the kwargs below).
        prompts: List of prompt strings; its length is passed as the batch size.
        truncation_rate: Embedded in the sampler string ``"top{rate}r"``.
        save_root: Output directory (created if missing).
        replicate: Images requested per prompt.
        guidance_scale: Assigned to ``model.model.guidance_scale`` before sampling.
        return_tokens: If True, return ``model_out['content_token']`` when present.
        start_index: Index used for the first saved file.

    Returns:
        ``(content, tokens)``. ``content`` is a float32 tensor with one image per
        leading index. It is divided by 255 when it looks like 0-255 pixel data
        (an integer dtype, or a float max above 1), but it is not clamped; only
        the saved copies are clamped to [0, 1]. ``tokens`` is None unless
        requested and provided by the model.
    """
    os.makedirs(save_root, exist_ok=True)

    model.model.guidance_scale = guidance_scale

    data_i = {
        'text': prompts,
        'image': None,
        'batch_size': len(prompts),
    }

    print(f"Generating {len(prompts)} prompts with {replicate} image(s) per prompt...")

    with torch.no_grad():
        model_out = model.model.generate_content(
            batch=data_i,
            filter_ratio=0,
            replicate=replicate,
            content_ratio=1,
            return_att_weight=False,
            sample_type="top"+str(truncation_rate)+'r',
        )

    # Normalize the model output into a single tensor.
    content = model_out['content']
    if isinstance(content, list):
        content = [c if isinstance(c, torch.Tensor) else torch.as_tensor(c) for c in content]
        content = torch.stack(content, dim=0)
    elif not isinstance(content, torch.Tensor):
        content = torch.as_tensor(content)

    # Decide the scale from the *original* dtype: integer outputs are treated
    # as 8-bit pixel values even when every value happens to be <= 1.
    integer_pixels = not content.dtype.is_floating_point and content.dtype != torch.bool
    if content.dtype != torch.float32:
        content = content.float()
    if integer_pixels:
        content = content / 255.0
    elif content.numel() > 0:  # max() raises on an empty tensor
        max_val = content.max()
        if torch.isfinite(max_val) and max_val > 1.0:
            content = content / 255.0

    tokens = None
    if return_tokens and 'content_token' in model_out:
        tokens = model_out['content_token']

    for i, image in enumerate(content):
        global_index = start_index + i
        save_path = os.path.join(save_root, f"{global_index:06d}.png")
        torchvision.utils.save_image(image.clamp(0, 1), save_path)
        print(f"Saved image {global_index+1}: {save_path}")

    print(f"✅ Successfully saved {content.shape[0]} images")

    return content, tokens


# ImageNet class ids available for random sampling (0-999).
_NUM_IMAGENET_CLASSES = 1000


def _accumulate_tokens(all_tokens, tokens):
    """Append a detached CPU copy of one batch's tokens to ``all_tokens``.

    ``tokens`` may be a tensor, or a list/tuple of tensors that is stacked
    along a new leading dimension. ``None`` or any other type is ignored.
    """
    if tokens is None:
        return
    if isinstance(tokens, torch.Tensor):
        all_tokens.append(tokens.detach().cpu())
    elif isinstance(tokens, (list, tuple)):
        all_tokens.append(torch.stack([t.detach().cpu() for t in tokens], dim=0))


def main():
    """Configure, load the VQ-Diffusion model, and generate the image set.

    All settings live in the nested ``Config`` class; edit them there.

    Text models sample ``NUM_PROMPTS`` captions from the MS COCO annotation
    file. They generate ``REPLICATE`` images per caption in batches of
    ``BATCH_SIZE``; any leftover prompts are processed one at a time.
    Optionally, all tokens are saved to ``all_tokens.pt``.

    The ImageNet model instead generates images for random or fixed class
    labels.

    Raises:
        FileNotFoundError: If a text model is selected and the COCO file is missing.
        ValueError: On invalid configuration values.
        RuntimeError: If CUDA is unavailable or no captions could be loaded.
    """

    class Config:
        """Hard-coded run configuration, validated on construction."""

        def __init__(self):
            # GPU Configuration
            # Mirrors any existing CUDA_VISIBLE_DEVICES; set explicitly to pin a GPU.
            self.GPU_ID = os.environ.get('CUDA_VISIBLE_DEVICES', None)
            self.DEVICE = 'cuda'
            self.SEED = 5

            # Model selection
            self.MODEL_NAME = 'ithq'  # 'ithq', 'coco', 'cub', 'cc', 'imagenet'

            # MSCOCO dataset path
            self.MSCOCO_DATA_PATH = "/data/datasets/mscoco2014val/annotations/captions_val2014.json"

            # Batch size control
            self.NUM_PROMPTS = 50000  # Number of COCO prompts to use
            self.REPLICATE = 1    # Images per prompt (1 = 1 image per prompt)
            self.BATCH_SIZE = 25   # Number of prompts to process at once

            # Class-based generation (for ImageNet models)
            self.CLASS_LABEL = None  # Only used for ImageNet model (e.g., 407 for teddy bear)
            self.USE_RANDOM_CLASSES = True  # Sample random classes instead of fixed class

            # Quality and speed parameters
            self.TRUNCATION_RATE = 1.0
            self.GUIDANCE_SCALE = 5.0
            self.INFER_SPEED = None  # Speed multiplier (0.1-10), None for default

            # Advanced parameters
            self.PRIOR_RULE = 2  # 0: VQ-Diffusion v1, 1: high-quality, 2: purity
            self.PRIOR_WEIGHT = 0  # Probability adjust parameter
            self.LEARNABLE_CF = True  # Use learnable classifier-free

            # Token export settings
            self.EXPORT_TOKENS = True  # Export discrete tokens alongside images

            # Output settings
            self.OUTPUT_DIR = "generated_images"

            # Set up model-specific configurations
            model_configs = {
                'ithq': {
                    'config': 'configs/ithq.yaml',
                    'path': '/checkpoints/pretrained_model/ithq_learnable.pth',
                    'imagenet_cf': False,
                    'type': 'text'
                },
                'coco': {
                    'config': '/checkpoints/pretrained_model/config_text.yaml',
                    'path': '/checkpoints/pretrained_model/coco_learnable.pth',
                    'imagenet_cf': False,
                    'type': 'text'
                },
                'cub': {
                    'config': '/checkpoints/pretrained_model/config_text.yaml',
                    'path': '/checkpoints/pretrained_model/cub_learnable.pth',
                    'imagenet_cf': False,
                    'type': 'text'
                },
                'cc': {
                    'config': '/checkpoints/pretrained_model/config_text.yaml',
                    'path': '/checkpoints/pretrained_model/cc_learnable.pth',
                    'imagenet_cf': False,
                    'type': 'text'
                },
                'imagenet': {
                    'config': 'configs/imagenet.yaml',
                    'path': '/checkpoints/pretrained_model/imagenet_learnable.pth',
                    'imagenet_cf': True,
                    'type': 'class'
                }
            }

            # Encode the key hyper-parameters in the output folder name.
            hp_folder = f"{self.MODEL_NAME}_np{self.NUM_PROMPTS}_pr{self.PRIOR_RULE}_seed{self.SEED}"
            self.OUTPUT_DIR = os.path.join(self.OUTPUT_DIR, hp_folder)

            if self.MODEL_NAME not in model_configs:
                raise ValueError(f"Unknown model: {self.MODEL_NAME}. Available models: {list(model_configs.keys())}")

            model_config = model_configs[self.MODEL_NAME]
            self.CONFIG_PATH = model_config['config']
            self.MODEL_PATH = model_config['path']
            self.IMAGENET_CF = model_config['imagenet_cf']
            self.MODEL_TYPE = model_config['type']

            # Only text models read COCO captions, so only they need the file.
            if self.MODEL_TYPE == 'text' and not os.path.exists(self.MSCOCO_DATA_PATH):
                raise FileNotFoundError(f"MSCOCO data path not found: {self.MSCOCO_DATA_PATH}")

            # Validate parameters (before any of them are used below)
            if self.MODEL_TYPE == 'class' and not self.USE_RANDOM_CLASSES and self.CLASS_LABEL is None:
                raise ValueError("CLASS_LABEL must be provided for ImageNet model when USE_RANDOM_CLASSES is False")

            if self.TRUNCATION_RATE < 0 or self.TRUNCATION_RATE > 1:
                raise ValueError("TRUNCATION_RATE must be between 0 and 1")

            if self.REPLICATE < 1:
                raise ValueError("REPLICATE must be at least 1")

            if self.BATCH_SIZE < 1:
                raise ValueError("BATCH_SIZE must be at least 1")

            if self.GUIDANCE_SCALE < 0:
                raise ValueError("GUIDANCE_SCALE must be non-negative")

            if self.INFER_SPEED is not None and (self.INFER_SPEED < 0.1 or self.INFER_SPEED > 10):
                raise ValueError("INFER_SPEED must be between 0.1 and 10")

            if self.PRIOR_RULE not in [0, 1, 2]:
                raise ValueError("PRIOR_RULE must be 0, 1, or 2")

            if self.MODEL_TYPE == 'text':
                if self.NUM_PROMPTS < 1:
                    raise ValueError("NUM_PROMPTS must be at least 1")
                # Calculate total batch size for text models
                self.TOTAL_BATCH_SIZE = self.NUM_PROMPTS * self.REPLICATE
                print(f"Batch size configuration: {self.NUM_PROMPTS} prompts × {self.REPLICATE} images = {self.TOTAL_BATCH_SIZE} total images")
            elif self.USE_RANDOM_CLASSES:
                # Classes are drawn without replacement, so at most
                # _NUM_IMAGENET_CLASSES can be requested; fail before model load.
                if not 1 <= self.NUM_PROMPTS <= _NUM_IMAGENET_CLASSES:
                    raise ValueError(f"NUM_PROMPTS must be between 1 and {_NUM_IMAGENET_CLASSES} when USE_RANDOM_CLASSES is True")
                self.TOTAL_BATCH_SIZE = self.NUM_PROMPTS * self.REPLICATE
                print(f"Batch size configuration: {self.NUM_PROMPTS} random classes × {self.REPLICATE} images = {self.TOTAL_BATCH_SIZE} total images")
            else:
                # Use fixed class with replicate
                self.TOTAL_BATCH_SIZE = self.REPLICATE
                print(f"Batch size configuration: 1 class × {self.REPLICATE} images = {self.TOTAL_BATCH_SIZE} total images")

    config = Config()

    # Sets GPU visibility, seeds RNGs and creates OUTPUT_DIR; the returned
    # device is not needed further here.
    setup_environment(config)

    print("Initializing VQ-Diffusion model...")
    model = VQ_Diffusion(
        config=config.CONFIG_PATH,
        path=config.MODEL_PATH,
        imagenet_cf=config.IMAGENET_CF
    )

    print("=== Starting image generation ===")

    if config.MODEL_TYPE == 'text':
        print("Loading MS COCO captions...")
        coco_captions = load_coco_captions(config.MSCOCO_DATA_PATH)
        if not coco_captions:
            raise RuntimeError("Failed to load COCO captions. Please check the data path.")
        if config.NUM_PROMPTS > len(coco_captions):
            raise ValueError(f"NUM_PROMPTS ({config.NUM_PROMPTS}) exceeds the {len(coco_captions)} available COCO captions")

        # Select random prompts from COCO (reproducible once `random` is seeded)
        prompts = random.sample(coco_captions, config.NUM_PROMPTS)
        print(f"Selected {len(prompts)} COCO prompts to process in batches of {config.BATCH_SIZE}")

        total_images_generated = 0  # also the file index of the next image
        all_tokens = [] if config.EXPORT_TOKENS else None

        full_batches, remaining_prompts = divmod(len(prompts), config.BATCH_SIZE)
        num_batches = full_batches + (1 if remaining_prompts > 0 else 0)

        # Full batches first
        for batch_idx in range(full_batches):
            batch_start = batch_idx * config.BATCH_SIZE
            batch_end = batch_start + config.BATCH_SIZE
            batch_prompts = prompts[batch_start:batch_end]

            print(f"\n--- Processing batch {batch_idx + 1}/{num_batches} (prompts {batch_start+1}-{batch_end}) ---")

            _, tokens = generate_images(
                model=model,
                prompts=batch_prompts,
                truncation_rate=config.TRUNCATION_RATE,
                save_root=config.OUTPUT_DIR,
                replicate=config.REPLICATE,
                guidance_scale=config.GUIDANCE_SCALE,
                return_tokens=config.EXPORT_TOKENS,
                start_index=total_images_generated
            )
            if config.EXPORT_TOKENS:
                _accumulate_tokens(all_tokens, tokens)

            total_images_generated += config.BATCH_SIZE * config.REPLICATE

        # Leftover prompts are generated one at a time (batch size 1) to avoid
        # dimension issues with a partial batch.
        if remaining_prompts > 0:
            remainder_prompts = prompts[full_batches * config.BATCH_SIZE:]

            print(f"\n--- Processing remaining {remaining_prompts} prompts individually to avoid batch size issues ---")

            for i, prompt in enumerate(remainder_prompts):
                print(f"Processing remaining prompt {i+1}/{remaining_prompts}: {prompt[:50]}...")

                _, tokens = generate_images(
                    model=model,
                    prompts=[prompt],  # Single prompt
                    truncation_rate=config.TRUNCATION_RATE,
                    save_root=config.OUTPUT_DIR,
                    replicate=config.REPLICATE,
                    guidance_scale=config.GUIDANCE_SCALE,
                    return_tokens=config.EXPORT_TOKENS,
                    start_index=total_images_generated
                )
                if config.EXPORT_TOKENS:
                    _accumulate_tokens(all_tokens, tokens)

                total_images_generated += config.REPLICATE
            print(f"Completed batch. Total images generated so far: {total_images_generated}")

        # Save a single stacked token tensor at the end
        if config.EXPORT_TOKENS and all_tokens:
            final_tokens = torch.cat(all_tokens, dim=0)
            tokens_path = os.path.join(config.OUTPUT_DIR, "all_tokens.pt")
            torch.save(final_tokens, tokens_path)
            print(f"Saved all tokens tensor: {tokens_path} with shape {tuple(final_tokens.shape)}")
    else:
        # Generate images for ImageNet models
        if config.USE_RANDOM_CLASSES:
            # Distinct random class ids; NUM_PROMPTS was bounded in Config.
            random_classes = random.sample(range(_NUM_IMAGENET_CLASSES), config.NUM_PROMPTS)
            print(f"Selected {len(random_classes)} random ImageNet classes:")
            for i, class_id in enumerate(random_classes):
                print(f"  {i+1}. Class {class_id}")

            for class_id in random_classes:
                print(f"Generating images for class {class_id}...")

                model.inference_generate_sample_with_class(
                    text=class_id,
                    truncation_rate=config.TRUNCATION_RATE,
                    save_root=config.OUTPUT_DIR,
                    batch_size=config.REPLICATE,
                    guidance_scale=config.GUIDANCE_SCALE,
                    return_tokens=config.EXPORT_TOKENS
                )
        else:
            # Generate images for fixed class
            print(f"Generating images for class {config.CLASS_LABEL}")

            model.inference_generate_sample_with_class(
                text=config.CLASS_LABEL,
                truncation_rate=config.TRUNCATION_RATE,
                save_root=config.OUTPUT_DIR,
                batch_size=config.REPLICATE,
                guidance_scale=config.GUIDANCE_SCALE,
                return_tokens=config.EXPORT_TOKENS
            )

    print("=== Generation completed successfully ===")
    print(f"Images saved to: {config.OUTPUT_DIR}")
    print(f"Total images generated: {config.TOTAL_BATCH_SIZE}")


if __name__ == "__main__":
    # Start generation only when run as a script, never on import.
    main()
 