from typing import Any, Dict, List, Tuple

import hydra
import rootutils
import lightning as L
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
import torch.distributed as dist

# Locate the project root (searching upward for `.git` or `pyproject.toml`) and put it on
# the Python path, so the `src.*` imports below resolve when this file is run directly.
# Must execute before those imports.
rootutils.setup_root(__file__, indicator=[".git", "pyproject.toml"], pythonpath=True)

from src.utils import (
    RankedLogger,
    extras,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

# Module logger. With rank_zero_only=True it presumably emits only on the global rank-0
# process in multi-process runs — NOTE(review): confirm against src.utils.RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)




@task_wrapper
def test(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluate the given checkpoint on the train, val and test splits of the datamodule.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    For each split in ``("train", "val", "test")`` a fresh model, set of loggers, datamodule and
    trainer are instantiated from ``cfg``. The datamodule's ``test_split_name`` is set to the
    split, and ``trainer.test`` is run with ``cfg.ckpt_path``. The model's
    ``gen_sample_json_name`` / ``gen_metric_json_name`` attributes get per-split file names so the
    three runs use distinct output names (presumably consumed by the model when it saves
    generated samples/metrics — confirm in the model class).

    :param cfg: DictConfig configuration composed by Hydra. Must set ``ckpt_path``.
    :return: Tuple of a metric dict (currently always empty) and a dict with the objects
        instantiated for the last evaluated split ("test").
    """
    from datetime import timedelta

    # Only needed for the datamodule annotation below; imported once instead of per split.
    from src.data.graph_order.module import GraphOrder

    assert cfg.ckpt_path, "cfg.ckpt_path must point to the checkpoint to evaluate"

    # `is not None` (not truthiness) so that an explicit `seed: 0` is still honored.
    if cfg.get("seed") is not None:
        L.seed_everything(cfg.seed, workers=True)

    # Create the default process group up front with a very long timeout (presumably so slow
    # generation/evaluation doesn't trip the default collective timeout; Lightning skips its
    # own init when a group already exists). Skip creation if a group is already present,
    # e.g. from an earlier job of a Hydra multirun in this process: initializing the default
    # group twice raises.
    if dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=timedelta(hours=48))

    object_dict: Dict[str, Any] = {}
    for split_name in ("train", "val", "test"):
        log.info(f"Instantiating model <{cfg.model._target_}>")
        model: LightningModule = hydra.utils.instantiate(cfg.model)
        # Per-split output names so the three evaluations don't share file names.
        model.gen_sample_json_name = f"generated_samples_{split_name}.json"
        model.gen_metric_json_name = f"generated_metrics_{split_name}.json"

        log.info("Instantiating loggers...")
        logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

        log.info(f"Instantiating datamodule <{cfg.data._target_}>")
        datamodule: GraphOrder = hydra.utils.instantiate(cfg.data)

        log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
        trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

        log.info(f"Starting testing on {split_name} dataset!")

        # Presumably makes the datamodule's test dataloader serve this split — confirm in
        # GraphOrder.
        datamodule.test_split_name = split_name

        object_dict = {
            "cfg": cfg,
            "datamodule": datamodule,
            "model": model,
            "logger": logger,
            "trainer": trainer,
        }

        if logger:
            log.info("Logging hyperparameters!")
            log_hyperparameters(object_dict)

        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    return {}, object_dict

@hydra.main(version_base="1.3", config_path="../configs", config_name="test.yaml")
def main(cfg: DictConfig) -> None:
    """Command-line entry point that evaluates a checkpoint.

    Hydra composes ``cfg`` from ``../configs/test.yaml`` (plus any command-line
    overrides) and passes it in.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # Config-driven utilities run first (e.g. prompting for tags when none are
    # provided, printing the config tree). Then the evaluation task runs; its
    # return value is not used here.
    extras(cfg)
    test(cfg)
    return None





if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator builds and injects `cfg`.
    main()
