# -*- coding: utf-8 -*-
import numpy as np
import argparse
import torch as to
from typing import Tuple
import h5py

from tvem.exp import EEMConfig, FullEMConfig, ExpConfig, Training
from tvem.utils.param_init import init_sigma_default, init_W_data_mean
from tvem.models import TVAE


def parse_args(argv=None):
    """
    Build the command-line parser of this training script and parse arguments.

    :param argv: list of argument strings to parse (without the program name).
        Defaults to ``None``, in which case argparse reads ``sys.argv[1:]``,
        i.e. the behavior when called with no arguments is unchanged.
    :returns: the parsed ``argparse.Namespace``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", help="HDF5 file as expected in input by tvem.Training")
    parser.add_argument("--Ksize", type=int, help="size of each K^n set")
    parser.add_argument("--epochs", type=int, default=50, help="number of training epochs")
    parser.add_argument(
        "--net-shape",
        required=True,
        # converts e.g. "10:10:2" into (10, 10, 2)
        type=parse_net_shape,
        help="colon-separated list of layer sizes, e.g. 10:10:2",
    )
    parser.add_argument("--min_lr", type=float, help="MLP min learning rate", required=True)
    parser.add_argument("--max_lr", type=float, help="MLP max learning rate", required=True)
    parser.add_argument("--batch-size", type=int, required=True)
    parser.add_argument("--output", help="output file for train log", required=True)
    parser.add_argument(
        "--seed", type=int, help="seed value for random number generators. default is a random seed"
    )
    return parser.parse_args(argv)


def parse_net_shape(net_shape: str) -> Tuple[int, ...]:
    """
    Convert a colon-separated TVAE shape string into a tuple of layer sizes.

    A non-integer entry raises ``ValueError`` (from ``int``).

    :param net_shape: colon-separated list of integers, e.g. `"10:10:2"`
    :returns: the layer sizes as integers, e.g. `(10, 10, 2)`
    """
    layer_sizes = net_shape.split(":")
    return tuple(int(size) for size in layer_sizes)


def main():
    """
    Train a TVAE on the HDF5 dataset given on the command line.

    Seeds the RNGs when ``--seed`` is given, reads the data shape and the
    ground-truth log-likelihood from the file, builds a ``TVAE`` with the
    requested layer sizes and learning-rate range, and runs ``args.epochs``
    training epochs, printing the log of every epoch.

    :raises ValueError: if the first entry of ``--net-shape`` differs from the
        data dimensionality ``D``.
    """
    args = parse_args()

    # parameter initialization
    print("parameter initialization...", end="")
    # Compare against None so that an explicit ``--seed 0`` is honoured too.
    if args.seed is not None:
        to.manual_seed(args.seed)
        np.random.seed(args.seed)
    S = args.Ksize
    net_shape = args.net_shape
    H = net_shape[-1]
    data_fname = args.dataset

    # Read what is needed up front and close the handle: Training re-opens
    # ``data_fname`` itself, so there is no reason to keep a second handle open.
    with h5py.File(data_fname, "r") as data_file:
        N, D = data_file["data"].shape
        if net_shape[0] != D:
            raise ValueError(
                f"first layer size of --net-shape ({net_shape[0]}) "
                f"must equal the data dimensionality D={D}"
            )
        gt = data_file["ground_truth"]
        true_logL = gt["logL"][...]
        # Ground-truth parameters, only needed by the disabled ground-truth
        # initialization of the model below. ``[...]`` copies the data into
        # memory, so these stay valid after the file is closed.
        W_gt = [to.from_numpy(gt["W0"][...]), to.from_numpy(gt["W1"][...])]
        sigma2_gt = float(gt["sigma2"][...])
        pi_gt = to.from_numpy(gt["pies"][...])

    print(f"\ninput file: {args.dataset}")
    print(f"true logL: {true_logL}")
    print(f"net layers: {net_shape}")
    for name, value in (("S", S), ("H", H), ("N", N), ("D", D)):
        print(f"{name} = {value}")

    model = TVAE(net_shape, min_lr=args.min_lr, max_lr=args.max_lr, cycliclr_step_size_up=400)
    # Ground-truth initialization (disabled): additionally pass
    # W_init=W_gt, pi_init=pi_gt, sigma2_init=sigma2_gt to TVAE above.
    conf = ExpConfig(batch_size=args.batch_size, output=args.output)
    # Truncated E-step alternative (disabled):
    # EEMConfig(n_states=args.Ksize, n_parents=min(3, S), n_children=min(2, S),
    #           n_generations=1, crossover=False)
    estep_conf = FullEMConfig()
    training = Training(conf, estep_conf, model, data_fname)
    print("\nlearning...")
    for e_log in training.run(args.epochs):
        e_log.print()


if __name__ == "__main__":
    main()
