# -*- coding: utf-8 -*-
import argparse
import h5py
from contextlib import closing
import torch as to

from tvem.utils import get
from tvem.models import TVAE
from tvem.variational import FullEM

def add_simple_correlation(W0):
    """Set a cyclic chain of unit entries in W0, in place.

    Row h gets a 1 in column h + 1; the last row wraps around and gets its 1
    in column 0. All other entries are left untouched.

    :param W0: 2-d array/tensor supporting ``shape`` and ``W0[row, col] = ...``;
               modified in place

    :returns: None
    """
    n_rows = W0.shape[0]
    for row in range(n_rows):
        # the modulo sends the successor of the last row back to column 0
        W0[row, (row + 1) % n_rows] = 1

def add_complex_correlation(W0):
    """Write two negative off-diagonal entries into W0, in place.

    Sets ``W0[0, 1]`` and ``W0[-1, 0]`` to -2; every other entry is left
    untouched.

    :param W0: 2-d array/tensor supporting ``W0[row, col] = ...``; modified
               in place

    :returns: None
    """
    coupling = -2
    for row, col in ((0, 1), (-1, 0)):
        W0[row, col] = coupling

def bars_from_TVAE(N: int, H: int):
    """Sample datapoints from a ground-truth TVAE whose output weights are bars.

    The output weight matrix places one bar per latent on an R x R grid with
    R = H // 2: latent i covers all grid positions with first index i, latent
    R + i covers all grid positions with second index i. The first weight
    matrix is an identity with the extra entries written by
    add_complex_correlation.

    :param N: Number of datapoints
    :param H: Number of latent variables

    :returns: (data, labels, logL, ground_truth) tuple. labels is the
              "hidden_state" entry of the generated data, logL is the free
              energy per datapoint evaluated with FullEM states, and
              ground_truth is a dict holding N, H, the generating parameters
              and logL.
    """
    side = H // 2
    n_pixels = side ** 2
    dtype = to.float64
    bar_value, background = 1.0, 0.0

    # first weight matrix: identity plus the two -2 couplings
    W0 = to.eye(H)
    add_complex_correlation(W0)

    # second weight matrix: one generative field per latent on a side x side grid
    fields = background * to.ones((side, side, H), dtype=dtype)
    for k in range(side):
        fields[k, :, k] = bar_value
        fields[:, k, side + k] = bar_value
    # flatten the grid and put the latent axis first -> shape (H, n_pixels)
    W_gt = fields.view((n_pixels, H)).t()

    sigma2_gt = to.tensor([0.01], dtype=dtype)
    pies_gt = to.full((H,), 2.0 / H, dtype=dtype)

    model = TVAE(
        shape=[n_pixels, H, H],
        W_init=[W0, W_gt],
        sigma2_init=sigma2_gt,
        pi_init=pies_gt,
    )

    data, labels = get(model.generate_data(N), "data", "hidden_state")

    # evaluate the generating model on its own samples using FullEM states
    states = FullEM(N, H, precision=model.precision)
    model.init_epoch()
    model.init_batch()
    states.update(to.arange(N), data, model)
    logL = model.free_energy(to.arange(N), data, states) / N
    print(f"logL: {logL}")

    ground_truth = {
        "N": N,
        "H": H,
        "pies": pies_gt,
        "W0": W0,
        "W1": W_gt,
        "sigma2": sigma2_gt,
        "logL": logL,
    }
    return data, labels, logL, ground_truth


def parse_args():
    """Parse the command line of the dataset generation script.

    All three options are mandatory: ``-N`` (int), ``-H`` (int) and
    ``--output-dir``.

    :returns: argparse.Namespace with attributes N, H and output_dir
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("-N", {"type": int, "help": "number of datapoints"}),
        ("-H", {"type": int, "help": "number of components. D=(H/2)**2"}),
        ("--output-dir", {}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, required=True, **options)
    return cli.parse_args()


if __name__ == "__main__":
    # Generate a dataset and write it, together with the generating
    # parameters, to <output-dir>/H<H>_N<N>.h5
    args = parse_args()
    N, H, out_dir = args.N, args.H, args.output_dir

    # logL is also stored inside ground_truth, so it is not written separately
    data, labels, logL, ground_truth = bars_from_TVAE(N, H)

    with closing(h5py.File(f"{out_dir}/H{H}_N{N}.h5", "w")) as f:
        # samples and their labels go to the file root (moved to CPU first)
        for name, tensor in (("data", data), ("labels", labels)):
            f.create_dataset(name, data=tensor.cpu())

        # one dataset per ground-truth entry inside its own group
        g = f.create_group("ground_truth")
        for key, value in ground_truth.items():
            g.create_dataset(key, data=value)
