import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm.auto import tqdm


class SVDiffusion(object):
    """Gaussian diffusion wrapper around a noise-predicting network.

    The constructor precomputes a linear beta schedule and the derived
    per-timestep coefficients.  The instance then offers the forward
    (noising) process, a single reverse step, and a full reverse sampling
    loop that re-imposes known values after every step.

    The schedule tensors are created on the default (CPU) device and kept
    there; ``_get_index_from_list`` moves the gathered coefficients to the
    device of the timestep tensor on demand.
    """

    def __init__(self, time_steps, unet, w, device=None, model_path=None):
        """Precompute the noise schedule and move the network to the device.

        Args:
            time_steps: number of diffusion steps in the schedule.
            unet: network called as ``unet(x, t)`` to predict the noise.
            w: stored as ``self.w``; not used by any method of this class.
            device: target device.  Defaults to CUDA when available,
                otherwise CPU.
            model_path: accepted for interface compatibility; not used here
                (a checkpoint path is passed to ``sampling_sequence``).
        """
        self.time_steps = time_steps
        self.w = w
        self.betas = self._linear_beta_schedule()

        self.alphas = 1. - self.betas
        # `dim` is the documented torch keyword (`axis` is only a NumPy alias).
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Cumulative product shifted right by one step, with 1.0 at index 0.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        self.posterior_variance = (
            self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        )

        if device is not None:
            self.device = device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = unet.to(self.device)

    def _get_index_from_list(self, vals: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
        """Pick ``vals[t]`` per batch element, shaped to broadcast against ``x``.

        Args:
            vals (torch.Tensor): 1-D tensor of per-timestep values.
            t (torch.Tensor): 1-D integer tensor of timesteps, one per sample.
            x_shape (torch.Size): shape of the tensor the result will be
                multiplied with; only its rank is used.

        Returns:
            torch.Tensor: tensor of shape ``(batch_size, 1, ..., 1)`` with
            ``len(x_shape)`` dimensions, on the same device as ``t``.
        """
        batch_size = t.shape[0]
        # Gather on whichever device holds the schedule (CPU by default),
        # then hand the result back on the caller's device.
        out = vals.gather(-1, t.to(vals.device))
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

    def _linear_beta_schedule(self, start=0.0001, end=0.02) -> torch.Tensor:
        """Return ``self.time_steps`` betas evenly spaced from start to end.

        Args:
            start (float, optional): beta at timestep 0. Defaults to 0.0001.
            end (float, optional): beta at last timestep. Defaults to 0.02.

        Returns:
            torch.Tensor: beta schedule of length ``self.time_steps``.
        """
        return torch.linspace(start, end, self.time_steps)

    def _q_sample(self, x_0: torch.Tensor, t: torch.Tensor):
        """Noise ``x_0`` to timestep ``t``; shared by ``forward``/``add_noise``.

        Returns:
            A ``(noisy, noise)`` pair, both on ``self.device``.
        """
        # Exactly one RNG draw, made on x_0's device before moving, so the
        # random stream consumed by callers is unchanged.
        noise = torch.randn_like(x_0).to(self.device)

        sqrt_alphas_cumprod_t = self._get_index_from_list(
            self.sqrt_alphas_cumprod, t, x_0.shape
        ).to(self.device)
        sqrt_one_minus_alphas_cumprod_t = self._get_index_from_list(
            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
        ).to(self.device)

        # mean + scaled noise
        noisy = sqrt_alphas_cumprod_t * x_0.to(self.device) \
            + sqrt_one_minus_alphas_cumprod_t * noise
        return noisy, noise

    def forward(self, x_0: torch.Tensor, t: torch.Tensor, type='forecast'):
        """Forward (noising) process of the diffusion model.

        Args:
            x_0 (torch.Tensor): clean input.
            t (torch.Tensor): 1-D integer tensor of timesteps, one per sample.
            type: unused; kept so existing callers keep working.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: the noised input and the noise
            that was added, both on ``self.device``.
        """
        return self._q_sample(x_0, t)

    @torch.no_grad()
    def add_noise(self, x_0, t):
        """Return ``x_0`` noised to timestep ``t``, without tracking gradients.

        Same computation as ``forward`` but only the noised tensor is
        returned.
        """
        noisy, _ = self._q_sample(x_0, t)
        return noisy

    @torch.no_grad()
    def sample_timestep(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run one reverse step from timestep ``t``.

        Calls the model to predict the noise in ``x`` and returns the
        posterior mean, plus freshly sampled noise scaled by the posterior
        standard deviation unless this is the final step.

        Args:
            x (torch.Tensor): current noisy sample.
            t (torch.Tensor): 1-D integer tensor of timesteps, one per sample.
                Only ``t[0]`` is inspected to detect the final step, so all
                entries are expected to be equal.

        Returns:
            torch.Tensor: the sample after this reverse step.
        """
        betas_t = self._get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self._get_index_from_list(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = self._get_index_from_list(self.sqrt_recip_alphas, t, x.shape)

        # Noise prediction from the network; `self.w` is not applied here.
        eps = self.model(x, t)
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * eps / sqrt_one_minus_alphas_cumprod_t
        )

        # Timesteps here are zero-based (offset by one from the DDPM paper),
        # so index 0 is the last reverse step and gets no extra noise.
        if t[0] == 0:
            return model_mean

        posterior_variance_t = self._get_index_from_list(self.posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def sampling(self, batch, VT, x_T: torch.Tensor) -> torch.Tensor:
        """Run the full reverse process, re-imposing known values each step.

        After every reverse step the sample is mapped through
        ``batch.right_embedding`` / ``batch.left_embedding`` and ``VT``, the
        entries at positions ``batch.yc`` are overwritten with ``batch.xc``,
        and the result is mapped back.

        Args:
            batch: object providing ``yc``, ``xc``, ``left_embedding`` and
                ``right_embedding`` tensors.  ``yc`` is indexed as
                ``(batch, n_positions)``.
                NOTE(review): presumably ``yc``/``xc`` are the observed
                positions and values -- confirm against the caller.
            VT (torch.Tensor): 3-D tensor multiplied on the right of the
                sample; its ``transpose(2, 1)`` is used to map back.
            x_T (torch.Tensor): starting sample; must already be on
                ``self.device``.

        Returns:
            torch.Tensor: the sample after all ``self.time_steps`` steps.
        """
        # Loop-invariant operands: compute / transfer them once.
        VT_inverse = VT.transpose(2, 1)
        known_pos = batch.yc.to(self.device)
        known_val = batch.xc.to(self.device)
        left, right = batch.left_embedding, batch.right_embedding

        x = x_T
        batch_idx = None
        # reversed(range(...)) has no len(), so give tqdm the total explicitly.
        for i in tqdm(reversed(range(self.time_steps)), desc="Sampling", total=self.time_steps):
            n = x.shape[0]
            t = torch.full((n,), i, dtype=torch.long, device=self.device)
            # Row indices pairing each sample with its known positions;
            # rebuilt only if the batch size changes between iterations.
            if batch_idx is None or batch_idx.shape[0] != n:
                batch_idx = torch.arange(n).unsqueeze(1).expand(-1, known_pos.shape[1]).to(self.device)

            x = self.sample_timestep(x, t)
            x = torch.matmul(right, torch.matmul(x, left))

            pred_signal = torch.matmul(x, VT)
            pred_signal[batch_idx, known_pos] = known_val
            x = torch.matmul(pred_signal, VT_inverse)
            x = torch.matmul(left, torch.matmul(x, right))

        return x

    @torch.no_grad()
    def sampling_sequence(self, batch, VT, model_path=None, original_data=None) -> torch.Tensor:
        """Optionally load weights, then sample starting from ``original_data``.

        Side effects: puts the model in eval mode and reseeds the global
        torch RNGs with seed 0 so repeated calls are reproducible.

        Args:
            batch: passed through to ``sampling``; its ``left_embedding`` and
                ``right_embedding`` are also applied to the final sample.
            VT (torch.Tensor): passed through to ``sampling``.
            model_path: optional checkpoint path; the ``'unet'`` entry of the
                loaded object is used as the model's state dict.
            original_data (torch.Tensor): starting tensor for the reverse
                process.  Required despite the ``None`` default.

        Returns:
            torch.Tensor: the sampled tensor with singleton dimensions
            squeezed out, detached, left on the device it was computed on.

        Raises:
            ValueError: if ``original_data`` is ``None``.
        """
        # Fail fast, before touching model weights or global RNG state.
        if original_data is None:
            raise ValueError("original_data must be provided to sampling_sequence")

        if model_path is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            self.model.load_state_dict(torch.load(model_path, map_location=self.device)['unet'])
        self.model.eval()

        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        x_T = original_data.to(self.device)
        sampled_tensor = self.sampling(batch, VT, x_T)
        sampled_tensor = torch.matmul(batch.right_embedding, torch.matmul(sampled_tensor, batch.left_embedding))
        return sampled_tensor.squeeze().detach()