#!/usr/bin/env python3

import logging
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import config
from utils import call_model, llm_answer_to_img, image_answer_to_img

# Setup minimal logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')


def process_random_baseline(model, epoch):
    """Process single model for single epoch - random baseline."""
    # Parse model string and create safe filename
    parts = model.split("/", 1)
    provider = parts[0]
    model_name = parts[1] if len(parts) > 1 else "default"
    safe_model_name = model_name.replace("/", "_")

    output_dir = Path("./random_baseline")
    output_file = output_dir / f"{provider}_{safe_model_name}_{epoch}.png"

    if output_file.exists():
        logging.info(f"Skip: {model}+{epoch} (exists)")
        return

    # Choose appropriate instruction based on model type
    image_models = ["gpt-image-1", "gemini-2.5-flash-image-preview", "grok-2-image"]
    if any(img_model in model for img_model in image_models):
        prompt_instr = config.RANDOM_IMAGE_INSTR
    else:
        prompt_instr = config.RANDOM_LLM_INSTR

    prompt = config.RULES + prompt_instr
    images = []  # Empty list - no input images

    # Retry logic for model call and conversion
    for attempt in range(config.MAX_RETRIES):
        try:
            # Call model with no images
            response = call_model(model, images, prompt)
            if not response:
                if attempt == config.MAX_RETRIES - 1:
                    logging.error(f"Model call failed after {config.MAX_RETRIES} attempts for {model}+{epoch}")
                    return "error"
                logging.warning(f"Model call attempt {attempt + 1} failed for {model}+{epoch}, retrying...")
                continue

            # Convert to image
            image_models = ["gpt-image-1", "gemini-2.5-flash-image-preview", "grok-2-image"]
            if any(img_model in model for img_model in image_models):
                result_img = image_answer_to_img(response)
            else:
                result_img = llm_answer_to_img(response)

            if not result_img:
                if attempt == config.MAX_RETRIES - 1:
                    logging.error(f"Could not convert response after {config.MAX_RETRIES} attempts for {model}+{epoch}")
                    return "error"
                logging.warning(f"Conversion attempt {attempt + 1} failed for {model}+{epoch}, retrying...")
                continue

            # Save
            output_dir.mkdir(parents=True, exist_ok=True)
            result_img.save(output_file)
            logging.info(f"Saved: {output_file}")
            return None  # Success

        except Exception as e:
            if attempt == config.MAX_RETRIES - 1:
                logging.error(f"Processing failed after {config.MAX_RETRIES} attempts for {model}+{epoch}: {e}")
                return "error"
            logging.warning(f"Processing attempt {attempt + 1} failed for {model}+{epoch}: {e}, retrying...")
            continue

def main():
    print(f"Processing {len(config.MODELS)} models for {config.EPOCHS} epochs (random baseline)")

    tasks = []
    for epoch in range(1, config.EPOCHS + 1):
        for model in config.MODELS:
            tasks.append((model, epoch))

    # Run in parallel with statistics tracking
    successful = 0
    failed = 0
    start_time = datetime.now()

    with ThreadPoolExecutor(max_workers=config.NUM_WORKERS) as executor:
        futures = {executor.submit(process_random_baseline, model, epoch): (model, epoch)
                   for model, epoch in tasks}

        for future in as_completed(futures):
            model, epoch = futures[future]
            try:
                result = future.result()
                if result is None:  # Indicates success
                    successful += 1
                else:
                    failed += 1
            except Exception as e:
                logging.error(f"Task {model}+{epoch} failed: {e}")
                failed += 1

    # Print summary
    end_time = datetime.now()
    duration = end_time - start_time
    print(f"\n{'='*60}")
    print("RANDOM BASELINE EXECUTION SUMMARY")
    print(f"{'='*60}")
    print(f"Total tasks: {len(tasks)}")
    print(f"Successful: {successful}")
    print(f"Failed: {failed}")
    print(f"Duration: {duration}")
    print(f"{'='*60}")

if __name__ == "__main__":
    main()