import argparse
import logging
import math
import os
import shutil
from pathlib import Path
import json
import time

import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import torch.utils.checkpoint
import transformers
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub.utils import insecure_hashlib
from huggingface_hub import create_repo, upload_folder
from jax.experimental.compilation_cache import compilation_cache as cc

from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
from matplotlib import pyplot as plt

from diffusers import (
    FlaxAutoencoderKL,
    FlaxDDIMScheduler,
    FlaxStableDiffusionPipeline,
    FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker

from config import common_args
from google.cloud import storage

from rpo_datasets import PreferenceDataset

from utils import (
    upload_file_to_gcs, 
    remove_unique_token,
    clean_subject,
    training_prompts,
)

# Cache compiled models across invocations of this script.
# The cache directory is resolved under the current user's home directory.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

# Module-level logger; its level is adjusted per process inside main().
logger = logging.getLogger(__name__)

def get_params_to_save(params):
    """Return a host-side copy of ``params`` with the leading axis removed.

    Index 0 is taken along the first axis of every leaf in the pytree, and
    the resulting tree is fetched with ``jax.device_get``.
    """
    first_slice = jax.tree_util.tree_map(lambda leaf: leaf[0], params)
    return jax.device_get(first_slice)

def main():
    """Fine-tune a Flax Stable Diffusion UNet with a preference loss and report results.

    All settings come from the command line (``common_args.add_args``). The
    run proceeds in this order:

    1. Load the pretrained pipeline and, when fewer generated-image folders
       exist than entries in the reference directory, sample images for every
       training prompt.
    2. Build the ``PreferenceDataset`` / ``DataLoader`` and load the text
       encoder, VAE and UNet. A copy of the initial UNet weights is kept as a
       reference model for the loss.
    3. Train with a ``pmap``-ed step. The loss is the denoising MSE on the
       half of the batch the code names ``ref`` plus a sigmoid preference
       term comparing the trained and reference UNet errors.
    4. Every ``args.eval_steps`` steps, sample images, score them with
       ``train_dataset.reward_model``, plot the reward curve, and save a
       checkpoint whenever the reward beats the best seen so far.
    5. Write the latency (time until the last saved checkpoint, or the whole
       run when none was saved) and the validation rewards under
       ``path-to-results/`` and upload both to ``args.bucket``.
    """
    parser = argparse.ArgumentParser(description="Training")
    common_args.add_args(parser)
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    is_main_process = jax.process_index() == 0
    logger.setLevel(logging.INFO if is_main_process else logging.ERROR)
    if is_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    # PRNGKey cannot be built from None; fall back to 0 when no seed is given.
    rng = jax.random.PRNGKey(args.seed if args.seed is not None else 0)

    # ------------------------ Loading pipeline ------------------------ #

    pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
    )
    print("Pipeline loaded!")

    # ------------------------ Generate sample images from pretrained model ------------------------ #
    generated_images_dir = args.generated_data_dir
    os.makedirs(generated_images_dir, exist_ok=True)

    sub_folders = [
        name
        for name in os.listdir(generated_images_dir)
        if os.path.isdir(os.path.join(generated_images_dir, name))
    ]
    num_generated_folders = len(sub_folders)

    reference_images_dir = args.reference_data_dir
    num_prompts = sum(1 for _ in Path(reference_images_dir).iterdir())
    subject = clean_subject(args.subject)
    live = subject in ("dog", "cat")

    generated_prompts = training_prompts("[V]", args.class_token, live)
    print(f"training prompts = {generated_prompts}")
    if num_generated_folders < num_prompts:
        pipeline.set_progress_bar_config(disable=True)

        # One image per local device for every prompt.
        num_samples = jax.local_device_count()
        num_new_images = (num_prompts - num_generated_folders) * num_samples
        logger.info(f"Number of generated images to sample: {num_new_images}.")

        # The pipeline parameters do not change while sampling, so replicate once.
        p_params = jax_utils.replicate(pipeline_params)
        for text in tqdm(generated_prompts, desc="Generating images", disable=not is_main_process):
            gen_image_save_dir = os.path.join(generated_images_dir, text)
            os.makedirs(gen_image_save_dir, exist_ok=True)
            prompt_ids = pipeline.prepare_inputs([text] * num_samples)
            prompt_ids = shard(prompt_ids)
            rng = jax.random.split(rng)[0]
            # One key per local device, matching the leading axis produced by shard().
            sample_rng = jax.random.split(rng, num_samples)
            images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
            images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
            images = pipeline.numpy_to_pil(np.array(images))

            for image in images:
                # Name each file by the hash of its pixels.
                hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
                image_filename = gen_image_save_dir + f"/{hash_image}.jpg"
                image.save(image_filename)

        # Drop the replicated copy so it is not kept alive during training.
        del p_params

    print("Finished generating images!")

    # Handle the repository creation
    if is_main_process:
        if args.savepath is not None:
            os.makedirs(args.savepath, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.savepath).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load the tokenizer and add the placeholder token as an additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
        )
    else:
        raise NotImplementedError("No tokenizer specified!")

    # ----------------------- Load the dataset ------------------------ #
    train_dataset = PreferenceDataset(
        reference_data_root=args.reference_data_dir,
        generated_data_root=args.generated_data_dir,
        prompt=args.prompt,
        desc_prompts=generated_prompts,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
        lambda_=0,
    )

    def collate_fn(examples):
        """Stack dataset examples into a dict of NumPy arrays for one batch.

        ``input_ids`` holds every example's ``prompt_ids`` followed by every
        example's ``desc_prompt_ids``, padded to the tokenizer's maximum length.
        """
        input_ids = [example["prompt_ids"] for example in examples]
        input_ids += [example["desc_prompt_ids"] for example in examples]
        pixel_values = [example["pixel_values"] for example in examples]
        labels = [example["labels"] for example in examples]

        input_ids = tokenizer.pad(
            {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
        ).input_ids

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        labels = torch.stack(labels)

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "labels": labels,
        }
        return {k: v.numpy() for k, v in batch.items()}

    total_train_batch_size = args.train_batch_size * jax.local_device_count()
    if len(train_dataset) < total_train_batch_size:
        raise ValueError(
            f"Training batch size is {total_train_batch_size}, but your dataset only contains"
            f" {len(train_dataset)} images. Please, use a larger dataset or reduce the effective batch size. Note that"
            f" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that."
        )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True
    )

    weight_dtype = jnp.float32
    if args.mixed_precision == "fp16":
        weight_dtype = jnp.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = jnp.bfloat16

    vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})

    # Load models and create wrapper for stable diffusion
    text_encoder = FlaxCLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        dtype=weight_dtype,
        revision=args.revision,
    )
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        vae_arg,
        dtype=weight_dtype,
        **vae_kwargs,
    )
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        dtype=weight_dtype,
        revision=args.revision,
    )
    # Copy the parameters to have a reference.
    # jax.tree_util.tree_map is used because jax.tree_map is a deprecated alias.
    ref_unet_params = jax.tree_util.tree_map(lambda x: x.copy(), unet_params)

    # Optimization
    if args.scale_lr:
        args.learning_rate = args.learning_rate * total_train_batch_size

    constant_scheduler = optax.constant_schedule(args.learning_rate)

    adamw = optax.adamw(
        learning_rate=constant_scheduler,
        b1=args.adam_beta1,
        b2=args.adam_beta2,
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(args.max_grad_norm),
        adamw,
    )

    unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
    text_encoder_state = train_state.TrainState.create(
        apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer
    )

    noise_scheduler, noise_scheduler_state = FlaxDDIMScheduler.from_pretrained(
        "stabilityai/stable-diffusion-2",
        subfolder="scheduler",
    )

    # Initialize our training
    train_rngs = jax.random.split(rng, jax.local_device_count())

    def train_step(unet_state, text_encoder_state, vae_params, batch, train_rng):
        """Run one optimisation step on a per-device batch (called under ``pmap``).

        Returns the updated UNet state, the (possibly updated) text-encoder
        state, a dict of loss metrics averaged over the ``"batch"`` axis, and
        the RNG key to use for the next step.
        """
        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)

        if args.train_text_encoder:
            params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params}
        else:
            params = {"unet": unet_state.params}

        def compute_loss(params):
            """Return ``(loss, info)`` for the trainable ``params``."""
            # pixel_values is of shape (N, 2 * C, H, W)
            # reshape it to (2 * N, C, H, W)
            feed_pixel_values = jnp.concatenate(
                jnp.split(batch["pixel_values"], 2, axis=1), axis=0
            )
            # Convert images to latent space
            vae_outputs = vae.apply(
                {"params": vae_params}, feed_pixel_values, deterministic=True, method=vae.encode
            )
            latents = vae_outputs.latent_dist.sample(sample_rng)
            # (2 * N, H, W, C) -> (2 * N, C, H, W)
            latents = jnp.transpose(latents, (0, 3, 1, 2))
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            # NOTE(review): sample_rng was already consumed by the latent
            # sampling above and is split again here.
            noise_rng, timestep_rng = jax.random.split(sample_rng)
            noise = jax.random.normal(noise_rng, latents.shape, dtype=latents.dtype)
            # Sample a random timestep for each image
            bsz = latents.shape[0]
            timesteps = jax.random.randint(
                timestep_rng,
                (bsz,),
                0,
                noise_scheduler.config.num_train_timesteps,
            )

            # Make timesteps and noise same for both references and generated images
            split_noise = jnp.split(noise, 2)[0]
            noise = jnp.tile(split_noise, (2, 1, 1, 1))
            split_timesteps = jnp.split(timesteps, 2)[0]
            timesteps = jnp.tile(split_timesteps, 2)

            # Add noise to the latents according to the noise magnitude at each timestep
            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

            # Get the text embedding for conditioning
            if args.train_text_encoder:
                encoder_hidden_states = text_encoder_state.apply_fn(
                    batch["input_ids"], params=params["text_encoder"], dropout_rng=dropout_rng, train=True
                )[0]
            else:
                encoder_hidden_states = text_encoder(
                    batch["input_ids"], params=text_encoder_state.params, train=False
                )[0]

            # Predict the noise residual
            model_pred = unet.apply(
                {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
            ).sample
            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Get the difference for learned model
            model_losses = jnp.mean((model_pred - target) ** 2, axis=(1, 2, 3))
            model_losses_ref, model_losses_gen = jnp.split(model_losses, 2, axis=0)

            # Get the reference prediction
            ref_model_pred = unet.apply(
                {"params": ref_unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
            ).sample
            ref_losses = jnp.mean((ref_model_pred - target) ** 2, axis=(1, 2, 3))
            ref_losses_ref, ref_losses_gen = jnp.split(ref_losses, 2, axis=0)

            # Compute the loss
            kl_diff = (ref_losses_ref - model_losses_ref) - (ref_losses_gen - model_losses_gen)

            labels = batch["labels"]
            similar_loss = jnp.mean(model_losses_ref)
            preference_loss = -jnp.mean(
                labels * jax.nn.log_sigmoid(args.beta * kl_diff)
                + (1 - labels) * jax.nn.log_sigmoid(-args.beta * kl_diff)
            )
            loss = similar_loss + preference_loss
            info = {"similar loss": similar_loss, "preference loss": preference_loss, "kl_diff": kl_diff.mean()}

            return loss, info

        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
        (loss, info), grad = grad_fn(params)
        grad = jax.lax.pmean(grad, "batch")

        new_unet_state = unet_state.apply_gradients(grads=grad["unet"])
        if args.train_text_encoder:
            new_text_encoder_state = text_encoder_state.apply_gradients(grads=grad["text_encoder"])
        else:
            new_text_encoder_state = text_encoder_state

        info["loss"] = loss
        info = jax.lax.pmean(info, axis_name="batch")

        return new_unet_state, new_text_encoder_state, info, new_train_rng

    def evaluate_step(
            unet_state,
            text_encoder_state,
            vae_params,
            save_path,
            step,
            reward_model,
            pipeline,
            pipeline_params,
        ):
        """Sample images for every training prompt and return the mean reward.

        The current (replicated) model parameters are written into
        ``pipeline_params`` in place. One image per local device is generated
        for each prompt and saved as ``{save_path}{prompt}-{step}-{j}.jpg``,
        then scored by ``reward_model.get_reward``.
        """
        # NOTE: this mutates the caller's dict before rebinding the local name.
        pipeline_params["text_encoder"] = jax_utils.unreplicate(text_encoder_state.params)
        pipeline_params["vae"] = jax_utils.unreplicate(vae_params)
        pipeline_params["unet"] = jax_utils.unreplicate(unet_state.params)

        pipeline_params = jax_utils.replicate(pipeline_params)

        # shard() splits over local devices, so sample one image per local device.
        num_samples = jax.local_device_count()
        # The same fixed keys are reused for every prompt and every evaluation.
        eval_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)

        rewards = 0.0
        for prompt in generated_prompts:
            prompt_ids = shard(pipeline.prepare_inputs(num_samples * [prompt]))
            images = pipeline(prompt_ids, pipeline_params, eval_seed, jit=True).images
            images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
            for j, image in enumerate(images):
                image_filename = save_path + f"{prompt}-{step}-{j}.jpg"
                image.save(image_filename)

            clean_prompt = remove_unique_token(prompt)
            rewards += reward_model.get_reward(images, [clean_prompt] * num_samples, args.reward_lambda).mean()
        return rewards / len(generated_prompts)

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))

    # Replicate the train state on each device
    unet_state = jax_utils.replicate(unet_state)
    text_encoder_state = jax_utils.replicate(text_encoder_state)
    vae_params = jax_utils.replicate(vae_params)

    # Train!
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))

    # Scheduler and math around the number of training steps.
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    def checkpoint(step=None):
        """Save the current model as a pipeline under ``args.savepath``.

        When ``step`` is truthy the pipeline is written to a sub-directory
        named after it. If ``args.push_to_hub`` is set, ``args.savepath`` is
        uploaded to the Hub repository afterwards.
        """
        # Create the pipeline using the trained modules and save it.
        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", from_pt=True
        )
        feature_extractor = CLIPImageProcessor.from_pretrained(
            "stabilityai/stable-diffusion-2",
            subfolder="feature_extractor",
        )
        trained_pipeline = FlaxStableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=noise_scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

        outdir = os.path.join(args.savepath, str(step)) if step else args.savepath
        print(f"Saving model to {outdir}")
        trained_pipeline.save_pretrained(
            outdir,
            params={
                "text_encoder": get_params_to_save(text_encoder_state.params),
                "vae": get_params_to_save(vae_params),
                "unet": get_params_to_save(unet_state.params),
                "safety_checker": safety_checker.params,
            },
        )

        if args.push_to_hub:
            message = f"checkpoint-{step}" if step is not None else "End of training"
            upload_folder(
                repo_id=repo_id,
                folder_path=args.savepath,
                commit_message=message,
                ignore_patterns=["step_*", "epoch_*"],
            )

    global_step = 0

    epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
    train_metrics = []
    os.makedirs(f'figs/rpo/{args.subject}', exist_ok=True)
    validate_path = f'validate/rpo/{args.subject}/'
    os.makedirs(validate_path, exist_ok=True)

    max_reward = 0
    start = time.time()
    # Time of the most recent saved checkpoint; stays None if none is saved.
    end = None
    for epoch in epochs:
        # -------------------------- Training ------------------------ #
        steps_per_epoch = len(train_dataset) // total_train_batch_size
        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
        # train
        for batch in train_dataloader:
            batch = shard(batch)
            unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
                unet_state, text_encoder_state, vae_params, batch, train_rngs
            )
            print("train metric: ", train_metric)
            # The bar's total counts optimisation steps, so advance by one per step.
            train_step_progress_bar.update(1)

            global_step += 1

            # ------------------------- Validation ----------------------- #
            if global_step % args.eval_steps == 0:
                rewards = evaluate_step(
                    unet_state,
                    text_encoder_state,
                    vae_params,
                    validate_path,
                    global_step,
                    train_dataset.reward_model,
                    pipeline,
                    pipeline_params,
                )
                print("-" * 100)
                print(f"Current gradient step: {global_step}")
                print(f"average reward for validation: {rewards:.4f}")
                train_metrics.append(rewards.item())
                # One x value per evaluation: eval_steps, 2 * eval_steps, ...
                x = np.arange(args.eval_steps, global_step + 1, args.eval_steps)
                plt.plot(x, train_metrics, label="Validation Reward")
                plt.xlabel("Gradient Steps")
                plt.ylabel("Reward")
                plt.legend()
                plt.savefig(f"figs/rpo/{args.subject}/validate_reward.png")
                plt.clf()
                if rewards > max_reward:
                    print(f"Max reward increasing: {max_reward:.4f} -> {rewards:.4f}")
                    max_reward = rewards
                    # Keep only the best checkpoint: wipe the directory, then save.
                    shutil.rmtree(args.savepath)
                    os.makedirs(args.savepath, exist_ok=True)
                    checkpoint()
                    print("checkpoint saved!")
                    end = time.time()

            if global_step >= args.max_train_steps:
                break

        print(f"train_metrics: {train_metrics}")
        train_step_progress_bar.close()

    # No evaluation improved on the initial reward: report the whole run's duration.
    if end is None:
        end = time.time()
    duration = end - start
    mins, secs = divmod(duration, 60)
    print(f"Latency report time: {int(mins)} minutes and {secs:.2f} seconds")
    os.makedirs('path-to-results', exist_ok=True)

    latency = {"latency": duration}
    filename = "path-to-results/rpo_latency_" + args.subject + ".json"
    with open(filename, "w") as f:
        json.dump(latency, f, indent=4)
    destination_blob_name = f"lambda={args.reward_lambda}/validation_rewards/rpo/latency_{args.subject}.json"
    upload_file_to_gcs(filename, args.bucket, destination_blob_name)

    filename = "path-to-results/rpo_" + args.subject + ".npy"
    destination_blob_name = f"lambda={args.reward_lambda}/validation_rewards/rpo/{args.subject}.npy"
    np.save(filename, train_metrics)
    upload_file_to_gcs(filename, args.bucket, destination_blob_name)

    print("Finished training and uploading results!")

# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()