import numpy as np

import torch
from tqdm import tqdm
import json
import os

from wavegrad.base import BaseModule
from wavegrad.nn import WaveGradNN

class WaveGrad(BaseModule):
    def __init__(self, preprocess_config, model_config):
        """
        Builds the WaveGrad wrapper around the `WaveGradNN` backbone.
        :param preprocess_config (dict): preprocessing configuration; only
            `preprocessing.stft.hop_length` is read here
        :param model_config (dict): model configuration; `wavegrad.factors` is read here
            and the whole dict is forwarded to `WaveGradNN`
        """
        super(WaveGrad, self).__init__()
        # No schedule exists until `set_new_noise_schedule` is called.
        self.noise_schedule_is_set = False

        # `np.product` was removed in NumPy 2.0; `np.prod` is the supported spelling.
        self.total_factor = np.prod(model_config["wavegrad"]["factors"])
        assert self.total_factor == preprocess_config["preprocessing"]["stft"]["hop_length"], \
            """Total factor-product should be equal to the hop length of STFT."""
        self.nn = WaveGradNN(model_config)

    def set_new_noise_schedule(
        self,
        init=torch.linspace,
        init_kwargs=None
    ):
        """
        Sets sampling noise schedule. WaveGrad supports variable noise schedules during inference
        thanks to the continuous noise level conditioning.
        :param init (callable function, optional): function which initializes betas
        :param init_kwargs (dict, optional): dict of arguments to be pushed to `init` function.
            Should always contain the key `steps` corresponding to the number of iterations to be done by the model.
            This is done so because `torch.linspace` has this argument named as `steps`.
            Defaults to `{'steps': 100, 'start': 1e-6, 'end': 1e-2}` when omitted.
        """
        # Build the default per call: a dict literal in the signature would be one shared
        # object that is also stored on the instance below.
        if init_kwargs is None:
            init_kwargs = {'steps': 100, 'start': 1e-6, 'end': 1e-2}
        assert 'steps' in init_kwargs, \
            '`init_kwargs` should always contain the key `steps` corresponding to the number of iterations to be done by the model.'
        n_iter = init_kwargs['steps']
        betas = init(**init_kwargs)
        alphas = 1 - betas
        alphas_cumprod = alphas.cumprod(dim=0)
        # Leading "1" created on the device of `betas` so the concatenations
        # do not fail for schedules that are not on the CPU.
        one = torch.ones(1, dtype=torch.float32, device=betas.device)
        alphas_cumprod_prev = torch.cat([one, alphas_cumprod[:-1]])
        alphas_cumprod_prev_with_last = torch.cat([one, alphas_cumprod])

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        sqrt_alphas_cumprod = alphas_cumprod.sqrt()
        # Stored as a NumPy array (one entry longer than `betas`) because
        # `sample_continuous_noise_level` indexes it with NumPy integer arrays.
        self.sqrt_alphas_cumprod_prev = alphas_cumprod_prev_with_last.sqrt().cpu().numpy()
        sqrt_recip_alphas_cumprod = (1 / alphas_cumprod).sqrt()
        sqrt_alphas_cumprod_m1 = (1 - alphas_cumprod).sqrt() * sqrt_recip_alphas_cumprod

        self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod)
        self.register_buffer('sqrt_recip_alphas_cumprod', sqrt_recip_alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod_m1', sqrt_alphas_cumprod_m1)

        posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        # Floor at 1e-20 before taking the log: the first entry is zero
        # because `alphas_cumprod_prev[0]` is 1.
        posterior_log_variance_clipped = posterior_variance.clamp(min=1e-20).log()

        posterior_mean_coef1 = betas * alphas_cumprod_prev.sqrt() / (1 - alphas_cumprod)
        posterior_mean_coef2 = (1 - alphas_cumprod_prev) * alphas.sqrt() / (1 - alphas_cumprod)

        self.register_buffer('posterior_log_variance_clipped', posterior_log_variance_clipped)
        self.register_buffer('posterior_mean_coef1', posterior_mean_coef1)
        self.register_buffer('posterior_mean_coef2', posterior_mean_coef2)

        self.n_iter = n_iter
        # Keep a private copy so later mutation of the caller's dict cannot alter the record.
        self.noise_schedule_kwargs = {'init': init, 'init_kwargs': dict(init_kwargs)}
        self.noise_schedule_is_set = True


    def sample_continuous_noise_level(self, batch_size, device):
        """
        Draws one continuous noise level per batch element.
        A step index is picked uniformly from 1..n_iter, then the level is drawn
        uniformly between the two neighbouring entries of `sqrt_alphas_cumprod_prev`.
        :param batch_size (int): number of levels to draw
        :param device: device the result is moved to
        :return (torch.Tensor): float tensor of shape [batch_size, 1]
        """
        table = self.sqrt_alphas_cumprod_prev
        step_ids = np.random.choice(np.arange(1, self.n_iter + 1), size=batch_size)
        levels = np.random.uniform(table[step_ids - 1], table[step_ids], size=batch_size)
        return torch.FloatTensor(levels).to(device).unsqueeze(-1)
    
    def q_sample(self, y_0, continuous_sqrt_alpha_cumprod=None, eps=None):
        """
        Efficiently computes diffusion version y_t from y_0 using a closed form expression:
            y_t = sqrt(alpha_cumprod)_t * y_0 + sqrt(1 - alpha_cumprod_t) * eps,
            where eps is sampled from a standard Gaussian.
        :param y_0 (torch.Tensor): clean signal
        :param continuous_sqrt_alpha_cumprod (torch.Tensor, optional): noise level multiplied
            with `y_0`; drawn with `sample_continuous_noise_level` when omitted
        :param eps (torch.Tensor, optional): noise of the same shape as `y_0`;
            drawn from a standard Gaussian when omitted
        :return (torch.Tensor): noised signal
        """
        # Each optional argument is defaulted on its own. Previously the level was resampled
        # whenever `eps` was missing, which discarded a level supplied by the caller and
        # left the level as None when only `eps` was given.
        if continuous_sqrt_alpha_cumprod is None:
            continuous_sqrt_alpha_cumprod = self.sample_continuous_noise_level(
                y_0.shape[0], device=y_0.device
            )
        if eps is None:
            eps = torch.randn_like(y_0)
        # Closed form signal diffusion
        outputs = continuous_sqrt_alpha_cumprod * y_0 + (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * eps

        return outputs

    def q_posterior(self, y_start, y, t):
        """
        Returns the mean and clipped log-variance of the reverse (denoising)
        process posterior q(y_{t-1}|y_0, y_t, x) at step `t`.
        :param y_start: estimate of the clean signal y_0
        :param y: current state y_t
        :param t: index into the precomputed schedule buffers
        :return (tuple): (posterior mean, clipped posterior log-variance)
        """
        weight_start = self.posterior_mean_coef1[t]
        weight_current = self.posterior_mean_coef2[t]
        mean = weight_start * y_start + weight_current * y
        log_variance = self.posterior_log_variance_clipped[t]
        return mean, log_variance

    def predict_start_from_noise(self, y, t, eps):
        """
        Recovers an estimate of y_0 from the state y_t and the reconstructed noise.
        The estimate is used to build the reverse (denoising)
        process posterior q(y_{t-1}|y_0, y_t, x).
        :param y: current state y_t
        :param t: index into the precomputed schedule buffers
        :param eps: reconstructed noise
        :return: estimate of y_0
        """
        signal_scale = self.sqrt_recip_alphas_cumprod[t]
        noise_scale = self.sqrt_alphas_cumprod_m1[t]
        return signal_scale * y - noise_scale * eps

    def p_mean_variance(self, texts, y, t, clip_denoised: bool):
        """
        Computes Gaussian transitions of Markov chain at step t
        for further computation of y_{t-1} given current state y_t and features.
        :param texts (torch.Tensor): conditioning features; the first dimension is the batch
        :param y (torch.Tensor): current state y_t
        :param t (int): step index
        :param clip_denoised (bool): clip the reconstructed signal to [-1, 1]
        :return (tuple): (model mean, posterior log-variance)
        """
        batch_size = texts.shape[0]

        # `.to(texts)` would also adopt the dtype of `texts`; for integer-typed `texts`
        # that truncates the noise level to 0. Only take the dtype when it is floating point.
        level_dtype = texts.dtype if texts.is_floating_point() else y.dtype
        noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[t+1]]) \
            .repeat(batch_size, 1).to(device=texts.device, dtype=level_dtype)
        eps_recon = self.nn(texts, y, noise_level)
        y_recon = self.predict_start_from_noise(y, t, eps_recon)

        if clip_denoised:
            y_recon.clamp_(-1.0, 1.0)

        model_mean, posterior_log_variance = self.q_posterior(y_start=y_recon, y=y, t=t)

        return model_mean, posterior_log_variance

    def compute_inverse_dynamics(self, texts, y, t, clip_denoised=True):
        """
        Performs one step of the reverse (denoising) process dynamics.
        Closely related to the idea of Langevin dynamics.
        :param texts (torch.Tensor): text features of shape [B, n_texts, T//hop_length]
        :param y (torch.Tensor): previous state from dynamics trajectory
        :param t (int): step index; no noise is added when `t` is not positive
        :param clip_denoised (bool, optional): clip signal to [-1, 1]
        :return (torch.Tensor): next state
        """
        mean, log_variance = self.p_mean_variance(texts, y, t, clip_denoised)
        if t > 0:
            noise = torch.randn_like(y)
        else:
            noise = torch.zeros_like(y)
        std = (0.5 * log_variance).exp()
        return mean + noise * std
    
    def sample(self, texts, store_intermediate_states=False):
        """
        Samples speech waveform via progressive denoising of white noise with guidance of text features.
        :param texts (torch.Tensor): text features of shape [B, n_texts, T//hop_length]
        :param store_intermediate_states (bool, optional): whether to store dynamics trajectory or not
        :return ys (list of torch.Tensor) (if store_intermediate_states=True)
            or y_0 (torch.Tensor): predicted signals on every dynamics iteration of shape [B, T]
        """
        with torch.no_grad():
            device = next(self.parameters()).device
            batch_size, T = texts.shape[0], texts.shape[-1]
            y_t = torch.randn(batch_size, T*self.total_factor, dtype=torch.float32).to(device)
            # Only keep the trajectory when it was asked for; otherwise every
            # intermediate state would stay alive until the loop finishes.
            ys = [y_t] if store_intermediate_states else None

            # Steps run from n_iter - 1 down to 0.
            for t in tqdm(range(self.n_iter - 1, -1, -1)):
                y_t = self.compute_inverse_dynamics(texts, y=y_t, t=t)
                if ys is not None:
                    ys.append(y_t)

            return ys if store_intermediate_states else y_t

    def compute_loss(self, texts, y_0, continuous_sqrt_alpha_cumprod=None, eps=None):
        """
        Computes loss between GT Gaussian noise and reconstructed noise by model from diffusion process.
        :param texts (torch.Tensor): text features of shape [B, n_texts, T//hop_length]
        :param y_0 (torch.Tensor): GT speech signals
        :param continuous_sqrt_alpha_cumprod (torch.Tensor, optional): noise level to diffuse with;
            drawn with `sample_continuous_noise_level` when omitted
        :param eps (torch.Tensor, optional): noise to diffuse with; drawn from a standard
            Gaussian when omitted
        :return loss (torch.Tensor): loss of diffusion model
        :raises RuntimeError: if no noise schedule has been set
        """
        self._verify_noise_schedule_existence()

        # Default each optional argument independently. Previously a caller-supplied `eps`
        # was always overwritten and the level default was keyed on `eps` instead of the level.
        if continuous_sqrt_alpha_cumprod is None:
            continuous_sqrt_alpha_cumprod = self.sample_continuous_noise_level(
                y_0.shape[0], device=y_0.device
            )
        if eps is None:
            eps = torch.randn_like(y_0)

        y_noisy = self.q_sample(y_0, continuous_sqrt_alpha_cumprod, eps)
        eps_recon = self.nn(texts, y_noisy, continuous_sqrt_alpha_cumprod)
        loss = torch.nn.L1Loss()(eps_recon, eps)

        return loss

    def forward(self, texts, store_intermediate_states=False):
        """
        Generates speech from given text by running `sample`.
        :param texts (torch.Tensor): text tensor of shape [1, n_texts, T//hop_length]
        :param store_intermediate_states (bool, optional):
            flag to set return tensor to be a set of all states of denoising process
        :return: whatever `sample` returns for the given flag
        """
        # Fails early when no noise schedule has been configured.
        self._verify_noise_schedule_existence()
        generated = self.sample(texts, store_intermediate_states)
        return generated
    
    def _verify_noise_schedule_existence(self):
        if not self.noise_schedule_is_set:
            raise RuntimeError(
                'No noise schedule is found. Specify your noise schedule '
                'by pushing arguments into `set_new_noise_schedule(...)` method. '
                'For example: '
                "`wavegrad.set_new_noise_level(init=torch.linspace, init_kwargs=\{'steps': 50, 'start': 1e-6, 'end': 1e-2\})`."
            )



    def naive_attack(self, texts, y0, attack_num=100, interval=10, store_intermediate_states=True):
        """
        Scores `y0` by comparing the noises from `naive_reverse` with the noises
        the network predicts in `naive_denoise`.
        :param texts: conditioning features forwarded to the helpers
        :param y0 (torch.Tensor): signal under test
        :param attack_num (int, optional): forwarded to the helpers
        :param interval (int, optional): forwarded to the helpers
        :param store_intermediate_states (bool, optional): not used by this method
        :return (torch.Tensor): squared differences summed over every dimension after the
            first two of the stacked states
        :raises RuntimeError: if no noise schedule has been set
        """
        self._verify_noise_schedule_existence()
        noises = self.naive_reverse(texts, y0, attack_num, interval)
        predictions = self.naive_denoise(texts, y0, noises, attack_num, interval)
        stacked_noises = torch.stack(noises)
        stacked_predictions = torch.stack(predictions)
        squared_error = (stacked_noises - stacked_predictions) ** 2
        return squared_error.flatten(2).sum(dim=-1)
 
    def naive_reverse(self, texts, y0, attack_num, interval):
        """
        Draws the Gaussian noises used by `naive_attack`.
        :param texts: unused here
        :param y0 (torch.Tensor): signal whose shape, dtype and device the noises copy
        :param attack_num (int): number of noises drawn (for a positive `interval`)
        :param interval (int): stride of the step range; only affects how many noises are drawn
        :return (list of torch.Tensor): independently drawn standard Gaussian noises
        :raises RuntimeError: if no noise schedule has been set
        """
        self._verify_noise_schedule_existence()
        steps = range(0, interval * attack_num, interval)
        return [torch.randn_like(y0) for _ in steps]

    def naive_denoise(self, texts, y0, intermediates, attack_num, interval):
        """
        Diffuses `y0` with each given noise and asks the network to predict that noise back.
        The i-th noise is paired with `alphas_cumprod[i]`; `attack_num` and `interval`
        do not influence the result.
        :param texts: conditioning features passed to the network
        :param y0 (torch.Tensor): signal under test
        :param intermediates (list of torch.Tensor): noises, e.g. from `naive_reverse`
        :param attack_num (int): unused
        :param interval (int): unused
        :return (list of torch.Tensor): network noise predictions, one per input noise
        :raises RuntimeError: if no noise schedule has been set
        """
        self._verify_noise_schedule_existence()
        predictions = []
        for step, noise in enumerate(intermediates):
            n_items = y0.shape[0]
            # The raw `alphas_cumprod` entry is used directly as the level multiplying `y0`.
            level = torch.FloatTensor([self.alphas_cumprod[step]]).repeat(n_items, 1).to(y0.device)
            noisy = level * y0 + (1 - level**2).sqrt() * noise
            predictions.append(self.nn(texts, noisy, level))
        return predictions


    def secmi_attack(self, texts, y0, attack_num=100, interval=10, store_intermediate_states=True):
        """
        Scores `y0` by comparing the states from `ddim_reverse` with their
        reconstructions from `ddim_denoise`.
        :param texts: conditioning features forwarded to the helpers
        :param y0 (torch.Tensor): signal under test
        :param attack_num (int, optional): forwarded to the helpers
        :param interval (int, optional): forwarded to the helpers
        :param store_intermediate_states (bool, optional): not used by this method
        :return (torch.Tensor): squared differences summed over every dimension after the
            first two of the stacked states
        :raises RuntimeError: if no noise schedule has been set
        """
        self._verify_noise_schedule_existence()
        forward_states = self.ddim_reverse(texts, y0, attack_num, interval)
        rebuilt_states = self.ddim_denoise(texts, y0, forward_states, attack_num, interval)
        stacked_forward = torch.stack(forward_states)
        stacked_rebuilt = torch.stack(rebuilt_states)
        squared_error = (stacked_forward - stacked_rebuilt) ** 2
        return squared_error.flatten(2).sum(dim=-1)

    def ddim_reverse(self, texts, x0, attack_num, interval):
        """
        Builds a chain of noised states starting from `x0`.
        At position i the network predicts noise from the previous state, and the next
        state is formed from `x0` and that prediction with level `alphas_cumprod[i]`.
        `interval` only determines the number of positions.
        :param texts: conditioning features passed to the network
        :param x0 (torch.Tensor): signal under test
        :param attack_num (int): the chain has `attack_num - 1` states for a positive `interval`
        :param interval (int): stride of the step range
        :return (list of torch.Tensor): the generated states, in order
        """
        n_positions = len(range(0, interval * (attack_num - 1), interval))
        n_items, target_device = x0.shape[0], x0.device
        states = []
        current = x0
        for position in range(n_positions):
            level = torch.FloatTensor([self.alphas_cumprod[position]]).repeat(n_items, 1).to(target_device)
            predicted_noise = self.nn(texts, current, level)
            # Every state is rebuilt from the original `x0`, not from the previous state.
            current = level * x0 + (1 - level**2).sqrt() * predicted_noise
            states.append(current)
        return states

    def ddim_denoise(self, texts, x0, intermediates, attack_num, interval):
        """
        Reconstructs each given state: predicts its noise, estimates the clean signal,
        then re-noises that estimate with the same predicted noise.
        The i-th state is paired with `alphas_cumprod[i]`; `attack_num` and `interval`
        do not influence the result.
        :param texts: conditioning features passed to the network
        :param x0 (torch.Tensor): signal under test; only its batch size and device are read
        :param intermediates (list of torch.Tensor): states, e.g. from `ddim_reverse`
        :param attack_num (int): unused
        :param interval (int): unused
        :return (list of torch.Tensor): reconstructed states, one per input state
        """
        n_items, target_device = x0.shape[0], x0.device
        rebuilt = []
        for position, state in enumerate(intermediates):
            level = torch.FloatTensor([self.alphas_cumprod[position]]).repeat(n_items, 1).to(target_device)
            predicted_noise = self.nn(texts, state, level)
            # NOTE(review): the clean-signal estimate scales with sqrt(level) / sqrt(1 - level),
            # while the re-noising below scales with level / sqrt(1 - level**2) — confirm the
            # mixed convention is intended.
            clean_estimate = (state - torch.sqrt(1 - level) * predicted_noise) / torch.sqrt(level + 1e-5)
            rebuilt.append(level * clean_estimate + (1 - level**2).sqrt() * predicted_noise)
        return rebuilt


    def pia_attack(self, texts, y0, attack_num=100, interval=10, store_intermediate_states=True):
        """
        Scores `y0` by comparing the noise from `pia_reverse` with the noises
        the network predicts in `pia_denoise`.
        :param texts: conditioning features forwarded to the helpers
        :param y0 (torch.Tensor): signal under test
        :param attack_num (int, optional): forwarded to the helpers
        :param interval (int, optional): forwarded to the helpers
        :param store_intermediate_states (bool, optional): not used by this method
        :return (torch.Tensor): fourth powers of the absolute differences, summed over every
            dimension after the first two of the stacked states
        :raises RuntimeError: if no noise schedule has been set
        """
        self._verify_noise_schedule_existence()
        reference = self.pia_reverse(texts, y0, attack_num, interval)
        predicted = self.pia_denoise(texts, y0, reference, attack_num, interval)
        stacked_reference = torch.stack(reference)
        stacked_predicted = torch.stack(predicted)
        error = (stacked_reference - stacked_predicted).abs() ** 4
        return error.flatten(2).sum(dim=-1)
    
    def pia_reverse(self, texts, y0, attack_num, interval):
        """
        Runs the network once on `y0` at level `alphas_cumprod[0]` and repeats that
        single prediction for every step of the attack.
        :param texts: conditioning features passed to the network
        :param y0 (torch.Tensor): signal under test
        :param attack_num (int): the list has `attack_num - 1` entries for a positive `interval`
        :param interval (int): stride of the step range; only affects the list length
        :return (list of torch.Tensor): the same prediction object repeated
        """
        n_items = y0.shape[0]
        level = torch.FloatTensor([self.alphas_cumprod[0]]).repeat(n_items, 1).to(y0.device)
        predicted_noise = self.nn(texts, y0, level)

        repeats = len(range(0, interval * (attack_num - 1), interval))
        return [predicted_noise] * repeats

    def pia_denoise(self, texts, y0, intermediates, attack_num, interval):
        """
        Diffuses `y0` with each given noise at steps `interval, 2*interval, ...`
        and asks the network to predict the noise back.
        :param texts: conditioning features passed to the network
        :param y0 (torch.Tensor): signal under test
        :param intermediates (list of torch.Tensor): noises, e.g. from `pia_reverse`;
            indexed by position, so it needs at least one entry per step
        :param attack_num (int): `attack_num - 1` steps are visited for a positive `interval`
        :param interval (int): distance between visited steps of `alphas_cumprod`
        :return (list of torch.Tensor): network noise predictions, one per step
        """
        n_items = y0.shape[0]
        # Same steps as range(interval, interval * (attack_num - 1) + interval, interval).
        visited_steps = range(interval, interval * attack_num, interval)
        predictions = []
        for position, step in enumerate(visited_steps):
            noise = intermediates[position]
            level = torch.FloatTensor([self.alphas_cumprod[step]]).repeat(n_items, 1).to(y0.device)
            noisy = level * y0 + (1 - level**2).sqrt() * noise
            predictions.append(self.nn(texts, noisy, level))
        return predictions
    
    def pian_attack(self, texts, y0, attack_num=100, interval=10, store_intermediate_states=True):
        """
        Scores `y0` by comparing the rescaled noise from `pian_reverse` with the noises
        the network predicts in `pian_denoise`.
        :param texts: conditioning features forwarded to the helpers
        :param y0 (torch.Tensor): signal under test
        :param attack_num (int, optional): forwarded to the helpers
        :param interval (int, optional): forwarded to the helpers
        :param store_intermediate_states (bool, optional): not used by this method
        :return (torch.Tensor): fourth powers of the absolute differences, summed over every
            dimension after the first two of the stacked states
        :raises RuntimeError: if no noise schedule has been set
        """
        self._verify_noise_schedule_existence()
        reference = self.pian_reverse(texts, y0, attack_num, interval)
        predicted = self.pian_denoise(texts, y0, reference, attack_num, interval)
        difference = torch.stack(reference) - torch.stack(predicted)
        return (difference.abs() ** 4).flatten(2).sum(dim=-1)
    
    def pian_reverse(self, texts, y0, attack_num, interval):
        """
        Runs the network once on `y0` at level `alphas_cumprod[0]`, rescales the prediction
        so its per-item mean absolute value becomes sqrt(2 / pi), and repeats it for every step.
        :param texts: conditioning features passed to the network
        :param y0 (torch.Tensor): signal under test
        :param attack_num (int): the list has `attack_num - 1` entries for a positive `interval`
        :param interval (int): stride of the step range; only affects the list length
        :return (list of torch.Tensor): the same rescaled prediction object repeated
        """
        n_items = y0.shape[0]
        level = torch.FloatTensor([self.alphas_cumprod[0]]).repeat(n_items, 1).to(y0.device)
        raw_noise = self.nn(texts, y0, level)

        # Average the magnitude over every dimension except the first (batch) one.
        non_batch_dims = list(range(1, raw_noise.ndim))
        mean_magnitude = raw_noise.abs().mean(non_batch_dims, keepdim=True)
        rescaled = raw_noise / mean_magnitude * (2 / torch.pi) ** 0.5

        repeats = len(range(0, interval * (attack_num - 1), interval))
        return [rescaled] * repeats

    def pian_denoise(self, texts, y0, intermediates, attack_num, interval):
        """
        Diffuses `y0` with each given noise at steps `interval, 2*interval, ...`
        and asks the network to predict the noise back.
        :param texts: conditioning features passed to the network
        :param y0 (torch.Tensor): signal under test
        :param intermediates (list of torch.Tensor): noises, e.g. from `pian_reverse`;
            indexed by position, so it needs at least one entry per step
        :param attack_num (int): `attack_num - 1` steps are visited for a positive `interval`
        :param interval (int): distance between visited steps of `alphas_cumprod`
        :return (list of torch.Tensor): network noise predictions, one per step
        """
        last_step = interval * (attack_num - 1)
        n_items = y0.shape[0]
        recovered = []

        position = 0
        for step in range(interval, last_step + interval, interval):
            noise = intermediates[position]
            level = torch.FloatTensor([self.alphas_cumprod[step]]).repeat(n_items, 1).to(y0.device)
            diffused = level * y0 + (1 - level**2).sqrt() * noise
            recovered.append(self.nn(texts, diffused, level))
            position += 1
        return recovered
    