import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
import os
import argparse
from utils import (
    load_prompts, 
    remove_unique_token,
    upload_file_to_gcs, 
    setup_logger)
from evaluation_metrics import DINO_score, CLIP_I_score, CLIP_T_score, RewardModel
from tqdm import tqdm
import json
import sys

def convert(o):
    """``default`` hook for ``json.dump`` that makes NumPy/array values serializable.

    json only calls this for objects it cannot encode natively (``np.float64``
    subclasses ``float`` and never reaches here, but ``np.float32``,
    ``np.int64`` etc. do).

    Args:
        o: The object json could not serialize.

    Returns:
        A plain Python scalar for NumPy scalars, or a (nested) list for
        ndarrays and other objects exposing ``__array__`` (e.g. JAX arrays).

    Raises:
        TypeError: For any other type. This is the contract json expects from a
            ``default`` callable; returning ``None`` instead would make json
            silently write ``null`` and hide the bad value.
    """
    if isinstance(o, np.generic):
        return o.item()
    if hasattr(o, "__array__"):
        # ndarray, 0-d arrays and array-likes such as jax.Array.
        return np.asarray(o).tolist()
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")

def parse_args(argv=None):
    """Parse command-line options for evaluating a trained checkpoint.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly like the original
            zero-argument call; passing a list lets other code and tests parse
            options without touching ``sys.argv``.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(
        description="Sample images from a trained checkpoint and score them against reference images."
    )

    parser.add_argument(
        "--reference_data_dir",
        type=str,
        default="../dreambooth/dataset/dog",
        help="The directory where the reference images are stored.",
    )
    parser.add_argument(
        "--savepath",
        type=str,
        default="path-to-save-model/rpo/",
        help="Directory containing the trained pipeline checkpoint to evaluate.",
    )
    parser.add_argument(
        "--class_token",
        type=str,
        default="dog",
        help="Class token for the specific subject.",
    )
    parser.add_argument(
        "--subject",
        type=str,
        default="dog",
        help="The subject folder of the dreambooth dataset.",
    )
    parser.add_argument(
        "--algo",
        type=str,
        default="rpo",
        help="The algorithm used for training the model.",
    )
    parser.add_argument(
        "--reward_lambda",
        type=float,
        default=0.3,
        help="The reward lambda the model was trained with; used to name the GCS upload paths.",
    )
    parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
    # NOTE(review): --resolution is not read by main() in this file.
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument("--bucket", type=str, default="dpo_booth_bucket", help="Google Cloud Bucket to store the data.")

    return parser.parse_args(argv)

def main():
    """Sample images from a saved checkpoint for every test prompt and score them.

    Steps:
      1. Load the Flax Stable Diffusion pipeline stored at ``args.savepath``.
      2. For each prompt from ``load_prompts`` generate one image per device.
      3. Score each batch with DINO, CLIP-I, CLIP-T and the reward model.
      4. Save every image under ``logs/sampled_images/<algo>/<subject>/`` and
         upload it to ``args.bucket``.
      5. Write the mean scores to ``logs/results/<algo>/<subject>.json``.
      6. Upload that JSON file together with ``logs/error.log``.
    """
    args = parse_args()

    sample_dir = f"logs/sampled_images/{args.algo}/{args.subject}/"
    # Also creates ``logs/`` before ``setup_logger`` runs (presumably it writes
    # logs/error.log, which is uploaded at the end -- verify in utils).
    os.makedirs(sample_dir, exist_ok=True)
    logger = setup_logger()

    def handle_exception(exc_type, exc_value, exc_traceback):
        """Exception handler that logs the traceback and exception message."""
        if issubclass(exc_type, KeyboardInterrupt):
            # Allow user to interrupt the program without logging it as an error
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return
        logger.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))

    # Installed before the checkpoint is loaded so loading failures are logged too.
    sys.excepthook = handle_exception

    # ----------------------- Loading Models ----------------------- #
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(args.savepath)

    live_subjects = {"dog", "dog2", "dog3", "dog5", "dog6", "dog7", "dog8", "cat", "cat2"}
    live = args.subject in live_subjects
    test_prompts = load_prompts("[V]", args.class_token, live=live)
    reward_model = RewardModel(args.reference_data_dir)

    # Loop invariants, computed once instead of per prompt:
    # - params are copied to every device a single time;
    # - every prompt reuses the same per-device RNG keys (as before). They are
    #   now derived from --seed instead of a hard-coded 0; the default (0)
    #   reproduces the previous samples exactly.
    num_samples = jax.device_count()
    inference_params = replicate(params)
    test_seed = jax.random.split(jax.random.PRNGKey(args.seed), num_samples)

    dino_score, clip_i_score, clip_t_score, rewards = [], [], [], []
    for test_prompt in tqdm(test_prompts):
        # ----------------------- Inference ----------------------- #
        # One copy of the prompt per device, sharded so each device gets one.
        prompt_ids = shard(pipeline.prepare_inputs(num_samples * [test_prompt]))
        images = pipeline(prompt_ids, inference_params, test_seed, jit=True).images
        # Collapse the leading device axes to (num_samples, H, W, C) before PIL conversion.
        images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

        # ----------------------- Evaluation ----------------------- #
        clean_test_prompt = remove_unique_token(test_prompt, "[V]")
        dino = DINO_score(args.reference_data_dir, images)
        clip_i = CLIP_I_score(args.reference_data_dir, images)
        clip_t = CLIP_T_score([clean_test_prompt] * len(images), images)
        reward_mean = reward_model.get_reward(images, [clean_test_prompt] * len(images)).mean()

        dino_score.append(dino)
        clip_i_score.append(clip_i)
        clip_t_score.append(clip_t)
        rewards.append(reward_mean)
        print(f"Current dino: {dino}")
        print(f"Current clip-i: {clip_i}")
        print(f"Current clip-t: {clip_t}")
        print(f"Current reward: {reward_mean}")

        # NOTE(review): the raw prompt (including "[V]") is used as a file and
        # blob name; a prompt containing "/" would break the local path.
        for i, image in enumerate(images):
            image_name = f"{test_prompt}-{i}.jpg"
            image_filename = sample_dir + image_name
            image.save(image_filename)
            # ----------------------- Upload images to Google Cloud ----------------------- #
            upload_file_to_gcs(
                image_filename,
                args.bucket,
                f"lambda={args.reward_lambda}/generated_images/{args.algo}/{args.subject}/{image_name}",
            )

        print(f"Inference images saved to {sample_dir}")

    # Means are computed once and reused for both the printout and the JSON file.
    results = {
        "DINO Score": np.mean(dino_score),
        "CLIP I Score": np.mean(clip_i_score),
        "CLIP T Score": np.mean(clip_t_score),
        "Reward": np.mean(rewards),
    }
    print(f"Results for {args.subject}:")
    for metric, value in results.items():
        print(f"{metric}: {value}")

    file_dir = "logs/results/" + f"{args.algo}/"
    os.makedirs(file_dir, exist_ok=True)
    filename = file_dir + f"{args.subject}.json"
    with open(filename, "w") as f:
        json.dump(results, f, default=convert, indent=4)

    blob_destination = f"lambda={args.reward_lambda}/experiments/{args.algo}/{args.subject}.json"
    upload_file_to_gcs(filename, args.bucket, blob_destination)
    upload_file_to_gcs("logs/error.log", args.bucket, f"lambda={args.reward_lambda}/experiments/{args.algo}/error.log")

# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()