import torch
from PIL import Image
import sys
import os

# Put the current working directory on sys.path so the project-local modules
# imported below (aesthetic_scorer, prompts, diffusers_patch) can be found;
# presumably the script is launched from the repository root.
cwd = os.getcwd()
sys.path.append(cwd)

from aesthetic_scorer import AestheticScorerDiff
from tqdm import tqdm
import random
from collections import defaultdict
import prompts as prompts_file
import numpy as np
import torch.utils.checkpoint as checkpoint
import wandb
import contextlib
import torchvision
from transformers import AutoProcessor, AutoModel
import sys
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.loaders import AttnProcsLayers
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
import datetime
import hpsv2
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
from accelerate.logging import get_logger    
from accelerate import Accelerator
from absl import app, flags
from ml_collections import config_flags

from diffusers_patch.ddim_with_kl import ddim_step_KL


# absl command-line flags. `--config` loads an ml_collections config from a
# Python file; the default is config/align_prop.py with config string "aesthetic".
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/align_prop.py:aesthetic")

from accelerate.utils import set_seed, ProjectConfiguration
# accelerate's logger adapter (by default it emits only on the main process).
logger = get_logger(__name__)


def hps_loss_fn(inference_dtype=None, device=None):
    """Build a differentiable HPSv2 reward/loss closure.

    Creates the HPSv2 ViT-H-14 model, loads the HPS_v2_compressed.pt weights
    from the hpsv2 cache, and returns a function scoring images against prompts.

    Args:
        inference_dtype: passed as ``precision`` to the model factory and used
            to cast the loaded model.
        device: device for the model and for tokenized captions.

    Returns:
        ``loss_fn(im_pix, prompts) -> (loss, scores)``, where ``scores`` is the
        diagonal of ``image_features @ text_features.T`` (image i scored against
        prompt i) and ``loss = 1 - scores``.
    """
    model_name = "ViT-H-14"
    # The preprocessing transforms are not needed: loss_fn does its own
    # resize/normalize on tensors so gradients can flow through.
    model, _preprocess_train, _preprocess_val = create_model_and_transforms(
        model_name,
        'laion2B-s32B-b79K',
        precision=inference_dtype,
        device=device,
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=False,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        light_augmentation=True,
        aug_cfg={},
        output_dict=True,
        with_score_predictor=False,
        with_region_predictor=False
    )
    tokenizer = get_tokenizer(model_name)

    checkpoint_path = os.path.join(
        os.path.expanduser("~"), ".cache", "hpsv2", "HPS_v2_compressed.pt"
    )
    # force download of model via score.
    # NOTE(review): some hpsv2 releases download into the Hugging Face cache
    # instead of ~/.cache/hpsv2 — confirm checkpoint_path for the installed version.
    hpsv2.score([], "")

    # Named `state` (not `checkpoint`) to avoid shadowing the module-level
    # `torch.utils.checkpoint as checkpoint` import.
    # NOTE(review): torch.load unpickles; only trust files fetched by hpsv2 itself.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['state_dict'])
    model = model.to(device, dtype=inference_dtype)
    model.eval()

    target_size = 224
    # CLIP image normalization statistics.
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                 std=[0.26862954, 0.26130258, 0.27577711])

    def loss_fn(im_pix, prompts):
        """Score ``im_pix`` (presumably in [-1, 1]) against ``prompts`` pairwise by index."""
        # Map [-1, 1] -> [0, 1].
        im_pix = ((im_pix / 2) + 0.5).clamp(0, 1)
        # Resize(int) scales the shorter side to 224 without cropping; assumes
        # square inputs so the model sees 224x224 — TODO confirm.
        x_var = torchvision.transforms.Resize(target_size, antialias=False)(im_pix)
        x_var = normalize(x_var).to(im_pix.dtype)
        caption = tokenizer(prompts).to(device)
        outputs = model(x_var, caption)
        image_features, text_features = outputs["image_features"], outputs["text_features"]
        logits = image_features @ text_features.T
        scores = torch.diagonal(logits)
        loss = 1.0 - scores
        return loss, scores

    return loss_fn
    

