from collections import defaultdict
import contextlib
import os
import datetime
from concurrent import futures
import time
from absl import app, flags
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import numpy as np
from ddpo_pytorch.stat_tracking import PerPromptStatTracker
from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob
from diffusers_patch.pipeline_using_FreeDoM import pipeline_using_freedom
import rewards
import torch
import wandb
from functools import partial
import tqdm
import tempfile
from PIL import Image
from DiffusionSampler import DiffusionModelSampler

# Replace the ``tqdm`` module name with a progress-bar factory that always
# passes ``dynamic_ncols=True``; call sites below use ``tqdm(iterable, ...)``.
tqdm = partial(
    tqdm.tqdm,
    dynamic_ncols=True,
)

class FreeDoM(DiffusionModelSampler):
    """Sampler that generates evaluation images through ``pipeline_using_freedom``.

    Construction is delegated to ``DiffusionModelSampler``; this subclass
    overrides reward construction and image sampling so that each batch is
    denoised with FreeDoM-style reward guidance (``rho`` and
    ``time_travel_repeat`` come from ``config.freedom``).
    """

    def __init__(self, config):
        super().__init__(config)

    def prepare_prompts_and_rewards(self):
        """Build the reward, loss and HPS scoring callables for ``config.reward_fn``.

        Sets ``self.reward_fn`` (passed to the pipeline for guidance and used
        to score the final images), ``self.loss_fn`` (the same scorer built
        with ``return_loss=True``) and ``self.hps_fn`` (an HPS scorer whose
        values are recorded for every sampled batch). The ``'inpaint'``
        option additionally sets ``self.masked_target``.

        Raises:
            NotImplementedError: if ``config.reward_fn`` is not one of
                ``'hps'``, ``'hps_score'``, ``'aesthetic'``,
                ``'aesthetic_score'`` or ``'inpaint'``.
        """
        super().prepare_prompts_and_rewards()

        # Retrieve the reward function from the ``rewards`` module using the config.
        reward_name = self.config.reward_fn
        print(reward_name)
        device = self.accelerator.device
        if reward_name in ('hps', 'hps_score'):
            self.reward_fn = rewards.hps_score(inference_dtype=self.inference_dtype, device=device, accelerator=self.accelerator)
            self.loss_fn = rewards.hps_score(inference_dtype=self.inference_dtype, device=device, return_loss=True)
        elif reward_name in ('aesthetic', 'aesthetic_score'):
            self.reward_fn = rewards.aesthetic_score(torch_dtype=self.inference_dtype, device=device, accelerator=self.accelerator)
            self.loss_fn = rewards.aesthetic_score(torch_dtype=self.inference_dtype, device=device, return_loss=True)
        elif reward_name == 'inpaint':
            inpaint_cfg = self.config.inpaint
            inpaint_kwargs = dict(
                x=inpaint_cfg.x,
                width=inpaint_cfg.width,
                y=inpaint_cfg.y,
                height=inpaint_cfg.height,
                sample_name=inpaint_cfg.sample_name,
            )
            self.reward_fn, self.masked_target = rewards.inpaint(**inpaint_kwargs, return_loss=False)
            self.loss_fn, self.masked_target = rewards.inpaint(**inpaint_kwargs, return_loss=True)
        else:
            # Previously the exception class was evaluated but never raised, so
            # an unknown reward name was silently accepted.
            raise NotImplementedError(f"Unsupported reward_fn: {reward_name!r}")

        self.hps_fn = rewards.hps_score(inference_dtype=self.inference_dtype, device=device, accelerator=self.accelerator)

    def _sample_initial_latents(self):
        """Draw the starting noise for all evaluation batches as one tensor.

        The tensor has shape ``(sample.batch_size * max_vis_images,
        unet.config.in_channels, unet.sample_size, unet.sample_size)`` and is
        created on the accelerator device in ``self.inference_dtype``.
        """
        if self.config.same_evaluation:
            # torch.manual_seed seeds the CPU generator and every CUDA device's
            # default generator, so this draw (and any later draw that uses the
            # global RNG) is reproducible on whatever device the accelerator
            # uses. ``torch.cuda.manual_seed`` returns None, only seeds the
            # current CUDA device, and is silently ignored without CUDA.
            torch.manual_seed(self.config.seed)
        unet = self.pipeline.unet
        shape = (
            self.config.sample.batch_size * self.config.max_vis_images,
            unet.config.in_channels,
            unet.sample_size,
            unet.sample_size,
        )
        return torch.randn(shape, device=self.accelerator.device, dtype=self.inference_dtype)

    def sample_images(self, train=False):
        """Sample evaluation images with FreeDoM guidance and record their scores.

        Runs up to ``config.max_vis_images`` batches over ``self.eval_prompts``
        and appends, per batch, the reward scores, images, prompts and HPS
        scores to ``self.info_eval_vis``.

        Args:
            train: Unused by this override (presumably kept to match the base
                class signature).
        """
        self.pipeline.unet.eval()

        num_prompts_per_gpu = self.config.sample.batch_size
        self.sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(num_prompts_per_gpu, 1, 1)

        prompts = self.eval_prompts
        print(prompts)

        print("Same evaluation: ", self.config.same_evaluation)
        latents_0 = self._sample_initial_latents()

        with torch.no_grad():
            for vis_idx in tqdm(
                range(self.config.max_vis_images),
                desc="Sampling images",
                disable=not self.accelerator.is_local_main_process,
                position=0,
            ):
                start = vis_idx * num_prompts_per_gpu
                prompts_batch = prompts[start : start + num_prompts_per_gpu]
                print(prompts_batch)

                # The final slice may be short (or empty) when there are fewer
                # eval prompts than batch_size * max_vis_images; keep every
                # per-batch tensor the same size as the prompt slice.
                batch_size = len(prompts_batch)
                if batch_size == 0:
                    break
                latents_batch = latents_0[start : start + batch_size]
                neg_prompt_embeds = self.sample_neg_prompt_embeds[:batch_size]

                # Encode prompts
                prompt_ids = self.pipeline.tokenizer(
                    prompts_batch,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.pipeline.tokenizer.model_max_length,
                ).input_ids.to(self.accelerator.device)
                prompt_embeds = self.pipeline.text_encoder(prompt_ids)[0]

                # Bind this batch's prompts so the pipeline can score images alone.
                image_reward_fn = lambda images: self.reward_fn(images, prompts_batch)

                # Sample images
                with self.autocast():
                    images, _, latents = pipeline_using_freedom(
                        self.pipeline,
                        prompt_embeds=prompt_embeds,
                        negative_prompt_embeds=neg_prompt_embeds,
                        num_inference_steps=self.config.sample.num_steps,
                        guidance_scale=self.config.sample.guidance_scale,
                        eta=self.config.sample.eta,
                        output_type="pt",
                        latents=latents_batch,
                        rho=self.config.freedom.rho,
                        time_travel_repeat=self.config.freedom.time_travel_repeat,
                        reward_fn=image_reward_fn,
                    )

                # Latent-trajectory diagnostics; currently only their shapes are printed.
                latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, 4, 64, 64)
                print(latents.shape)
                latents_norm = torch.norm(latents.reshape(latents.shape[0], latents.shape[1], -1), dim=-1)  # (batch_size, num_steps + 1)
                timesteps = torch.cat((self.pipeline.scheduler.timesteps, torch.tensor([1]).to(self.accelerator.device)))
                timesteps = timesteps.repeat(batch_size, 1)  # (batch_size, num_steps + 1)
                print(timesteps.shape)
                # self.info_eval["latents_norm"].append(latents_norm)
                # self.info_eval["timesteps"].append(timesteps)

                # Named ``batch_rewards`` so it does not shadow the ``rewards`` module.
                batch_rewards = self.reward_fn(images, prompts_batch)
                print(batch_rewards)

                self.info_eval_vis["eval_rewards_img"].append(batch_rewards.clone().detach())
                self.info_eval_vis["eval_image"].append(images.clone().detach())
                self.info_eval_vis["eval_prompts"] = list(self.info_eval_vis["eval_prompts"]) + list(prompts_batch)

                hps_score = self.hps_fn(images, prompts_batch)
                self.info_eval_vis["eval_hps_score_img"].append(hps_score.clone().detach())


FLAGS = flags.FLAGS
# NOTE(review): the default config path is config/tds.py; confirm this is the
# intended default for the FreeDoM sampler.
config_flags.DEFINE_config_file("config", "config/tds.py", "Sampling configuration.")


def main(_):
    """Build a FreeDoM sampler from the ``--config`` flag and run its evaluation.

    Args:
        _: argv passed by ``absl.app.run`` (unused).
    """
    # Load the configuration
    config = FLAGS.config
    print(config)

    # Initialize the sampler with the configuration
    sampler = FreeDoM(config)

    # Run the evaluation
    sampler.run_evaluation()


if __name__ == "__main__":
    app.run(main)
