from tvem.exp import FullEMConfig, ExpConfig, Training, EEMConfig
from tvem.models import BSC
from tvem.utils.param_init import init_W_data_mean
import h5py

import argparse
import torch as to

def get_eem_new_states(c: EEMConfig):
    """Return how many new states an EEM configuration generates.

    The result is used as ``Snew`` in the BSC configuration. It is the number
    of parent combinations times ``n_children`` times ``n_generations``, where
    the parent-combination count is ``n_parents * (n_parents - 1)`` when
    ``crossover`` is enabled and plain ``n_parents`` otherwise.
    """
    parent_combinations = c.n_parents * (c.n_parents - 1) if c.crossover else c.n_parents
    return parent_combinations * c.n_children * c.n_generations

def parse_args(argv=None):
    """Parse the command-line options of this training script.

    Args:
        argv: List of argument strings to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, which
            is what the script uses; passing a list allows programmatic use.

    Returns:
        An ``argparse.Namespace`` with attributes ``dataset``, ``epochs``,
        ``batch_size``, ``output``, ``H`` and ``Ksize``.
    """
    parser = argparse.ArgumentParser(
        description="Train a BSC model with evolutionary EM (EEM) using tvem."
    )
    parser.add_argument("dataset", help="HDF5 file as expected in input by tvem.Training")
    parser.add_argument("--epochs", type=int, default=50, help="number of training epochs")
    parser.add_argument("--batch-size", type=int, required=True, help="mini-batch size")
    parser.add_argument("--output", help="output file for train log", required=True)
    parser.add_argument("--H", type=int, required=True, help="number of latent units H of the BSC model")
    parser.add_argument(
        "--Ksize",
        type=int,
        required=True,
        help="number of variational states (passed to EEMConfig as n_states)",
    )
    return parser.parse_args(argv)

def train():
    """Train a BSC model on the ``"data"`` dataset of an HDF5 file with EEM.

    Reads the command-line options (``parse_args``), initializes the BSC
    weights from the per-feature mean and variance of the data, then runs
    ``args.epochs`` training epochs through ``tvem.Training`` and prints the
    log of each epoch.
    """
    args = parse_args()
    H = args.H
    batch_size = args.batch_size
    data_fname = args.dataset

    # Load the data only to compute its statistics. tvem.Training is given the
    # file name (not this handle), so close our handle once the array is read.
    with h5py.File(data_fname, "r") as data_file:
        data = to.from_numpy(data_file["data"][...])
    _, D = data.shape  # unpacking also enforces a 2-D dataset
    data_mean = to.mean(data, dim=0)
    data_var = to.var(data, dim=0)
    W_from_data = init_W_data_mean(data_mean, data_var, H)

    estep_conf = EEMConfig(n_states=args.Ksize, n_parents=8, n_children=7, n_generations=2, crossover=True)
    # estep_conf = FullEMConfig()
    bsc_conf = {
        "D": D,
        "H": H,
        # NOTE(review): 2**H is the size of the full binary state space (as
        # for the commented-out FullEMConfig); the EEM E-step above keeps
        # args.Ksize states. Confirm which value BSC expects for "S".
        "S": 2**H,
        "Snew": get_eem_new_states(estep_conf),
        "batch_size": batch_size,
        "precision": to.double,
    }
    model = BSC(bsc_conf, W_init=W_from_data)
    conf = ExpConfig(batch_size=batch_size, output=args.output)
    t = Training(conf, estep_conf, model, data_fname)
    print("\nlearning...")
    for e_log in t.run(args.epochs):
        e_log.print()

if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    train()
