import abc
from typing import List, Optional

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as td

from tqdm import tqdm


class BasicNCSNModule(nn.Module):
    """
    Small fully connected score network: two hidden ReLU layers followed by a linear output layer.

    The noise scale is not an input to any layer; the raw network output is divided by it instead.
    """

    def __init__(self, n_dim: int, n_hidden: int = 10):
        """
        :param n_dim: size of the last dimension of both the input and the output.
        :param n_hidden: width of the two hidden layers.
        """
        super().__init__()
        # Layers are registered in input-to-output order.
        self.linear1 = nn.Linear(in_features=n_dim, out_features=n_hidden)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(in_features=n_hidden, out_features=n_hidden)
        self.act2 = nn.ReLU()
        self.linear3 = nn.Linear(in_features=n_hidden, out_features=n_dim)

    def forward(self, x, sigma):
        """
        Run the network on `x` and divide the result by `sigma`.

        :param x: input whose last dimension has size `n_dim`.
        :param sigma: noise scale used as the divisor of the network output.
        :return: scaled network output with the same shape as `x`.
        """
        hidden = self.linear1(x)
        hidden = self.act1(hidden)
        hidden = self.linear2(hidden)
        hidden = self.act2(hidden)
        unscaled = self.linear3(hidden)
        return unscaled / sigma


