import numpy as np
import torch
import pickle
import os
import time
from PIL import Image

import sys
sys.path.append("../EinsumNetworks/src")

from EinsumNetwork import Graph, EinsumNetwork
import eiutils
import datasets

from prep_einsum import *


def _to_rows(array):
    """Flatten a 4-D numpy array into a 2-D tensor with one row per entry of axis 0.

    Axis 1 (presumably channels) is moved to the end before flattening, so the
    layout goes from (N, C, H, W) to (N, H * W * C).
    """
    return torch.from_numpy(array).permute(0, 2, 3, 1).reshape(array.shape[0], -1)


def _load_split(path):
    """Load one data split and prepare flattened samples plus prior statistics.

    Reads the ``zs``/``ys1``/``ys2`` samples and their prior parameters
    (``*_0`` used as means, ``*_1`` used as log-scales) from the ``.npz`` file
    at *path*, prints the discretized-logistic baseline for the split, and
    returns ``(datasets, means, log_scales)``: three lists ordered
    ``[z, ys1, ys2]`` of 2-D tensors with one row per sample.
    """
    keys = ("zs", "ys1", "ys2",
            "pzs_0", "pzs_1", "pys1_0", "pys1_1", "pys2_0", "pys2_1")
    # Read every array exactly once and release the file handle; an NpzFile
    # re-reads (and decompresses) an entry on every item access.
    with np.load(path) as npz:
        arrays = {key: npz[key] for key in keys}

    log_discretized_logistic_baseline(arrays)

    datasets = [_to_rows(arrays[key]) for key in ("zs", "ys1", "ys2")]

    # The z prior parameters have no sample axis (they get one via [None]) and
    # are shared by every z sample: tile them to one row per *z* sample so they
    # can be indexed with the same batch indices as the z data.
    num_z = arrays["zs"].shape[0]
    z_mean = _to_rows(arrays["pzs_0"][None]).repeat(num_z, 1)
    z_log_scale = _to_rows(arrays["pzs_1"][None]).repeat(num_z, 1)

    means = [z_mean, _to_rows(arrays["pys1_0"]), _to_rows(arrays["pys2_0"])]
    log_scales = [z_log_scale, _to_rows(arrays["pys1_1"]), _to_rows(arrays["pys2_1"])]

    # Shift each log-scale tensor so that its maximum is 0.
    # NOTE(review): the shift uses this split's own maximum, so train, valid
    # and test inputs end up normalized by different constants — confirm this
    # is intended rather than reusing the training-set maximum.
    log_scales = [s - torch.max(s) for s in log_scales]
    return datasets, means, log_scales


