"""A module for a mixture density network layer

For more info on MDNs, see _Mixture Density Networks_ by Bishop, 1994.
"""
import torch
import torch.nn as nn
from torch.distributions import Categorical
from torch.distributions.normal import Normal
import math
from torch.utils.data import Dataset

# Precomputed 1 / sqrt(2 * pi), the leading factor of the standard normal
# density. NOTE(review): no definition visible in this file references it.
ONEOVERSQRT2PI = 1.0 / math.sqrt(2 * math.pi)

class MyDataset(Dataset):
    """Dataset of float32 tensors built from selected columns of ``df``.

    Without ``parent`` each item is one row of ``df[bucket]``. With ``parent``
    each item is an ``(x, y)`` pair: ``x`` is the row of ``df[parent]`` and
    ``y`` is the matching row of ``df[bucket]``.

    Args:
        df: Object supporting ``df[key].values`` (e.g. a pandas DataFrame).
        bucket: Column selector for the data (or for the targets when
            ``parent`` is given).
        parent: Optional column selector for the inputs. ``None`` (default)
            builds a dataset without targets.
    """

    def __init__(self, df, bucket, parent=None):
        # Identity check on purpose: ``parent == None`` is evaluated
        # element-wise by array-like selectors (e.g. a pandas Index), and the
        # resulting boolean array cannot be used as an ``if`` condition.
        self.parent = parent is not None
        if self.parent:
            self.x = torch.tensor(df[parent].values, dtype=torch.float32)
            self.y = torch.tensor(df[bucket].values, dtype=torch.float32)
        else:
            self.x = torch.tensor(df[bucket].values, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(x[idx], y[idx])`` when targets exist, else ``x[idx]``."""
        if self.parent:
            return self.x[idx], self.y[idx]
        return self.x[idx]

    def __len__(self):
        """Return the number of rows held in ``x``."""
        return len(self.x)


class Model(nn.Module):
    """Feed-forward network with one hidden layer: Linear -> Tanh -> Linear.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        hiddenlayers: Number of units in the single hidden layer.
    """

    def __init__(self, in_features, out_features, hiddenlayers):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.hiddenlayers = hiddenlayers

        # The hidden projection is built before the output projection.
        input_projection = nn.Linear(in_features, hiddenlayers)
        self.hidden = nn.Sequential(input_projection, nn.Tanh())
        self.output = nn.Linear(hiddenlayers, out_features)

    def forward(self, minibatch):
        """Map ``minibatch`` through the hidden layer and the output layer."""
        return self.output(self.hidden(minibatch))


# try another architecture!!
class Model2(nn.Module):
    """Classifier that embeds integer-coded inputs and feeds a two-layer MLP.

    Args:
        num_cat_features: Multiplier used to size the first linear layer
            (``embedding_size * num_cat_features`` inputs).
        num_classes: Size of the output layer.
        embedding_size: Width of each embedding vector.
        hidden_size: Width of the hidden linear layer.
    """

    def __init__(self, num_cat_features=2, num_classes=94, embedding_size=10, hidden_size=32):
        super().__init__()

        # Two lookup tables with hard-coded vocabularies of 4 and 3 entries.
        self.embedding1 = nn.Embedding(4, embedding_size)
        self.embedding2 = nn.Embedding(3, embedding_size)

        # MLP on top of the concatenated, flattened embeddings.
        flattened_width = embedding_size * num_cat_features
        self.fc1 = nn.Linear(flattened_width, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x1x2):
        """Embed the columns of ``x1x2`` and return the output of ``fc2``."""
        # NOTE(review): the first four columns are all looked up in
        # embedding1 and every remaining column in embedding2, so fc1's input
        # width only matches when x1x2 has num_cat_features columns in total.
        # Confirm this column split is what callers intend.
        leading_cols, trailing_cols = x1x2[:, :4], x1x2[:, 4:]
        emb_a = self.embedding1(leading_cols)
        emb_b = self.embedding2(trailing_cols)

        # Collapse (batch, columns, embedding) into (batch, columns * embedding).
        flat_a = emb_a.view(emb_a.size(0), -1)
        flat_b = emb_b.view(emb_b.size(0), -1)

        features = torch.cat((flat_a, flat_b), dim=1)
        activated = torch.relu(self.fc1(features))
        return self.fc2(activated)


def mdn_loss(output, target):
    """Return the mean negative log-likelihood of ``target`` under a MoG.

    Args:
        output: Tuple ``(pi, sigma, mu)`` of mixture weights, standard
            deviations and means. Presumably shaped like the output of
            ``MDN.forward`` in this file, i.e. ``pi`` is (batch, gaussians)
            and ``sigma``/``mu`` are (batch, gaussians, out_features).
        target: Observed values; a component axis is inserted at dim 1 and
            the result is expanded to the shape of ``sigma``.

    Returns:
        Scalar tensor: the negated log-sum-exp over dim 1 of
        ``log(pi) + log N(target; mu, sigma)``, averaged over what remains.
    """
    pi, sigma, mu = output
    components = Normal(mu, sigma)

    # Repeat the target once per mixture component.
    tiled_target = target.unsqueeze(1).expand_as(sigma)
    component_log_prob = components.log_prob(tiled_target)

    # Spread the log mixture weights across the trailing axis of sigma.
    log_weights = torch.log(pi).unsqueeze(2).expand_as(sigma)

    log_likelihood = torch.logsumexp(log_weights + component_log_prob, dim=1)
    return -log_likelihood.mean()


def sample(output):
    """Draw one sample per batch row from a mixture of Gaussians.

    Args:
        output: Tuple ``(pi, sigma, mu)`` where ``pi`` holds the mixture
            weights with shape (batch, gaussians) and ``sigma``/``mu`` hold
            the standard deviations and means with shape
            (batch, gaussians, out_features).

    Returns:
        Tensor of shape (batch, out_features), detached from the graph of
        ``sigma`` and ``mu``.
    """
    pi, sigma, mu = output
    batch_size, out_features = sigma.size(0), sigma.size(2)

    # Choose which gaussian each row is sampled from. The index is expanded
    # across the feature axis so gather returns every output dimension of the
    # chosen component; a (batch, 1, 1) index would only read dimension 0.
    chosen = Categorical(pi).sample().view(batch_size, 1, 1)
    chosen = chosen.expand(batch_size, 1, out_features)

    std_samples = sigma.detach().gather(1, chosen).squeeze(1)
    mean_samples = mu.detach().gather(1, chosen).squeeze(1)

    # Noise is created on the parameters' device so non-CPU inputs work. It is
    # drawn as (out_features, batch) and transposed to keep the draw order of
    # the single-output case stable under a fixed seed.
    gaussian_noise = torch.randn(
        (out_features, batch_size), requires_grad=False, device=sigma.device
    ).transpose(0, 1)
    return gaussian_noise * std_samples + mean_samples


class MDN(nn.Module):
    """Mixture density network with one Tanh hidden layer.

    Args:
        in_features: Size of each input sample.
        out_features: Number of output dimensions per gaussian.
        num_gaussians: Number of mixture components.
        hiddenlayers: Number of units in the hidden layer.
    """

    def __init__(self, in_features, out_features, num_gaussians, hiddenlayers):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_gaussians = num_gaussians
        self.hiddenlayers = hiddenlayers

        # One sigma and one mu value per (gaussian, output dimension) pair.
        per_component_width = out_features * num_gaussians

        self.hidden = nn.Sequential(nn.Linear(in_features, hiddenlayers), nn.Tanh())
        self.pi = nn.Sequential(
            nn.Linear(hiddenlayers, num_gaussians),
            nn.Softmax(dim=-1),
        )
        self.sigma = nn.Linear(hiddenlayers, per_component_width)
        self.mu = nn.Linear(hiddenlayers, per_component_width)

    def forward(self, minibatch):
        """Return ``(pi, sigma, mu)`` for ``minibatch``.

        ``pi`` is the softmax output over the gaussians. ``sigma`` (made
        positive with ``exp``) and ``mu`` are viewed as
        (-1, num_gaussians, out_features).
        """
        features = self.hidden(minibatch)
        component_shape = (-1, self.num_gaussians, self.out_features)

        pi = self.pi(features)
        sigma = torch.exp(self.sigma(features)).view(*component_shape)
        mu = self.mu(features).view(*component_shape)
        return pi, sigma, mu