# Frozen aesthetic scorers keyed by (device, dtype). Building AestheticScorerDiff
# loads a CLIP backbone, so it is done once per key instead of on every loss call.
_AESTHETIC_SCORERS = {}


def _get_aesthetic_scorer(device, torch_dtype):
    """Return a cached, frozen AestheticScorerDiff on ``device`` with ``torch_dtype``."""
    key = (device, torch_dtype)
    scorer = _AESTHETIC_SCORERS.get(key)
    if scorer is None:
        scorer = AestheticScorerDiff(dtype=torch_dtype).to(device, dtype=torch_dtype)
        scorer.requires_grad_(False)
        _AESTHETIC_SCORERS[key] = scorer
    return scorer


def AWR_aesthetic_loss(
    imgs,
    adv_predictions,
    alpha=1.0,
    device=None,
    torch_dtype=None,
    aesthetic_target=10,
):
    """Per-sample squared error between predicted advantages and aesthetic rewards.

    Args:
        imgs: decoded images, presumably in [-1, 1] (they are mapped to [0, 1]
            via ``imgs / 2 + 0.5`` and clamped) — TODO confirm.
        adv_predictions: predictor outputs, one value per image.
        alpha: unused; kept for interface compatibility.
        device: device the aesthetic scorer runs on.
        torch_dtype: dtype of the aesthetic scorer.
        aesthetic_target: unused; kept for interface compatibility.

    Returns:
        ``(adv_predictions - rewards) ** 2``, shaped like the scorer's rewards.
    """
    target_size = 224
    # CLIP image normalization statistics.
    normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                 std=[0.26862954, 0.26130258, 0.27577711])
    scorer = _get_aesthetic_scorer(device, torch_dtype)

    im_pix = ((imgs / 2) + 0.5).clamp(0, 1)
    im_pix = torchvision.transforms.Resize(target_size, antialias=False)(im_pix)
    im_pix = normalize(im_pix).to(imgs.dtype)
    rewards = scorer(im_pix)

    # Align shapes explicitly: e.g. (B, 1) predictions minus (B,) rewards would
    # otherwise broadcast to (B, B) and silently mix samples. A real size
    # mismatch now raises instead.
    adv_predictions = adv_predictions.reshape(rewards.shape)
    return torch.square(adv_predictions - rewards)


