from collections import defaultdict
import contextlib
import os
import datetime
from concurrent import futures
import time
from absl import app, flags
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import numpy as np
from diffusers_patch.pipeline_using_SMC import pipeline_using_smc
import rewards
import torch
import wandb
from functools import partial
import tqdm
import tempfile
from PIL import Image
from DiffusionSampler import DiffusionModelSampler
import matplotlib.pyplot as plt

# Rebind the module name `tqdm` to a partial of `tqdm.tqdm` with
# dynamic_ncols=True, so every `tqdm(...)` call below builds a progress bar
# with that option already set. After this line `tqdm` is no longer the module.
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)

class SMC(DiffusionModelSampler):
    """Sampler that generates evaluation images through ``pipeline_using_smc``.

    Extends ``DiffusionModelSampler`` by overriding image sampling and
    evaluation logging so that the traces returned by the pipeline (ESS,
    scale factor, intermediate rewards, manifold deviation and diffusion
    log-probability) are collected in ``self.info_eval_vis`` and written to
    ``self.log_dir``.
    """

    def __init__(self, config):
        """Forward ``config`` unchanged to the parent constructor."""
        super().__init__(config)

    def sample_images(self, train=False):
        """Sample images for the evaluation prompts and record diagnostics.

        Draws the initial latents once, then walks over ``self.eval_prompts``
        in slices, calls ``pipeline_using_smc`` for each slice under
        ``torch.no_grad()`` and appends the returned traces, the final images,
        their rewards and the prompts to ``self.info_eval_vis``.

        Args:
            train: Not read by this method.
        """
        # Prepare vae with accelerator
        self.pipeline.vae = self.accelerator.prepare(self.pipeline.vae)

        num_particles = self.config.smc.num_particles
        batch_size = self.config.sample.batch_size

        # One prompt per call when the particles alone fill a batch; otherwise
        # as many prompts as whole particle groups fit into the batch.
        if num_particles >= batch_size:
            num_prompts_per_gpu = 1
        else:
            num_prompts_per_gpu = batch_size // num_particles
        self.sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(num_prompts_per_gpu, 1, 1)
        batch_p = min(num_particles, batch_size)

        prompts = self.eval_prompts

        print("Same evaluation: ", self.config.same_evaluation)
        if self.config.same_evaluation:
            # torch.cuda.manual_seed seeds the global CUDA RNG and returns
            # None, so there is no Generator object to hand around.
            torch.cuda.manual_seed(self.config.seed)
        # Always defined: previously this name was only bound in the seeded
        # branch, so the pipeline call below raised UnboundLocalError when
        # same_evaluation was false.
        generator = None

        unet = self.pipeline.unet
        latents_0 = torch.randn(
            (
                num_particles * batch_size * self.config.max_vis_images,
                unet.config.in_channels,
                unet.sample_size,
                unet.sample_size,
            ),
            device=self.accelerator.device,
            dtype=self.inference_dtype,
        )

        latents_per_call = num_particles * num_prompts_per_gpu
        num_calls = int(self.config.max_vis_images * batch_size / num_prompts_per_gpu)

        with torch.no_grad():
            for vis_idx in tqdm(
                range(num_calls),
                desc="Sampling images",
                disable=not self.accelerator.is_local_main_process,
                position=0,
            ):
                prompts_batch = prompts[vis_idx * num_prompts_per_gpu : (vis_idx + 1) * num_prompts_per_gpu]
                print(prompts_batch)
                # Each prompt is repeated batch_p times for the in-pipeline reward call.
                repeated_prompts = [prompt for prompt in prompts_batch for _ in range(batch_p)]

                latents_batch = latents_0[vis_idx * latents_per_call : (vis_idx + 1) * latents_per_call]

                # Encode prompts
                prompt_ids = self.pipeline.tokenizer(
                    prompts_batch,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.pipeline.tokenizer.model_max_length,
                ).input_ids.to(self.accelerator.device)
                prompt_embeds = self.pipeline.text_encoder(prompt_ids)[0]

                # convert reward function to get image as only input
                image_reward_fn = lambda images: self.reward_fn(
                    images,
                    repeated_prompts,
                )

                # Sample images
                with self.autocast():
                    print(batch_p)
                    (
                        images,
                        _,
                        _log_w,
                        normalized_w,
                        _latents,
                        _all_log_w,
                        _resample_indices,
                        ess_trace,
                        scale_factor_trace,
                        rewards_trace,
                        manifold_deviation_trace,
                        log_prob_diffusion_trace,
                    ) = pipeline_using_smc(
                        self.pipeline,
                        prompt_embeds=prompt_embeds,
                        negative_prompt_embeds=self.sample_neg_prompt_embeds,
                        num_inference_steps=self.config.sample.num_steps,
                        guidance_scale=self.config.sample.guidance_scale,
                        eta=self.config.sample.eta,
                        output_type="pt",
                        latents=latents_batch,
                        num_particles=num_particles,
                        batch_p=batch_p,
                        resample_strategy=self.config.smc.resample_strategy,
                        ess_threshold=self.config.smc.ess_threshold,
                        tempering=self.config.smc.tempering,
                        tempering_schedule=self.config.smc.tempering_schedule,
                        tempering_gamma=self.config.smc.tempering_gamma,
                        tempering_start=self.config.smc.tempering_start,
                        reward_fn=image_reward_fn,
                        penalty_coeff=self.config.smc.penalty_coeff,
                        generator=generator,
                    )
                print(ess_trace)
                print(normalized_w)
                self.info_eval_vis["eval_ess"].append(ess_trace)
                self.info_eval_vis["scale_factor_trace"].append(scale_factor_trace)
                self.info_eval_vis["rewards_trace"].append(rewards_trace)
                self.info_eval_vis["manifold_deviation_trace"].append(manifold_deviation_trace)
                self.info_eval_vis["log_prob_diffusion_trace"].append(log_prob_diffusion_trace)

                # Named final_rewards so the imported `rewards` module is not shadowed.
                final_rewards = self.reward_fn(images, prompts_batch)
                print(final_rewards)

                self.info_eval_vis["eval_rewards_img"].append(final_rewards.clone().detach())
                self.info_eval_vis["eval_image"].append(images.clone().detach())
                self.info_eval_vis["eval_prompts"] = list(self.info_eval_vis["eval_prompts"]) + list(prompts_batch)

    def log_evaluation(self, epoch=None, inner_epoch=None):
        """Run the parent logging, then write per-row trace plots and arrays.

        For every row ``i`` of the concatenated ESS trace a directory named
        after the index, prompt, reward and HPS score is created under
        ``self.log_dir``. It receives ``ess.png``, ``intermediate_rewards.png``,
        ``manifold_deviation.png`` and ``log_prob_diffusion.png`` plus ``.npy``
        dumps of the ESS, manifold-deviation and log-probability traces.

        Args:
            epoch: Forwarded to the parent ``log_evaluation``.
            inner_epoch: Forwarded to the parent ``log_evaluation``.
        """
        # Forward the caller's values; previously None was always passed.
        # NOTE(review): confirm the parent expects the real epoch values here.
        super().log_evaluation(epoch=epoch, inner_epoch=inner_epoch)

        rewards = torch.cat(self.info_eval_vis["eval_rewards_img"])
        prompts = self.info_eval_vis["eval_prompts"]

        # sample_images in this class does not record HPS scores, so the entry
        # may be missing or empty; torch.cat would fail on an empty list.
        try:
            hps_chunks = self.info_eval_vis["eval_hps_score_img"]
        except KeyError:
            hps_chunks = []
        hps_scores = torch.cat(list(hps_chunks)) if len(hps_chunks) > 0 else None

        # Move traces to host memory so matplotlib and numpy can consume them
        # regardless of the device they were produced on.
        ess_trace = torch.cat(self.info_eval_vis["eval_ess"]).detach().cpu()
        rewards_trace = torch.cat(self.info_eval_vis["rewards_trace"]).detach().cpu()
        manifold_deviation_trace = torch.cat(self.info_eval_vis["manifold_deviation_trace"]).detach().cpu()
        log_prob_diffusion_trace = torch.cat(self.info_eval_vis["log_prob_diffusion_trace"]).detach().cpu()

        for i, ess in enumerate(ess_trace):
            hps_text = hps_scores[i] if hps_scores is not None else "n/a"
            caption = f"{i:03d}_{prompts[i]} | reward: {rewards[i]} | hps: {hps_text}"
            out_dir = f"{self.log_dir}/{caption}"
            os.makedirs(out_dir, exist_ok=True)

            fig, ax1 = plt.subplots()
            try:
                # Secondary y-axis is kept (empty) so ess.png keeps its layout.
                ax1.twinx()
                ax1.plot(range(len(ess)), ess, 'b-')
                plt.savefig(f"{out_dir}/ess.png")
                plt.clf()

                plt.plot(rewards_trace[i])
                plt.savefig(f"{out_dir}/intermediate_rewards.png")
                plt.clf()

                plt.plot(manifold_deviation_trace[i])
                plt.savefig(f"{out_dir}/manifold_deviation.png")
                plt.clf()

                plt.plot(log_prob_diffusion_trace[i])
                plt.savefig(f"{out_dir}/log_prob_diffusion.png")
                plt.clf()
            finally:
                # plt.clf() only clears the figure; close it so one figure per
                # row does not accumulate for the lifetime of the process.
                plt.close(fig)

            np.save(f"{out_dir}/ess.npy", ess)
            np.save(f"{out_dir}/manifold_deviation.npy", manifold_deviation_trace[i])
            np.save(f"{out_dir}/log_prob_diffusion.npy", log_prob_diffusion_trace[i])




# Global absl flag registry; `FLAGS.config` is populated from the --config flag.
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    name="config",
    default="config/smc.py",
    help_string="Sampling configuration.",
)

def main(_):
    """Build an ``SMC`` sampler from the ``config`` flag and run evaluation.

    The single positional argument supplied by ``app.run`` is ignored.
    """
    cfg = FLAGS.config
    print(cfg)

    # Construct the sampler and run its evaluation loop.
    SMC(cfg).run_evaluation()

# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == "__main__":
    app.run(main)
