from tvem.models import TVAE
from tvem.exp import ExpConfig, EEMConfig, Training
import h5py
import torch as to
import numpy as np

import argparse

def parse_args(argv=None):
    """Parse the command line options of the script.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``bsc_train_out``,
        ``epochs``, ``batch_size`` and ``output``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or an
            option value cannot be converted.
    """
    parser = argparse.ArgumentParser()
    # Both input files are consumed unconditionally by train(), so a missing
    # option is rejected here instead of failing later on a ``None`` file name.
    parser.add_argument(
        "--dataset",
        required=True,
        help="HD5 file as expected in input by tvem.Training",
    )
    parser.add_argument(
        "--bsc-train-out",
        required=True,
        help="HD5 file with output of BSC training",
    )
    parser.add_argument("--epochs", type=int, default=50, help="number of training epochs")
    parser.add_argument("--batch-size", type=int, required=True)
    parser.add_argument("--output", help="output file for train log", required=True)
    return parser.parse_args(argv)

def train(args):
    """Train a TVAE whose parameters and states are seeded from a BSC run.

    The last stored BSC parameters (``W``, ``pies``, ``sigma``) are read from
    ``args.bsc_train_out`` and used to initialise the TVAE; the stored
    ``train_states`` are copied into the training states of the experiment
    before the training loop starts.

    Args:
        args: namespace providing ``bsc_train_out``, ``dataset``,
            ``batch_size``, ``epochs`` and ``output`` (see ``parse_args``).
    """
    with h5py.File(args.bsc_train_out, "r") as bsc_out:
        theta = bsc_out["theta"]
        # index -1 picks the last entry stored along the first axis
        W_bsc = to.from_numpy(theta["W"][-1]).t()
        pi_bsc = to.from_numpy(theta["pies"][-1])
        # the file stores sigma; TVAE is initialised with its square
        sigma2_bsc = to.from_numpy(theta["sigma"][-1]).pow(2)
        K = to.from_numpy(bsc_out["train_states"][...])

    H = W_bsc.shape[0]
    # Assumes the tvem state layout (N datapoints, S states per datapoint, H
    # latents) -- confirm against the tvem docs. The datapoint count is the
    # first axis; it was previously read from the second axis, i.e. it
    # silently equalled S.
    N, S = K.shape[0], K.shape[1]

    # setup TVAE
    epochs_per_cycle = 10
    # half a learning-rate cycle, counted in batches
    batches_per_epoch = np.ceil(N / args.batch_size)
    step = int(epochs_per_cycle * batches_per_epoch // 2)
    tvae = TVAE(
        W_init=[to.eye(H), W_bsc],
        pi_init=pi_bsc,
        sigma2_init=sigma2_bsc,
        min_lr=0.0001,
        max_lr=0.001,
        cycliclr_step_size_up=step,
    )

    # setup experiment
    estep_conf = EEMConfig(
        n_states=S,
        n_parents=min(8, S),
        n_children=min(7, S),
        n_generations=2,
        crossover=False,
    )
    conf = ExpConfig(batch_size=args.batch_size, output=args.output)
    t = Training(conf, estep_conf, tvae, args.dataset)
    # start from the BSC variational states instead of the default initial states
    t.train_states.K[:] = K
    for e_log in t.run(args.epochs):
        e_log.print()

def main():
    """Script entry point: parse the command line and run the training."""
    train(parse_args())

# Run the training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
