import sys
import os
from pathlib import Path
import torch
import pandas as pd
from diffusers import DiffusionPipeline
from PIL import Image
from tqdm import tqdm
from diffusers import FluxPipeline


class FLUXGenerator:
    """Thin wrapper around the FLUX.1-dev text-to-image pipeline.

    The pipeline is loaded eagerly in ``__init__`` and moved to ``device``.
    """

    def __init__(self, device):
        self.device = device
        self.pipeline = None
        self.setup_pipeline()

    def setup_pipeline(self):
        """Initialize the FLUX.1-dev pipeline (bfloat16 weights) on the specified device."""
        self.pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        print(f"Model loaded successfully on {self.device}")

    def generate_image(
        self,
        prompt,
        output_path,
        *,
        height=512,
        width=512,
        guidance_scale=3.5,
        num_inference_steps=50,
        max_sequence_length=512,
    ):
        """Generate a single image using FLUX and save it to ``output_path``.

        Args:
            prompt: Text prompt passed to the pipeline.
            output_path: Destination path handed to the image's ``save``.
            height, width: Output resolution in pixels (default 512x512).
            guidance_scale: Guidance strength passed to the pipeline.
            num_inference_steps: Number of denoising steps.
            max_sequence_length: Maximum prompt token length passed to the pipeline.

        Returns:
            True if the image was generated and saved, False if any step raised
            (the error is printed rather than propagated).
        """
        try:
            # No gradients are needed for sampling; mirrors SDXLGenerator.
            with torch.inference_mode():
                result = self.pipeline(
                    prompt,
                    height=height,
                    width=width,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    max_sequence_length=max_sequence_length,
                )
            generated_image = result.images[0]
            generated_image.save(output_path)
            return True
        except Exception as e:
            # Best-effort by design: report and let the caller move on to the next prompt.
            print(f"Error generating image: {str(e)}")
            return False


    
class SDXLGenerator:
    """Thin wrapper around the Stable Diffusion XL base 1.0 pipeline.

    The pipeline is loaded as soon as the object is constructed.
    """

    def __init__(self, device):
        self.device = device
        self.pipeline = None
        self.setup_pipeline()

    def setup_pipeline(self):
        """Load SDXL base 1.0 (float16, safetensors) and move it to ``self.device``."""
        model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        loaded = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        self.pipeline = loaded.to(self.device)
        print(f"Model loaded successfully on {self.device}")

    def generate_image(self, prompt, output_path):
        """Render ``prompt`` and write the first returned image to ``output_path``.

        Returns True on success; on any exception the message is printed and
        False is returned.
        """
        try:
            with torch.inference_mode():
                first_image = self.pipeline(prompt=prompt).images[0]
            first_image.save(output_path)
        except Exception as err:
            print(f"Error generating image: {str(err)}")
            return False
        else:
            return True

def setup_directories(csv_path, expname, modelname):
    """Create one output directory per sampling temperature and return them.

    Directories live under ``flickr_train-{expname}-{modelname}-images/`` and
    are keyed by temperature string ("0.75", "0.85", "0.95"). ``csv_path`` is
    accepted for interface compatibility but not used.
    """
    root = Path(f"flickr_train-{expname}-{modelname}-images/")
    dirs = {temp: root / temp for temp in ("0.75", "0.85", "0.95")}

    for path in dirs.values():
        print("creating ", path)
        path.mkdir(parents=True, exist_ok=True)

    return dirs

def extract_caption(text):
    """Extract the caption from a model output string.

    Keeps the text after the last "assistant" marker and then after the last
    "caption:" marker, and strips ':', '**' and '"' characters and surrounding
    whitespace. It then replaces "enhance " with "generate ". The whole text
    is lowercased first so the markers match case-insensitively, which means
    the returned caption is lowercase as well.

    Args:
        text: Raw model output. Non-string values (e.g. the float NaN pandas
            produces for an empty CSV cell, or None) yield "".

    Returns:
        The cleaned caption, or "" if ``text`` is not a string.
    """
    if not isinstance(text, str):
        # Explicit check instead of a catch-all: string ops below cannot fail on a str.
        print(f"Error extracting caption: expected str, got {type(text).__name__}")
        return ""

    caption = text.lower().split("assistant")[-1].split("caption:")[-1]
    # Order matters only for readability; each token is removed independently.
    for token in (":", "**", '"'):
        caption = caption.replace(token, "")
    caption = caption.strip()
    return caption.replace("enhance ", "generate ")

import ast
def generate_images(start_idx, end_idx, modelname, csv_path, expname, device):
    """Generate one image per temperature-specific caption for a range of CSV rows.

    Rows are processed for every index ``idx`` in ``[start_idx, end_idx]``.
    Both ends are inclusive and the range is clipped to ``[0, len(df))``.

    For each temperature directory returned by ``setup_directories``:

    - The caption in column ``output_temp_{expname}_{temp}`` is cleaned with
      ``extract_caption``.
    - It is rendered into ``output_dirs[temp] / row["filename"]``.
    - Empty captions and images that already exist are skipped, so an
      interrupted run can be resumed with the same arguments.

    Args:
        start_idx: First row index to process (inclusive).
        end_idx: Last row index to process (inclusive).
        modelname: "sdxl" or "flux". Any other value prints an error and
            returns before the CSV is read or any directory is created.
        csv_path: CSV with a ``filename`` column plus the caption columns above.
        expname: Experiment name. It is used both in the caption column names
            and in the output directory name.
        device: Device string the generation pipeline is moved to.
    """
    generator_classes = {"sdxl": SDXLGenerator, "flux": FLUXGenerator}
    if modelname not in generator_classes:
        # Validate first so a typo does not leave an empty output tree on disk.
        print("Wrong modelname given")
        return

    df = pd.read_csv(csv_path)
    output_dirs = setup_directories(csv_path, expname, modelname)
    generator = generator_classes[modelname](device)

    # end_idx is inclusive. Clamp start at 0 so a negative value does not wrap
    # around through iloc's negative indexing.
    row_range = range(max(start_idx, 0), min(end_idx + 1, len(df)))
    for idx in tqdm(row_range, desc=f"Generating images on {device}"):
        row = df.iloc[idx]
        filename = row["filename"]

        # Temperatures come from the output directories so the two cannot drift
        # apart. A missing caption column raises KeyError here (fail fast) rather
        # than being logged once per row.
        prompts = {
            temp: extract_caption(row[f"output_temp_{expname}_{temp}"])
            for temp in output_dirs
        }

        for temp, prompt in prompts.items():
            try:
                if not prompt:
                    continue

                output_path = output_dirs[temp] / filename
                if output_path.exists():
                    # Already generated in a previous run.
                    continue

                print(f"Generating base_cap image for temp {temp}: {prompt}")
                success = generator.generate_image(
                    prompt=prompt,
                    output_path=output_path,
                )
                if not success:
                    print(f"Failed to generate base_cap image for index {idx} with temperature {temp}")

            except Exception as e:
                # One bad row/temperature should not abort the whole batch.
                print(f"Error processing base_cap index {idx} for temp {temp}: {str(e)}")
                continue

        # Periodically release cached GPU memory between rows.
        if idx % 5 == 0:
            torch.cuda.empty_cache()

if __name__ == "__main__":
    # CLI: six positional arguments, forwarded to generate_images.
    cli_args = sys.argv[1:]
    if len(cli_args) != 6:
        print("Usage: python generate.py start_idx end_idx modelname device csv_path expname")
        sys.exit(1)

    raw_start, raw_end, modelname, device, csv_path, expname = cli_args
    generate_images(int(raw_start), int(raw_end), modelname, csv_path, expname, device)