def main(_):
    """Train the advantage predictor on aesthetic rewards of sampled images.

    Entry point for ``absl.app.run``; the positional argument (leftover argv)
    is ignored and configuration is read from ``FLAGS.config``.

    Each inner iteration draws initial latents and prompts, runs the frozen
    UNet through the DDIM timesteps with classifier-free guidance (under
    ``torch.no_grad``), decodes the result with the VAE, and regresses
    ``predictor(initial_latents)`` onto the aesthetic score of the decoded
    images via ``AWR_aesthetic_loss``. Only ``predictor`` is optimized.
    """
    config = FLAGS.config
    unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")

    # Make the run name unique by suffixing a timestamp.
    if not config.run_name:
        config.run_name = unique_id
    else:
        config.run_name += "_" + unique_id

    if config.resume_from:
        config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
        if "checkpoint_" not in os.path.basename(config.resume_from):
            # get the most recent checkpoint in this directory
            checkpoints = [x for x in os.listdir(config.resume_from) if "checkpoint_" in x]
            if not checkpoints:
                raise ValueError(f"No checkpoints found in {config.resume_from}")
            config.resume_from = os.path.join(
                config.resume_from,
                sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1],
            )

    accelerator_config = ProjectConfiguration(
        project_dir=os.path.join(config.logdir, config.run_name),
        automatic_checkpoint_naming=True,
        total_limit=config.num_checkpoint_limit,
    )

    accelerator = Accelerator(
        log_with="wandb",
        mixed_precision=config.mixed_precision,
        project_config=accelerator_config,
        gradient_accumulation_steps=config.train.gradient_accumulation_steps,
    )

    if accelerator.is_main_process:
        wandb_args = {"name": config.run_name}
        if config.debug:
            wandb_args["mode"] = "disabled"
        accelerator.init_trackers(
            project_name="FDERC-final", config=config.to_dict(), init_kwargs={"wandb": wandb_args}
        )
        accelerator.project_configuration.project_dir = os.path.join(config.logdir, config.run_name)
        accelerator.project_configuration.logging_dir = os.path.join(config.logdir, wandb.run.name)

    logger.info(f"\n{config}")

    # set seed (device_specific is very important to get different prompts on different devices)
    set_seed(config.seed, device_specific=True)

    # load scheduler, tokenizer and models.
    if config.pretrained.model.endswith((".safetensors", ".ckpt")):
        pipeline = StableDiffusionPipeline.from_single_file(config.pretrained.model)
    else:
        pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)

    # Freeze the diffusion pipeline; only the advantage predictor is trained.
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.unet.requires_grad_(False)

    # disable safety checker
    pipeline.safety_checker = None

    # make the progress bar nicer
    pipeline.set_progress_bar_config(
        position=1,
        disable=not accelerator.is_local_main_process,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )

    # switch to DDIM scheduler
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler.set_timesteps(config.steps)

    # The frozen pipeline runs in full precision regardless of
    # config.mixed_precision (casting to fp16/bf16 is intentionally disabled).
    inference_dtype = torch.float32

    # Move unet, vae and text_encoder to device and cast to inference_dtype
    pipeline.vae.to(accelerator.device, dtype=inference_dtype)
    pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype)
    pipeline.unet.to(accelerator.device, dtype=inference_dtype)
    # scheduler.step is called with on-device timesteps below; move the table
    # once here instead of on every iteration.
    pipeline.scheduler.alphas_cumprod = pipeline.scheduler.alphas_cumprod.to(accelerator.device)

    unet = pipeline.unet

    if config.reward_fn == 'aesthetic':
        from aesthetic_scorer import AdvantagePredictor
        predictor = AdvantagePredictor()
        predictor = predictor.to(accelerator.device)
        predictor.requires_grad_(True)
    else:
        raise NotImplementedError(f"Unsupported reward_fn: {config.reward_fn!r}")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    optimizer = torch.optim.AdamW(
        predictor.parameters(),
        lr=config.train.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        eps=config.train.adam_epsilon,
    )

    prompt_fn = getattr(prompts_file, config.prompt_fn)

    # Resolved (and thereby validated) even though evaluation is not run here.
    if config.eval_prompt_fn == '':
        eval_prompt_fn = prompt_fn
    else:
        eval_prompt_fn = getattr(prompts_file, config.eval_prompt_fn)

    # generate negative (empty-prompt) embeddings for classifier-free guidance
    neg_prompt_embed = pipeline.text_encoder(
        pipeline.tokenizer(
            [""],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=pipeline.tokenizer.model_max_length,
        ).input_ids.to(accelerator.device)
    )[0]

    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size_per_gpu_available, 1, 1)

    autocast = contextlib.nullcontext

    # Prepare everything with our `accelerator`.
    unet, predictor, optimizer = accelerator.prepare(unet, predictor, optimizer)

    timesteps = pipeline.scheduler.timesteps

    if config.resume_from:
        logger.info(f"Resuming from {config.resume_from}")
        accelerator.load_state(config.resume_from)
        first_epoch = int(config.resume_from.split("_")[-1]) + 1
    else:
        first_epoch = 0

    global_step = 0

    def save_cnn_checkpoint(model, checkpoint_dir, epoch, checkpoint_name="AWR_model"):
        """Write ``model``'s state dict to ``<checkpoint_dir>/<checkpoint_name>_<epoch>.pth``."""
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"{checkpoint_name}_{epoch}.pth")
        # Unwrap the accelerate (e.g. DDP) wrapper so saved keys carry no
        # "module." prefix and load directly into a bare AdvantagePredictor.
        torch.save(accelerator.unwrap_model(model).state_dict(), checkpoint_path)

    #################### TRAINING ####################
    for epoch in range(first_epoch, config.num_epochs):
        unet.eval()
        info = defaultdict(list)

        for inner_iters in tqdm(
                range(config.train.data_loader_iterations),
                position=0,
                disable=not accelerator.is_local_main_process
            ):
            latent = torch.randn((config.train.batch_size_per_gpu_available, 4, 64, 64),
                device=accelerator.device, dtype=inference_dtype)

            # Keep the initial noise: the predictor is conditioned on it, while
            # `latent` is rebound at every denoising step below.
            original_latents = latent

            if accelerator.is_main_process:
                logger.info(f"{config.run_name.rsplit('/', 1)[0]} Epoch {epoch}.{inner_iters}: training")

            prompts, prompt_metadata = zip(
                *[prompt_fn() for _ in range(config.train.batch_size_per_gpu_available)]
            )

            prompt_ids = pipeline.tokenizer(
                prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=pipeline.tokenizer.model_max_length,
            ).input_ids.to(accelerator.device)

            prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

            with accelerator.accumulate(predictor):
                with autocast():
                    # Sampling is a fixed data-generation step: no gradients.
                    with torch.no_grad():
                        for t in tqdm(
                            timesteps,
                            total=len(timesteps),
                            disable=not accelerator.is_local_main_process,
                        ):
                            t = torch.tensor([t],
                                    dtype=inference_dtype,
                                    device=latent.device
                                )
                            t = t.repeat(config.train.batch_size_per_gpu_available)

                            noise_pred_uncond = unet(latent, t, train_neg_prompt_embeds).sample
                            noise_pred_cond = unet(latent, t, prompt_embeds).sample
                            # classifier-free guidance
                            guidance = noise_pred_cond - noise_pred_uncond
                            noise_pred = noise_pred_uncond + config.sd_guidance_scale * guidance

                            latent = pipeline.scheduler.step(
                                        noise_pred, t[0].long(),
                                        latent,
                                        config.sample_eta).prev_sample

                        ims = pipeline.vae.decode(latent.to(pipeline.vae.dtype) / 0.18215).sample

                    adv_predictions = predictor(original_latents)
                    loss = AWR_aesthetic_loss(
                                ims,
                                adv_predictions,
                                alpha=config.train.kl_weight/config.train.loss_coeff,
                                torch_dtype = inference_dtype,
                                device = accelerator.device
                                )
                    loss = loss.mean()

                    # Log a detached copy so the stored value holds no autograd references.
                    info["AWR-loss"].append(loss.detach())

                    # backward pass; the prepared optimizer only steps/zeros
                    # when accelerate syncs gradients.
                    accelerator.backward(loss)
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(predictor.parameters(), config.train.max_grad_norm)
                    optimizer.step()
                    optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                assert (
                    inner_iters + 1
                ) % config.train.gradient_accumulation_steps == 0
                # log training and evaluation
                logger.info("Logging")

                info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                info = accelerator.reduce(info, reduction="mean")
                logger.info(f"AWR-loss: {info['AWR-loss']}")

                info.update({"epoch": epoch, "inner_epoch": inner_iters})
                accelerator.log(info, step=global_step)

                global_step += 1
                info = defaultdict(list)

        # make sure we did an optimization step at the end of the inner epoch
        assert accelerator.sync_gradients

        if (epoch % config.save_freq == 0 or epoch == config.num_epochs - 1) and accelerator.is_main_process:
            model_dir = os.path.join(os.getcwd(), 'model', f"{config.run_name}")
            save_cnn_checkpoint(predictor, model_dir, epoch, checkpoint_name="AWR_model")
            

if __name__ == "__main__":
    # absl parses the command-line flags (including --config), then calls main.
    app.run(main)
