from omegaconf import DictConfig, OmegaConf
from misc.attr_dict import AttrDict

import pprint

import logging
import os
import json

def initialize_config(cfg: DictConfig, mode: str):
    """Build the runtime config from the root Hydra config.

    The Hydra ``DictConfig`` is turned into a plain container and wrapped in
    an ``AttrDict``. Its ``config`` sub-tree is tagged with ``mode`` and then
    validated/completed by ``infer_and_assert_hydra_config``.

    Args:
        cfg: the root Hydra config; must contain a ``config`` node.
        mode: run mode stored as ``MODE`` ("pretrain", "extract" or
            "finetune" are the values accepted downstream).

    Returns:
        The ``config`` sub-tree after ``infer_and_assert_hydra_config`` has
        filled in the derived fields.
    """
    # NOTE(review): to_container() is called without resolve=True, so any
    # ${...} interpolations stay as literal strings — confirm configs don't
    # rely on them.
    container = OmegaConf.to_container(cfg)
    run_cfg = AttrDict(container).config
    run_cfg.MODE = mode
    return infer_and_assert_hydra_config(run_cfg)


def infer_and_assert_hydra_config(cfg):
    """Validate ``cfg`` and fill in the settings derived from it.

    ``cfg`` is mutated in place and also returned. Derived fields:

    * ``DISTRIBUTED.WORLD_SIZE`` = ``NUM_NODES * NUM_PROC_PER_NODE``;
    * ``TRAIN``/``TEST`` ``BATCH_SIZE.batch_size_per_proc`` = effective batch
      size divided by the world size, truncated to an int (a warning is
      logged when the division is not exact);
    * ``PROJ_ROOT`` (first non-empty ``PYTHONPATH`` entry) and the
      mode-specific ``CKPT_PATH`` / ``LOG_PATH`` (plus
      ``PRETRAINED_CKPT_PATH`` for "finetune");
    * ``IMG_PATH`` / ``ANNO_PATH`` / ``FEAT_PATH`` under ``DATA_PATH``;
    * ``TRAINER.devices`` / ``TRAINER.num_nodes``, and
      ``TRAINER.sync_batchnorm = True`` when ``MODEL.use_sync_bn`` is truthy;
    * ``TRAIN.OPTIMIZER.lr.scaled_lr``: ``base_lr``, linearly rescaled by
      ``effective_batch_size / base_lr_batch_size`` when ``auto_scale`` is set.

    Args:
        cfg: attribute-access config with ``MODE`` already set.

    Returns:
        The same ``cfg`` object.

    Raises:
        AssertionError: if ``PYTHONPATH`` is unset or ``DATA_PATH`` is None.
        NotImplementedError: for an unsupported ``MODE`` or ``DATASET``.
        ZeroDivisionError: if the computed world size is 0.
    """
    # --- distributed / batch size -------------------------------------------
    world_size = int(cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE)
    cfg.DISTRIBUTED.WORLD_SIZE = world_size

    train_bs = cfg.TRAIN.BATCH_SIZE.effective_batch_size
    test_bs = cfg.TEST.BATCH_SIZE.effective_batch_size
    for split, effective_bs in (("TRAIN", train_bs), ("TEST", test_bs)):
        # The per-process size is truncated below, so a non-divisible
        # effective batch size would otherwise shrink silently.
        if effective_bs % world_size:
            logging.warning(
                "%s effective_batch_size=%s is not divisible by WORLD_SIZE=%s; "
                "batch_size_per_proc is truncated to %s",
                split,
                effective_bs,
                world_size,
                int(effective_bs / world_size),
            )
    cfg.TRAIN.BATCH_SIZE.batch_size_per_proc = int(train_bs / world_size)
    cfg.TEST.BATCH_SIZE.batch_size_per_proc = int(test_bs / world_size)

    # --- paths ----------------------------------------------------------------
    assert "PYTHONPATH" in os.environ, "PYTHONPATH must be set to the project root"
    pythonpath = os.environ["PYTHONPATH"]
    # PYTHONPATH is an os.pathsep-separated list (often with a trailing
    # separator, e.g. after `export PYTHONPATH=$PWD:$PYTHONPATH`). Only its
    # first non-empty entry is a usable project root; fall back to the raw
    # value when there is none.
    proj_root = next((entry for entry in pythonpath.split(os.pathsep) if entry), pythonpath)
    cfg.PROJ_ROOT = proj_root

    if cfg.MODE in ("pretrain", "extract"):
        cfg.CKPT_PATH = os.path.join(proj_root, "ckpt/pretrain")
        cfg.LOG_PATH = os.path.join(proj_root, "logs/pretrain")
    elif cfg.MODE == "finetune":
        cfg.PRETRAINED_CKPT_PATH = os.path.join(proj_root, "ckpt/pretrain")
        cfg.CKPT_PATH = os.path.join(proj_root, "ckpt/finetune")
        cfg.LOG_PATH = os.path.join(proj_root, "logs/finetune")
    else:
        raise NotImplementedError(f"Unsupported MODE: {cfg.MODE!r}")

    assert cfg.DATA_PATH is not None, "DATA_PATH must be set"
    if cfg.DATASET in ("movienet", "BBC", "OVSD", "predict"):
        cfg.IMG_PATH = os.path.join(cfg.DATA_PATH, "240P_frames")
        cfg.ANNO_PATH = os.path.join(cfg.DATA_PATH, "anno")
        cfg.FEAT_PATH = os.path.join(cfg.DATA_PATH, "features")
    else:
        raise NotImplementedError(f"Unsupported DATASET: {cfg.DATASET!r}")

    # --- trainer ----------------------------------------------------------------
    cfg.TRAINER.devices = cfg.DISTRIBUTED.NUM_PROC_PER_NODE
    cfg.TRAINER.num_nodes = cfg.DISTRIBUTED.NUM_NODES
    if cfg.MODEL.use_sync_bn:
        cfg.TRAINER.sync_batchnorm = True

    # --- learning rate (linear scaling with the effective batch size) ----------
    cfg.TRAIN.OPTIMIZER.lr.scaled_lr = cfg.TRAIN.OPTIMIZER.lr.base_lr
    if cfg.TRAIN.OPTIMIZER.lr.auto_scale:
        cfg.TRAIN.OPTIMIZER.lr.scaled_lr = (
            cfg.TRAIN.OPTIMIZER.lr.base_lr
            * cfg.TRAIN.BATCH_SIZE.effective_batch_size
            / float(cfg.TRAIN.OPTIMIZER.lr.base_lr_batch_size)
        )

    return cfg


