import numpy as np
import torch
from torch import nn
from torch.nn import Module
from . import log_normal, bijector
from torch.nn import functional as F

# Module-wide switch read by the samplers in this file: when truthy, noise that
# they draw themselves is multiplied by zero before being pushed through a flow.
DETERMINISTIC = False


def set_deterministic(val):
    """Store ``val`` as the module-level ``DETERMINISTIC`` flag."""
    global DETERMINISTIC
    DETERMINISTIC = val


def get_deterministic():
    """Return the current value of the module-level ``DETERMINISTIC`` flag."""
    # Reading a module global needs no ``global`` declaration.
    return DETERMINISTIC


class ConditionalGaussian(Module):
    """Density over ``z`` defined by a ``bijector.ScaleShift`` flow on base noise.

    The two flow parameters (``log_w`` and ``b``) are produced by a single
    linear layer applied to the conditioning input. That layer's weight and
    bias are zero-initialised, so both parameters are exactly zero before any
    training step.

    NOTE(review): ``log_normal`` is assumed to be an element-wise standard
    normal log-density and ``log_w`` / ``b`` a log-scale and shift -- confirm
    against the package's ``log_normal`` and ``bijector.ScaleShift``.
    """

    def __init__(self, input_size, conditioned_size=0):
        """Build the flow and the zero-initialised parameter layer.

        Args:
            input_size: size of the last dimension of the noise / sample.
            conditioned_size: size of the last dimension of the conditioning
                input fed to the linear layer.
        """
        super().__init__()
        self.input_size = input_size
        self.flow = bijector.ScaleShift(dims=-1)
        # One layer emits both parameter vectors, hence twice the width.
        self.transform = nn.Linear(conditioned_size, input_size * 2)
        for tensor in (self.transform.weight, self.transform.bias):
            torch.nn.init.zeros_(tensor)

    def gaussian_params(self, conditioned):
        """Return ``(log_w, b)``: the two halves of the linear layer's output.

        The output is split in two along the last dimension.
        """
        stacked = self.transform(conditioned)
        log_w, b = torch.chunk(stacked, 2, dim=-1)
        return log_w, b

    def forward(self, z, conditioned):
        """Return the log-density of ``z`` given ``conditioned``.

        ``z`` is mapped back to base noise by the flow; the result is the base
        log-density summed over the last dimension plus the flow's inverse
        log-det-Jacobian term.
        """
        log_w, b = self.gaussian_params(conditioned)
        noise = self.flow.inverse(z, log_w, b)
        ldji = self.flow.inverse_log_det_jacobian(z, log_w, b)
        base_log_prob = log_normal(noise).sum(dim=-1)
        return base_log_prob + ldji

    def sample_and_logprob(self, conditioned, eps=None):
        """Draw a sample and return it together with its log-density.

        Args:
            conditioned: conditioning input; its leading dimensions and device
                determine those of freshly drawn noise.
            eps: optional base noise. When omitted, standard-normal noise is
                drawn (and zeroed if ``DETERMINISTIC`` is set).

        Returns:
            ``(z, log_q_z)`` as produced by the flow and the base log-density.
        """
        if eps is None:
            noise_shape = conditioned.size()[:-1] + (self.input_size,)
            eps = torch.randn(noise_shape, device=conditioned.device)
            if DETERMINISTIC:
                # Still draws from the RNG first, then collapses to zeros.
                eps = eps * 0.

        log_w, b = self.gaussian_params(conditioned)
        z, ldji = self.flow.forward_and_invlogdet(eps, log_w, b)
        log_q_z = log_normal(eps).sum(-1) + ldji
        return z, log_q_z


class TransformedGaussian(ConditionalGaussian):
    """Variant of ``ConditionalGaussian`` driven by a caller-supplied flow.

    The conditioning input is handed straight to the flow instead of being
    turned into parameters by a linear layer.

    NOTE: the inherited ``gaussian_params`` reads ``self.transform``, which
    this class never creates, so it is not usable on instances of this class.
    """

    def __init__(self, input_size, flow: bijector.Bijector):
        """Store ``input_size`` and ``flow`` without building the parent's layers."""
        # Starting the lookup *after* ConditionalGaussian is deliberate: it
        # runs the next initialiser in the MRO and skips
        # ConditionalGaussian.__init__, which would otherwise build a
        # ScaleShift flow and a linear layer this class does not use.
        super(ConditionalGaussian, self).__init__()
        self.input_size = input_size
        self.flow = flow

    def forward(self, z, conditioned):
        """Return the log-density of ``z`` given ``conditioned``.

        The flow returns the base noise and an inverse log-det-Jacobian term
        in one call; the base log-density is summed over the last dimension.
        """
        noise, ldji = self.flow.inverse_and_invlogdet(z, conditioned)
        base_log_prob = log_normal(noise).sum(dim=-1)
        return base_log_prob + ldji

    def sample_and_logprob(self, conditioned, eps=None):
        """Draw a sample through the flow and return ``(z, log_q_z)``.

        When ``eps`` is omitted, standard-normal noise is drawn on
        ``conditioned``'s device with ``conditioned``'s leading dimensions,
        and zeroed if ``DETERMINISTIC`` is set.
        """
        if eps is None:
            leading_dims = conditioned.size()[:-1]
            eps = torch.randn(leading_dims + (self.input_size,),
                              device=conditioned.device)
            if DETERMINISTIC:
                eps = eps * 0

        z, ldji = self.flow.forward_and_invlogdet(eps, conditioned)
        log_q_z = log_normal(eps).sum(dim=-1) + ldji
        return z, log_q_z


