# -*- coding: utf-8 -*-
import numpy as np
import argparse
import torch as to
from typing import Tuple
import h5py
import math
from pathlib import Path
import os
import re
from depatchify import depatchify

import tvem
from tvem.exp import EEMConfig, ExpConfig, Training
from tvem.models import TVAE
from tvem.utils.parallel import pprint, init_processes, all_reduce, gather_from_processes, barrier
import torch.distributed as dist


def parse_args() -> argparse.Namespace:
    """
    Build the command-line interface of the training script and parse the process arguments.

    Every option except ``--epochs`` and ``--lockfile`` is mandatory. The 0/1 switches
    (``--crossover``, ``--analytical-pi``, ``--analytical-sigma``) are parsed as plain integers.

    :returns: namespace whose attributes are named after the options
              (dashes replaced by underscores), plus the positional ``dataset``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", help="HD5 file as expected in input by tvem.Training")
    # Ksize has no sensible default: leaving it unset used to surface later as an
    # obscure TypeError when it was compared against n_parents, so require it here.
    parser.add_argument("--Ksize", type=int, help="size of each K^n set", required=True)
    parser.add_argument("--epochs", type=int, default=50, help="number of training epochs")
    parser.add_argument(
        "--net-shape",
        required=True,
        type=parse_net_shape,
        help="colon-separated list of layer sizes",
    )
    parser.add_argument("--min-lr", type=float, help="MLP min learning rate", required=True)
    parser.add_argument("--max-lr", type=float, help="MLP max learning rate", required=True)
    parser.add_argument(
        "--epochs-per-cycle", type=int, help="Number of epochs for every cyclic learning rate cycle", required=True
    )
    parser.add_argument("--n-parents", type=int, help="EEM parents per generation", required=True)
    parser.add_argument("--n-children", type=int, help="EEM children per generation", required=True)
    parser.add_argument("--n-generations", type=int, help="EEM generations", required=True)
    parser.add_argument(
        "--crossover", type=int, help="0/1 whether to use crossover in EEM", required=True
    )
    parser.add_argument(
        "--analytical-pi",
        type=int,
        help="0/1 whether to use analytical equations for pi updates",
        required=True,
    )
    parser.add_argument(
        "--analytical-sigma",
        type=int,
        help="0/1 whether to use analytical equations for sigma updates",
        required=True,
    )
    parser.add_argument("--batch-size", type=int, required=True)
    parser.add_argument("--output", help="output file for train log", required=True)
    parser.add_argument(
        "--lockfile",
        default="__train_lockfile__",
        help="name of file that different processes can use as an inter-process lock",
    )
    return parser.parse_args()


def parse_net_shape(net_shape: str) -> Tuple[int, ...]:
    """
    Convert a textual TVAE layer specification into a tuple of layer sizes.

    :param net_shape: colon-separated list of integers, e.g. `"10:10:2"`
    :returns: the layer sizes as a tuple of integers, e.g. `(10, 10, 2)`
    :raises ValueError: if any field is not a valid integer literal
    """
    layer_sizes = [int(field) for field in net_shape.split(":")]
    return tuple(layer_sizes)


def train(
    dataset,
    Ksize,
    epochs,
    net_shape,
    min_lr,
    max_lr,
    epochs_per_cycle,
    n_parents,
    n_children,
    n_generations,
    crossover,
    analytical_pi,
    analytical_sigma,
    batch_size,
    output,
    lockfile,
):
    """
    Train a TVAE on an HDF5 dataset with an EEM E-step and print the log of every epoch.

    :param dataset: path of the HDF5 file; its ``"data"`` entry must have a 2-D shape ``(N, D)``
    :param Ksize: number of variational states, passed to ``EEMConfig`` as ``n_states``;
                  also caps ``n_parents`` and ``n_children``
    :param epochs: number of epochs passed to ``Training.run``
    :param net_shape: sequence of layer sizes passed to ``TVAE``; the first entry must equal ``D``
    :param min_lr: minimum learning rate passed to ``TVAE``
    :param max_lr: maximum learning rate passed to ``TVAE``
    :param epochs_per_cycle: length of one cyclic-learning-rate cycle, in epochs
    :param n_parents: EEM parents per generation (capped at ``Ksize``)
    :param n_children: EEM children per generation (capped at ``Ksize``); only forwarded
                       when ``crossover == 0``, otherwise ``None`` is passed
    :param n_generations: EEM generations
    :param crossover: crossover is enabled only when this equals 1
    :param analytical_pi: analytical pi updates are enabled only when this equals 1
    :param analytical_sigma: analytical sigma updates are enabled only when this equals 1
    :param batch_size: batch size passed to ``ExpConfig``
    :param output: output file passed to ``ExpConfig``
    :param lockfile: currently unused; kept so existing callers keep working
    :raises AssertionError: if ``net_shape[0]`` differs from the data dimensionality ``D``
    """
    print("Run policy: ", tvem.get_run_policy())
    if tvem.get_run_policy() == "mpi":
        init_processes()

    # parameter initialization
    pprint("parameter initialization...", end="")
    S = Ksize
    H = net_shape[-1]
    data_fname = dataset
    # Only the shape is needed here; close the handle right away instead of leaking it
    # for the whole training run (Training receives the file name, not the handle).
    with h5py.File(data_fname, "r") as data_file:
        N, D = data_file["data"].shape
    assert net_shape[0] == D, f"first layer size {net_shape[0]} does not match data dimensionality {D}"

    pprint(f"\ninput file: {dataset}")
    pprint(f"net layers: {net_shape}")
    for name, value in (("S", S), ("H", H), ("N", N), ("D", D)):
        pprint(f"{name} = {value}")

    # convert epochs_per_cycle to number of (batch) iterations per half cycle
    cycliclr_step_size_up = int(epochs_per_cycle * np.ceil(N / batch_size) // 2)

    m = TVAE(
        net_shape,
        min_lr=min_lr,
        max_lr=max_lr,
        cycliclr_step_size_up=cycliclr_step_size_up,
        precision=to.float32,
        analytical_sigma_updates=analytical_sigma == 1,
        analytical_pi_updates=analytical_pi == 1,
    )
    conf = ExpConfig(batch_size=batch_size, output=output, keep_best_states=True, eval_F_at_epoch_end=True)
    estep_conf = EEMConfig(
        n_states=Ksize,
        n_parents=min(n_parents, S),
        # n_children is only specified when crossover is explicitly disabled
        n_children=min(n_children, S) if crossover == 0 else None,
        n_generations=n_generations,
        crossover=crossover == 1,
    )
    t = Training(conf, estep_conf, m, data_fname)

    print("\nlearning...")
    for e_log in t.run(epochs):
        e_log.print()


if __name__ == "__main__":
    # Script entry point: parse the command line and forward every option to train()
    # by keyword, so the binding does not depend on the parameter order.
    cli = parse_args()
    train(
        dataset=cli.dataset,
        Ksize=cli.Ksize,
        epochs=cli.epochs,
        net_shape=cli.net_shape,
        min_lr=cli.min_lr,
        max_lr=cli.max_lr,
        epochs_per_cycle=cli.epochs_per_cycle,
        n_parents=cli.n_parents,
        n_children=cli.n_children,
        n_generations=cli.n_generations,
        crossover=cli.crossover,
        analytical_pi=cli.analytical_pi,
        analytical_sigma=cli.analytical_sigma,
        batch_size=cli.batch_size,
        output=cli.output,
        lockfile=cli.lockfile,
    )
