import torch
import torch.nn as nn
from distributions import bijector


class ActNormFlow(bijector.Bijector):
    '''
    Bijector wrapper around a frozen ``ActNorm`` layer.

    The forward map is ``y = (x - shift) * exp(-log_scale)`` and the inverse
    is ``x = y * exp(log_scale) + shift``, where ``shift`` and ``log_scale``
    are the per-feature buffers held by the wrapped ``ActNorm``.

    Args:
        num_features: length of the per-feature ``shift`` / ``log_scale``
            vectors.
        eps: constant added to the batch standard deviation before taking
            its log during the data-dependent initialisation.
        dims: forwarded to the ``Bijector`` base class.
    '''
    def __init__(self, num_features, eps=1e-6, dims=-1):
        # NOTE(review): the reduction axis used by this class is fixed to the
        # last one regardless of ``dims`` (``log_scale`` is 1-D, so summing it
        # over any other axis would fail). ``dims`` is only handed to the base
        # class; whether the base class reassigns ``self.dims`` is not visible
        # from this file -- confirm.
        self.dims = -1
        super(ActNormFlow, self).__init__(dims=dims)
        # ``eps`` must be passed by keyword: ActNorm's second positional
        # parameter is ``data_dep_init``, so a positional ``eps`` would be
        # swallowed as that flag and ActNorm would silently use its own
        # default epsilon.
        self.actnorm = ActNorm(num_features, eps=eps, frozen=True)

    def forward(self, x):
        '''Map ``x`` to ``y = (x - shift) * exp(-log_scale)``.'''
        return self.actnorm(x)

    def inverse(self, y):
        '''Map ``y`` back to ``x = y * exp(log_scale) + shift``.'''
        x = y * torch.exp(self.actnorm.log_scale) + self.actnorm.shift
        return x

    def inverse_log_det_jacobian(self, y):
        '''
        Return the log-determinant of the Jacobian of the inverse map.

        The inverse scales each feature by ``exp(log_scale)``, so the
        log-determinant is the sum of ``log_scale``. It does not depend on
        the values in ``y``; ``y`` only supplies the output shape (the shape
        of ``y`` reduced over ``self.dims``), dtype and device.
        '''
        ldji = self.actnorm.log_scale.sum(self.dims)
        # Broadcast the single value to one entry per remaining element of y.
        ldji = torch.zeros_like(y.detach().sum(dim=self.dims)) + ldji
        return ldji

    def forward_and_invlogdet(self, x):
        '''Return ``(y, inverse_log_det_jacobian(y))`` for ``y = forward(x)``.'''
        y = self.forward(x)
        return y, self.inverse_log_det_jacobian(y)

    def inverse_and_invlogdet(self, y):
        '''Return ``(x, inverse_log_det_jacobian(y))`` for ``x = inverse(y)``.'''
        x = self.inverse(y)
        # Evaluate at ``y`` to match the method's parameter; ``x`` has the
        # same shape, so the returned value is unchanged.
        return x, self.inverse_log_det_jacobian(y)


class ActNorm(nn.Module):
    '''
    Base class for activation normalization [1].

    Applies the per-feature affine map
    ``y = (x - shift) * exp(-log_scale)`` over the last axis of the input.
    When ``data_dep_init`` is true, the first forward pass made in training
    mode sets ``shift`` and ``log_scale`` from the mean and standard
    deviation of that batch.

    Args:
        num_features: length of the ``shift`` and ``log_scale`` vectors.
        data_dep_init: if true, initialise ``shift`` / ``log_scale`` from the
            first training batch; if false, keep their zero initialisation
            until ``data_init`` is called explicitly.
        eps: constant added to the batch standard deviation before taking
            its log.
        frozen: if true, ``shift`` and ``log_scale`` are registered as
            buffers instead of ``nn.Parameter``s.

    NOTE: ``initialised`` is a plain Python attribute, not a parameter or
    buffer, so it is not stored in ``state_dict()``. A freshly constructed
    module starts with ``initialised = False`` even if weights are loaded
    into it afterwards.

    References:
        [1] Glow: Generative Flow with Invertible 1×1 Convolutions,
            Kingma & Dhariwal, 2018, https://arxiv.org/abs/1807.03039
    '''
    def __init__(self, num_features, data_dep_init=True, eps=1e-5,
                 frozen=False):
        super(ActNorm, self).__init__()
        self.num_features = num_features
        self.data_dep_init = data_dep_init
        self.eps = eps
        self.initialised = False
        if not frozen:
            self.shift = nn.Parameter(torch.zeros(self.num_features))
            self.log_scale = nn.Parameter(torch.zeros(self.num_features))
        else:
            self.register_buffer('shift', torch.zeros(self.num_features))
            self.register_buffer('log_scale', torch.zeros(self.num_features))

    def data_init(self, x):
        '''
        Set ``shift`` / ``log_scale`` from the statistics of ``x``, once.

        Does nothing if the module has already been initialised. Calling
        this directly initialises regardless of ``data_dep_init``.
        '''
        if not self.initialised:
            print("Initialising", self)
            with torch.no_grad():
                # Collapse every leading axis so the statistics are taken
                # per entry of the last axis.
                x_mean, x_std = self.compute_stats(x.flatten(0, -2))
                self.shift.data = x_mean
                self.log_scale.data = torch.log(x_std + self.eps)
            self.initialised = True

    def compute_stats(self, x):
        '''
        Compute x_mean and x_std along the first axis of ``x``.

        NOTE: ``torch.std`` uses the unbiased estimator by default, so an
        input with a single row produces NaN for ``x_std``.
        '''
        x_mean = torch.mean(x, dim=0)
        x_std = torch.std(x, dim=0)
        return x_mean, x_std

    def forward(self, x):
        '''Return ``(x - shift) * exp(-log_scale)``.'''
        # Honour the constructor flag: only run the data-dependent
        # initialisation when it was requested.
        if self.training and self.data_dep_init:
            self.data_init(x)
        y = (x - self.shift) * torch.exp(-self.log_scale)
        return y
