import torch

import os
import logging
from omegaconf import OmegaConf, open_dict


def load_hydra_config_from_run(load_dir):
    """Load the config stored under ``<load_dir>/.hydra/config.yaml``.

    Args:
        load_dir: Run directory containing the ``.hydra`` sub-folder.

    Returns:
        Whatever ``OmegaConf.load`` returns for that YAML file.
    """
    return OmegaConf.load(os.path.join(load_dir, ".hydra/config.yaml"))


def makedirs(dirname):
    """Create ``dirname`` and any missing parents; no-op if it already exists.

    An empty ``dirname`` (what ``os.path.dirname`` returns for a bare filename)
    denotes the current directory, so there is nothing to create. Passing it
    straight to ``os.makedirs`` would raise ``FileNotFoundError``.
    """
    if not dirname:
        return
    os.makedirs(dirname, exist_ok=True)


def get_logger(logpath, package_files=None, displaying=True, saving=True, debug=False):
    """Configure and return the root logger.

    Any handlers already attached to the root logger are closed and removed
    first. Repeated calls therefore neither duplicate output nor leak the
    previous log file's descriptor.

    Args:
        logpath: File that records are appended to when ``saving`` is true.
        package_files: Optional iterable of file paths. Each path is logged,
            followed by the full contents of that file.
        displaying: Also emit records through a console ``StreamHandler``
            (stderr).
        saving: Append records to ``logpath`` through a ``FileHandler``.
        debug: Use ``logging.DEBUG`` instead of ``logging.INFO`` for the
            logger and all new handlers.

    Returns:
        The root ``logging.Logger``.
    """
    logger = logging.getLogger()
    level = logging.DEBUG if debug else logging.INFO

    # Close before detaching: clearing the handler list alone leaves any
    # FileHandler's file open for the lifetime of the process.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    logger.setLevel(level)
    formatter = logging.Formatter('%(asctime)s - %(message)s')

    new_handlers = []
    if saving:
        new_handlers.append(logging.FileHandler(logpath, mode="a"))
    if displaying:
        new_handlers.append(logging.StreamHandler())
    for handler in new_handlers:
        handler.setLevel(level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    for package_file in package_files or ():
        logger.info(package_file)
        with open(package_file, "r", encoding="utf-8") as package_f:
            logger.info(package_f.read())

    return logger


def restore_checkpoint(ckpt_dir, state, device, is_distill=False):
    """Load a checkpoint file into ``state`` (in place) and return ``state``.

    Args:
        ckpt_dir: Path to the checkpoint *file* read with ``torch.load``.
            Despite the name, this is a file path, not a directory.
        state: Mutable mapping of training objects. In plain mode the keys
            ``optimizer``, ``model``, ``ema`` and ``step`` are used. In distill
            mode the keys are ``optimizer_student``, ``optimizer_fake``,
            ``student``, ``fake_model``, ``ema_student``, ``ema_fake`` and
            ``step``. Model entries must expose the wrapped network via
            ``.module`` (presumably a DDP/DataParallel wrapper; confirm with
            callers).
        device: Forwarded to ``torch.load`` as ``map_location``.
        is_distill: Select the distillation key layout.

    Returns:
        ``state`` itself. If nothing exists at ``ckpt_dir``, its parent
        directory is created, a warning is logged, and ``state`` is returned
        unchanged.
    """
    if not os.path.exists(ckpt_dir):
        parent_dir = os.path.dirname(ckpt_dir)
        # A bare filename has an empty dirname; os.makedirs("") would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        logging.warning(f"No checkpoint found at {ckpt_dir}. Returned the same state as input")
        return state

    # Both layouts contain the same kinds of objects (optimizer/EMA state
    # dicts), so both are loaded the same way. weights_only=False unpickles
    # arbitrary Python objects: only restore checkpoints from a trusted source.
    loaded_state = torch.load(ckpt_dir, map_location=device, weights_only=False)
    if is_distill:
        state['optimizer_student'].load_state_dict(loaded_state['optimizer_student'])
        state['optimizer_fake'].load_state_dict(loaded_state['optimizer_fake'])
        # strict=False: missing/unexpected parameter keys are ignored.
        state['student'].module.load_state_dict(loaded_state['student'], strict=False)
        state['fake_model'].module.load_state_dict(loaded_state['fake_model'], strict=False)
        state['ema_student'].load_state_dict(loaded_state['ema_student'])
        state['ema_fake'].load_state_dict(loaded_state['ema_fake'])
    else:
        state['optimizer'].load_state_dict(loaded_state['optimizer'])
        # strict=False: missing/unexpected parameter keys are ignored.
        state['model'].module.load_state_dict(loaded_state['model'], strict=False)
        state['ema'].load_state_dict(loaded_state['ema'])
    state['step'] = loaded_state['step']
    return state


def save_checkpoint(ckpt_dir, state, is_distill=False):
    """Serialize the training objects in ``state`` to ``ckpt_dir`` with ``torch.save``.

    The saved keys mirror the layout read by ``restore_checkpoint``. For models,
    the state dict of ``.module`` is stored rather than that of the wrapper.

    When ``ckpt_dir`` is a filesystem path, the checkpoint is written to a
    sibling ``<ckpt_dir>.tmp`` file and then moved into place with
    ``os.replace``. A crash or serialization error mid-save therefore leaves
    any previous checkpoint intact instead of truncated. A file-like
    ``ckpt_dir`` is written to directly.

    Args:
        ckpt_dir: Destination checkpoint file path (or a writable file-like object).
        state: Mapping of training objects. Plain mode uses ``optimizer``,
            ``model``, ``ema`` and ``step``. Distill mode uses
            ``optimizer_student``, ``optimizer_fake``, ``student``,
            ``fake_model``, ``ema_student``, ``ema_fake`` and ``step``.
        is_distill: Select the distillation key layout.

    Raises:
        Whatever ``torch.save`` or ``os.replace`` raises. The temp file is
        removed before the exception propagates.
    """
    if is_distill:
        saved_state = {
            'optimizer_student': state['optimizer_student'].state_dict(),
            'optimizer_fake': state['optimizer_fake'].state_dict(),
            'student': state['student'].module.state_dict(),
            'fake_model': state['fake_model'].module.state_dict(),
            'ema_student': state['ema_student'].state_dict(),
            'ema_fake': state['ema_fake'].state_dict(),
            'step': state['step']
        }
    else:
        saved_state = {
            'optimizer': state['optimizer'].state_dict(),
            'model': state['model'].module.state_dict(),
            'ema': state['ema'].state_dict(),
            'step': state['step']
        }

    if not isinstance(ckpt_dir, (str, bytes, os.PathLike)):
        # File-like target: there is nothing to rename, so write straight through.
        torch.save(saved_state, ckpt_dir)
        return

    target_path = os.fsdecode(ckpt_dir)
    tmp_path = target_path + ".tmp"
    try:
        torch.save(saved_state, tmp_path)
        os.replace(tmp_path, target_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise