import numpy as np
import torch
import pickle
import os
import time
from PIL import Image

import sys
sys.path.append("../EinsumNetworks/src")

from EinsumNetwork import Graph, EinsumNetwork
import eiutils

from prep_einsum import *


def main_z(gpu_idx=1):
    """Fit an Einsum Network to the stored ``zs`` arrays and report LL / bpd.

    For each of ``data/train_zs.npz``, ``data/valid_zs.npz`` and
    ``data/test_zs.npz`` the discretized-logistic baseline is printed, then
    the ``zs`` array is flattened to one row per sample. An Einsum Network
    over random binary trees is trained with online EM; after every epoch the
    average log-likelihood and the derived bits-per-dimension of all three
    splits are printed.

    Args:
        gpu_idx: Index of the CUDA device to train on. When CUDA is not
            available the CPU is used. When the index is not smaller than the
            number of visible CUDA devices, device 0 is used instead and a
            notice is printed.
    """
    num_epochs = 100
    batch_size = 64
    online_em_frequency = 1
    online_em_stepsize = 0.05

    # Every sample is flattened into height * width * num_channels variables.
    height = 4
    width = 4
    num_channels = 48
    num_var = height * width * num_channels

    if torch.cuda.is_available():
        # Requesting a device index that does not exist fails later with a
        # CUDA error; fall back to the first device instead.
        if gpu_idx >= torch.cuda.device_count():
            print(f"CUDA device {gpu_idx} is not available; using device 0 instead")
            gpu_idx = 0
        device = torch.device(f"cuda:{gpu_idx}")
    else:
        device = torch.device("cpu")

    def load_split(path):
        # np.load on an .npz returns an archive object that keeps the file
        # open; the context manager closes it once the arrays are read.
        with np.load(path) as archive:
            log_discretized_logistic_baseline(archive)
            zs = torch.from_numpy(archive["zs"])
        # Move axis 1 to the end, then flatten each sample into one row.
        return zs.permute(0, 2, 3, 1).reshape(-1, num_var)

    def to_bpd(avg_ll):
        # Nats -> bits, then divide by 32 * 32 * 3.
        # NOTE(review): the divisor is presumably the dimensionality of the
        # original images -- confirm against the data pipeline.
        return -avg_ll * np.log(np.e) / np.log(2.0) / 32 / 32 / 3

    ## Load data (the baseline of each split is printed while loading) ##
    train_data = load_split("data/train_zs.npz")
    valid_data = load_split("data/valid_zs.npz")
    test_data = load_split("data/test_zs.npz")

    print("Generating graph...")
    t = time.time()
    # NOTE: Graph.MI_binary_trees(data=train_data, num_var=..., depth=5,
    # num_repetitions=10) was previously kept here, commented out, as an
    # alternative structure.
    graph = Graph.random_binary_trees(num_var=train_data.size(1), depth=5, num_repetitions=10)
    print("({:.2f})".format(time.time() - t))

    args = EinsumNetwork.Args(
        num_var=train_data.size(1),
        num_dims=1,
        num_classes=1,
        num_sums=8,
        num_input_distributions=8,
        exponential_family=EinsumNetwork.NormalArray,
        exponential_family_args={'min_var': 1e-6, 'max_var': 1.0},
        online_em_frequency=online_em_frequency,
        online_em_stepsize=online_em_stepsize)

    einet = EinsumNetwork.EinsumNetwork(graph, args)
    einet.initialize()
    einet.to(device)

    ## Train model ##
    for epoch_idx in range(num_epochs):
        for batch_idx in make_shuffled_batch(train_data.size(0), batch_size):
            batch = train_data[batch_idx, :].to(device)

            # The EM statistics are read from the gradients of the summed
            # log-likelihood, hence the backward pass without an optimizer.
            ll_sample = einet.forward(batch)
            log_likelihood = ll_sample.sum()
            log_likelihood.backward()
            einet.em_process_batch()

        einet.em_update()

        ## Evaluate ##
        train_ll = eval_ll(einet, train_data, batch_size, device)
        valid_ll = eval_ll(einet, valid_data, batch_size, device)
        test_ll = eval_ll(einet, test_data, batch_size, device)
        train_bpd = to_bpd(train_ll)
        valid_bpd = to_bpd(valid_ll)
        test_bpd = to_bpd(test_ll)

        print("[{}]   train LL {:.2f}   valid LL {:.2f}   test LL {:.2f}".format(epoch_idx, train_ll, valid_ll, test_ll))
        print("[{}]   train bpd {:.4f}   valid bpd {:.4f}   test bpd {:.4f}".format(epoch_idx, train_bpd, valid_bpd, test_bpd))
    
    
