"""Mixture density network (MDN) layers and related models and helpers.

For more info on MDNs, see _Mixture Density Networks_ by Bishop, 1994.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical
from torch.distributions.normal import Normal
import math
from torch.utils.data import Dataset


# 1 / sqrt(2 * pi): normalising constant of the standard normal density.
# NOTE(review): not referenced by the visible code (mdn_loss uses
# torch.distributions.Normal for log-densities); confirm before removing.
ONEOVERSQRT2PI = 1.0 / math.sqrt(2 * math.pi)


class MyDataset(Dataset):
    """Tensor-backed ``Dataset`` built from columns of a DataFrame-like object.

    Args:
        df: object indexable by column label(s) whose result exposes
            ``.values`` (e.g. a pandas ``DataFrame``).
        bucket: column label(s) holding the target values.
        parent: optional column label(s) holding the conditioning inputs.
            When given, items are ``(x, y)`` pairs with ``x`` taken from
            ``parent`` and ``y`` from ``bucket``; otherwise each item is
            just ``x`` taken from ``bucket``.

    All values are stored as ``torch.float32`` tensors.
    """

    def __init__(self, df, bucket, parent=None):
        # Identity check, not ``== None``: a list-like of labels (pandas
        # Index, numpy array) compares element-wise, and the truth value of
        # the resulting array is ambiguous.
        if parent is None:
            self.parent = False
            self.x = torch.tensor(df[bucket].values, dtype=torch.float32)
        else:
            self.parent = True
            self.x = torch.tensor(df[parent].values, dtype=torch.float32)
            self.y = torch.tensor(df[bucket].values, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(x, y)`` when built with ``parent``, else just ``x``."""
        if self.parent:
            return self.x[idx], self.y[idx]
        return self.x[idx]

    def __len__(self):
        return len(self.x)


def mdn_loss(output, target):
    """Mean negative log-likelihood of ``target`` under the predicted MoG.

    Args:
        output: ``(pi, sigma, mu)`` with ``pi`` of shape ``(B, G)`` and
            ``sigma`` / ``mu`` of shape ``(B, G, O)``, as produced by
            ``MDN.forward``.
        target: tensor of shape ``(B, O)``. A 1-D ``(B,)`` target is also
            accepted when ``O == 1``.

    Returns:
        Scalar tensor
        ``-mean_b log sum_g pi[b, g] * prod_o N(target[b, o]; mu[b, g, o], sigma[b, g, o])``.
    """
    pi, sigma, mu = output
    if target.dim() == 1 and sigma.size(2) == 1:
        target = target.unsqueeze(-1)                              # (B, 1)
    target = target.unsqueeze(1).expand_as(sigma)                  # (B, G, O)
    # Output dimensions are independent within a component, so their
    # log-densities add up to the component's joint log-density.
    log_prob = Normal(mu, sigma).log_prob(target).sum(dim=2)       # (B, G)
    # pi is already (B, G) and lines up with the component axis.
    weighted_log_prob = torch.log(pi) + log_prob                   # (B, G)
    return -torch.logsumexp(weighted_log_prob, dim=1).mean()



def sample(output):
    """Draw one sample per batch row from a mixture of Gaussians.

    Args:
        output: ``(pi, sigma, mu)`` with ``pi`` of shape ``(B, G)`` and
            ``sigma`` / ``mu`` of shape ``(B, G, O)``, as produced by
            ``MDN.forward``.

    Returns:
        Tensor of shape ``(B, O)``, detached from the autograd graph.
    """
    pi, sigma, mu = output
    batch, _, n_out = sigma.shape
    # Choose one component per row, then expand that index over all O output
    # dimensions so every dimension is drawn from the chosen component.
    component = Categorical(pi).sample().view(batch, 1, 1).expand(batch, 1, n_out)
    std = sigma.detach().gather(1, component).squeeze(1)     # (B, O)
    mean = mu.detach().gather(1, component).squeeze(1)       # (B, O)
    # randn_like keeps dtype/device of the parameters and does not track grad.
    noise = torch.randn_like(std)
    return noise * std + mean


class MDN(nn.Module):
    """Mixture density network head: one tanh hidden layer feeding MoG parameters.

    Args:
        in_features: size of each input row.
        out_features: dimensionality ``O`` of each Gaussian component.
        num_gaussians: number of mixture components ``G``.
        hiddenlayers: width of the single tanh hidden layer.

    ``forward`` returns ``(pi, sigma, mu)``:
    ``pi`` has shape ``(B, G)`` and is softmax-normalised over components.
    ``sigma`` and ``mu`` have shape ``(B, G, O)``, and ``sigma`` is made
    positive with ``exp``.
    """

    def __init__(self, in_features, out_features, num_gaussians, hiddenlayers):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_gaussians = num_gaussians
        self.hiddenlayers = hiddenlayers

        n_params = out_features * num_gaussians
        # Keep this creation order: it fixes the RNG draws used for
        # parameter initialisation and the state_dict key order.
        self.hidden = nn.Sequential(nn.Linear(in_features, hiddenlayers), nn.Tanh())
        self.pi = nn.Sequential(nn.Linear(hiddenlayers, num_gaussians), nn.Softmax(dim=-1))
        self.sigma = nn.Linear(hiddenlayers, n_params)
        self.mu = nn.Linear(hiddenlayers, n_params)

    def forward(self, minibatch):
        features = self.hidden(minibatch)
        mog_shape = (-1, self.num_gaussians, self.out_features)
        pi = self.pi(features)
        sigma = self.sigma(features).exp().view(*mog_shape)
        mu = self.mu(features).view(*mog_shape)
        return pi, sigma, mu

