from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import rootutils
import torch
import wandb
from hydra.core.hydra_config import HydraConfig
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf

# Make Python's built-in ``eval`` available in configs as the ``${eval:...}`` resolver
# (e.g. for arithmetic inside YAML values).
# NOTE(review): this evaluates config text as arbitrary Python — only load trusted configs.
# By default OmegaConf raises if a resolver name is registered twice, so this should be the
# only place registering "eval" in the process — confirm.
OmegaConf.register_new_resolver("eval", eval)

# Locate the project root via the ".project-root" marker file (details in the comment below).
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src.utils.hydra_cli_overrides import compare_dicts, overrides_as_str
from src.utils.pylogger import RankedLogger
from src.utils.utils import extras, get_metric_value, load_ckpt_path, load_class, task_wrapper
from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.logging_utils import load_run_config_from_wb, log_hyperparameters


# Module-level logger. ``rank_zero_only=True`` presumably limits output to the rank-0 process
# in distributed runs — see src.utils.pylogger.RankedLogger to confirm.
log = RankedLogger(__name__, rank_zero_only=True)


def recursive_dict_update(original_dict: Mapping, update_dict: Optional[Mapping]) -> Mapping:
    """Recursively merge ``update_dict`` into ``original_dict`` in place.

    Keys whose values are mappings in *both* arguments are merged key by key. Every other
    value from ``update_dict`` overwrites (or adds) the corresponding entry.

    Works for plain ``dict`` objects as well as OmegaConf ``DictConfig`` nodes. ``DictConfig``
    implements ``collections.abc.Mapping`` but does not subclass ``dict``, so the nesting check
    uses ``Mapping``; otherwise nested configs would never be recursed into. For
    ``DictConfig`` targets, new keys are added even when the config is in struct mode.

    :param original_dict: The mapping to update. It is mutated in place.
    :param update_dict: The override values. ``None`` or an empty mapping is a no-op.
    :return: ``original_dict`` itself, for convenience.
    """
    if not update_dict:
        return original_dict

    for key, value in update_dict.items():
        if key in original_dict and isinstance(original_dict[key], Mapping) and isinstance(value, Mapping):
            recursive_dict_update(original_dict[key], value)
        elif isinstance(original_dict, dict):
            # Plain dicts: OmegaConf.update only accepts OmegaConf containers.
            original_dict[key] = value
        else:
            # force_add temporarily lifts struct mode so overrides may introduce new keys.
            OmegaConf.update(original_dict, key, value, force_add=True)
    return original_dict

@task_wrapper
def train(cfg: DictConfig, cli_overrides: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Continue training a model restored from a previous W&B run. Can additionally evaluate on a
    testset, using best weights obtained during training.

    The model architecture (``model.model``) and the starting checkpoint directory are taken
    from the run config returned by ``load_run_config_from_wb(**cfg.wandb)``. The optional
    ``cfg.model_overrides`` are merged on top of that architecture before it is instantiated.
    Data, trainer, callbacks, loggers and loss come from the current ``cfg``.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra. It is mutated in place:
        ``cfg.model.model`` is replaced by the W&B run's model config (plus ``model_overrides``).
    :param cli_overrides: Space-separated Hydra command-line overrides. They are stored in the
        object dict that is passed to ``log_hyperparameters``.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # Set seed for random number generators in pytorch, numpy and python.random.
    # Compare against None so that an explicit `seed: 0` is honoured (0 is falsy).
    seed = cfg.get("seed")
    if seed is not None:
        L.seed_everything(seed, workers=True)

    # Config of the source W&B run. Only its model architecture and checkpoint dir are reused.
    wandb_cfg = load_run_config_from_wb(**cfg.wandb)
    cfg.model.model = wandb_cfg.model.model
    # `cfg.model.model.update(...)` and `OmegaConf.update(...)` cannot apply a whole mapping of
    # overrides here (OmegaConf.update takes a single dotted key), hence the recursive helper.
    model_overrides = cfg.get("model_overrides")
    if model_overrides:
        recursive_dict_update(cfg.model.model, model_overrides)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiate and load checkpoint <{cfg.model._target_}>")
    checkpoint = load_ckpt_path(wandb_cfg.callbacks.model_checkpoint.dirpath, last=False)
    model_class = load_class(cfg.model._target_)
    loss_fn = hydra.utils.instantiate(cfg.model.loss_function)
    model_instance = hydra.utils.instantiate(cfg.model.model)
    # Keyword arguments override hyperparameters saved in the checkpoint. This way the
    # (possibly overridden) architecture and the current loss config are used.
    model: LightningModule = model_class.load_from_checkpoint(
        checkpoint,
        model=model_instance,
        loss_function=loss_fn,
    )

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    loggers: List[Logger] = instantiate_loggers(cfg.get("logger"))

    if cfg.get("log_grads"):
        # `watch()` is a WandbLogger feature. Use the first logger that provides it instead of
        # assuming it is loggers[0].
        watcher = next((lg for lg in (loggers or []) if hasattr(lg, "watch")), None)
        if watcher is not None:
            watcher.watch(model)
        else:
            log.warning("`log_grads` is set but no logger supports `watch()`; skipping.")

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=loggers)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": loggers,
        "trainer": trainer,
        "CLI overrides": cli_overrides,
    }

    if loggers:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting training!")
    # Optional `ckpt_path` lets Lightning resume from a full training checkpoint.
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    # Take a copy: the metrics dict is owned by the trainer and may be updated by later stages.
    train_metrics = dict(trainer.callback_metrics)

    if cfg.get("test"):
        log.info("Starting testing!")
        # `checkpoint_callback` is None when no ModelCheckpoint callback is configured.
        ckpt_callback = trainer.checkpoint_callback
        ckpt_path = ckpt_callback.best_model_path if ckpt_callback is not None else ""
        if not ckpt_path:
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = dict(trainer.callback_metrics)

    # Merge train and test metrics. Test values win on key collisions.
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict



@hydra.main(version_base="1.3", config_path="../configs")
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run training and report the metric to optimize.

    No ``config_name`` is given to ``hydra.main``. The primary config is presumably chosen on
    the command line, e.g. ``--config-name train_2d.yaml``.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: The value of ``cfg.optimized_metric`` as looked up by ``get_metric_value`` in the
        run's metrics. It is used by Hydra sweepers for hyperparameter optimization.
    """
    # Optional helpers, e.g. prompting for tags or printing the config tree.
    extras(cfg)

    # This task's command-line overrides, forwarded to `train` so they are logged with hparams.
    task_overrides = HydraConfig.get().overrides.task
    metric_dict, _objects = train(cfg, " ".join(task_overrides))

    return get_metric_value(
        metric_dict=metric_dict,
        metric_name=cfg.get("optimized_metric"),
    )


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator composes `cfg` from ../configs and
    # the command-line overrides.
    main()