def eval_ll(einet, data, batch_size, device):
    """Return the average per-sample log-likelihood of ``data`` under ``einet``.

    The data is processed in contiguous batches along dimension 0, so every
    sample is visited exactly once and the result is deterministic. The
    summed log-likelihood does not depend on sample order, so no shuffling is
    needed for evaluation.

    Args:
        einet: Model exposing ``forward2(batch)``; its output is summed over
            all elements (assumed to hold one log-likelihood per sample --
            confirm against the model implementation).
        data: Tensor whose first dimension indexes samples.
        batch_size: Maximum number of samples per ``forward2`` call.
        device: Device each batch is moved to before evaluation.

    Returns:
        The sum of all ``forward2`` outputs divided by ``data.size(0)``
        (a zero-dimensional tensor).

    Raises:
        ValueError: If ``data`` contains no samples.
    """
    num_samples = data.size(0)
    if num_samples == 0:
        raise ValueError("eval_ll() needs at least one sample, got an empty dataset")

    with torch.no_grad():
        ll = 0.0
        for batch in data.split(batch_size):
            ll_sample = einet.forward2(batch.to(device))
            ll = ll + ll_sample.sum()

    return ll / num_samples
    
    
def log_discretized_logistic_baseline(data):
    """Print the discretized-logistic log-likelihood of ``data["zs"]`` and its bpd.

    ``data`` is indexed with the keys ``"zs"``, ``"pzs_0"`` and ``"pzs_1"``,
    each of which must yield a NumPy array. ``pzs_0`` is used as the mean and
    ``pzs_1`` as the log-scale of a logistic distribution discretized into
    bins of width ``1 / 256``. The element-wise log-probabilities are averaged
    over dimension 0 and summed over the remaining dimensions; the result and
    its conversion to bits per dimension (divisor ``32 * 32 * 3``) are printed.
    Nothing is returned.
    """
    from torch.nn.functional import logsigmoid

    inverse_bin_width = 256.0
    epsilon = 1e-8  # keeps the argument of the logarithm strictly positive

    latents = torch.from_numpy(data["zs"])
    # A leading axis is added so the parameters broadcast over the samples.
    prior_mean = torch.from_numpy(data["pzs_0"]).unsqueeze(0)
    prior_logscale = torch.from_numpy(data["pzs_1"]).unsqueeze(0)
    prior_scale = torch.exp(prior_logscale)

    # Log-CDF at the upper and lower edge of the bin centred on each value.
    log_cdf_upper = logsigmoid((latents + 0.5 / inverse_bin_width - prior_mean) / prior_scale)
    log_cdf_lower = logsigmoid((latents - 0.5 / inverse_bin_width - prior_mean) / prior_scale)

    # log(exp(upper) - exp(lower)), evaluated relative to the upper term.
    log_prob = log_cdf_upper + torch.log(1 - torch.exp(log_cdf_lower - log_cdf_upper) + epsilon)

    log_pzs = log_prob.mean(dim=0).sum()
    bpd_pzs = -log_pzs * np.log(np.e) / np.log(2.0) / 32 / 32 / 3
    print(f"log_discretized_logistic_baseline: {log_pzs}; bpd: {bpd_pzs}")
    
    
# Script entry point: run training with main_z's default GPU index.
if __name__ == "__main__":
    main_z()