import os
import re
import torch
import torch.distributed as dist
from pathlib import Path
from diffusers import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.utils import is_torch_xla_available
from torch.utils.data import Dataset, DistributedSampler
from safetensors.torch import load_file
import argparse
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import copy
import json

import hpsv2
from sample_flux import DualFluxPipeline
import json
from tqdm import tqdm  
from hpsv2.utils import root_path, hps_version_map
from fastvideo.utils.reward_utils import initialize_hps_model

# Record once whether torch_xla can be used; its model helpers are imported
# only in that case so the script still starts on machines without XLA.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm


# Device string main() uses when moving reward-model inputs.
device = "cuda"

def _load_pipeline(args):
    """Build the Flux pipeline selected by ``args`` and place it on CUDA.

    A ``DualFluxPipeline`` with the checkpoint from ``args.model_path`` is
    used when ``args.mix_sampling_steps > 0``. Otherwise a plain
    ``FluxPipeline`` is built, and its transformer weights are replaced from
    ``args.model_path`` unless ``args.baseline`` is set.
    """
    if args.mix_sampling_steps > 0:
        pipe = DualFluxPipeline.from_pretrained(
            args.flux_baseline_model_dir,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        ).to("cuda")
        pipe.load_new_model(args.model_path)
        return pipe

    pipe = FluxPipeline.from_pretrained(
        args.flux_baseline_model_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
    ).to("cuda")
    if not args.baseline:
        model_state_dict = load_file(args.model_path)
        pipe.transformer.load_state_dict(model_state_dict, strict=True)
        pipe.to("cuda")
    return pipe


def _hps_score(reward_model, preprocess_val, tokenizer, image, prompt):
    """Return the reward model's image-text score for one image as a float."""
    image_batch = preprocess_val(image).unsqueeze(0).to(device=device, non_blocking=True)
    text_batch = tokenizer([prompt]).to(device=device, non_blocking=True)
    outputs = reward_model(image_batch, text_batch)
    image_features = outputs["image_features"]
    text_features = outputs["text_features"]
    logits_per_image = image_features @ text_features.T
    # Batch size is 1, so the diagonal holds exactly one score; .item() also
    # moves it off the GPU so no tensors are kept alive across iterations.
    return torch.diagonal(logits_per_image).item()


def main(args):
    """Generate one image per prompt and compare its HPS score to a reference.

    Reads a JSON list from ``args.prompts_file``; every entry must provide
    ``"prompt"`` and ``"score_mean"`` and may provide ``"i"``, ``"j"`` and
    ``"group"`` (used only to name the saved image). For each entry an image
    is sampled, saved as ``{i}_{j}_{group}.jpg`` under ``args.output_dir``,
    scored with the HPS reward model, and the difference to ``score_mean`` is
    recorded. All records are written as JSON to ``args.output_json`` and the
    mean difference is printed.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``prompts_file``, ``output_json``, ``output_dir``, ``seed``,
            ``baseline``, ``model_path``, ``mix_sampling_steps``,
            ``total_sampling_steps``, ``flux_baseline_model_dir``. The whole
            namespace is also forwarded to ``initialize_hps_model``.

    Raises:
        OSError: If the prompts file cannot be read or the output cannot be
            written.
        json.JSONDecodeError: If the prompts file is not valid JSON.
        KeyError: If an entry lacks ``"prompt"`` or ``"score_mean"``.
    """
    # Read the prompts before touching any model so that a wrong path or a
    # malformed file fails immediately instead of after the weights load.
    with open(args.prompts_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    pipe = _load_pipeline(args)
    reward_model, preprocess_val, tokenizer = initialize_hps_model(args, 0)

    generator = torch.Generator(device='cuda')
    generator.manual_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    print('output_dir is:', args.output_dir)

    # Both pipeline variants share these settings; only the dual pipeline is
    # given the extra ``mix_sampling_steps`` keyword.
    sampling_kwargs = {
        "guidance_scale": 3.5,
        "height": 1024,
        "width": 1024,
        "num_inference_steps": args.total_sampling_steps,
        "max_sequence_length": 512,
        "generator": generator,
    }
    if args.mix_sampling_steps > 0:
        sampling_kwargs["mix_sampling_steps"] = args.mix_sampling_steps

    results = []
    with torch.no_grad():
        for item in tqdm(data, desc="Processing prompts"):
            prompt = item["prompt"]
            score_mean = item["score_mean"]
            i = item.get("i", -1)
            j = item.get("j", -1)
            group = item.get("group", 1)

            image = pipe(prompt, **sampling_kwargs).images[0]
            image_name = f"{i}_{j}_{group}.jpg"
            image.save(os.path.join(args.output_dir, image_name))
            print(f'saved {image_name}')

            hps_score = _hps_score(reward_model, preprocess_val, tokenizer, image, prompt)
            results.append({
                "prompt": prompt,
                "score_mean": score_mean,
                "i": i,
                "j": j,
                "group": group,
                "score_difference": hps_score - score_mean,
            })

    # Create the parent directory of the report if the path names one.
    output_parent = os.path.dirname(args.output_json)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)
    with open(args.output_json, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    # Compute the average score_difference (0 when there were no prompts).
    score_differences = [record["score_difference"] for record in results]
    avg_diff = sum(score_differences) / len(score_differences) if score_differences else 0
    print(f"完成！平均 score_difference: {avg_diff:.4f}")

if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description="Flux Inference for MixGRPO")
    parser.add_argument("--model_path", type=str,
                        help="Path to the MixGRPO model checkpoint")
    parser.add_argument("--prompts_file", type=str, default="./data/prompts_test.txt",
                        help="Path to the file containing prompts")
    parser.add_argument("--output_dir", type=str, default="./output_flux",
                        help="Directory to save generated images")
    parser.add_argument("--output_json", type=str, default="output_flux.json",
                        help="Path to save the output JSON file with metadata")
    parser.add_argument("--seed", type=int, default=42,
                        help="Seed for random number generation")
    parser.add_argument("--baseline", action='store_true', default=False,
                        help="Use baseline model settings")
    parser.add_argument("--mix_sampling_steps", type=int, default=-1,
                        help="Number of sampling steps of the MixGRPO model")
    parser.add_argument("--total_sampling_steps", type=int, default=50,
                        help="Total number of sampling steps")
    parser.add_argument("--flux_baseline_model_dir", type=str, default="./data/flux",
                        help="Directory of the pretrained Flux pipeline to load")

    # Options not read by main() directly; the full namespace is forwarded to
    # initialize_hps_model, which presumably consumes the hps_* paths
    # (NOTE(review): confirm against fastvideo.utils.reward_utils).
    parser.add_argument("--style", type=str, default=None,
                        help="styles")
    parser.add_argument("--hps_path",
        type=str,
        default='',
        help="The path of hps model",
    )
    parser.add_argument("--hps_checkpoint_path",
        type=str,
        default='',
        help="The checkpoint path of hps model",
    )

    args = parser.parse_args()

    # main() loads the checkpoint whenever mixed sampling is enabled or the
    # run is not a pure baseline; reject a missing path here rather than
    # after the base pipeline has already been loaded onto the GPU.
    needs_checkpoint = args.mix_sampling_steps > 0 or not args.baseline
    if needs_checkpoint and not args.model_path:
        parser.error(
            "--model_path is required unless --baseline is set and "
            "--mix_sampling_steps is not positive"
        )

    main(args)