import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from continuous.utils import pytorch_util as ptu
from continuous.utils.pytorch_core import PyTorchModule
from continuous.utils.system import reproduce

class VAEDensity(PyTorchModule):
    """MLP variational autoencoder used to score observations.

    The encoder maps an observation to the mean and log-variance of a
    diagonal Gaussian over a ``code_dim``-dimensional latent code; the decoder
    maps a code back to observation space. The score returned by
    :meth:`get_output_for` is the negative squared reconstruction error.
    """

    def __init__(self,
                 input_size,
                 code_dim=4,
                 beta=0.5,
                 lr=1e-3,
                 mlp_dim=32,
                 device=torch.device("cpu"),
                 seed=1,
                 ):
        """
        A simple VAE model. Copied and adjusted from Lisa's SMM code.

        Args:
            input_size: int, size of the last dimension of an observation.
                It is passed directly to ``nn.Linear`` as the feature count,
                so it must be an integer (not a shape tuple).
            code_dim: dim of latent vector.
            beta: coefficient of KL loss.
            lr: learning rate of the Adam optimizer.
            mlp_dim: width of the hidden layers of encoder and decoder.
            device: torch device on which all layers are placed.
            seed: value forwarded to ``reproduce`` before the layers are
                built (presumably seeds the RNGs -- confirm in
                ``continuous.utils.system``).
        """
        # Must stay the first statement so that locals() holds only the
        # constructor arguments.
        self.save_init_params(locals())
        super().__init__()

        self.device = device
        reproduce(seed)

        input_dim = input_size
        self.enc = nn.Sequential(
            nn.Linear(input_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.ReLU(),
        ).to(self.device)

        # Two heads on top of the shared encoder trunk.
        self.enc_mu = nn.Linear(mlp_dim, code_dim).to(self.device)
        self.enc_logvar = nn.Linear(mlp_dim, code_dim).to(self.device)

        self.dec = nn.Sequential(
            nn.Linear(code_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, input_dim),
        ).to(self.device)

        self.lr = lr
        self.beta = beta
        params = (list(self.enc.parameters()) +
                  list(self.enc_mu.parameters()) +
                  list(self.enc_logvar.parameters()) +
                  list(self.dec.parameters()))
        self.optimizer = optim.Adam(params, lr=self.lr)

    def _encode(self, obs):
        """Return the posterior mean and log-variance for ``obs``."""
        enc_features = self.enc(obs)
        return self.enc_mu(enc_features), self.enc_logvar(enc_features)

    @staticmethod
    def _reparameterize(mu, logvar, sample=True):
        """Draw a latent code from N(mu, exp(logvar)), or return its mean.

        The noise is created with ``randn_like`` so that it always lives on
        the same device (and has the same dtype) as ``mu``.
        """
        if not sample:
            # Deterministic encoding: the mean of the posterior.
            return mu
        stds = (0.5 * logvar).exp()
        return torch.randn_like(mu) * stds + mu

    def get_output_for(self, aug_obs, sample=True):
        """
        Returns the log-probability score of the given observation.

        The score is the negative squared error between the observation and
        its reconstruction, summed over the last dimension.

        Args:
            aug_obs: float tensor of observations, on ``self.device``.
            sample: if True, decode a code sampled from the posterior;
                if False, decode the posterior mean.

        Returns:
            Tensor with the last dimension reduced to size 1. It carries no
            gradient (computed under ``torch.no_grad()``).
        """
        obs = aug_obs
        with torch.no_grad():
            mu, logvar = self._encode(obs)
            code = self._reparameterize(mu, logvar, sample=sample)

            obs_distribution_params = self.dec(code)
            log_prob = -1. * F.mse_loss(obs, obs_distribution_params,
                                        reduction='none')
            log_prob = torch.sum(log_prob, -1, keepdim=True)
        return log_prob

    def get_density(self, obs):
        """
        Score a batch of observations and return the result as a numpy array.

        Args:
            obs: batch of observations in any form accepted by
                ``torch.FloatTensor`` (e.g. a numpy array).

        Returns:
            numpy array of log-probability scores, see ``get_output_for``.
        """
        obs = torch.FloatTensor(obs).to(self.device)
        log_prob = self.get_output_for(obs)
        return log_prob.cpu().numpy()

    def score_samples(self, obs):
        """Alias of :meth:`get_density`."""
        return self.get_density(obs)

    def update(self, aug_obs):
        """
        Take one optimizer step on a batch and return the loss.

        The loss is ``beta * KL + MSE``: the KL term is summed over latent
        dimensions (dim 1) and averaged over the batch, the reconstruction
        term is the mean squared error over all elements.

        Args:
            aug_obs: 2-D float tensor (batch, input_size) on ``self.device``.

        Returns:
            The scalar loss as a Python float.
        """
        obs = aug_obs

        mu, logvar = self._encode(obs)
        code = self._reparameterize(mu, logvar)

        # KL divergence between N(mu, exp(logvar)) and N(0, I).
        kle = -0.5 * torch.sum(
            1 + logvar - mu.pow(2) - logvar.exp(), dim=1
        ).mean()

        obs_distribution_params = self.dec(code)
        # 'mean' is the current name of the deprecated 'elementwise_mean'.
        log_prob = -1. * F.mse_loss(obs, obs_distribution_params,
                                    reduction='mean')

        loss = self.beta * kle - log_prob

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, obs=True, batch_size=64, train_epoch=200):
        """
        Fit the VAE on ``obs`` and return the mean loss over all steps.

        Each of the ``train_epoch`` steps draws a random mini-batch (without
        replacement) from ``obs`` and calls :meth:`update` on it.

        NOTE(review): this method shadows ``torch.nn.Module.train(mode)``
        (assuming ``PyTorchModule`` derives from ``nn.Module`` -- confirm).
        To keep ``.train()`` / ``.eval()`` usable, a boolean first argument
        is forwarded to the parent implementation instead of being treated
        as data.

        Args:
            obs: indexable batch of observations (e.g. a numpy array), or a
                bool to switch training/eval mode.
            batch_size: mini-batch size; clamped to ``len(obs)``.
            train_epoch: number of gradient steps.

        Returns:
            Mean loss as a float (0.0 when ``train_epoch`` is not positive),
            or the parent's return value when ``obs`` is a bool.

        Raises:
            ValueError: if ``obs`` is empty.
        """
        if isinstance(obs, bool):
            return super().train(obs)

        num_obs = len(obs)
        if num_obs == 0:
            raise ValueError("Cannot train on an empty set of observations.")
        if train_epoch <= 0:
            return 0.0

        # Sampling without replacement needs batch <= population size.
        batch = min(batch_size, num_obs)

        loss = 0.0
        for _ in range(train_epoch):
            idxes = np.random.choice(num_obs, size=batch, replace=False)
            train_batch = torch.FloatTensor(obs[idxes]).to(self.device)
            loss += self.update(train_batch)
        return loss / train_epoch