import os

import hydra
import yaml
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

import smlm
from smlm import utils


@hydra.main(version_base=None, config_path="../config")
@utils.oom.handle_oom()
def train(cfg: DictConfig):
    """
    Train a neural network from a Hydra configuration.

    Stages, in execution order: torch/Fabric/config initialisation, dataset
    and dataloader construction, model + training module + optimizer +
    scheduler instantiation, optional restoration from a training checkpoint
    or a weights file, optional calibration dataloader, log directory
    resolution, Fabric wrapping, and finally the training loop.

    Args:
        cfg: Hydra configuration. Keys read directly in this function:
            ``seed``, ``ds_train``, ``ds_val``, ``model``, ``trainer``,
            ``optimizer``, ``scheduler``, ``compile``, ``n_epochs``,
            ``n_accum_steps``, ``patience`` and ``watched_metric``.
            Optional keys: ``ckpt_path`` (resume a full training state),
            ``weights_path`` (only consulted when ``ckpt_path`` is absent)
            and ``log_dir`` (when absent, one is derived from ``name``).
            ``ds_calib`` is read only if the model exposes an
            ``apply_thresholds`` attribute. The ``utils`` helpers also
            receive ``cfg`` and may read further keys.

    Returns:
        Whatever ``smlm.loops.training_step_loop`` returns.
    """
    utils.torch.initialize_torch()
    fabric = utils.fabric.initialize_fabric(cfg.seed)

    # Each helper hands back the config object to use from then on.
    cfg = utils.config.initialize_config(cfg)
    cfg = utils.config.add_git_commit_hash(cfg)
    cfg = utils.config.add_eff_batch_size(cfg, world_size=fabric.world_size)
    if fabric.is_global_zero:
        print(OmegaConf.to_yaml(cfg, sort_keys=True))

    # --- data -------------------------------------------------------------
    train_set = instantiate(cfg.ds_train)
    train_loader = utils.dataloader.build_from_config(
        fabric=fabric, cfg=cfg, ds=train_set
    )
    val_set = instantiate(cfg.ds_val)
    val_loader = utils.dataloader.build_from_config(
        fabric=fabric, cfg=cfg, ds=val_set
    )
    # NOTE(review): the dataset length (not the dataloader length) is passed as
    # `step_per_epoch` -- confirm this is what `add_total_steps` expects.
    cfg = utils.config.add_total_steps(cfg, step_per_epoch=len(train_set))

    # --- model ------------------------------------------------------------
    # Nothing below touches `cfg` until the restore step, so the membership
    # test is evaluated once and reused there.
    resume_training = "ckpt_path" in cfg
    with fabric.init_module(empty_init=resume_training):
        model = instantiate(cfg.model)
        training_module = instantiate(cfg.trainer, model=model)
        opt = instantiate(cfg.optimizer, params=training_module.parameters())
        scheduler = instantiate(cfg.scheduler, optimizer=opt)
    if fabric.is_global_zero:
        utils.model.present_model(model)
    if cfg.compile:
        training_module.compile()
        if fabric.is_global_zero:
            print("Compilation enabled.")

    # --- restore state ------------------------------------------------------
    # A full training checkpoint takes precedence over a bare weights file.
    first_epoch, first_step = 0, 0
    if resume_training:
        saved_epoch, first_step = utils.checkpoint.load_training(
            fabric=fabric,
            ckpt_path=cfg.ckpt_path,
            optimizer=opt,
            scheduler=scheduler,
            training_module=training_module,
        )
        # Training restarts at the epoch following the one in the checkpoint.
        first_epoch = saved_epoch + 1
        if fabric.is_global_zero:
            print(f"Training checkpoint loaded from {cfg.ckpt_path}")
    elif "weights_path" in cfg:
        utils.checkpoint.load_weights(
            fabric=fabric, ckpt_path=cfg.weights_path, model=model
        )
        if fabric.is_global_zero:
            print(f"Weights loaded from {cfg.weights_path}")

    # --- calibration data (only for models exposing `apply_thresholds`) -----
    needs_calibration = hasattr(model, "apply_thresholds")
    calib_loader = None
    if needs_calibration:
        calib_set = instantiate(cfg.ds_calib)
        calib_loader = utils.dataloader.build_from_config(
            fabric=fabric, cfg=cfg, ds=calib_set
        )

    # --- logs ---------------------------------------------------------------
    if "log_dir" not in cfg:
        # Only rank 0 picks the directory; the broadcast gives every process
        # the same value.
        chosen_dir = (
            utils.logs.get_log_dir(cfg.name) if fabric.is_global_zero else None
        )
        cfg.log_dir = fabric.broadcast(chosen_dir, src=0)
    if fabric.is_global_zero:
        print(f"Log directory is {cfg.log_dir}")

    def write_config_to_log_dir():
        """Dump the current config as YAML under the log directory (rank 0 only)."""
        if not fabric.is_global_zero:
            return
        target = utils.logs.get_config_path(cfg.log_dir)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        # NOTE(review): interpolations are written unresolved (no `resolve=True`)
        # -- confirm that is intended for whoever reloads this file.
        with open(target, "w") as f:
            yaml.dump(OmegaConf.to_container(cfg), f, default_flow_style=False)

    # --- acceleration -------------------------------------------------------
    training_module, opt = fabric.setup(training_module, opt)
    if needs_calibration:
        loaders = (train_loader, calib_loader, val_loader)
    else:
        loaders = (train_loader, val_loader)
    wrapped = fabric.setup_dataloaders(*loaders)
    if needs_calibration:
        train_loader, calib_loader, val_loader = wrapped
    else:
        train_loader, val_loader = wrapped

    # --- train --------------------------------------------------------------
    return smlm.loops.training_step_loop(
        fabric=fabric,
        opt=opt,
        scheduler=scheduler,
        model=model,
        training_module=training_module,
        dl_train=train_loader,
        dl_calib=calib_loader,
        dl_val=val_loader,
        log_dir=cfg.log_dir,
        begin_epoch=first_epoch,
        begin_step=first_step,
        n_epochs=cfg.n_epochs,
        n_accum_steps=cfg.n_accum_steps,
        patience=cfg.patience,
        watched_metric=cfg.watched_metric,
        # Name suggests the loop calls this once the first epoch completes --
        # TODO confirm in `training_step_loop`.
        first_epoch_done_callback=write_config_to_log_dir,
    )


# Script entry point: `train` is wrapped by `@hydra.main`, so it is called
# without arguments here and Hydra supplies the `cfg` argument.
if __name__ == "__main__":
    train()
