# -*- coding: utf-8 -*-
import argparse
import h5py
from contextlib import closing
import torch as to

from tvem.utils import get
from tvem.models import BSC
from tvem.variational import FullEM


def bars_from_BSC(N: int, H: int, *, neg_amp: bool = False):
    """Sample bar datapoints from a BSC model.

    Ground-truth weights are laid out on an R x R grid with R = H // 2:
    latent ``i`` sets every pixel whose first grid index is ``i`` and latent
    ``R + i`` sets every pixel whose second grid index is ``i``. For odd H the
    last latent keeps the background amplitude everywhere.

    :param N: Number of datapoints
    :param H: Number of latent variables
    :param neg_amp: If True, flip the sign of each latent's weights with
                    probability 50 percent. Defaults to False.

    :returns: (data, labels, logL, ground_truth) tuple, where ``labels`` are
              the hidden states returned by ``model.generate_data``, ``logL``
              is ``model.free_energy`` divided by N, and ``ground_truth`` is a
              dict with keys "N", "H", "pies", "W", "sigma2" and "logL".
    """
    R = H // 2  # side length of the square pixel grid
    D = R ** 2  # number of observables (flattened grid)
    bar_amp = 1.0
    bg_amp = 0.0
    precision: to.dtype = to.float64

    W = bg_amp * to.ones((R, R, H), dtype=precision)
    for i in range(R):
        W[i, :, i] = bar_amp
        W[:, i, R + i] = bar_amp

    if neg_amp:
        # `size` has to be a tuple; `(H)` without the trailing comma is just the int H.
        sign = 1 - 2 * to.randint(high=2, size=(H,))
        W = sign[None, None, :].to(dtype=precision) * W

    W_gt = W.view((D, H))
    sigma_gt = to.tensor([0.1], dtype=precision)
    pies_gt = to.full((H,), 2.0 / H, dtype=precision)

    # NOTE(review): sigma_gt is handed to BSC as-is while ground_truth stores
    # sigma_gt squared -- confirm whether BSC expects a standard deviation or a variance.
    model = BSC(H, D, W_gt, sigma_gt, pies_gt)

    data, labels = get(model.generate_data(N), "data", "hidden_state")

    # Evaluate the free energy of the generating parameters on the sampled data.
    states = FullEM(N, H, precision=model.precision)
    model.init_storage(S=2**H, Snew=0, batch_size=N)
    model.init_epoch()
    model.init_batch()
    idx = to.arange(N)  # all datapoints are processed as a single batch
    states.update(idx, data, model)
    logL = model.free_energy(idx, data, states) / N
    print(f"logL: {logL}")

    ground_truth = {
        "N": N,
        "H": H,
        "pies": pies_gt,
        "W": W_gt,
        "sigma2": sigma_gt.pow(2),
        "logL": logL,
    }
    return data, labels, logL, ground_truth


def parse_args():
    """Parse the command line options of the script.

    All three options are mandatory: ``-N`` and ``-H`` are parsed as
    integers, ``--output-dir`` is kept as a string.

    :returns: the populated argparse namespace (attributes ``N``, ``H``
              and ``output_dir``)
    """
    cli = argparse.ArgumentParser()
    int_options = (
        ("-N", "number of datapoints"),
        ("-H", "number of components. D=(H/2)**2"),
    )
    for flag, description in int_options:
        cli.add_argument(flag, type=int, help=description, required=True)
    cli.add_argument("--output-dir", required=True)
    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: sample a bars dataset and store it, together with
    # the generating parameters, in a single HDF5 file.
    cli_args = parse_args()
    N = cli_args.N
    H = cli_args.H
    out_dir = cli_args.output_dir

    data, labels, logL, ground_truth = bars_from_BSC(N, H)

    # The file name encodes the number of latents and the number of datapoints.
    h5_path = f"{out_dir}/H{H}_N{N}.h5"
    with closing(h5py.File(h5_path, "w")) as h5_file:
        # Tensors are moved to the CPU before being handed to h5py.
        for name, tensor in (("data", data), ("labels", labels)):
            h5_file.create_dataset(name, data=tensor.cpu())
        # Generating parameters go into their own group, one dataset per entry.
        gt_group = h5_file.create_group("ground_truth")
        for key in ground_truth:
            gt_group.create_dataset(key, data=ground_truth[key])