class IAF(Module):
    """Sampler that chains four (random permutation, ``bijector.IAF``) pairs.

    The pairs are wrapped in a ``bijector.Sequential`` stored as ``self.flow``.
    Only sampling is implemented here; no ``forward`` is defined.

    NOTE(review): ``layers``, ``hidden_upscale`` and ``activation`` are
    accepted for interface compatibility but are not read. The number of flow
    steps (4), ``n_layers=3`` and ``block_size=8`` are hard-coded. The
    ``activation`` default is a single ``nn.ELU`` instance created at
    definition time; this is harmless only because it is never used.
    """

    # Fixed sized hidden layers for now.
    def __init__(self, input_size, conditioned_size=0,
                 layers=1, hidden_upscale=8, activation=nn.ELU()):
        """Build the flow.

        Args:
            input_size: size of the last dimension of the noise / sample.
            conditioned_size: feature size of the conditioning input passed to
                each ``bijector.IAF`` step as ``in_features``.
            layers: currently unused.
            hidden_upscale: currently unused.
            activation: currently unused.
        """
        super(IAF, self).__init__()
        steps = []
        # Permutation and IAF step are created alternately so that NumPy and
        # torch RNG draws happen in the same order as four written-out pairs.
        for _ in range(4):
            steps.append(bijector.PermuteDimensions(
                torch.from_numpy(np.random.permutation(input_size))))
            steps.append(bijector.IAF(out_features=input_size,
                                      in_features=conditioned_size,
                                      n_layers=3, block_size=8))
        self.flow = bijector.Sequential(*steps)
        self.input_size = input_size

    def sample_and_logprob(self, conditioned, eps=None):
        """Draw a sample through the flow and return ``(z, log_q_z)``.

        Args:
            conditioned: conditioning input; its leading dimensions and device
                determine those of freshly drawn noise.
            eps: optional base noise. When omitted, standard-normal noise is
                drawn (and zeroed if ``DETERMINISTIC`` is set).
        """
        if eps is None:
            # Allocate directly on the conditioning tensor's device. A CPU
            # draw followed by ``.cuda()`` would land on the *current* CUDA
            # device, which may differ from ``conditioned.device``.
            eps = torch.randn(conditioned.size()[:-1] + (self.input_size,),
                              device=conditioned.device)
            if DETERMINISTIC:
                eps = eps * 0
        z, ldji = self.flow.forward_and_invlogdet(eps, conditioned)
        log_q_z0 = log_normal(eps).sum(-1)
        log_q_z = log_q_z0 + ldji
        return z, log_q_z

class NVPTransform(nn.Module):
    """Small network mapping two inputs to a ``(log_scale, shift)`` pair.

    Each input is projected linearly to a shared hidden width; the projections
    are averaged, passed through ReLU and a final linear layer, and the result
    is split in half along the last dimension. The final layer is
    zero-initialised, so a fresh module returns all-zero ``shift``.
    """

    def __init__(self, size_1, size_2, out_size, hidden_size=None):
        """Create the two input projections and the output head.

        Args:
            size_1: feature size of the first input.
            size_2: feature size of the second input.
            out_size: feature size of each of the two outputs.
            hidden_size: hidden width; defaults to ``max(size_1, size_2)``.
        """
        super().__init__()
        self.hidden_size = max(size_1, size_2) if hidden_size is None else hidden_size

        # Only the second projection carries a bias, so the summed hidden
        # pre-activation has a single bias term.
        self.in_combine_1 = nn.Linear(size_1, self.hidden_size, bias=False)
        self.in_combine_2 = nn.Linear(size_2, self.hidden_size)
        for projection in (self.in_combine_1, self.in_combine_2):
            nn.init.xavier_uniform_(projection.weight)
        nn.init.zeros_(self.in_combine_2.bias)

        head = nn.Linear(self.hidden_size, out_size * 2)
        self.out_transform = nn.Sequential(nn.ReLU(), head)
        nn.init.zeros_(head.weight)
        nn.init.zeros_(head.bias)

    def forward(self, in1, in2=None):
        """Return ``(log_scale, shift)`` computed from ``in1`` and ``in2``.

        NOTE(review): ``in2`` defaults to ``None`` but is passed to a linear
        layer unconditionally, so callers are expected to always supply it --
        confirm whether the default is intentional.
        """
        hidden = self.in_combine_1(in1)
        hidden = (hidden + self.in_combine_2(in2)) / 2.
        raw = self.out_transform(hidden)
        log_scale, shift = torch.chunk(raw, 2, dim=-1)
        # softplus is positive, so this keeps log_scale above -32 while
        # leaving values well above that bound essentially unchanged.
        bounded_log_scale = F.softplus(log_scale + 32) - 32
        return bounded_log_scale, shift
