import torch
import numpy as np
import os
import PIL.Image
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import torchvision
import json
import argparse
import copy
import random
from typing import List, Dict
import uuid
from pathlib import Path
import sys

# Make the repository root (four directories above this file) importable so the
# project-local `NoiseAR` package resolves. `sys` is already imported at the top
# of the file, so it is not re-imported here.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../'))
sys.path.insert(0, project_root)
from NoiseAR.run_try import NoiseARNet

# Import Qwen-Image model
from diffusers import DiffusionPipeline
from modelscope import snapshot_download

def seed_all(seed):
    """Seed every RNG this script touches and force deterministic cuDNN.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG,
    the current CUDA device RNG and all CUDA device RNGs. Then it disables cuDNN
    autotuning (``benchmark``) and requests deterministic cuDNN kernels.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

# Command-line interface. Defaults point at the original experiment's
# checkpoint, the GenEval prompt file and an output folder named after them.
parser = argparse.ArgumentParser()

# --- inputs -----------------------------------------------------------------
# Directory expected to contain `pytorch_model.bin` (joined below).
parser.add_argument(
    "--model_path",
    type=str,
    default="/home/tmp/***/official_run_grpo_noise_exp4/checkpoint",
)
parser.add_argument(
    "--data_path",
    type=str,
    default="/home/tmp/geneval/prompts/evaluation_metadata.jsonl",
    help="Path to GenEval metadata file",
)
# Read into `cot_prompt` at start-up; the file must exist.
parser.add_argument(
    "--reasoning_prompt_path",
    type=str,
    default="/home/tmp/***/data/prompt/reasoning_prompt.txt",
)

# --- outputs ----------------------------------------------------------------
parser.add_argument(
    "--save_dir",
    type=str,
    default='output_official_run_grpo_noise_exp4_checkpoint',
    help="Path to save GenEval standard format output",
)
parser.add_argument(
    "--num_generation",
    type=int,
    default=4,
)

# --- runtime ----------------------------------------------------------------
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="Device to use for inference",
)
parser.add_argument(
    "--seed",
    type=int,
    default=42,
    help="Random seed",
)

# --- prompt sharding (half-open range [start_idx, end_idx)) ----------------
parser.add_argument(
    "--start_idx",
    type=int,
    default=0,
    help="Start index for processing prompts",
)
parser.add_argument(
    "--end_idx",
    type=int,
    default=None,
    help="End index for processing prompts (exclusive)",
)

args = parser.parse_args()

# Seed all RNGs before any model is constructed.
seed_all(args.seed)

# SDXL pipeline. In this script only its `encode_prompt` is called (to produce
# the text embedding fed to NoiseAR); it is parked on CPU between uses.
from diffusers import StableDiffusionXLPipeline

sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
sdxl_pipe = sdxl_pipe.to('cpu')
print("SDXL model loaded successfully for inference!")


# Qwen-Image model: resolved (downloading if needed) through ModelScope, then
# loaded as a diffusers pipeline in bfloat16.
# The cache path is `~`-expanded explicitly. Python path APIs (os.makedirs,
# open, ...) do not expand `~` themselves, so passing the raw string risks
# creating a literal "~" directory under the current working directory.
model_dir = snapshot_download(
    'Qwen/Qwen-Image',
    cache_dir=os.path.expanduser('~/.cache/modelscope'),
    local_files_only=False,
)
pipe = DiffusionPipeline.from_pretrained(
    model_dir,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
print("Qwen-Image model loaded successfully for inference!")
# Let diffusers move each sub-model to the accelerator only while it runs, to
# keep peak GPU memory low. The device is diffusers' default, not --device.
pipe.enable_model_cpu_offload()


# Load trained NoiseAR model
model_file_path = os.path.join(args.model_path, "pytorch_model.bin")
# Announce the path BEFORE loading so a failing load still shows which file was tried.
print(f"Loading model weights from: {model_file_path}")
state_dict = torch.load(model_file_path, map_location='cpu', weights_only=True)

# These hyper-parameters must match the ones the checkpoint was trained with.
noise_ar_net = NoiseARNet(
    patch_size=32,
    d_model=2048,
    d_ff=2048 * 2,
    n_heads=1,
    n_layers=1,
    dropout=0.15,
    pretrained_path="",
    pipeline="SDXL",
).to('cpu')
noise_ar_net.load_state_dict(state_dict)
# Assuming the standard nn.Module.load_state_dict (which copies into the
# existing parameters), the loaded dict is now a redundant second copy of the
# weights; release it.
del state_dict
# NOTE(review): the net is built with dropout=0.15 and nothing in this script
# switches it to eval mode, so any dropout layers presumably stay active during
# inference. Confirm whether that stochasticity is intended, or add
# `noise_ar_net.eval()`.
print("Trained NoiseAR model loaded successfully!")


# Load GenEval prompts: one JSON object per line (JSONL). Blank or
# whitespace-only lines (e.g. a trailing empty line) are skipped instead of
# crashing json.loads; json.loads already tolerates surrounding whitespace.
with open(args.data_path, 'r', encoding='utf-8') as f:
    prompt_list = [json.loads(line) for line in f if line.strip()]

# Reasoning prompt template. NOTE(review): `cot_prompt` is not referenced
# anywhere else in this script, yet the file must exist or start-up fails.
with open(args.reasoning_prompt_path, 'r', encoding='utf-8') as f:
    cot_prompt = f.read().strip()

# Root output directory; per-sample subdirectories are created later by
# create_geneval_format_output.
output_dir = Path(args.save_dir)
output_dir.mkdir(exist_ok=True, parents=True)

def get_caption_height(text, font, img_width, draw):
    """Return the pixel height needed to render ``text`` word-wrapped at ``img_width``.

    Words are packed greedily onto lines whose rendered width, measured with
    ``draw.textlength(line, font=font)``, stays strictly below
    ``img_width - 20`` (20 px of horizontal padding). A single word that is
    wider than that gets its own (overflowing) line. Each line is
    ``font_size + 4`` px tall and 20 px of vertical padding is added, so empty
    text yields 20.

    The font size is ``font.size`` when the font has one (e.g. FreeType fonts).
    Otherwise it is the larger dimension of the "X" glyph box, taken from
    ``font.getbbox`` when available, else from the legacy ``font.getsize``
    (removed in Pillow 10).
    """
    max_line_width = img_width - 20
    lines = []
    current_line = ""

    for word in text.split():
        candidate = f"{current_line} {word}" if current_line else word
        # Only break when there is something to break off. Previously an
        # over-wide first word emitted an empty line before itself, inflating
        # the height by one line.
        if current_line and draw.textlength(candidate, font=font) >= max_line_width:
            lines.append(current_line)
            current_line = word
        else:
            current_line = candidate

    if current_line:
        lines.append(current_line)

    font_size = getattr(font, "size", None)
    if font_size is None:
        if hasattr(font, "getbbox"):
            left, top, right, bottom = font.getbbox("X")
            font_size = max(right - left, bottom - top)
        else:
            font_size = max(font.getsize("X"))

    line_height = font_size + 4
    return len(lines) * line_height + 20

def create_geneval_format_output(visual_img, answer_list, base_save_dir, prompt_text, num_generation, sample_id, metadata):
    """Write one prompt's images to disk in the GenEval directory layout.

    Produces::

        <base_save_dir>/<sample_id:05d>/
            samples/00000.png, 00001.png, ...   one PNG per image (num_generation)
            grid.png                            all images tiled, no captions
            metadata.jsonl                      the original metadata record, one line

    ``visual_img[i]`` must be accepted by ``PIL.Image.fromarray``. The grid step
    permutes each image from (H, W, C) to (C, H, W), so 3-D arrays are assumed.
    ``answer_list`` and ``prompt_text`` are accepted for interface compatibility
    but are not used.

    Returns:
        The per-sample directory path, as built by ``os.path.join``.
    """
    sample_dir = os.path.join(base_save_dir, f"{sample_id:05d}")
    samples_dir = os.path.join(sample_dir, "samples")
    for directory in (sample_dir, samples_dir):
        os.makedirs(directory, exist_ok=True)

    # Individual images: samples/00000.png, samples/00001.png, ...
    for index in range(num_generation):
        png_path = os.path.join(samples_dir, f"{index:05d}.png")
        Image.fromarray(visual_img[index]).save(png_path)

    # Caption-free grid: (H, W, C) arrays -> (C, H, W) tensors -> tiled grid -> (H, W, C).
    tiles = [
        torch.from_numpy(np.array(Image.fromarray(visual_img[index]))).permute(2, 0, 1)
        for index in range(num_generation)
    ]
    columns = int(np.ceil(np.sqrt(num_generation)))
    grid_hwc = (
        torchvision.utils.make_grid(tiles, nrow=columns)
        .permute(1, 2, 0)
        .numpy()
        .astype(np.uint8)
    )
    PIL.Image.fromarray(grid_hwc).save(os.path.join(sample_dir, "grid.png"))

    # metadata.jsonl holds exactly the original record, with no extra fields.
    with open(os.path.join(sample_dir, "metadata.jsonl"), 'w') as handle:
        handle.write(json.dumps(metadata) + '\n')

    return sample_dir

def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
    return latents

@torch.inference_mode()
def generate_with_noise_ar(
    noise_ar_net: NoiseARNet,
    sdxl_pipe,
    qwen_image_pipe,
    prompt_text: str,
    sample_id: int,
    metadata: Dict,
    num_generation: int = 4,
    img_size: int = 1024,  # Changed to match training code
):
    """Generate images for one prompt from NoiseAR-predicted initial noise and save them.

    Steps:
      1. Encode ``prompt_text`` with ``sdxl_pipe.encode_prompt``. SDXL is moved to
         ``args.device`` only for this call and is always moved back to CPU, even
         if encoding fails.
      2. Run ``noise_ar_net`` (float32, on ``args.device``) on the prompt
         embeddings to predict initial noise.
      3. Add a leading axis and tile the noise 4x along the channel axis to reach
         the 16 latent channels used here. Pack it with ``_pack_latents`` and pass
         it as ``latents`` to ``qwen_image_pipe``.
      4. Save the result with ``create_geneval_format_output`` under ``args.save_dir``.

    Only ONE image is generated; the same array is repeated ``num_generation``
    times so the output has the sample count GenEval expects.
    NOTE(review): presumably intentional because the initial latents are fixed
    by NoiseAR. Confirm.

    Args:
        noise_ar_net: Trained NoiseAR model.
        sdxl_pipe: SDXL pipeline; only ``encode_prompt`` is used.
        qwen_image_pipe: Qwen-Image diffusers pipeline that renders the image.
        prompt_text: Raw prompt text.
        sample_id: Used for the zero-padded output directory name.
        metadata: The prompt's GenEval record, written to ``metadata.jsonl``.
        num_generation: Number of (identical) samples written.
        img_size: Square output resolution in pixels. The NoiseAR prediction must
            match the resulting latent size (128x128 at the default 1024).

    Returns:
        ``(visual_img, answer_list, sample_dir)``: the list of image arrays, one
        "Raw prompt: ..." string per sample, and the per-sample output directory.

    Raises:
        ValueError: if the SDXL embeddings or the predicted noise do not have the
            expected shape. These checks were previously ``assert``s, which are
            stripped under ``python -O``.
    """
    device = torch.device(args.device)

    answer_list = [f"Raw prompt: {prompt_text}" for _ in range(num_generation)]

    # raw prompt -> SDXL embedding. SDXL is moved back to CPU in `finally` so a
    # failure here cannot leave it resident on the accelerator for all remaining
    # prompts (main() catches per-sample errors and keeps going).
    try:
        sdxl_pipe = sdxl_pipe.to(device)
        prompt_embeds, _, _, _ = sdxl_pipe.encode_prompt(prompt_text, device=device)
    finally:
        sdxl_pipe = sdxl_pipe.to('cpu')

    expected_embeds_shape = (1, 77, 2048)
    if tuple(prompt_embeds.shape) != expected_embeds_shape:
        raise ValueError(
            f"SDXL prompt_embeds has shape {tuple(prompt_embeds.shape)}, "
            f"expected {expected_embeds_shape}"
        )

    # embedding -> predicted initial noise (NoiseAR runs in float32).
    noise_ar_net = noise_ar_net.float().to(device)
    pred_noise = noise_ar_net(text_emb=prompt_embeds.clone().float())

    width = height = img_size
    vae_scale_factor = 8
    num_channels_latents = 16
    batch_size = 1
    # Latent side = image side / 8, rounded down to an even number so it splits
    # into 2x2 patches.
    latent_height = 2 * (int(height) // (vae_scale_factor * 2))
    latent_width = 2 * (int(width) // (vae_scale_factor * 2))
    latent_shape = (batch_size, 1, num_channels_latents, latent_height, latent_width)

    torch.cuda.empty_cache()

    # Shape adjustment for predicted noise (matching training code): add a
    # leading axis, then tile the channel axis 4x.
    pred_noise = pred_noise.detach().unsqueeze(0).repeat(1, 1, 4, 1, 1)
    if tuple(pred_noise.shape) != latent_shape:
        raise ValueError(
            f"Predicted noise has shape {tuple(pred_noise.shape)} after tiling, "
            f"expected {latent_shape} for img_size={img_size}"
        )
    latents = _pack_latents(
        pred_noise.to(torch.bfloat16), batch_size, num_channels_latents, latent_height, latent_width
    )

    del pred_noise
    torch.cuda.empty_cache()

    print(f'Generating image with prompt: {prompt_text}')

    generated_image = qwen_image_pipe(
        prompt=prompt_text,
        negative_prompt=" ",
        width=width,
        height=height,
        num_inference_steps=50,
        true_cfg_scale=4.0,
        # Fixed seed 42, independent of --seed. NOTE(review): with `latents`
        # supplied, the generator presumably no longer drives the initial
        # noise. Confirm.
        generator=torch.Generator(device="cpu").manual_seed(42),
        latents=latents,
    ).images[0]

    img_array = np.array(generated_image)
    del latents, generated_image
    torch.cuda.empty_cache()

    # The same single image fills every sample slot (see docstring).
    visual_img = [img_array] * num_generation

    sample_dir = create_geneval_format_output(
        visual_img, answer_list, args.save_dir, prompt_text, num_generation, sample_id, metadata
    )

    return visual_img, answer_list, sample_dir

def main():
    """Generate GenEval outputs for ``prompt_list[start_idx:end_idx]`` into ``--save_dir``.

    The range comes from ``--start_idx`` / ``--end_idx`` (end exclusive; defaults
    to every prompt) and is clamped to the available prompts. An empty range
    prints an error and returns. A prompt that raises is reported with its full
    traceback and skipped, so one bad prompt does not abort a long run. The ids
    of skipped samples are listed at the end.
    """
    import traceback

    # Determine the range of prompts to process, clamped to what exists.
    start_idx = max(0, args.start_idx)
    end_idx = args.end_idx if args.end_idx is not None else len(prompt_list)
    end_idx = min(len(prompt_list), end_idx)

    if start_idx >= end_idx:
        print(f"Error: start_idx ({start_idx}) >= end_idx ({end_idx})")
        return

    subset_prompts = prompt_list[start_idx:end_idx]

    print(f"Starting GenEval evaluation with {len(subset_prompts)} prompts (indices {start_idx}-{end_idx-1})")
    print(f"Total prompts available: {len(prompt_list)}")
    print(f"Output directory: {args.save_dir}")

    failed_sample_ids = []
    for idx, metadata in enumerate(tqdm(subset_prompts, desc="Processing prompts")):
        # Index in the full prompt list; used as the sample id when the record has none.
        actual_idx = start_idx + idx
        sample_id = metadata.get('sample_id', actual_idx)
        prompt_text = metadata.get('prompt', '')   # e.g. 'a photo of a bench'

        print(f"\nProcessing sample {sample_id} (index {actual_idx}): {prompt_text}")

        try:
            visual_img, answer_list, sample_dir = generate_with_noise_ar(
                noise_ar_net=noise_ar_net,
                sdxl_pipe=sdxl_pipe,
                qwen_image_pipe=pipe,
                prompt_text=prompt_text,
                sample_id=sample_id,
                metadata=metadata,
                num_generation=args.num_generation,
            )
        except Exception as e:
            # Per-sample boundary: keep the run going, but keep the traceback
            # (the bare message alone is usually not enough to debug) and
            # remember the failure for the final summary.
            print(f"Error processing sample {sample_id}: {e}")
            traceback.print_exc()
            failed_sample_ids.append(sample_id)
            continue

        print(f"Generated {len(visual_img)} images for sample {sample_id}")
        print(f"GenEval format output saved to: {sample_dir}")

    print(f"\nProcessed prompts {start_idx}-{end_idx-1}")
    if failed_sample_ids:
        print(f"Failed samples ({len(failed_sample_ids)}): {failed_sample_ids}")
    print(f"Results saved to: {args.save_dir}")

if __name__ == "__main__":
    # CLI parsing, model loading and prompt loading all run at module import
    # time above; only the generation loop is deferred to script execution.
    main()
