from __future__ import print_function
import argparse
import time
import torch
import torch.utils.data
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math
import random
import os
import datetime
import sys
from torch.autograd import Variable

# Extend the module search path so the project-local import below can resolve.
# The entries are relative paths, so they are interpreted against the current
# working directory: launch the script from a directory where they are valid.
sys.path.append("./")
sys.path.append("../integer_discrete_flows")
sys.path.append("../EinsumNetworks/src")

# Presumably resolved via one of the directories appended above
# (likely ../EinsumNetworks/src) -- confirm against the repository layout.
from EinsumNetwork import Graph, EinsumNetwork


def main():
    """Compare gradients of ``einet.forward_with_grad2`` with ``baseline``.

    Builds an EinsumNetwork over ``num_var`` variables (random binary-tree
    graph, ``LogisticArray`` exponential family, a single repetition / input
    distribution / sum), then evaluates both the network and ``baseline`` on
    identical random ``x``, ``mean`` and ``logscale`` tensors of shape
    ``(32, num_var)``.  The batch mean of each result is back-propagated and
    the largest absolute difference between the two sets of input gradients
    is printed, in the order ``x``, ``mean``, ``logscale``.

    The maximum *absolute* difference is reported rather than a signed sum:
    with a signed sum, discrepancies of opposite sign cancel and a real
    mismatch could print as (almost) zero.
    """
    num_var = 100
    num_repetitions = 1
    num_input_distributions = 1
    num_sums = 1
    batch_size = 32

    graph = Graph.random_binary_trees(
        num_var=num_var,
        depth=int(np.ceil(np.log(num_var))),
        num_repetitions=num_repetitions,
    )
    einet_args = EinsumNetwork.Args(
        num_var=num_var,
        num_dims=1,
        num_classes=1,
        num_input_distributions=num_input_distributions,
        num_sums=num_sums,
        exponential_family=EinsumNetwork.LogisticArray,
        exponential_family_args={'scale_min': 1e-4, 'scale_max': 1.0},
        online_em_frequency=1,
        online_em_stepsize=0.05,
    )

    einet = EinsumNetwork.EinsumNetwork(graph, einet_args)
    einet.initialize()

    # Sampled in the same order as before so a fixed RNG seed reproduces the
    # same inputs.
    x = torch.rand([batch_size, num_var])
    mean = torch.rand([batch_size, num_var])
    logscale = torch.rand([batch_size, num_var])

    def leaf_copies():
        # Fresh, independent leaf tensors: each backward pass must accumulate
        # into its own ``.grad`` so the two implementations do not interfere.
        # Leaves keep ``.grad`` automatically, so no ``retain_grad()`` call is
        # needed, and the deprecated ``Variable`` wrapper is unnecessary.
        return [
            tensor.detach().clone().requires_grad_(True)
            for tensor in (x, mean, logscale)
        ]

    einet_inputs = leaf_copies()
    einet.forward_with_grad2(*einet_inputs).mean().backward()

    baseline_inputs = leaf_copies()
    baseline(*baseline_inputs).mean().backward()

    # One line per input (x, mean, logscale): worst-case gradient mismatch.
    for einet_leaf, baseline_leaf in zip(einet_inputs, baseline_inputs):
        print((einet_leaf.grad - baseline_leaf.grad).abs().max())
    
    
def baseline(x, mean, logscale):
    """Reference log-likelihood of ``x`` under a discretized logistic.

    Each element of ``x`` is treated as the centre of a bin of width
    ``1 / 256``.  The bin edges are standardised with ``mean`` and
    ``exp(logscale)``, and the log-probability of the bin is the log of the
    difference of the logistic CDF (sigmoid) at the two edges, evaluated in
    log-space with a small epsilon (``1e-8``) inside the logarithm.

    Args:
        x: Tensor of observations; must have at least two dimensions because
            the result is reduced over ``dim=1`` (``main`` in this file
            passes shape ``(batch, num_var)``).
        mean: Tensor of locations, combined elementwise with ``x``.
        logscale: Tensor of log-scales, combined elementwise with ``x``.

    Returns:
        Tensor of per-element log-probabilities summed over ``dim=1``.
    """
    half_bin = 0.5 / 256.0
    scale = torch.exp(logscale)
    centered = x - mean

    # Standardised lower / upper edges of the bin around each observation.
    lower_edge = (centered - half_bin) / scale
    upper_edge = (centered + half_bin) / scale

    # log(sigmoid(upper) - sigmoid(lower)) computed as
    # a + log(1 - exp(b - a) + eps) with a = log sigmoid(upper),
    # b = log sigmoid(lower); eps keeps the log argument away from zero.
    log_cdf_upper = F.logsigmoid(upper_edge)
    log_cdf_lower = F.logsigmoid(lower_edge)
    log_bin_mass = log_cdf_upper + torch.log(
        1 - torch.exp(log_cdf_lower - log_cdf_upper) + 1e-8
    )

    return log_bin_mass.sum(dim=1)


# Run the gradient comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()