import os
import torch
import lightning as L
from typing import List, Sequence
from datasets import Dataset
from datasets import load_dataset
from tqdm import tqdm
import csv
import time
from omegaconf import DictConfig
import hydra

# Path of the `data` directory located in the parent of the directory that
# contains this file (two path components stripped from __file__, then `data`).
DATA_DIR = os.path.join(os.path.split(os.path.split(__file__)[0])[0], 'data')

def load_texts(file_path: str | Sequence[str]) -> List[str]:
    if isinstance(file_path, str):
        file_path = [file_path]
    texts = []
    for path in file_path:
        with open(path, 'r', encoding='utf-8') as f:
            texts.extend([line.strip() for line in f if line.strip()])
    return texts


def _safe_prompt(text: str) -> str:
    """Return *text* with '/', '\\' and ':' replaced by '_' so it can be used as a filename stem."""
    return text.replace('/', '_').replace('\\', '_').replace(':', '_')


def _save_image(image, save_path: str) -> None:
    """Write one generated image to *save_path*.

    Objects that expose a ``save`` method (e.g. PIL images) are saved directly;
    anything else is converted with ``PIL.Image.fromarray`` first (presumably a
    NumPy array -- confirm against what the pipeline can return).
    """
    if hasattr(image, 'save'):
        image.save(save_path)
        return
    from PIL import Image  # only needed for the non-PIL fallback
    Image.fromarray(image).save(save_path)


def gen(
        model,
        data_path: str | Sequence[str],
        img_size: int,
        num_samples: int,
        output_dirs: str | Sequence[str],
        start: int,
        end: int,
        **kwargs):
    """
    Generate images from text prompts. Save num_samples images per prompt in output_dirs,
    with filenames formatted as prompt+six_digit_number, e.g., a green bench and a blue bowl_000000.png

    Args:
        model: Object whose ``generate(prompt, num_images_per_prompt=..., output_type="pil",
            height=..., width=..., **kwargs)`` returns something with an ``.images`` attribute.
        data_path: Prompt file path, or a sequence of them (one prompt per non-blank line).
        img_size: Used as both height and width of the generated images.
        num_samples: Number of images requested per prompt.
        output_dirs: Output directory, or a sequence of them paired index-by-index
            with ``data_path``. Directories are created if missing.
        start, end: Slice bounds applied to each file's prompt list (``texts[start:end]``,
            ordinary Python slice semantics).
        **kwargs: Forwarded unchanged to ``model.generate``.

    Raises:
        ValueError: If ``data_path`` and ``output_dirs`` have different lengths.
    """
    if isinstance(data_path, str):
        data_path = [data_path]
    if isinstance(output_dirs, str):
        output_dirs = [output_dirs]
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # after which zip() below would silently drop the unmatched entries.
    if len(data_path) != len(output_dirs):
        raise ValueError("length of data_path and output_dirs must be the same")

    for path, out_dir in zip(data_path, output_dirs):
        os.makedirs(out_dir, exist_ok=True)
        texts = load_texts(path)[start:end]
        print(f"Loaded {len(texts)} texts from {path}, generating images to {out_dir}")
        for text in tqdm(texts, desc=f"Generating images for {path}"):
            images = model.generate(
                text,
                num_images_per_prompt=num_samples,
                output_type="pil",
                height=img_size,
                width=img_size,
                **kwargs).images
            # The filename stem depends only on the prompt: compute it once per prompt.
            # NOTE(review): identical prompts within the slice map to the same
            # filenames, so a later duplicate overwrites the earlier one's images.
            stem = _safe_prompt(text)
            for idx, image in enumerate(images):
                _save_image(image, os.path.join(out_dir, f"{stem}_{idx:06d}.png"))
        print(f"Generated images saved to {out_dir}")

@hydra.main(version_base=None, config_path="../configs")
def main(cfg: DictConfig):
    """Hydra entry point: seed, build the NoxEye pipeline and run ``gen``.

    Config keys read: ``seed``, ``model_kwargs``, ``data_path``, ``img_size``,
    ``num_samples``, ``output_dirs``, ``start``, ``end``; optionally
    ``other_kwargs`` (extra keyword arguments forwarded to ``gen`` and from there
    to ``model.generate``) and ``device`` (defaults to ``"cuda"``).
    """
    # `or {}` also covers an explicit `other_kwargs: null`, for which
    # `cfg.get('other_kwargs', {})` would return None and `**None` would fail.
    other_kwargs = cfg.get('other_kwargs') or {}
    if isinstance(other_kwargs, DictConfig):
        # Forward plain dicts/lists rather than OmegaConf containers, so that
        # downstream isinstance(x, list)/dict checks behave as expected.
        from omegaconf import OmegaConf
        other_kwargs = OmegaConf.to_container(other_kwargs, resolve=True)
    device = cfg.get('device', 'cuda')
    L.seed_everything(cfg.seed)
    from models.noxeye import NoxEyePipeline
    model = NoxEyePipeline(**cfg.model_kwargs)
    model.to(device)
    print(f"Loaded model to device {device}")
    gen(
        model=model,
        data_path=cfg.data_path,
        img_size=cfg.img_size,
        num_samples=cfg.num_samples,
        output_dirs=cfg.output_dirs,
        start=cfg.start,
        end=cfg.end,
        **other_kwargs
    )


if __name__ == "__main__":
    main()