class ScoreBasedGenerativeModel(abc.ABC):
    """
    Abstract interface for generative models built around a score estimation network.

    Concrete subclasses have to implement `loss`, `train` and `compute_score`.
    """

    def __init__(self, device=torch.device('cpu')):
        """
        :param device: device stored on the instance as `self.device`.
        """
        self.device = device
        # Score estimation network; remains None until a subclass assigns one.
        self.model: Optional[nn.Module] = None

    @abc.abstractmethod
    def loss(self, x: torch.Tensor, n_samples: int, **kwargs) -> torch.Tensor:
        """
        Return the training objective of the score estimation model.

        :param x: data points the loss is evaluated on.
        :param n_samples: sample count; its exact meaning is defined by the subclass.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def train(self, x_train: torch.Tensor, n_epochs: int, **kwargs):
        """
        Fit the score estimation model.

        :param x_train: training data.
        :param n_epochs: number of epochs to train for.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compute_score(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Return the score that the score estimation model assigns to arbitrary data points.

        :param x: data points to evaluate.
        """
        raise NotImplementedError


class NoiseConditionalScoreNetwork(ScoreBasedGenerativeModel):
    def __init__(self, n_dim: int, scales: torch.Tensor, base_model: Optional[nn.Module] = None, **kwargs):
        """
        Noise conditional score network class.
        Lets us obtain score estimates for arbitrary data points, intended for use in DLA.
        Samples can be drawn with annealed Langevin dynamics via `sample`.

        Reading:
        * Generative Modeling by Estimating Gradients of the Data Distribution, https://arxiv.org/abs/1907.05600
        * Score-Based Generative Modeling through Stochastic Differential Equations, https://arxiv.org/abs/2011.13456

        :param n_dim: number of data point dimensions.
        :param scales: non-decreasing tensor of standard deviations for the normal distribution,
            so `scales[0]` is the smallest and `scales[-1]` the largest noise level.
        :param base_model: score estimation model, called as `base_model(x, scale)`.
            If None, a feedforward ReLU network is used.
        :param kwargs: forwarded to the parent constructor (e.g. `device`).
        """
        super().__init__(**kwargs)
        assert list(scales) == sorted(scales), 'scales must be sorted in non-decreasing order'
        self.n_dim = n_dim
        self.scales = scales.to(self.device)

        if base_model is None:
            base_model = BasicNCSNModule(n_dim=n_dim)
        self.model = base_model.to(self.device)

    def loss(self, x: torch.Tensor, n_samples: int = 10, **kwargs) -> torch.Tensor:
        """
        Compute loss as in the Noise Conditional Score Network.
        Equation 1 in https://arxiv.org/abs/2011.13456.

        For every scale, each data point is perturbed `n_samples` times with Gaussian noise and the
        model output is regressed onto the score of the perturbation kernel. All perturbed points of
        one scale are evaluated in a single forward pass.

        :param x: data points with shape (n, n_dim).
        :param n_samples: number of samples to take from the neighborhood of each provided data point.
        :return: tensor with a single element holding the loss summed over all scales.
        :raises ValueError: if `x` contains no data points.
        """
        if len(x) == 0:
            raise ValueError('x must contain at least one data point')
        x = x.to(self.device)
        point_shape = x.shape[1:]

        total = torch.zeros(1, device=self.device)
        for scale in self.scales:
            # x_tilde = x + scale * noise is a draw from N(x, scale^2 I).
            noise = torch.randn((n_samples, *x.shape), device=self.device, dtype=x.dtype)
            xs_tilde = (x + scale * noise).reshape(-1, *point_shape)
            # Closed-form score of N(x, scale^2 I) at x_tilde: -(x_tilde - x) / scale^2 = -noise / scale.
            target = (-noise / scale).reshape(-1, *point_shape)

            score_estimate = self.model(xs_tilde, scale)
            assert score_estimate.shape == target.shape
            # Every data point contributes the same number of elements, so the global mean equals
            # the mean over data points of the per-point means.
            e_data = torch.mean((score_estimate - target) ** 2)
            total = total + (scale ** 2) * e_data
        assert total.numel() == 1
        return total

    def train(self,
              x_train: torch.Tensor,
              n_epochs: int,
              n_samples: int = 10,
              lr: float = 1e-3,
              x_val: Optional[torch.Tensor] = None,
              use_tqdm: bool = False,
              **kwargs):
        """
        Train the NCSN model. Afterwards, the model is accessible via `self.model`.
        Every epoch performs one full-batch Adam step on `x_train`.

        :param x_train: train data, points in data space with shape (n_train, n_dim).
        :param n_epochs: number of epochs to train for.
        :param n_samples: number of neighborhood samples for each data point used to compute the loss.
        :param lr: learning rate for Adam.
        :param x_val: validation data, points in data space with shape (n_val, n_dim).
        :param use_tqdm: use tqdm.
        :return: dictionary with training and validation loss histories.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        train_loss_history = []
        val_loss_history = []

        epochs = range(n_epochs)
        # Only a tqdm bar has `set_postfix_str`; a plain range must not be treated as one.
        pbar = tqdm(epochs) if use_tqdm else None
        for _ in (epochs if pbar is None else pbar):
            optimizer.zero_grad()
            loss = self.loss(x=x_train, n_samples=n_samples)
            train_loss_history.append(float(loss))

            if x_val is not None:
                # Validation is monitoring only, so no autograd graph is needed.
                with torch.no_grad():
                    val_loss = self.loss(x=x_val, n_samples=n_samples)
                val_loss_history.append(float(val_loss))

            if pbar is not None:
                pbar.set_postfix_str(
                    f'Train loss: {float(train_loss_history[-1]):.4f}, '
                    f'Val loss: {float(val_loss_history[-1]) if x_val is not None else 0.0:.4f}'
                )

            loss.backward()
            optimizer.step()

        history = {
            'train_loss': train_loss_history,
            'val_loss': val_loss_history
        }
        return history

    @torch.no_grad()
    def compute_score(self, x: torch.Tensor, scale: Optional[float] = None, **kwargs) -> torch.Tensor:
        """
        Compute score for new data.
        According to https://arxiv.org/abs/2011.13456, Section 2.1, this means using NCSN at the smallest scale.
        Note that the score estimate is approximate.

        :param x: data points with shape (n, n_dim).
        :param scale: scale of the normal distribution to use. The closer to 0, the more accurate the score estimate.
            If None, the smallest training scale `self.scales[0]` is used.
        :return: tensor of scores for each data point with shape (n, n_dim).
        """
        if scale is None:
            scale = self.scales[0]
        else:
            assert scale > 0, 'scale must be positive'
        return self.model(x.to(self.device), scale)

    @torch.no_grad()
    def sample(self,
               n_samples: int,
               chain_length: int = 100,
               eps: float = 1e-5,
               use_tqdm: bool = False,
               return_history: bool = True):
        """
        Sample from the distribution learned by NCSN using annealed Langevin dynamics.
        The Langevin MCMC algorithm is described in https://arxiv.org/abs/2011.13456, Section 2.1.

        Chains start from noise at the largest scale and are annealed down to the smallest one.
        The step size at scale sigma is `eps * sigma^2 / sigma_min^2`.

        :param n_samples: number of samples to take.
        :param chain_length: length of the Langevin MCMC chain for each scale, for each data point.
        :param eps: step size for Langevin MCMC at the smallest scale.
        :param use_tqdm: use tqdm.
        :param return_history: additionally return a dictionary with 'deltas' (norm of change in x
            at every Langevin step) and 'score_norms' (norm of the predicted score at every step).
        :return: samples with shape (n_samples, n_dim), or a tuple (samples, history) if `return_history`.
        """
        # Scales are non-decreasing: index 0 is the smallest noise level, index -1 the largest.
        smallest_scale = self.scales[0]
        largest_scale = self.scales[-1]

        x = td.Normal(loc=0.0, scale=largest_scale).sample((n_samples, self.n_dim)).to(self.device).float()
        annealing = reversed(self.scales)  # Largest to smallest.
        if use_tqdm:
            annealing = tqdm(annealing)

        deltas = []
        score_pred_norms = []
        for scale in annealing:
            # Step size shrinks with the noise level and equals eps at the smallest scale.
            alpha = eps * scale ** 2 / (smallest_scale ** 2)
            for _ in range(chain_length):
                score_pred = self.model(x, scale)
                delta = alpha * score_pred + ((2 * alpha) ** 0.5) * torch.randn_like(x)
                # Langevin dynamics moves along the score, i.e. towards higher density.
                x = x + delta
                deltas.append(float(torch.linalg.norm(delta)))
                score_pred_norms.append(float(torch.linalg.norm(score_pred)))

        # Denoising step to remove the influence of the last (smallest) noise scale
        x = x + (smallest_scale ** 2) * self.model(x, smallest_scale)

        history = {
            'deltas': deltas,
            'score_norms': score_pred_norms
        }
        if return_history:
            return x, history
        return x
