# -*- coding: utf-8 -*-

import torch as to
from tvem.models import TVAE
from tvem.exp import EEMConfig
from tvem.utils.parallel import broadcast, barrier
from exp import ImageReconstruction
from utils import (
    init_processes_and_get_rank,
    get_dataset_file_name,
    read_image_set_to_nan_get_patches_and_write_out,
    get_cycliclr_half_step_size,
)
from params import (
    get_exp_parser,
    get_tvae_network_parser,
    get_learning_rate_scheduler_parser,
    get_eem_parser,
    make_parser_and_parse_args,
    defaults,
)


def get_no_channels(image_shape):
    """Return the channel count implied by ``image_shape``.

    For a 3-D shape the last axis is taken as the channel count; any other
    rank (e.g. a 2-D grayscale image) is treated as single-channel.

    Args:
        image_shape: Shape sequence of the image (anything supporting ``len``
            and negative indexing).

    Returns:
        The number of channels as a Python ``int``.
    """
    return int(image_shape[-1]) if len(image_shape) == 3 else 1


def main():
    """Run the TVAE image-reconstruction experiment.

    Steps:
    - Parse the experiment, network, learning-rate-scheduler and EEM
      command-line arguments.
    - On rank 0 only, read the image, extract patches and write the training
      data set to ``dataset_file``.
    - Broadcast the patch count and channel count to all ranks.
    - Build the TVAE model and the EEM E-step configuration.
    - Run ``ImageReconstruction``.
    """
    comm_rank = init_processes_and_get_rank()

    args = make_parser_and_parse_args(
        [
            get_exp_parser(),
            get_tvae_network_parser(),
            get_learning_rate_scheduler_parser(),
            get_eem_parser(),
        ]
    )

    dataset_file = get_dataset_file_name(args.output_directory)

    # Only rank 0 touches the image file; the other ranks get None placeholders
    # and receive the scalar quantities they need via ``broadcast`` below.
    if comm_rank == 0:
        (
            patches,
            eval_metric_fn,
            eval_metric_name,
            eval_metric_label,
            reco_logger,
        ) = read_image_set_to_nan_get_patches_and_write_out(
            image_file=args.image_file,
            patch_height=args.patch_size[0],
            patch_width=args.patch_size[1],
            patch_shift=defaults.patch_shift,
            precision=defaults.precision,
            dataset_file=dataset_file,
            incomplete_percentage=args.incomplete_percentage,
        )
    else:
        patches = eval_metric_fn = eval_metric_name = eval_metric_label = reco_logger = None
    barrier()

    # ``broadcast`` requires the tensor to have the same shape on every rank,
    # so both branches build a 1-element tensor (the previous code produced a
    # 0-dim tensor on rank 0 for 2-D/grayscale images).
    # NOTE(review): assumes get_number_of_patches() returns an int so dtypes
    # match the int64 placeholder on the other ranks — confirm.
    if comm_rank == 0:
        no_data_points = to.tensor([patches.get_number_of_patches()])
        no_channels = to.tensor([get_no_channels(patches.get_image_shape())])
    else:
        no_data_points = to.tensor([0])
        no_channels = to.tensor([0])
    broadcast(no_data_points)
    broadcast(no_channels)

    cycliclr_half_step_size = get_cycliclr_half_step_size(
        no_data_points=no_data_points.item(),
        batch_size=args.batch_size,
        epochs_per_half_cycle=args.epochs_per_half_cycle,
        no_epochs=args.no_epochs,
    )

    # Input layer size is one flattened patch (height * width * channels),
    # followed by the user-specified hidden layers and the H latent units.
    net_shape = (
        (args.patch_size[0] * args.patch_size[1] * no_channels.item(),)
        + tuple(args.inner_net_shape)
        + (args.H,)
    )
    model = TVAE(
        shape=net_shape,
        min_lr=args.min_lr,
        max_lr=args.max_lr,
        cycliclr_step_size_up=cycliclr_half_step_size,
        precision=defaults.precision,
    )

    estep_conf = EEMConfig(
        n_states=args.Ksize,
        n_parents=args.no_parents,
        n_children=args.no_children,
        n_generations=args.no_generations,
        parent_selection=args.parent_selection,
        crossover=args.crossover,
    )

    image_reconstruction = ImageReconstruction(
        data_file=dataset_file,
        patches=patches,
        model=model,
        estep_conf=estep_conf,
        output_directory=args.output_directory,
        batch_size=args.batch_size,
        no_epochs=args.no_epochs,
        merge_every=args.merge_every,
        eval_metric_fn=eval_metric_fn,
        eval_metric_name=eval_metric_name,
        eval_metric_label=eval_metric_label,
        interactive=args.interactive,
        interactive_pause=defaults.interactive_pause,
        reco_logger=reco_logger,
        stop_if_eval_metric_diff_negative_in_x_of_y_epochs=args.stop_if_eval_metric_decreases_in_x_of_y_epochs,  # noqa
        keep_training_data_file=args.keep_training_data_file,
        keep_training_output_file=args.keep_training_output_file,
        keep_reco_file=args.keep_reco_file,
        warmup_Esteps=defaults.warmup_Esteps,
    )

    image_reconstruction.run()


if __name__ == "__main__":
    main()
