from __future__ import print_function
import argparse
import time
import torch
import torch.utils.data
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math
import random
import os
import datetime
import sys
from torch.autograd import Variable

sys.path.append("./")
sys.path.append("../integer_discrete_flows")
sys.path.append("../EinsumNetworks/src")

from EinsumNetwork import Graph, EinsumNetwork


def main():
    """Compare ``EinsumNetwork.forward_with_grad2`` with a direct computation.

    Builds a small Einsum network with logistic leaves and uniform parameters.
    It evaluates the network on random inputs, using random per-leaf ``mean`` and
    ``logscale`` tensors. It then prints the network's likelihood for the first
    sample next to a value derived from :func:`baseline` on the same tensors.
    """
    num_var = 2
    num_repetitions = 1
    num_input_distributions = 4
    num_sums = 4

    # Binary splitting depth: ceil(log2(num_var)). The previous natural log
    # (ln(n) <= log2(n)) could under-estimate it, e.g. num_var = 5 -> 2 instead
    # of 3. NOTE(review): assumes each level of random_binary_trees halves a
    # region's scope -- confirm against the EinsumNetworks Graph module.
    depth = int(np.ceil(np.log2(num_var)))

    graph = Graph.random_binary_trees(num_var=num_var, depth=depth,
                                      num_repetitions=num_repetitions)
    einet_args = EinsumNetwork.Args(num_var=num_var, num_dims=1, num_classes=1,
                                    num_input_distributions=num_input_distributions,
                                    num_sums=num_sums,
                                    exponential_family=EinsumNetwork.LogisticArray,
                                    exponential_family_args={'scale_min': 1e-4, 'scale_max': 1.0},
                                    online_em_frequency=1,
                                    online_em_stepsize=0.05,
                                    uniform_params=True)

    einet = EinsumNetwork.EinsumNetwork(graph, einet_args)
    einet.initialize()

    # Random batch of 32 samples. The leaf parameters carry one trailing slot
    # per (repetition, input distribution) pair.
    x = torch.rand([32, num_var])
    mean = torch.rand([32, num_var, num_repetitions * num_input_distributions])
    logscale = torch.rand([32, num_var, num_repetitions * num_input_distributions])

    ll_sample = einet.forward_with_grad2(x, mean, logscale)

    lls = baseline(x, mean, logscale)

    print(lls.size())
    print(ll_sample.size())

    # NOTE(review): with uniform_params=True, num_var == 2 and a single
    # repetition, the root likelihood presumably reduces to the product over
    # the two variables of the mean per-component leaf likelihood. The
    # expression below hard-codes variables 0 and 1, so it must be revisited
    # if num_var / num_repetitions change.
    print(torch.exp(ll_sample[0, 0]))
    print(torch.exp(lls[0, 0, :]).mean() * torch.exp(lls[0, 1, :]).mean())
    
    
def baseline(x, mean, logscale):
    """Return per-component discretized-logistic log-likelihoods of ``x``.

    For every component, computes::

        log( sigmoid((x - mean + h) / s) - sigmoid((x - mean - h) / s) )

    Here ``s = exp(logscale)`` and ``h = 0.5 / 256``. This is the log-probability
    mass a logistic with location ``mean`` and scale ``s`` assigns to a bin of
    width 1/256 centred on ``x``.

    ``x`` gets a new trailing axis (``unsqueeze(-1)``), so it broadcasts against
    the trailing component axis of ``mean`` / ``logscale``. NOTE(review): the
    bin width assumes ``x`` lies on a 1/256 grid; this is not enforced here.
    """
    half_bin = 0.5 / 256.0
    scale = torch.exp(logscale)
    centered = x.unsqueeze(-1) - mean

    lower = (centered - half_bin) / scale
    upper = (centered + half_bin) / scale

    log_cdf_upper = F.logsigmoid(upper)
    log_cdf_lower = F.logsigmoid(lower)

    # log(exp(a) - exp(b)) == a + log(1 - exp(b - a)); the 1e-8 keeps the log
    # finite when both CDF values coincide numerically.
    ratio = torch.exp(log_cdf_lower - log_cdf_upper)
    return log_cdf_upper + torch.log(1 - ratio + 1e-8)
    


if __name__ == "__main__":
    # Run the comparison only when executed as a script, not on import.
    main()