import os
import time, torch
from accelerate import PartialState
from diffusers import AutoPipelineForText2Image, DiffusionPipeline, AutoencoderKL, BitsAndBytesConfig, SD3Transformer2DModel, StableDiffusion3Pipeline, StableDiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from diffusers import SanaPipeline, SanaSprintPipeline
from transformers import T5EncoderModel
from tqdm import tqdm

class FullPipeline:
    """
    Full pipeline for a text-to-image model

    Given a model name, it loads the corresponding pipeline and parallelizes it if needed, or loads it on a given device.
    Supported model names (see ``_load_model``): "SD15", "SD2", "SANA" and "SDXLT_16b".
    """

    # Filenames ending with one of these are saved as-is; any other filename gets ".png" appended
    _IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, name, parallelize=False, device="cuda", torch_dtype="float16", local_path=None):
        """
        Load the pipeline for the given model

        Args:
            name (str): The name of the model ("SD15", "SD2", "SANA" or "SDXLT_16b")
            parallelize (bool): Whether to distribute generation across processes through ``PartialState``
            device (str): The device to load the pipeline on when not parallelizing
            torch_dtype (str): Either "float16" or "bfloat16"
            local_path (str): Optional local checkpoint path used instead of the default hub id (SD15, SD2 and SANA only)
        """
        self.name = name
        self.parallelize = parallelize
        self.device = device
        assert torch_dtype in ["float16", "bfloat16"], "torch_dtype must be either float16 or bfloat16"
        self.torch_dtype = torch.float16 if torch_dtype == "float16" else torch.bfloat16
        self.local_model_path = local_path
        self.model, self.distributed_state = self._load_model()
        self.max_length = self.model.tokenizer.model_max_length - 2 # 2 for starting and ending tokens

    def _load_model(self):
        """
        Load the full pipeline for ``self.name`` and move it to the target device

        The checkpoint is read from ``self.local_model_path`` when one was given (SD15, SD2 and SANA),
        otherwise from the default hub id. SDXLT_16b always loads from its hub ids.

        Returns:
            Pipeline: The pipeline loaded on the GPU or distributed across GPUs
            distributed_state: The distributed state if the pipeline is parallelized, None otherwise

        Raises:
            ValueError: If ``self.name`` is not a supported model name
        """
        # SD1.5, SD2 (Euler), SDXLT, and SANA are the models actually used to generate the final dataset
        if self.name == "SD15":
            model_id = self.local_model_path if self.local_model_path else "stable-diffusion-v1-5/stable-diffusion-v1-5"
            pipe = DiffusionPipeline.from_pretrained(model_id, safety_checker=None, torch_dtype=self.torch_dtype, use_safetensors=True)
        elif self.name == "SD2":
            model_id = self.local_model_path if self.local_model_path else "stabilityai/stable-diffusion-2"
            scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
            pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None, torch_dtype=self.torch_dtype)
        elif self.name == "SANA":
            print(f"Loading {self.torch_dtype} version of SANA 1.6B")
            if self.torch_dtype == torch.float16:
                model_id = self.local_model_path if self.local_model_path else "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
                pipe = SanaPipeline.from_pretrained(model_id, variant="fp16", torch_dtype=self.torch_dtype)
                # In the fp16 setup the VAE and the text encoder are kept in float32
                pipe.vae.to(torch.float32)
                pipe.text_encoder.to(torch.float32)
            elif self.torch_dtype == torch.bfloat16:
                model_id = self.local_model_path if self.local_model_path else "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
                pipe = SanaPipeline.from_pretrained(model_id, variant="bf16", torch_dtype=self.torch_dtype)
                pipe.vae.to(self.torch_dtype)
                pipe.text_encoder.to(self.torch_dtype)
            else:
                raise ValueError(f"Unknownd torch dtype {self.torch_dtype}")
            print("Compiling SANA transformer... ", end="")
            pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
            print("done!")
        elif self.name == "SDXLT_16b":
            vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=self.torch_dtype)
            pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=self.torch_dtype, variant="fp16", vae=vae)
            print("Compiling the UNet... ", end="")
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            print("done!")
        else:
            raise ValueError(f"Unknown model_name: {self.name}")

        # Pick the target device: the per-process device when parallelizing, the requested one otherwise
        if self.parallelize:
            print("Parallelizing... ", end="")
            distributed_state = PartialState()
            target_device = distributed_state.device
        else:
            print(f"Not parallelizing, loading to {self.device}... ", end="")
            distributed_state = None
            target_device = self.device
        pipe.to(target_device)
        pipe.set_progress_bar_config(disable=True)
        print("done!")

        # For SD15 / SD2 the UNet is compiled only after the pipeline has been moved to its device
        if "SD15" in self.name or "SD2" in self.name:
            print("Compiling the UNet... ", end="")
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            print("done!")
        return pipe, distributed_state

    def truncate_captions(self, captions):
        """
        Truncate the captions by a tokenizer round trip

        The captions are tokenized with truncation and max-length padding, then decoded back to text
        with the special tokens skipped.

        Args:
            captions (list): The list of captions

        Returns:
            list: One decoded caption per input caption
        """
        input_ids = self.model.tokenizer(captions, truncation=True, padding="max_length", return_tensors="pt").input_ids
        return [self.model.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]

    def default_inference(self, captions_batch, steps, w, h, gs, seed):
        """
        Run the pipeline on a batch of text prompts

        Args:
            captions_batch (list): The prompts of the batch
            steps (int): The number of inference steps to use
            w (int): The width of the generated images
            h (int): The height of the generated images
            gs (float): The guidance scale to use
            seed (int): The seed to use for the batch, or None to leave the random state untouched

        Returns:
            The ``images`` attribute of the pipeline output
        """
        call_kwargs = dict(num_inference_steps=steps, width=w, height=h, guidance_scale=gs)
        if seed is not None:
            # torch.manual_seed seeds the global default generator and returns it
            call_kwargs["generator"] = torch.manual_seed(seed)
        return self.model(captions_batch, **call_kwargs).images

    def embed_inference(self, prompt_embeds, negative_prompt_embeds, steps, w, h, gs, seed):
        """
        Run the pipeline on precomputed prompt embeddings instead of text prompts

        Args:
            prompt_embeds: The prompt embeddings, forwarded to the pipeline as ``prompt_embeds``
            negative_prompt_embeds: The negative prompt embeddings, forwarded as ``negative_prompt_embeds``
            steps (int): The number of inference steps to use
            w (int): The width of the generated images
            h (int): The height of the generated images
            gs (float): The guidance scale to use
            seed (int): The seed to use for the batch, or None to leave the random state untouched

        Returns:
            The ``images`` attribute of the pipeline output
        """
        call_kwargs = dict(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, num_inference_steps=steps, width=w, height=h, guidance_scale=gs)
        if seed is not None:
            call_kwargs["generator"] = torch.manual_seed(seed)
        return self.model(**call_kwargs).images

    def _resolve_dest_folder(self, dest_folder, makedir):
        """
        Return the folder the images are saved to, creating it if needed

        Args:
            dest_folder (str): The base destination folder; an empty string means the current working directory
            makedir (bool): Whether to save into a sub-folder named after the model

        Returns:
            str: The folder to save to (may be an empty string)
        """
        if makedir:
            dest_folder = f"{dest_folder}/{self.name}" if dest_folder else self.name
        # os.makedirs("") raises FileNotFoundError, so an empty folder is simply left alone
        if dest_folder:
            os.makedirs(dest_folder, exist_ok=True)
        return dest_folder

    def _save_images(self, images, filenames, dest_folder, abs_path, resize):
        """
        Save each image under its filename, optionally resizing it first

        Args:
            images (list): The generated images (objects exposing ``resize`` and ``save``)
            filenames (list): One filename per image
            dest_folder (str): The folder relative filenames are saved in
            abs_path (bool): Whether the filenames are full paths to be used as they are
            resize: Optional size passed to ``img.resize`` (presumably a (width, height) tuple - confirm against callers)
        """
        for img, f in zip(images, filenames):
            if resize:
                img = img.resize(resize)
            if abs_path:
                target = f
            else:
                if not f.endswith(self._IMAGE_EXTENSIONS):
                    f = f"{f}.png"
                # An empty dest_folder must not turn into "/<file>", which would point at the filesystem root
                target = f"{dest_folder}/{f}" if dest_folder else f
            img.save(target)

    @staticmethod
    def _report_timing(total_time, num_images):
        """Print the total and per-image generation time; the average is reported as 0.0 when nothing was generated"""
        average = total_time / num_images if num_images else 0.0
        print(f"Total time: {total_time}s, Average time per image: {average}s")

    def _sequential_generation(self, captions, filenames, batch_size, steps, size, gs, seed, dest_folder, makedir=True, inference_method=None, abs_path=False, resize=None, seed_list=None):
        """
        Generate images batch by batch in the current process and save them to the destination folder

        Args:
            captions (list): The list of captions
            filenames (list): The list of filenames
            batch_size (int): The batch size to use for generation
            steps (int): The number of inference steps to use
            size (int): The size of the generated images (used for both width and height)
            gs (float): The guidance scale to use
            seed (int): The seed to use for every batch when no seed list is given
            dest_folder (str): The destination folder to save the generated images
            makedir (bool): Whether to save into a sub-folder of dest_folder named after the model
            inference_method (callable): The inference method to use for generation
            abs_path (bool): Whether the filenames are absolute paths or relative to the destination folder
            resize: Optional size passed to ``img.resize`` before saving
            seed_list (list): Optional per-caption seeds; each batch uses the seed of its first caption
        """
        assert len(captions) == len(filenames), "The number of captions and filenames must be the same"
        assert inference_method is not None, "Inference method must be provided"

        dest_folder = self._resolve_dest_folder(dest_folder, makedir)

        start = time.time()
        for i in tqdm(range(0, len(captions), batch_size)):
            captions_batch = captions[i:i + batch_size]
            filenames_batch = filenames[i:i + batch_size]
            if seed_list:
                # One seed per batch: the seed of the first caption of the batch
                seed = seed_list[i]
            if "SANA" not in self.name:
                captions_batch = self.truncate_captions(captions_batch)
            result = inference_method(captions_batch, steps, size, size, gs, seed)
            self._save_images(result, filenames_batch, dest_folder, abs_path, resize)
        total_time = time.time() - start

        self._report_timing(total_time, len(captions))

    def _parallel_generation(self, captions, filenames, batch_size, steps, size, gs, seed, dest_folder, makedir=True, inference_method=None, abs_path=False, resize=None, seed_list=None):
        """
        Generate images in parallel using the given pipeline and save them to the destination folder

        Args:
            captions (list): The list of captions
            filenames (list): The list of filenames
            batch_size (int): The batch size to use for generation, for each process
            steps (int): The number of inference steps to use
            size (int): The size of the generated images
            gs (float): The guidance scale to use
            seed (int): The seed to use for generation
            dest_folder (str): The destination folder to save the generated images
            makedir (bool): Whether to create a new directory for the generated images
            inference_method (callable): The inference method to use for generation
            abs_path (bool): Whether the filenames are absolute paths or relative to the destination folder
            resize: Optional size passed to ``img.resize`` before saving
            seed_list (list): Optional per-caption seeds; each batch uses the seed of its first caption
        """
        assert len(captions) == len(filenames), "The number of captions and filenames must be the same"
        assert batch_size <= len(captions), "The batch size must be less than or equal to the number of captions"
        assert inference_method is not None, "Inference method must be provided"

        dest_folder = self._resolve_dest_folder(dest_folder, makedir)

        if seed_list:
            metadata = [{"caption": c, "filename": f, "seed": s} for c, f, s in zip(captions, filenames, seed_list)]
        else:
            metadata = [{"caption": c, "filename": f, "seed": seed} for c, f in zip(captions, filenames)]
        with self.distributed_state.split_between_processes(metadata) as metadata_split:
            start = time.time()
            for i in tqdm(range(0, len(metadata_split), batch_size)):
                metadata_batch = metadata_split[i:i + batch_size]
                captions_batch = [x["caption"] for x in metadata_batch]
                if "SANA" not in self.name:
                    captions_batch = self.truncate_captions(captions_batch)
                filenames_batch = [x["filename"] for x in metadata_batch]
                # One seed per batch: the seed of the first item of the batch
                seed = metadata_batch[0]["seed"]
                result = inference_method(captions_batch, steps, size, size, gs, seed)
                self._save_images(result, filenames_batch, dest_folder, abs_path, resize)

            total_time = time.time() - start
            # A process may receive an empty shard, so the average must not divide by zero
            self._report_timing(total_time, len(metadata_split))

    def generate_images(self, captions, filenames, batch_size, steps, size, gs, seed=None, dest_folder="", makedir=False, abs_path=False, resize=None, seed_list=None):
        """
        Generate one image per caption and save it under the matching filename

        Dispatches to the parallel or the sequential generation depending on ``self.parallelize``,
        always using ``default_inference`` as the inference method.

        Args:
            captions (list): The list of captions
            filenames (list): The list of filenames
            batch_size (int): The batch size to use for generation (per process when parallelized)
            steps (int): The number of inference steps to use
            size (int): The size of the generated images
            gs (float): The guidance scale to use
            seed (int): The seed to use for generation when no seed list is given
            dest_folder (str): The destination folder; an empty string means the current working directory
            makedir (bool): Whether to save into a sub-folder of dest_folder named after the model
            abs_path (bool): Whether the filenames are absolute paths or relative to the destination folder
            resize: Optional size passed to ``img.resize`` before saving
            seed_list (list): Optional per-caption seeds; each batch uses the seed of its first caption
        """
        if seed_list is not None:
            print("GENERATE IMAGES | Using seed list from CSV")
        elif seed is not None:
            print("GENERATE IMAGES | Using provided seed")
        else:
            print("Either no seed provided or just starting initial seed, each process will thus use a different starting noise")
        inference_method = self.default_inference
        generate = self._parallel_generation if self.parallelize else self._sequential_generation
        generate(captions, filenames, batch_size, steps, size, gs, seed, dest_folder, makedir, inference_method, abs_path, resize, seed_list=seed_list)