import json
import argparse
import os
import torch
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
from rich import print as rprint
from diffusers import FluxPipeline
from tqdm import tqdm

# Route the Hugging Face cache to /workspace/cache/.
# NOTE(review): huggingface_hub appears to read HF_HOME when it is first
# imported, and diffusers (imported above) pulls it in, so this assignment
# probably comes too late to affect its cache location. The explicit
# cache_dir passed in build_pipeline is what reliably applies. Consider
# moving this above the third-party imports — verify.
os.environ.update(HF_HOME="/workspace/cache/")

def number_to_words(n):
    """Spell out a count in English for use in text prompts.

    Supports 1 through 50 inclusive. Values from 21 to 50 that are not
    multiples of ten are hyphenated, e.g. "twenty-one".

    Args:
        n: The count to convert.

    Returns:
        The English word(s) for ``n``, or the string "Number out of range"
        for any ``n`` outside 1-50 (including zero and negative values).
    """
    num_to_word = {
        1: "one", 2: "two", 3: "three", 4: "four", 5: "five",
        6: "six", 7: "seven", 8: "eight", 9: "nine", 10: "ten",
        11: "eleven", 12: "twelve", 13: "thirteen", 14: "fourteen",
        15: "fifteen", 16: "sixteen", 17: "seventeen", 18: "eighteen",
        19: "nineteen", 20: "twenty", 30: "thirty", 40: "forty",
        50: "fifty",
    }

    # Lower bound matters: with a bare `n <= 20`, 0 and negatives reach the
    # dict lookup and raise KeyError instead of reporting out-of-range.
    if 1 <= n <= 20:
        return num_to_word[n]
    if 21 <= n <= 50:
        tens, ones = divmod(n, 10)
        tens_word = num_to_word[tens * 10]
        return tens_word if ones == 0 else f"{tens_word}-{num_to_word[ones]}"
    return "Number out of range"


def filter_processed_prompts(prompts, output_dir):
    """Return the prompts whose case has not yet been produced in *output_dir*.

    A case key is taken from every file named ``case_<key>_tmp`` (an
    in-progress placeholder) or ``case_<key>_output...`` (a generated image).
    Files that do not follow this naming scheme are ignored.

    Args:
        prompts: List of prompt dicts. Each must have ``gt_count`` and
            ``animals_str``; ``prompt_id`` is optional and defaults to 0.
        output_dir: Directory to scan for existing outputs.

    Returns:
        A new list holding, in the original order, the prompts whose key
        (``<gt_count>_<animals_str>`` plus ``_prompt<prompt_id>`` when
        ``prompt_id`` is non-zero) was not found.
    """
    processed_cases = set()
    for filename in os.listdir(output_dir):
        separator = '_tmp' if 'tmp' in filename else '_output'
        parts = filename.split(separator)[0].split('case_')
        if len(parts) < 2:
            # Unrelated file (e.g. .DS_Store, logs): previously an IndexError.
            continue
        processed_cases.add(parts[1])

    def case_key(prompt):
        # NOTE(review): generate_image names files with animals_str after
        # .replace("s", "") and without a prompt_id suffix, so these keys may
        # not match what it writes — confirm which naming is intended.
        gt_count = prompt["gt_count"]
        animals_str = prompt["animals_str"]
        prompt_id = prompt.get("prompt_id", 0)
        if prompt_id != 0:
            return f"{gt_count}_{animals_str}_prompt{prompt_id}"
        return f"{gt_count}_{animals_str}"

    return [prompt for prompt in prompts if case_key(prompt) not in processed_cases]

def build_pipeline(name, gpu_idx, cache_dir="/workspace/cache/"):
    """Load a FLUX pipeline in bfloat16 and place it on one CUDA device.

    Args:
        name: Model id or path passed to ``FluxPipeline.from_pretrained``
            (e.g. "black-forest-labs/FLUX.1-dev").
        gpu_idx: Index of the CUDA device to move the pipeline to.
        cache_dir: Download cache directory for the weights.

    Returns:
        The loaded pipeline, placed on ``cuda:<gpu_idx>``.
    """
    device = torch.device("cuda", index=gpu_idx)
    # Use the caller-supplied model id instead of a hard-coded one, so the
    # `name` parameter actually has an effect.
    pipeline = FluxPipeline.from_pretrained(
        name,
        torch_dtype=torch.bfloat16,
        cache_dir=cache_dir,
    )
    pipeline.to(device=device)
    return pipeline


def _resolve_object_str(entry):
    """Return the object label used in output file names for one prompt entry.

    Tries, in order: ``animals_str`` with every "s" removed, ``obj`` with
    spaces turned into underscores, then ``objects_str`` unchanged.
    """
    # NOTE(review): .replace("s", "") strips *every* "s", not only a plural
    # suffix (e.g. "horses" -> "hore"). Kept as-is because existing output
    # file names depend on it.
    try:
        return entry["animals_str"].replace("s", "")
    except (KeyError, AttributeError):
        pass
    try:
        return entry["obj"].replace(" ", "_")
    except (KeyError, AttributeError):
        return entry["objects_str"]


def _is_case_claimed(output_dir, gt):
    """Return True if output_dir has the placeholder or any output image for case gt."""
    placeholder_name = f"case_{gt}_tmp"
    output_prefix = f"case_{gt}_output"
    return any(
        name == placeholder_name or name.startswith(output_prefix)
        for name in os.listdir(output_dir)
    )