def print_cfg(cfg):
    """Log ``cfg`` at INFO level.

    Supports both a Hydra ``DictConfig`` (rendered as YAML) and any other
    config such as the ``AttrDict`` one (rendered with ``pprint``).

    Side effect: sets the root logger's level to DEBUG.
    """
    # Raise verbosity *before* logging anything: otherwise the header line
    # below is dropped whenever the root logger is still above INFO (e.g. the
    # stdlib default WARNING), while the config body afterwards is not.
    logging.getLogger().setLevel(logging.DEBUG)
    logging.info("Training with config:")
    if isinstance(cfg, DictConfig):
        # DictConfig.pretty() was deprecated in OmegaConf 2.0 and removed in
        # 2.1; OmegaConf.to_yaml() is the supported replacement.
        logging.info(OmegaConf.to_yaml(cfg))
    else:
        logging.info(pprint.pformat(cfg))


def save_config_to_disk(cfg):
    """Write ``cfg`` as indented JSON to ``<CKPT_PATH>/<EXPR_NAME>/config.json``.

    The experiment directory is created if it does not exist yet.

    Args:
        cfg: a JSON-serializable mapping that also exposes ``CKPT_PATH`` and
            ``EXPR_NAME`` as attributes.
    """
    expr_dir = os.path.join(cfg.CKPT_PATH, cfg.EXPR_NAME)
    os.makedirs(expr_dir, exist_ok=True)
    filename = os.path.join(expr_dir, "config.json")

    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # must be UTF-8 explicitly rather than whatever the locale default is.
    with open(filename, "w", encoding="utf-8") as fopen:
        json.dump(cfg, fopen, indent=4, ensure_ascii=False)

    logging.info(f"Saved Config Data to File: {filename}")