import torch
import torch.nn as nn
from PIL import Image
import sys
import os

# Put the launch directory on the import path; presumably needed so the
# project-local imports below (prompts, aesthetic_scorer, diffusers_patch)
# resolve when the script is started from the repository root.
cwd = os.getcwd()
sys.path.extend([cwd])

from tqdm import tqdm
import random
from collections import defaultdict
import prompts as prompts_file
import numpy as np
import torch.utils.checkpoint as checkpoint
import wandb
import contextlib
import torchvision
from transformers import AutoProcessor, AutoModel
import sys
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.loaders import AttnProcsLayers
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
import datetime
import hpsv2
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
from accelerate.logging import get_logger    
from accelerate import Accelerator
from absl import app, flags
from ml_collections import config_flags
from aesthetic_scorer import AdvantagePredictor,AestheticScorerDiff

from diffusers_patch.ddim_with_kl import ddim_step_KL


from accelerate.utils import set_seed, ProjectConfiguration

# absl flag registry. `--config` points at a ml_collections config file; the
# ":aesthetic" suffix is the argument handed to that file's config factory.
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/fderc_init_learner.py:aesthetic")

# Module logger obtained through accelerate's get_logger.
logger = get_logger(__name__)
        
def init_dist_loss_fn(alpha = 1, device = None, torch_dtype=None,
                      model_path='model/AWR_v2/_2024.01.12_19.21.30/AWR_model_50.pth'):
    """Load the frozen advantage predictor and return a closure that scores latents.

    Args:
        alpha: Unused; kept so existing callers that pass it keep working.
        device: Device the predictor is moved to.
        torch_dtype: Dtype the predictor is cast to.
        model_path: Path of the saved ``AdvantagePredictor`` state dict. Defaults
            to the checkpoint this function was originally hard-coded to load.

    Returns:
        ``loss_fn(latent)``, which returns ``predictor(latent)`` unchanged. The
        predictor's parameters have ``requires_grad`` disabled and it is in eval
        mode, but gradients still flow through it back to ``latent``.
    """
    # Load onto the CPU first. The checkpoint may hold CUDA tensors, and without
    # map_location every process would materialize them on the device they were
    # saved from (or fail where that device does not exist). load_state_dict copies
    # the values into the predictor, which is moved to `device` afterwards.
    state_dict = torch.load(model_path, map_location="cpu")

    # Checkpoints saved through a DataParallel/DDP wrapper prefix every key with
    # "module.". Strip only that leading prefix: str.replace would also mangle keys
    # that contain "module." elsewhere (e.g. "submodule.weight" -> "subweight").
    prefix = "module."
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    predictor = AdvantagePredictor()
    predictor.load_state_dict(new_state_dict)
    predictor = predictor.to(device, torch_dtype)
    predictor.requires_grad_(False)
    predictor.eval()

    # Sanity check: the predictor must stay frozen during training.
    for param in predictor.parameters():
        assert not param.requires_grad, "Some parameters are still requiring gradients."

    def loss_fn(latent):
        value_predictions = predictor(latent)
        return value_predictions
    return loss_fn



