import tvem
from tvem.models import BSC
from tvem.exp import ExpConfig, Training, EEMConfig
import h5py
import torch as to

import argparse

def parse_args():
  """Parse the command-line options of this training script.

  Reads the options from ``sys.argv``. ``--dataset``, ``--Ksize``,
  ``--batch-size`` and ``--output`` are mandatory; ``--epochs`` defaults to 50
  and ``--bsc-start`` is optional (``None`` when omitted).

  Returns:
    The ``argparse.Namespace`` holding the parsed values.
  """
  cli = argparse.ArgumentParser()
  # (flag, keyword arguments) pairs, registered in this order.
  option_specs = (
    ("--dataset", dict(help="HD5 file as expected in input by tvem.Training", required=True)),
    ("--Ksize", dict(type=int, help="size of each K^n set", required=True)),
    ("--epochs", dict(type=int, default=50, help="number of training epochs")),
    ("--batch-size", dict(type=int, required=True)),
    ("--bsc-start", dict(help="BSC starting parameters")),
    ("--output", dict(help="output file for train log", required=True)),
  )
  for flag, options in option_specs:
    cli.add_argument(flag, **options)
  return cli.parse_args()

def train():
  """Train a BSC model on the dataset named on the command line.

  The model uses a fixed H=300 and D=256. The E-step is configured through
  ``EEMConfig`` with ``--Ksize`` states per datapoint, at most 8 parents and
  7 children (both capped at ``--Ksize``), 2 generations and no crossover.

  When ``--bsc-start`` is given, the model parameters (``W``, ``pies``,
  ``sigma``) and the variational states are initialised from that HDF5 file;
  otherwise ``BSC`` is built without explicit initial parameters.

  One log entry is printed per item yielded by ``Training.run``.
  """
  args = parse_args()
  # The dataset is handed to Training by path. The script itself never reads
  # it, so no separate h5py handle is opened here (one would stay open and
  # unused for the whole run).
  data_fname = args.dataset
  H = 300
  S = args.Ksize
  estep_conf = EEMConfig(n_states=S, n_parents=min(8, S), n_children=min(7, S), n_generations=2, crossover=False)

  K = None  # initial variational states; only set when --bsc-start is given
  if args.bsc_start is not None:
    with h5py.File(args.bsc_start, "r") as bsc_start:
      theta = bsc_start["theta"]
      # [-1] takes the last entry along the first axis of each dataset
      # (presumably the most recently logged value -- TODO confirm).
      W = to.from_numpy(theta["W"][-1])  # D,H
      pi = to.from_numpy(theta["pies"][-1])
      sigma = to.tensor([theta["sigma"][-1].item()])
      K = to.from_numpy(bsc_start["train_states"][...])
    m = BSC(H=H, D=256, W_init=W, pies_init=pi, sigma_init=sigma)
  else:
    m = BSC(H=H, D=256)

  conf = ExpConfig(batch_size=args.batch_size, output=args.output)
  t = Training(conf, estep_conf, m, data_fname)
  if K is not None:
    # Overwrite the training states in place with the stored ones, then drop
    # the local reference so the loaded tensor can be freed.
    t.train_states.K[:] = K
    del K
  for e_log in t.run(args.epochs):
    e_log.print()

# Run the training only when this file is executed as a script, not on import.
if __name__ == "__main__":
  train()