def generate_image(prompts, pipeline, args):
    """Generate ``args.batch_size`` images per prompt and save them as PNGs.

    For each entry, a case key ``<gt_count>_<object label>`` is built.
    Cases that already have a ``case_<key>_tmp`` placeholder or
    ``case_<key>_output*`` files in ``args.output_dir`` are skipped.
    Otherwise a placeholder is written, the pipeline is run, and the images
    are saved as ``case_<key>_output_<i>.png``. The placeholder is always
    removed afterwards, so a failed case is retried on the next run.

    Args:
        prompts: List of dicts with ``gt_count``, ``prompt`` and one of
            ``animals_str`` / ``obj`` / ``objects_str``.
        pipeline: Callable diffusion pipeline returning an object whose
            ``.images`` are savable images.
        args: Namespace providing ``seed``, ``batch_size``, ``height``,
            ``width``, ``cfg`` and ``output_dir``.
    """
    # NOTE(review): the global RNG is seeded once and no per-prompt generator
    # is passed, so a prompt's images presumably depend on how many prompts
    # were generated before it in this run (e.g. after resuming).
    torch.manual_seed(args.seed)

    rprint(f"[bold green]Starting image generation: {len(prompts)} prompts, {args.batch_size} images per prompt[/bold green]")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TextColumn("[cyan]{task.fields[status]}[/cyan]"),
        transient=True,
    ) as progress:
        task = progress.add_task("Generation Progress", total=len(prompts), status="")

        for idx, entry in enumerate(prompts):
            gt_count = entry["gt_count"]
            object_str = _resolve_object_str(entry)
            prompt_text = entry["prompt"]
            progress.update(task, status=f"Processing Prompt {idx}")

            gt = f"{gt_count}_{object_str}"

            # The directory is re-listed for every prompt (not once up front),
            # presumably so placeholders written by other processes sharing
            # output_dir are noticed.
            if _is_case_claimed(args.output_dir, gt):
                progress.update(task, status=f"Skipping {gt} - already processed", advance=1)
                continue

            # Claim the case before the (slow) generation starts.
            placeholder_path = os.path.join(args.output_dir, f"case_{gt}_tmp")
            with open(placeholder_path, 'w') as f:
                f.write('placeholder')

            try:
                progress.update(task, status=f"Generating images for {gt}...")
                images = pipeline(
                    prompt=prompt_text,
                    num_images_per_prompt=args.batch_size,
                    height=args.height,
                    width=args.width,
                    guidance_scale=args.cfg,
                ).images
                for i, image in enumerate(images):
                    image.save(os.path.join(args.output_dir, f"case_{gt}_output_{i}.png"))
            except Exception as e:
                # Per-case boundary: keep going with the next prompt. Print
                # persistently, because the transient status line would be
                # overwritten on the next update and the error would be lost.
                progress.console.print(
                    f"Failed to generate {gt}: {e}", style="bold red", markup=False
                )
                progress.update(task, status=f"Failed to generate {gt}: {str(e)}", advance=1)
            else:
                progress.update(task, status=f"Completed {gt}", advance=1)
            finally:
                # Always release the claim. A placeholder left after a failure
                # would mark the case as processed forever with no outputs.
                if os.path.exists(placeholder_path):
                    os.remove(placeholder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate evaluation images with FLUX.1-dev, optionally loading LoRA weights from a training checkpoint."
    )
    parser.add_argument("--name", type=str, default="v01_lora_r64_bs16_flux_count",
                        help="Run name (not referenced by this script).")
    parser.add_argument("--height", type=int, default=512, help="Image height passed to the pipeline.")
    parser.add_argument("--width", type=int, default=512, help="Image width passed to the pipeline.")
    parser.add_argument("--cfg", type=float, default=3.5, help="guidance_scale passed to the pipeline.")
    parser.add_argument("--seed", type=int, default=42, help="Global torch seed, set once before generation.")
    parser.add_argument("--batch_size", type=int, default=10, help="Images generated per prompt.")
    parser.add_argument("--prompt_file", type=str, default="eval_prompt_one_animal_50.json",
                        help="Prompt JSON, relative to /workspace/data/ (absolute paths are used as-is).")
    parser.add_argument("--gpu_id", type=int, default=0, help="CUDA device index.")
    parser.add_argument("--ckpt_step", type=str, default="checkpoint-1000",
                        help="Checkpoint subfolder; 'checkpoint-0' skips loading LoRA weights.")
    parser.add_argument("--folder", type=str, default="v01_lora_r16_bs16_flux_count",
                        help="Training run folder under /workspace/output/.")
    parser.add_argument("--target_class", type=str, default=None,
                        help="Not referenced by this script.")
    args = parser.parse_args()

    ckpt_folder = f"/workspace/output/{args.folder}"
    args.ckpt_dir = os.path.join(ckpt_folder, args.ckpt_step)
    args.output_dir = f"/workspace/eval_output/{args.folder}/{args.ckpt_step}"

    # os.path.join gives the same result for relative names and, unlike string
    # concatenation, also accepts an absolute --prompt_file.
    args.prompt_file = os.path.join("/workspace/data", args.prompt_file)
    with open(args.prompt_file, 'r', encoding='utf-8') as f:
        prompts = json.load(f)

    os.makedirs(args.output_dir, exist_ok=True)
    print(f'Prompt file: {args.prompt_file}')
    print(f'Output directory: {args.output_dir}')
    print(f'Using {len(prompts)} prompts')

    base_model_id = "black-forest-labs/FLUX.1-dev"
    pipeline = build_pipeline(base_model_id, args.gpu_id)

    # checkpoint-0 denotes the untuned base model: no LoRA weights to load.
    if args.ckpt_step != "checkpoint-0":
        pipeline.load_lora_weights(args.ckpt_dir)

    generate_image(prompts, pipeline, args)