def main_z(gpu_idx = 1):
    """Train one EinsumNetwork per latent group (z, ys1, ys2) and report bpd.

    Loads ``data/{train,valid,test}_zs.npz`` (printing a discretized-logistic
    baseline for each split), builds one random-binary-tree EinsumNetwork with
    Normal leaves per group, and trains them with EM for ``num_epochs`` epochs.
    After every epoch it prints bits-per-dimension on all three splits for
    each network and summed over the three networks.

    Args:
        gpu_idx: CUDA device index used when CUDA is available; otherwise the
            CPU is used.
    """
    num_epochs = 100
    batch_size = 64
    online_em_frequency = 1
    online_em_stepsize = 0.05

    num_repetitions = 10
    num_sums = 8
    num_input_distributions = 8

    # bpd normalizer; presumably the dimensionality of the 32x32x3 source images.
    num_image_dims = 32 * 32 * 3

    device = torch.device(f"cuda:{gpu_idx}" if torch.cuda.is_available() else "cpu")

    ## Load data ##
    train_datasets, train_stats_means, train_stats_scales = _load_split("data/train_zs.npz")
    valid_datasets, valid_stats_means, valid_stats_scales = _load_split("data/valid_zs.npz")
    test_datasets, test_stats_means, test_stats_scales = _load_split("data/test_zs.npz")

    print("Generating graph...", end = "")
    t = time.time()
    # NOTE(review): depth uses the natural log of num_var; log2 may have been
    # intended for a binary tree — kept as-is to preserve the model.
    graphs = [
        Graph.random_binary_trees(
            num_var=rows.size(1),
            depth=int(np.ceil(np.log(rows.size(1)))),
            num_repetitions=num_repetitions,
        )
        for rows in train_datasets
    ]
    print("({:.2f})".format(time.time() - t))

    einets = []
    for graph, rows in zip(graphs, train_datasets):
        args = EinsumNetwork.Args(
            num_var=rows.size(1),
            num_dims=1,
            num_classes=1,
            num_sums=num_sums,
            num_input_distributions=num_input_distributions,
            exponential_family=EinsumNetwork.NormalArray,
            exponential_family_args={'min_var': 1e-6, 'max_var': 0.2},
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize
        )
        einet = EinsumNetwork.EinsumNetwork(graph, args)
        einet.initialize()
        einet.to(device)
        einets.append(einet)

    ## Train model ##
    for epoch_idx in range(num_epochs):
        tr_bpd, val_bpd, tt_bpd = 0.0, 0.0, 0.0
        for idx, einet in enumerate(einets):
            train_data = train_datasets[idx]
            train_mean = train_stats_means[idx]
            train_scale = train_stats_scales[idx]

            valid_data = valid_datasets[idx]
            valid_mean = valid_stats_means[idx]
            valid_scale = valid_stats_scales[idx]

            test_data = test_datasets[idx]
            test_mean = test_stats_means[idx]
            test_scale = test_stats_scales[idx]

            for batch_idx in make_shuffled_batch(train_data.size(0), batch_size):
                batch_data = train_data[batch_idx, :].to(device)
                batch_mean = train_mean[batch_idx, :].to(device)
                batch_scale = train_scale[batch_idx, :].to(device)
                # Standardize with the prior's mean and (shifted) log-scale.
                batch = (batch_data - batch_mean) / torch.exp(batch_scale)

                ll_sample = einet.forward(batch)
                log_likelihood = ll_sample.sum()
                # presumably em_process_batch consumes the gradients produced
                # by backward() to accumulate EM statistics — TODO confirm.
                log_likelihood.backward()
                einet.em_process_batch()

            einet.em_update()

            ## Evaluate ##
            train_ll = eval_ll(einet, train_data, train_mean, train_scale, batch_size, device)
            valid_ll = eval_ll(einet, valid_data, valid_mean, valid_scale, batch_size, device)
            test_ll = eval_ll(einet, test_data, test_mean, test_scale, batch_size, device)
            # Convert average nats per sample into bits per image dimension.
            train_bpd = -train_ll / np.log(2.0) / num_image_dims
            valid_bpd = -valid_ll / np.log(2.0) / num_image_dims
            test_bpd = -test_ll / np.log(2.0) / num_image_dims

            tr_bpd += train_bpd
            val_bpd += valid_bpd
            tt_bpd += test_bpd

            print("[{}]   train bpd {:.4f}   valid bpd {:.4f}   test bpd {:.4f}".format(epoch_idx, train_bpd, valid_bpd, test_bpd))

        print("total train bpd {:.4f}   total valid bpd {:.4f}   total test bpd {:.4f}".format(tr_bpd, val_bpd, tt_bpd))
    
    