def main(_):
    """Train LoRA attention processors of a Stable Diffusion UNet (the "init learner").

    Each training iteration starts from an all-zero latent and runs the whole DDIM
    chain with classifier-free guidance through the trainable UNet. The final latent
    is scored with the frozen predictor returned by ``init_dist_loss_fn``, and the
    loss that is minimized is

        -loss_coeff * mean(value) + kl_weight * sum_over_steps(mean(kl_terms))

    where ``kl_terms`` are whatever ``ddim_step_KL`` returns for each step when called
    with an all-zero ``old_noise_pred``. All settings come from ``FLAGS.config``; the
    positional argument supplied by absl is ignored.
    """
    config = FLAGS.config
    unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")

    # Every run gets a timestamp suffix so run/log directories never collide.
    if not config.run_name:
        config.run_name = unique_id
    else:
        config.run_name += "_" + unique_id

    if config.resume_from:
        config.resume_from = os.path.normpath(os.path.expanduser(config.resume_from))
        if "checkpoint_" not in os.path.basename(config.resume_from):
            # A run directory was given: pick the checkpoint_<N> entry with the largest N.
            checkpoints = list(filter(lambda x: "checkpoint_" in x, os.listdir(config.resume_from)))
            if len(checkpoints) == 0:
                raise ValueError(f"No checkpoints found in {config.resume_from}")
            config.resume_from = os.path.join(
                config.resume_from,
                sorted(checkpoints, key=lambda x: int(x.split("_")[-1]))[-1],
            )

    accelerator_config = ProjectConfiguration(
        project_dir=os.path.join(config.logdir, config.run_name),
        automatic_checkpoint_naming=True,
        total_limit=config.num_checkpoint_limit,
    )

    accelerator = Accelerator(
        log_with="wandb",
        mixed_precision=config.mixed_precision,
        project_config=accelerator_config,
        gradient_accumulation_steps=config.train.gradient_accumulation_steps,
    )

    if accelerator.is_main_process:
        wandb_args = {}
        wandb_args["name"] = config.run_name
        if config.debug:
            wandb_args.update({'mode':"disabled"})
        accelerator.init_trackers(
            project_name="FDERC-final", config=config.to_dict(), init_kwargs={"wandb": wandb_args}
        )
        accelerator.project_configuration.project_dir = os.path.join(config.logdir, config.run_name)
        accelerator.project_configuration.logging_dir = os.path.join(config.logdir, wandb.run.name)

    logger.info(f"\n{config}")

    # set seed (device_specific is very important to get different prompts on different devices)
    set_seed(config.seed, device_specific=True)

    # load scheduler, tokenizer and models.
    if config.pretrained.model.endswith(".safetensors") or config.pretrained.model.endswith(".ckpt"):
        pipeline = StableDiffusionPipeline.from_single_file(config.pretrained.model)
    else:
        pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)

    # Freeze the base models; only the LoRA attention processors installed below are trained.
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.unet.requires_grad_(False)

    # disable safety checker
    pipeline.safety_checker = None

    # make the progress bar nicer
    pipeline.set_progress_bar_config(
        position=1,
        disable=not accelerator.is_local_main_process,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )

    # switch to DDIM scheduler
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler.set_timesteps(config.steps)

    # The frozen weights are deliberately kept in float32, even when accelerate is
    # configured for mixed precision (no half-precision cast is applied here).
    inference_dtype = torch.float32

    # Move unet, vae and text_encoder to device and cast to inference_dtype
    pipeline.vae.to(accelerator.device, dtype=inference_dtype)
    pipeline.text_encoder.to(accelerator.device, dtype=inference_dtype)
    pipeline.unet.to(accelerator.device, dtype=inference_dtype)

    # Install a LoRA processor on every attention layer. Self-attention ("attn1")
    # gets no cross_attention_dim; hidden_size follows the block's channel width.
    lora_attn_procs = {}
    for name in pipeline.unet.attn_processors.keys():
        cross_attention_dim = (
            None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim
        )
        if name.startswith("mid_block"):
            hidden_size = pipeline.unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = pipeline.unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
    pipeline.unet.set_attn_processor(lora_attn_procs)

    # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
    # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
    # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
    class _Wrapper(AttnProcsLayers):
        def forward(self, *args, **kwargs):
            return pipeline.unet(*args, **kwargs)

    unet = _Wrapper(pipeline.unet.attn_processors)

    # set up diffusers-friendly checkpoint saving with Accelerate

    def save_model_hook(models, weights, output_dir):
        """Save only the LoRA attention processors, in diffusers' format."""
        assert len(models) == 1
        if isinstance(models[0], AttnProcsLayers):
            pipeline.unet.save_attn_procs(output_dir)
        else:
            raise ValueError(f"Unknown model type {type(models[0])}")
        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model

    def load_model_hook(models, input_dir):
        """Load LoRA processors from ``input_dir`` into the wrapped model.

        With ``config.soup_inference`` the loaded weights are blended:
        ``mixing_coef_1 * input_dir`` plus ``(1 - mixing_coef_1) * resume_from_2``, or
        only ``mixing_coef_1 * input_dir`` when ``resume_from_2 == "stablediffusion"``
        (presumably interpolating toward the LoRA-free base model).
        """
        assert len(models) == 1
        if config.soup_inference:
            tmp_unet = UNet2DConditionModel.from_pretrained(
                config.pretrained.model, revision=config.pretrained.revision, subfolder="unet"
            )
            tmp_unet.load_attn_procs(input_dir)
            if config.resume_from_2 != "stablediffusion":
                tmp_unet_2 = UNet2DConditionModel.from_pretrained(
                    config.pretrained.model, revision=config.pretrained.revision, subfolder="unet"
                )
                tmp_unet_2.load_attn_procs(config.resume_from_2)

                attn_state_dict_2 = AttnProcsLayers(tmp_unet_2.attn_processors).state_dict()

            attn_state_dict = AttnProcsLayers(tmp_unet.attn_processors).state_dict()
            if config.resume_from_2 == "stablediffusion":
                for attn_state_key, attn_state_val in attn_state_dict.items():
                    attn_state_dict[attn_state_key] = attn_state_val*config.mixing_coef_1
            else:
                for attn_state_key, attn_state_val in attn_state_dict.items():
                    attn_state_dict[attn_state_key] = attn_state_val*config.mixing_coef_1 + attn_state_dict_2[attn_state_key]*(1.0 - config.mixing_coef_1)

            models[0].load_state_dict(attn_state_dict)

            del tmp_unet

            if config.resume_from_2 != "stablediffusion":
                del tmp_unet_2

        elif isinstance(models[0], AttnProcsLayers):
            tmp_unet = UNet2DConditionModel.from_pretrained(
                config.pretrained.model, revision=config.pretrained.revision, subfolder="unet"
            )
            tmp_unet.load_attn_procs(input_dir)
            models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
            del tmp_unet
        else:
            raise ValueError(f"Unknown model type {type(models[0])}")
        models.pop()  # ensures that accelerate doesn't try to handle loading of the model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=config.train.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        eps=config.train.adam_epsilon,
    )

    prompt_fn = getattr(prompts_file, config.prompt_fn)

    # Not used by the training loop below, but resolving it still fails fast on a
    # misspelled `eval_prompt_fn` in the config.
    if config.eval_prompt_fn == '':
        eval_prompt_fn = prompt_fn
    else:
        eval_prompt_fn = getattr(prompts_file, config.eval_prompt_fn)

    # generate negative prompt embeddings
    neg_prompt_embed = pipeline.text_encoder(
        pipeline.tokenizer(
            [""],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=pipeline.tokenizer.model_max_length,
        ).input_ids.to(accelerator.device)
    )[0]

    train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size_per_gpu_available, 1, 1)

    autocast = contextlib.nullcontext

    # Prepare everything with our `accelerator`.
    unet, optimizer = accelerator.prepare(unet, optimizer)

    if config.reward_fn=='aesthetic': # easthetic
        loss_fn = init_dist_loss_fn(alpha = config.train.kl_weight/config.train.loss_coeff,
                                    device = accelerator.device,
                                    torch_dtype = inference_dtype)
    else:
        raise NotImplementedError

    timesteps = pipeline.scheduler.timesteps #[981, 961, 941, 921,]

    if config.resume_from:
        logger.info(f"Resuming from {config.resume_from}")
        accelerator.load_state(config.resume_from)
        first_epoch = int(config.resume_from.split("_")[-1]) + 1
    else:
        first_epoch = 0

    global_step = 0

    #################### TRAINING ####################
    for epoch in range(first_epoch, config.num_epochs):
        unet.train()
        info = defaultdict(list)

        for inner_iters in tqdm(
                range(config.train.data_loader_iterations),
                position=0,
                disable=not accelerator.is_local_main_process
            ):
            # Every trajectory starts from the same fixed all-zero latent.
            latent = torch.zeros((config.train.batch_size_per_gpu_available, 4, 64, 64),
                device=accelerator.device, dtype=inference_dtype)

            if accelerator.is_main_process:
                logger.info(f"{config.run_name.rsplit('/', 1)[0]} Epoch {epoch}.{inner_iters}: training")

            prompts, prompt_metadata = zip(
                *[prompt_fn() for _ in range(config.train.batch_size_per_gpu_available)]
            )

            prompt_ids = pipeline.tokenizer(
                prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=pipeline.tokenizer.model_max_length,
            ).input_ids.to(accelerator.device)

            pipeline.scheduler.alphas_cumprod = pipeline.scheduler.alphas_cumprod.to(accelerator.device)
            prompt_embeds = pipeline.text_encoder(prompt_ids)[0]

            with accelerator.accumulate(unet):
                with autocast():
                    with torch.enable_grad():

                        kl_loss = 0
                        for i, t in tqdm(
                            enumerate(timesteps),
                            total=len(timesteps),
                            disable=not accelerator.is_local_main_process,
                        ):
                            t = torch.tensor([t],
                                    dtype=inference_dtype,
                                    device=latent.device
                                )
                            t = t.repeat(config.train.batch_size_per_gpu_available)

                            if config.grad_checkpoint:
                                noise_pred_uncond = checkpoint.checkpoint(unet, latent, t, train_neg_prompt_embeds, use_reentrant=False).sample
                                noise_pred_cond = checkpoint.checkpoint(unet, latent, t, prompt_embeds, use_reentrant=False).sample
                            else:
                                noise_pred_uncond = unet(latent, t, train_neg_prompt_embeds).sample
                                noise_pred_cond = unet(latent, t, prompt_embeds).sample

                            # Truncated backprop: detach the UNet outputs of the early
                            # steps so only the later steps receive gradients.
                            if config.truncated_backprop:
                                if config.truncated_backprop_rand:
                                    timestep = random.randint(
                                        config.truncated_backprop_minmax[0],
                                        config.truncated_backprop_minmax[1]
                                    )
                                    if i < timestep:
                                        noise_pred_uncond = noise_pred_uncond.detach()
                                        noise_pred_cond = noise_pred_cond.detach()
                                else:
                                    if i < config.trunc_backprop_timestep:
                                        noise_pred_uncond = noise_pred_uncond.detach()
                                        noise_pred_cond = noise_pred_cond.detach()

                            # classifier-free guidance
                            grad = (noise_pred_cond - noise_pred_uncond)

                            noise_pred = noise_pred_uncond + config.sd_guidance_scale * grad

                            # all-zero reference prediction passed to ddim_step_KL
                            old_noise_pred = torch.zeros_like(noise_pred).to(noise_pred.device, inference_dtype)

                            latent, kl_terms = ddim_step_KL(
                                pipeline.scheduler,
                                noise_pred,   # (2,4,64,64),
                                old_noise_pred, # (2,4,64,64),
                                t[0].long(),
                                latent,
                                eta=config.sample_eta,  # 1.0
                            )
                            kl_loss += torch.mean(kl_terms).to(inference_dtype)

                        # The predictor scores the final latent directly, so no VAE
                        # decode is performed (it would only add an unused autograd graph).
                        value_predictions = loss_fn(latent) # Get the predicted values
                        loss = -value_predictions.mean() * config.train.loss_coeff

                        total_loss = loss + config.train.kl_weight * kl_loss

                        values_mean = value_predictions.mean()
                        values_std = value_predictions.std()

                        # Detach logged values so the accumulated lists do not keep
                        # autograd history alive across accumulation steps.
                        info['latent-norm (4,4,64,64)'].append(torch.norm(latent).detach())
                        info["init-learner-loss"].append(total_loss.detach())
                        info["init-learner-KL"].append(kl_loss.detach())
                        info["init-values"].append(values_mean.detach())
                        info["init-values_std"].append(values_std.detach())

                        # backward pass
                        accelerator.backward(total_loss)
                        if accelerator.sync_gradients:
                            accelerator.clip_grad_norm_(unet.parameters(), config.train.max_grad_norm)
                        optimizer.step()
                        optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                assert (
                    inner_iters + 1
                ) % config.train.gradient_accumulation_steps == 0
                # log training and evaluation

                logger.info("Logging")

                info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                info = accelerator.reduce(info, reduction="mean")
                logger.info(f"init-learner-loss: {info['init-learner-loss']}")

                info.update({"epoch": epoch, "inner_epoch": inner_iters})
                accelerator.log(info, step=global_step)

                global_step += 1
                info = defaultdict(list)

        # make sure we did an optimization step at the end of the inner epoch
        assert accelerator.sync_gradients

        if (epoch % config.save_freq == 0 or epoch == config.num_epochs - 1) and accelerator.is_main_process:
            accelerator.save_state()
            

if __name__ == "__main__":
    # absl parses the command-line flags (including --config), then invokes main.
    app.run(main=main)