#%%
# customised models for real data
class NumericalOutput(nn.Module):
    """Plain MLP regressor: Linear layers with ReLU in between, linear output.

    Originally described as "1~2 input, 1~2 output (numerical)"; nothing in
    the code restricts the sizes.

    Args:
        num_input: number of input features.
        num_output: number of output values.
        hidden_nodes: widths of the hidden layers, in order. Any iterable of
            ints (list, tuple, ...) is accepted; an empty one yields a single
            Linear layer.
    """

    def __init__(self, num_input, num_output, hidden_nodes):
        super().__init__()
        # Normalise to a list so tuples / other iterables work too.
        hidden_nodes = list(hidden_nodes)
        nodes_list = [num_input, *hidden_nodes, num_output]

        self.num_hidden = len(hidden_nodes)
        self.linear_list = nn.ModuleList(
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(nodes_list[:-1], nodes_list[1:])
        )
        # One ReLU after every Linear layer except the output layer.
        self.hidden_list = nn.ModuleList(nn.ReLU() for _ in range(self.num_hidden))

    def forward(self, x):
        for i, linear in enumerate(self.linear_list):
            x = linear(x)
            if i < self.num_hidden:
                x = self.hidden_list[i](x)
        return x


class BinaryOutput(nn.Module):
    """Two-hidden-layer ReLU MLP that returns raw logits for binary targets.

    The original note read "b schoolsup school / p age — (binary, binary)
    numerical". That presumably means the targets (bucket) are
    ``schoolsup`` and ``school`` (both binary) and the input (parent) is
    ``age`` (numerical). TODO confirm.

    Args:
        input_dim: number of input features.
        hidden_dim: width of both hidden layers.
        output_dim: number of logits produced.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Creation order kept as-is (fixes init RNG draws and state_dict keys).
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.hidden1 = nn.ReLU()
        self.hidden2 = nn.ReLU()
        self.binary_output = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Returns logits: no sigmoid here. Apply torch.sigmoid (as
        # binarySample does) to turn them into probabilities.
        blocks = ((self.layer1, self.hidden1), (self.layer2, self.hidden2))
        for linear, activation in blocks:
            x = activation(linear(x))
        return self.binary_output(x)


def binarySample(output):
    """Draw 0/1 samples from Bernoulli distributions given their logits.

    Args:
        output: tensor of logits (e.g. the result of ``BinaryOutput.forward``).

    Returns:
        Tensor of the same shape and dtype with entries in {0., 1.}.
    """
    probabilities = output.sigmoid()
    return probabilities.bernoulli()


class BinaryNumericalOutput(nn.Module):
    """MDN head plus one binary logit, both reading a shared tanh hidden layer.

    ``forward`` returns ``(binary_out, (pi, sigma, mu))``:
    ``binary_out`` has shape ``(B, 1)`` and holds raw logits.
    ``pi`` has shape ``(B, G)`` and is softmax-normalised.
    ``sigma`` and ``mu`` have shape ``(B, G, O)``, with ``sigma = exp(.)``.
    """

    def __init__(self, in_features, out_features, num_gaussians, hiddenlayers):
        """Build the shared layer, the binary logit head and the MoG heads.

        The original note says ``out_features = 1``, presumably the intended
        setting; the code itself works for any value.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_gaussians = num_gaussians
        self.hiddenlayers = hiddenlayers

        n_params = self.out_features * num_gaussians
        # Creation order kept as-is (fixes init RNG draws and state_dict keys).
        self.hidden = nn.Sequential(nn.Linear(in_features, hiddenlayers), nn.Tanh())
        self.binary_output = nn.Linear(hiddenlayers, 1)
        self.pi = nn.Sequential(nn.Linear(hiddenlayers, num_gaussians), nn.Softmax(dim=-1))
        self.sigma = nn.Linear(hiddenlayers, n_params)
        self.mu = nn.Linear(hiddenlayers, n_params)

    def forward(self, minibatch):
        shared = self.hidden(minibatch)
        mog_shape = (-1, self.num_gaussians, self.out_features)

        binary_out = self.binary_output(shared)
        pi = self.pi(shared)
        sigma = self.sigma(shared).exp().view(*mog_shape)
        mu = self.mu(shared).view(*mog_shape)
        return binary_out, (pi, sigma, mu)


def binaryNumSample(output):
    """Sample both parts of a ``BinaryNumericalOutput`` prediction.

    Args:
        output: ``(binary_logits, (pi, sigma, mu))`` as returned by
            ``BinaryNumericalOutput.forward``.

    Returns:
        The Bernoulli draws for the binary logits, concatenated along the
        last dimension with the mixture samples produced by ``sample``.
    """
    binary_logits, mog_params = output
    # Bernoulli draw first, then the MoG draw (same RNG consumption order).
    binary_draw = torch.sigmoid(binary_logits).bernoulli()
    numerical_draw = sample(mog_params)
    return torch.cat((binary_draw, numerical_draw), dim=-1)