import torch
import torchaudio
import numpy as np

import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from tqdm import trange
import math
from transformers import (
    ClapModel,
    ClapProcessor,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
)
from diffusers import (
    AutoencoderKL, 
    DDIMScheduler, 
)
from diffusers.utils import randn_tensor
from huggingface_hub import hf_hub_download

from .modules import UNetWrapper, TextEncoder, Text_align, Text_align_bind
from .diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

import random
import librosa.display
import librosa
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of a real number *x*.

    Two branches are used so that ``math.exp`` is only ever called on a
    non-positive argument, which avoids ``OverflowError`` for large ``|x|``.
    """
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


class VoiceLDMPipeline():
    """VoiceLDM text-to-speech pipeline with FIFO (queue-based) long-form sampling.

    ``__init__`` loads the AudioLDM scheduler/VAE/vocoder, the CLAP model used
    for description conditioning, the SpeechT5-based content text encoder and
    a VoiceLDM UNet checkpoint. ``__call__`` runs sampling and returns a
    waveform.
    """

    def __init__(
        self,
        model_config = None,
        ckpt_path = None,
        device = None,
    ):
        """Load all sub-models and the VoiceLDM checkpoint.

        Args:
            model_config: ``"s"`` selects the small UNet channel widths; any
                other value (default ``"m"``) uses the larger widths. It also
                selects the Hugging Face repo/filename when ``ckpt_path`` is None.
            ckpt_path: Local checkpoint path. When None, it is downloaded from
                ``glory20h/voiceldm-{model_config}``.
            device: Device every sub-model is moved to; stored as ``self.device``.
        """
        if model_config is None:
            model_config = "m"

        self.noise_scheduler = DDIMScheduler.from_pretrained("cvssp/audioldm-m-full", subfolder="scheduler")
        self.vae = AutoencoderKL.from_pretrained("cvssp/audioldm-m-full", subfolder="vae").eval()
        self.vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm-m-full", subfolder="vocoder").eval()
        self.clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused").eval()
        self.clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
        self.text_encoder = TextEncoder(SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")).eval()
        self.text_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

        unet = UNet2DConditionModel(
            sample_size = 128,
            in_channels = 8,
            out_channels = 8,
            down_block_types = (
                "DownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
            ),
            mid_block_type = "UNetMidBlock2DCrossAttn",
            up_block_types = (
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "UpBlock2D",
            ),
            only_cross_attention = False,
            block_out_channels = [128, 256, 384, 640] if model_config == "s" else [192, 384, 576, 960],
            layers_per_block = 2,
            cross_attention_dim = 768,
            class_embed_type = 'simple_projection',
            projection_class_embeddings_input_dim = 512,
            class_embeddings_concat = True,
        )

        if ckpt_path is None:
            ckpt_path = hf_hub_download(
                repo_id=f"glory20h/voiceldm-{model_config}",
                filename=f"voiceldm-{model_config}.ckpt"
            )

        def load_ckpt(model, ckpt_path):
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            ckpt = torch.load(ckpt_path, map_location="cpu")
            model.load_state_dict(ckpt)
            return model

        self.model = load_ckpt(UNetWrapper(unet, self.text_encoder), ckpt_path)

        self.device = device
        self.vae.to(device)
        self.vocoder.to(device)
        self.clap_model.to(device)
        self.text_encoder.to(device)
        self.model.to(device)
        self.model.eval()

    def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
        """Return initial latents scaled by the scheduler's ``init_noise_sigma``.

        The latent shape is ``(batch_size, num_channels_latents,
        height // s, vocoder.config.model_in_dim // s)``, where ``s`` is
        ``2 ** (len(vae.config.block_out_channels) - 1)``. Fresh noise is
        drawn with ``randn_tensor`` unless ``latents`` is given; given latents
        are moved to ``device``.

        Raises:
            ValueError: if ``generator`` is a list whose length differs from
                ``batch_size``.
        """
        vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        shape = (
            batch_size,
            num_channels_latents,
            height // vae_scale_factor,
            self.vocoder.config.model_in_dim // vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        latents = latents * self.noise_scheduler.init_noise_sigma
        return latents

    def start_latent(self, latents, vid_len, time_ind):
        """Build the initial FIFO queue by re-noising frames of a clean latent.

        Frames are indexed along dim 2. The queue starts with ``vid_len // 2``
        "buffer" slots, each a re-noised copy of frame 0 at
        ``alphas_cumprod[0]``. Slot ``vid_len // 2 + j`` is frame
        ``vid_len // 2 + j`` of ``latents``, re-noised at
        ``alphas_cumprod[time_ind[j]]``. Noise is drawn with
        ``torch.randn_like``, i.e. from the global RNG, in slot order.

        Args:
            latents: Clean latent. Presumably the 4-D output of the first
                sampling stage; TODO confirm.
            vid_len: Queue-window size; half of it is used as the buffer.
            time_ind: Sequence of scheduler timesteps, one per non-buffer slot.

        Returns:
            Tensor with ``vid_len // 2 + len(time_ind)`` frames along dim 2.

        Raises:
            ValueError: if ``latents`` has too few frames along dim 2.
        """
        video = latents
        num_buffer = vid_len // 2
        num_needed = num_buffer + len(time_ind)
        if num_needed > video.shape[2]:
            raise ValueError(
                f"start_latent needs {num_needed} frames along dim 2 "
                f"but the latent only has {video.shape[2]}."
            )

        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        queue = []

        # Buffer zone: re-noised copies of the first frame at alphas_cumprod[0].
        first_frame = video[:, :, [0]]
        alpha = alphas_cumprod[0]
        for _ in range(num_buffer):
            queue.append(alpha ** 0.5 * first_frame + (1 - alpha) ** 0.5 * torch.randn_like(first_frame))

        # Curved denoising: slot j gets the noise level of time_ind[j].
        for j, k in enumerate(time_ind):
            frame = video[:, :, [num_buffer + j]]
            alpha = alphas_cumprod[k]
            queue.append(alpha ** 0.5 * frame + (1 - alpha) ** 0.5 * torch.randn_like(frame))

        return torch.cat(queue, dim=2)

    def decode_latents(self, latents):
        """Undo the VAE scaling factor and return ``vae.decode(...).sample``.

        The result is treated as a mel spectrogram by the caller. Its exact
        layout is presumably (B, 1, T, n_mels); TODO confirm.
        """
        latents = 1 / self.vae.config.scaling_factor * latents
        mel_spectrogram = self.vae.decode(latents).sample
        return mel_spectrogram

    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        """Run the vocoder on a spectrogram and return a float32 CPU waveform.

        4-D input has dim 1 squeezed first; this is a no-op unless that dim
        has size 1.
        """
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)

        waveform = self.vocoder(mel_spectrogram)
        waveform = waveform.cpu().float()
        return waveform

    def normalize_wav(self, waveform):
        """Remove the global mean and scale so the peak absolute value is ~1.

        The ``1e-8`` term avoids division by zero for all-zero input.
        """
        waveform = waveform - torch.mean(waveform)
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
        return waveform

    def shift_latents(self, latents, device, generator, dtype, ind):
        """Dequeue step: shift ``latents`` left by one frame along dim 2.

        The last slot is filled with standard-normal noise from
        ``torch.randn_like``. ``latents`` is modified in place and returned.
        ``device``, ``generator``, ``dtype`` and ``ind`` are accepted for
        interface compatibility but are not used.
        """
        latents[:, :, :-1] = latents[:, :, 1:].clone()
        latents[:, :, -1] = torch.randn_like(latents[:, :, -1])

        return latents

    @staticmethod
    def _split_sentences(text):
        """Split *text* on "." into non-blank segments (whitespace kept as-is).

        Returns ``[text]`` when no non-blank segment exists, so the result is
        never empty.
        """
        segments = [seg for seg in text.split(".") if seg.strip()]
        return segments or [text]

    def _tokenize_content(self, texts):
        """Tokenize content texts with the SpeechT5 processor on ``self.device``."""
        return self.text_processor(
            text=texts,
            padding=True,
            truncation=True,
            max_length=1000,
            return_tensors="pt"
        ).to(self.device)

    def _encode_description(self, desc_prompt, audio_prompt):
        """Return the 4-row CLAP description condition ``[c, c, uc, uc]``.

        With no ``audio_prompt``, ``desc_prompt`` is embedded with the CLAP
        text tower and the unconditional rows use "". Otherwise the first
        channel of the audio file is resampled to 48 kHz, embedded with the
        CLAP audio tower, and the unconditional rows come from the CLAP text
        embedding of "".
        """
        if audio_prompt is None:
            clap_inputs = self.clap_processor(
                text=[desc_prompt] * 2 + [""] * 2,
                return_tensors="pt",
                padding=True
            ).to(self.device)
            embeds = self.clap_model.text_model(**clap_inputs).pooler_output
            return self.clap_model.text_projection(embeds)

        audio_sample, sr = torchaudio.load(audio_prompt)
        if sr != 48000:
            audio_sample = torchaudio.functional.resample(audio_sample, orig_freq=sr, new_freq=48000)
        audio_sample = audio_sample[0]

        clap_inputs = self.clap_processor(audios=audio_sample, sampling_rate=48000, return_tensors="pt", padding=True).to(self.device)
        embeds = self.clap_model.audio_model(**clap_inputs).pooler_output
        c_desc = self.clap_model.audio_projection(embeds)

        clap_inputs = self.clap_processor(text=[""], return_tensors="pt", padding=True).to(self.device)
        uncond_embeds = self.clap_model.text_model(**clap_inputs).pooler_output
        uc_desc = self.clap_model.text_projection(uncond_embeds)
        return torch.cat((c_desc, c_desc, uc_desc, uc_desc))

    def _encode_content(self, cont_prompt):
        """Encode each sentence of ``cont_prompt`` plus a null ("_") content.

        Each segment is encoded as the 4-row batch ``[seg, "_", seg, "_"]``
        and passed through ``self.model.durator``.

        Returns:
            ``(c_cont_list, cont_embed_mask_list, c_cont_null,
            cont_embed_mask_null)``. ``cont_embed_mask_list`` holds the
            masks from *before* the durator. The null pair is the durator's
            output for ``["_"] * 4``.
        """
        cont_embed_list = []
        cont_embed_mask_list = []
        for segment in self._split_sentences(cont_prompt):
            cont_tokens = self._tokenize_content(([segment] + ["_"]) * 2)
            cont_embed_list.append(self.text_encoder(cont_tokens))
            cont_embed_mask_list.append(cont_tokens.attention_mask.unsqueeze(1))

        null_tokens = self._tokenize_content(["_"] * 4)
        cont_embed_null = self.text_encoder(null_tokens)
        cont_embed_mask_null = null_tokens.attention_mask.unsqueeze(1)

        # NOTE(review): the durator also returns an updated mask for each
        # segment, which is discarded here. The FIFO loop pairs each durator
        # output with its pre-durator mask, whereas stage 1 uses the
        # durator's mask for the null content. Preserved as-is; confirm which
        # mask UNetWrapper expects.
        c_cont_list = [
            self.model.durator(embed, mask)[0]
            for embed, mask in zip(cont_embed_list, cont_embed_mask_list)
        ]
        c_cont_null, cont_embed_mask_null = self.model.durator(cont_embed_null, cont_embed_mask_null)
        return c_cont_list, cont_embed_mask_list, c_cont_null, cont_embed_mask_null

    @torch.no_grad()
    def __call__(
        self,
        desc_prompt,
        cont_prompt,
        audio_prompt = None,
        batch_size = 1,
        num_inference_steps = 100,
        audio_length_in_s = 10,
        do_classifier_free_guidance = True,
        guidance_scale = None,
        desc_guidance_scale = None,
        cont_guidance_scale = None,
        device=None,
        seed=None,
        **kwargs,
    ):
        """Generate speech with VoiceLDM followed by FIFO (queue-based) diffusion.

        Stage 1 is a regular DDIM pass of ``num_inference_steps`` steps. It is
        conditioned on the description and on null ("_") content, and
        produces a clean latent. Stage 2 re-noises that latent into a queue
        (``start_latent``) and runs one denoising step per iteration with a
        fixed 300-step schedule. Each iteration dequeues one latent frame and
        enqueues fresh noise. Every "."-separated sentence of ``cont_prompt``
        conditions a contiguous window of iterations.

        Only dual classifier-free guidance is implemented. The four noise
        predictions are for (desc, content), (desc, null), ("", content) and
        ("", null). ``desc_guidance_scale``/``cont_guidance_scale`` are used
        in stage 1 only; stage 2 uses fixed scales of 5 (description) and 7
        (content), and -1 for content after the last sentence.

        Args:
            desc_prompt: Text description, used when ``audio_prompt`` is None.
            cont_prompt: Content text, tokenized with the SpeechT5 processor.
            audio_prompt: Optional audio file path whose CLAP audio embedding
                replaces ``desc_prompt``.
            batch_size: Number of latents. NOTE(review): stage 2 indexes the
                4-way CFG batch directly, so values other than 1 are
                presumably unsupported.
            num_inference_steps: DDIM steps for stage 1.
            audio_length_in_s: Length of the stage-1 latent, in seconds.
            do_classifier_free_guidance: Must be True.
            guidance_scale: Single-guidance scale; not supported, must be None.
            desc_guidance_scale: Description guidance weight for stage 1.
            cont_guidance_scale: Content guidance weight for stage 1.
            device: Device for the latents; defaults to ``self.device``.
            seed: Optional integer seed (0 is a valid seed).
            **kwargs: Accepted and ignored.

        Returns:
            CPU waveform tensor normalised by ``normalize_wav``.

        Raises:
            NotImplementedError: for single guidance or disabled CFG.
            ValueError: if either dual guidance scale is missing.
        """
        if guidance_scale is not None:
            raise NotImplementedError(
                "Single classifier-free guidance (guidance_scale) is not supported; "
                "pass desc_guidance_scale and cont_guidance_scale instead."
            )
        if desc_guidance_scale is None or cont_guidance_scale is None:
            raise ValueError("Both desc_guidance_scale and cont_guidance_scale must be provided.")
        if not do_classifier_free_guidance:
            raise NotImplementedError("Sampling without classifier-free guidance is not supported.")
        if device is None:
            device = self.device

        c_desc = self._encode_description(desc_prompt, audio_prompt)
        c_cont_list, cont_embed_mask_list, c_cont_null, cont_embed_mask_null = self._encode_content(cont_prompt)

        # ---- Stage 1: regular VoiceLDM sampling with null content ----
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
        height = int(audio_length_in_s * 1.024 / vocoder_upsample_factor)

        self.noise_scheduler.set_timesteps(num_inference_steps)
        latents = self.prepare_latents(
            batch_size,
            self.model.unet.config.in_channels,
            height,
            c_desc.dtype,
            device=device,
            generator=torch.manual_seed(seed) if seed is not None else None,
            latents=None,
        )

        # Re-seed the global RNG after drawing the initial latents;
        # start_latent/shift_latents draw their noise from it via randn_like.
        extra_step_kwargs = {
            'eta': 0.0,
            'generator': torch.manual_seed(seed) if seed is not None else None,
        }

        for t in tqdm(self.noise_scheduler.timesteps):
            latent_model_input = self.noise_scheduler.scale_model_input(torch.cat([latents] * 4), t)
            noise_pred = self.model(
                latent_model_input,
                t,
                c_cont_null,
                cont_embed_mask_null,
                c_desc,
            ).to(latents.device)
            n1, n2, n3, n4 = noise_pred.chunk(4)
            noise_pred = n1 + desc_guidance_scale * (n2 - n4) + cont_guidance_scale * (n3 - n4)
            latents = self.noise_scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # ---- Stage 2: FIFO diffusion ----
        fifo_steps = 300
        vid_len = 64
        num_buffer = vid_len // 2
        seg_pad, seg_gap, tail = 45, 20, 70

        self.noise_scheduler.set_timesteps(fifo_steps)
        # Keep the first 101 scheduler steps and every 5th one after that.
        kept = [t for i, t in enumerate(self.noise_scheduler.timesteps) if i <= 100 or i % 5 == 1]
        queue_timesteps = torch.flip(torch.stack(kept), dims=(0,))

        latents = self.start_latent(latents, vid_len, queue_timesteps)
        # Buffer slots share queue_timesteps[0]; shape (1, 1, queue_len).
        ts = torch.cat([queue_timesteps[:1].repeat(num_buffer), queue_timesteps]).long()[None, None]

        # Segment k conditions iterations [point[k-1], point[k]) (point[-1] := 0).
        # Its window is its durator length + seg_pad, plus seg_gap for every
        # segment after the first. NOTE(review): these appear to be counted
        # in FIFO iterations (one dequeued latent frame each); confirm.
        point = []
        content_end = 0
        for k, seg_cont in enumerate(c_cont_list):
            if k:
                content_end += seg_gap
            content_end += int(seg_cont.shape[1]) + seg_pad
            point.append(content_end)
        new_video_length = content_end + tail

        # Fixed stage-2 guidance weights; latents keep their shape in the loop.
        desc_scale = torch.full_like(latents, 5.0)
        cont_scale_active = torch.full_like(latents, 7.0)

        fifo = []
        ind = 0
        num = 0
        for i in trange(new_video_length, desc="sampling"):
            # Counter passed to the model "for qkv sharing"; cycles 1..199.
            ind = ind + 1 if ind < 199 else 1

            if i < content_end:
                num = next(k for k, po in enumerate(point) if i < po)
                c_cont = c_cont_list[num]
                cont_embed_mask = cont_embed_mask_list[num]
                cont_scale = cont_scale_active
            else:
                # Tail: keep the last segment's conditioning with negative content guidance.
                cont_scale = -1

            latent_model_input = torch.cat([latents] * 4)
            # "Guidance alternation": odd segments swap CFG slots 2 and 3.
            # NOTE(review): with batch_size == 1 all four slots hold identical
            # latents, so this swap does not change the model input.
            if num % 2 == 1:
                latent_model_input = torch.stack(
                    [latent_model_input[0], latent_model_input[1], latent_model_input[3], latent_model_input[2]],
                    dim=0,
                )
            noise_pred = self.model(latent_model_input, ts, c_cont, cont_embed_mask, c_desc, ind)

            n1, n2, n3, n4 = noise_pred.chunk(4)
            noise_pred = n1 + desc_scale * (n2 - n4) + cont_scale * (n3 - n4)
            latents = self.noise_scheduler.step(noise_pred, ts, latents, **extra_step_kwargs).prev_sample

            # Dequeue the first non-buffer slot (a copy), then shift in fresh noise.
            fifo.append(latents[:, :, [num_buffer]])
            # Pass the existing generator; re-seeding here would make every
            # enqueued noise frame identical.
            latents = self.shift_latents(
                latents,
                device=device,
                generator=extra_step_kwargs['generator'],
                dtype=c_desc.dtype,
                ind=ind,
            )

        latents = torch.cat(fifo, dim=2)
        mel_spectrogram = self.decode_latents(latents)
        audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
        return self.normalize_wav(audio)
        