def eval_ll(einet, data, data_mean, data_scale, batch_size, device, inverse_bin_width=256.0):
    """Return the average discretized log-likelihood of *data* under *einet*.

    Each value is treated as a bin of width ``1 / inverse_bin_width`` centred
    on it. Both bin edges are standardized with the per-sample mean and
    log-scale (``(edge - mean) / exp(scale)``), and ``einet.forward3(lower,
    upper)`` supplies the per-sample log-likelihood of that interval.

    Args:
        einet: Model exposing ``forward3(lower, upper)``.
        data: 2-D tensor of samples, indexed by sample along dim 0.
        data_mean: Tensor indexed like *data*; subtracted before scaling.
        data_scale: Tensor indexed like *data*; log of the divisor.
        batch_size: Number of samples per forward pass.
        device: Device each batch is moved to.
        inverse_bin_width: Bins per unit value; the default 256.0 reproduces
            the previous hard-coded ``0.5 / 256.0`` half-width.

    Returns:
        Scalar tensor: the summed log-likelihood divided by ``data.size(0)``.
    """
    num_samples = data.size(0)
    half_width = 0.5 / inverse_bin_width
    ll = 0.0
    with torch.no_grad():
        # The total is order-independent, so walk contiguous slices instead of
        # a random permutation; every sample (incl. a partial last batch) is used.
        for start in range(0, num_samples, batch_size):
            stop = start + batch_size
            batch_data = data[start:stop].to(device)
            batch_mean = data_mean[start:stop].to(device)
            batch_scale = torch.exp(data_scale[start:stop].to(device))

            lower = (batch_data - half_width - batch_mean) / batch_scale
            upper = (batch_data + half_width - batch_mean) / batch_scale

            ll += einet.forward3(lower, upper).sum()
    return ll / num_samples
    
    
def log_discretized_logistic_baseline(data):
    """Print bits-per-dimension of the ``zs``/``ys1``/``ys2`` arrays under their priors.

    For each group, the matching ``*_0`` array is the logistic location and
    ``*_1`` the log-scale; the discretized-logistic log-probability (bin width
    1/256) is averaged over axis 0, summed over the remaining axes, converted
    from nats to bits and divided by 32 * 32 * 3. The three values and their
    total are printed; nothing is returned.
    """
    from torch.nn.functional import logsigmoid

    def discretized_log_prob(x, loc, log_scale, inverse_bin_width=256.0):
        # log(sigmoid(upper) - sigmoid(lower)), computed as
        # log_upper + log(1 - exp(log_lower - log_upper)) with a 1e-8 guard.
        scale = torch.exp(log_scale)
        log_upper = logsigmoid((x + 0.5 / inverse_bin_width - loc) / scale)
        log_lower = logsigmoid((x - 0.5 / inverse_bin_width - loc) / scale)
        return log_upper + torch.log(1 - torch.exp(log_lower - log_upper) + 1e-8)

    def to_bpd(total_log_p):
        return -total_log_p * np.log(np.e) / np.log(2.0) / 32 / 32 / 3

    # (sample key, location key, log-scale key, prior lacks a sample axis)
    groups = (
        ("zs", "pzs_0", "pzs_1", True),
        ("ys1", "pys1_0", "pys1_1", False),
        ("ys2", "pys2_0", "pys2_1", False),
    )
    per_group = []
    for sample_key, loc_key, scale_key, shared_prior in groups:
        samples = torch.from_numpy(data[sample_key])
        loc = torch.from_numpy(data[loc_key])
        log_scale = torch.from_numpy(data[scale_key])
        if shared_prior:
            loc = loc.unsqueeze(0)
            log_scale = log_scale.unsqueeze(0)
        total = discretized_log_prob(samples, loc, log_scale).mean(dim = 0).sum()
        per_group.append(to_bpd(total))

    bpd_pzs, bpd_pys1, bpd_pys2 = per_group
    bpd = bpd_pzs + bpd_pys1 + bpd_pys2
    print(f"log_discretized_logistic_baseline bpd: z - {bpd_pzs}, ys1 - {bpd_pys1}, ys2 - {bpd_pys2}, total: {bpd}")
    
    
if __name__ == "__main__":
    # Script entry point: run the full pipeline on CUDA device 1
    # (main_z falls back to the CPU when CUDA is unavailable).
    main_z(gpu_idx=1)