"""Demo: generate one image with FluxTTRLPipeline (FLUX.1-dev) and save it."""

import os
import sys
import time

# Put the parent directory on the import path BEFORE importing any
# project-local module (``rewards``, ``pipeline_FLUX_with_log_prob``).
# Importing ``rewards.rewards`` ahead of this line means a ``rewards``
# package living in the parent directory would never be found.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from rewards.rewards import multi_score, MultiScorer

# NOTE(review): several names below (StableDiffusion3Pipeline,
# SD3Transformer2DModel, DDIM*Scheduler, Image, np, importlib, time,
# multi_score) are not referenced by the active code in this script;
# FluxTransformer2DModel / BitsAndBytesConfig are only used by the
# commented-out 4-bit loading path. Kept to avoid surprising other users.
import torch
from PIL import Image
import numpy as np
from diffusers import StableDiffusion3Pipeline, SD3Transformer2DModel, DDIMScheduler, DDIMInverseScheduler
from pipeline_FLUX_with_log_prob import FluxTTRLPipeline
from diffusers import FluxTransformer2DModel
from transformers import BitsAndBytesConfig
import importlib

# Hugging Face Hub repo id of the base text-to-image model loaded below.
model_id = "black-forest-labs/FLUX.1-dev"

# Reward name -> weight. The same dict is given to MultiScorer here and to
# the pipeline call further down. Only ImageReward is active; the disabled
# options (with the weights previously tried) were:
#   "clipscore": 2.0, "aesthetic": 5, "jpeg_compressibility": 0.2,
#   "pickscore": 1.0, "jpeg_incompressibility": 0.1, "hps": 1.0,
#   "black_loss": 1, "white_loss": 5, "contrast_loss": 0.1
# (presumably all accepted by MultiScorer — verify in rewards/rewards.py)
score_dict = dict(imagereward=1.0)

# Reward scorer constructed on the CUDA device at import time.
scorer = MultiScorer(score_dict=score_dict, device="cuda")

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

# model_nf4 = FluxTransformer2DModel.from_pretrained(
#     model_id,
#     subfolder="transformer",
#     quantization_config=bnb_config,
#     torch_dtype=torch.bfloat16
# )

# pipe = FluxTTRLPipeline.from_pretrained(
#     model_id,
#     transformer=model_nf4,
#     torch_dtype=torch.bfloat16
# )

# Load every FLUX.1-dev component in bfloat16 (the commented-out code above
# sketches a 4-bit NF4 variant of the transformer instead).
pipe = FluxTTRLPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Memory savers. Gradient checkpointing only pays off when gradients flow
# through the transformer — presumably the pipeline's test-time RL updates
# do this (TODO confirm in pipeline_FLUX_with_log_prob).
transformer = pipe.transformer
transformer.enable_gradient_checkpointing()
pipe.enable_attention_slicing()

# Place the whole pipeline on the GPU.
# NOTE: enable_model_cpu_offload() is an alternative to .to("cuda"); use one
# or the other, not both.
# pipe.enable_model_cpu_offload()
pipe = pipe.to("cuda")

prompt = (
    "A small astronaut in a reflective suit drifting in microgravity, holding an "
    "LED panel that shows the text Flow-TTRL, with Earth glowing softly in the "
    "background."
)

# CPU generator seeded for reproducible initial noise.
# Generator.manual_seed returns the generator itself.
generator = torch.Generator().manual_seed(42)

# Create the output directory up front: saving into a missing "demo/" would
# raise FileNotFoundError only AFTER the expensive generation has finished.
output_dir = "demo"
os.makedirs(output_dir, exist_ok=True)

# Arguments from noise_range onward are FluxTTRLPipeline-specific; their
# semantics live in pipeline_FLUX_with_log_prob (not visible here).
images = pipe(
    prompt,
    # NOTE(review): stock diffusers FluxPipeline only applies negative_prompt
    # when true_cfg_scale > 1; confirm FluxTTRLPipeline actually uses it.
    negative_prompt="ugly,low resolution,blurry image,bad composition,disfigured,oversaturated",
    num_inference_steps=40,
    guidance_scale=3.5,
    generator=generator,
    noise_range=[1.5, 0.5],
    scale_factor=400,
    group_size=6,
    beta1=0.0002,
    beta2=0.002,
    score_dict=score_dict,
    scoring_fn=scorer,
    internal_reward_timestep=0.2,
    external_reward_timestep=0.5,
    reward_diff_threshold=0.0,
    # (sic) keyword spelling must match the pipeline's signature.
    RL_interation_num=2,
)

# Keep only the first returned image.
image = images.images[0]
image.save(os.path.join(output_dir, "1.png"))