from pytorch_lightning.utilities import rank_zero_only
from typing import List, Optional, Tuple, Dict
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import logging
from omegaconf import OmegaConf, DictConfig
import yaml
import os
import submitit
import git
from pathlib import Path


def setup_wandb(
    config: DictConfig,
    log: logging.Logger,
    git_hash: str = "",
    extra_configs: Optional[dict] = None,
) -> WandbLogger:
    """Create a WandbLogger for this run, falling back to offline mode.

    The config is resolved to a plain dict and augmented before being logged.
    The following are added to it:

    - the job directory;
    - the git hash;
    - any ``extra_configs``;
    - the top-level keys of the ``datamodule`` section.

    Args:
        config: Hydra config. Must contain ``module._target_``,
            ``datamodule._target_`` and a ``wandb`` section whose entries are
            forwarded to ``WandbLogger`` as keyword arguments.
        log: Logger used (via ``log_job_info``) to report the job directory
            and job id.
        git_hash: Commit hash stored under ``"git_hash"`` in the logged config.
        extra_configs: Extra key/value pairs merged into the logged config.
            ``None`` (the default) adds nothing.

    Returns:
        The constructed ``WandbLogger``. If the first construction attempt
        raises, ``WANDB_MODE`` is set to ``"offline"`` in the process
        environment and a second logger is constructed with identical
        arguments.
    """
    log_job_info(log)
    config_dict = yaml.safe_load(OmegaConf.to_yaml(config, resolve=True))
    job_logs_dir = os.getcwd()
    # increase timeout per wandb folks' suggestion
    os.environ["WANDB_INIT_TIMEOUT"] = "60"
    config_dict["job_logs_dir"] = job_logs_dir
    config_dict["git_hash"] = git_hash

    # Run name is "<ModuleClass>_<DataModuleClass>", taken from the _target_ paths.
    name = (
        config.module._target_.split(".")[-1]
        + "_"
        + config.datamodule._target_.split(".")[-1]
    )
    config_dict.update(extra_configs or {})
    # Flatten datamodule settings to the top level. This runs after
    # extra_configs, so datamodule keys win on a name collision.
    config_dict.update(config_dict["datamodule"])

    print("=config dict logged is ", config_dict)

    def _make_logger() -> WandbLogger:
        # Single construction path shared by the online and offline attempts.
        return WandbLogger(
            name=name,
            config=config_dict,
            settings={"start_method": "fork"},
            **config.wandb,
        )

    try:
        wandb_logger = _make_logger()
    except Exception as e:
        # Best-effort fallback: switch wandb to offline mode for the rest of
        # this process and retry, so the run can be synced later.
        print(f"exception: {e}")
        print("starting wandb in offline mode. To sync logs run")
        print(f"wandb sync {job_logs_dir}")
        os.environ["WANDB_MODE"] = "offline"
        wandb_logger = _make_logger()
    return wandb_logger


def get_git_hash() -> Optional[str]:
    """Return the commit hash of HEAD for the enclosing git repository.

    The repository is located with ``search_parent_directories=True``, so
    parent directories are searched too. This is best-effort: on any error
    (e.g. not inside a repository), a message is printed and ``None`` is
    returned instead of raising.
    """
    try:
        repo = git.Repo(search_parent_directories=True)
        return repo.head.object.hexsha
    except Exception:
        # ``except Exception`` rather than a bare ``except`` so that Ctrl-C
        # (KeyboardInterrupt) and SystemExit still propagate.
        print("not able to find git hash")
        return None


@rank_zero_only
def print_config(
    config: DictConfig,
    resolve: bool = True,
) -> None:
    """Print the config as YAML and save it to ``run_configs.yaml``.

    Decorated with ``rank_zero_only``, so it only executes on the rank-zero
    process.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool, optional): Whether to resolve interpolations in the
            *printed* YAML. The saved file is written by ``OmegaConf.save``
            with its default ``resolve=False``, so interpolations are kept
            as-is there regardless of this flag.
    """
    print(OmegaConf.to_yaml(config, resolve=resolve))
    # Relative path: written into the current working directory.
    # Explicit UTF-8 so non-ASCII config values don't depend on the locale.
    with open("run_configs.yaml", "w", encoding="utf-8") as f:
        OmegaConf.save(config=config, f=f)


def log_job_info(log: logging.Logger):
    """Report the working directory and the scheduler job id through ``log``.

    The job id is read from ``submitit.JobEnvironment``. If that raises
    ``RuntimeError`` (no recognised job environment), ``"local"`` is
    reported instead.
    """
    log.info(f"Logging to {os.getcwd()}")

    try:
        job_id = submitit.JobEnvironment().job_id
    except RuntimeError:
        # Not running under a recognised scheduler (e.g. a local run).
        job_id = "local"

    log.info(f"job id {job_id}")


def find_existing_checkpoint(dirpath: str) -> Optional[str]:
    """Search ``dirpath`` recursively for a model checkpoint to resume from.

    Args:
        dirpath: Directory to search. ``*.ckpt`` files in any subdirectory
            are considered.

    Returns:
        Path of the most recently modified ``*.ckpt`` file as a string, or
        ``None`` if no checkpoint was found.
    """
    ckpts = list(Path(dirpath).rglob("*.ckpt"))
    if not ckpts:
        return None
    # rglob's iteration order is filesystem-dependent, so "the last one found"
    # is arbitrary. Pick the newest by mtime instead (path string breaks ties
    # deterministically).
    newest = max(ckpts, key=lambda p: (p.stat().st_mtime, str(p)))
    ckpt = str(newest)
    print(f"resuming from existing checkpoint: {ckpt}")
    return ckpt
