
import argparse
import torch
import os

def save_config(args: argparse.Namespace, save_dir: str = None):
    """Write every attribute of *args* to ``<save_dir>/args.json``.

    Args:
        args: Parsed command-line arguments; each attribute becomes one key
            of the JSON object, in the Namespace's attribute order.
        save_dir: Directory that receives ``args.json``. When ``None`` the
            file is written to the current working directory.

    Raises:
        TypeError: if an attribute value is not JSON-serializable.
        OSError: if the output file cannot be opened for writing.
    """
    import json

    # vars() is the public way to read a Namespace's attributes; the
    # original used the private argparse helper ``_get_kwargs``.
    config = dict(vars(args))
    # os.path.join(None, ...) raises TypeError, so resolve the default first.
    target_dir = os.curdir if save_dir is None else save_dir
    out_path = os.path.join(target_dir, "args.json")
    with open(out_path, "w", encoding="utf-8") as outfile:
        json.dump(config, outfile, indent=4)

def save_model(model, optimizer, args, epoch, save_file):
    """Serialize a training checkpoint to *save_file* with ``torch.save``.

    The saved object is a dict with keys, in order:
    ``'args'`` (stored as given), ``'model'`` (``model.state_dict()``),
    ``'optimizer'`` (``optimizer.state_dict()``) and ``'epoch'``.
    A progress line is printed before saving.
    """
    print('==> Saving...')
    checkpoint = dict(
        args=args,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
    )
    torch.save(checkpoint, save_file)

def save_model2(model_dict, optimizer, args, epoch, save_file):
    """Serialize a checkpoint for several named models with ``torch.save``.

    Like ``save_model``, but ``'model'`` maps each key of *model_dict* to
    that module's ``state_dict()``. *model_dict* must be a ``dict``; this is
    checked with an ``assert`` and is therefore skipped under ``python -O``.
    """
    print('==> Saving...')
    assert isinstance(model_dict, dict)

    model_states = {}
    for key, module in model_dict.items():
        model_states[key] = module.state_dict()

    torch.save(
        {
            'args': args,
            'model': model_states,
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        },
        save_file,
    )


import logging
def get_logger(logpath, displaying=True, saving=True, debug=False):
    """Return a named logger writing to *logpath* and/or the console.

    The logger name is the file name of *logpath* up to its first ``.``
    (e.g. ``"runs/exp1.train.log"`` -> ``"exp1"``).

    Args:
        logpath: Path of the log file; also determines the logger name.
        displaying: When True, attach a ``logging.StreamHandler`` (stderr by
            default) that prints bare messages.
        saving: When True, attach a ``logging.FileHandler`` that truncates
            *logpath* and writes ``"<time> <name> : <message>"`` lines.
        debug: Use level ``DEBUG`` instead of ``INFO`` for the logger and
            its handlers.

    Returns:
        The configured ``logging.Logger``. Any handlers already attached to
        a logger of the same name are removed and closed first, so repeated
        calls reconfigure the logger instead of stacking duplicate handlers.
    """
    # os.path.basename also understands the platform's native separator,
    # unlike splitting on "/" by hand (identical result on POSIX).
    name = os.path.basename(logpath).split(".")[0]
    logger = logging.getLogger(name)
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)

    # logging.getLogger returns the same object for the same name, so handlers
    # from an earlier call would otherwise accumulate and each record would be
    # emitted once per call. Closing them also releases their file handles.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    if saving:
        # mode "w+" truncates any existing file at logpath.
        info_file_handler = logging.FileHandler(logpath, mode="w+")
        info_file_handler.setLevel(level)
        formatter = logging.Formatter('%(asctime)s %(name)s : %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
        info_file_handler.setFormatter(formatter)
        logger.addHandler(info_file_handler)
    if displaying:
        # No formatter is set here, so console lines show only the message.
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)

    return logger


#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from torchvision import utils as vutils
def save_image_tensor(input_tensor: torch.Tensor, filename):
    """Save a single-image batch tensor to *filename* via torchvision.

    Usage: ``save_image_tensor(data_t1[:1, ...], "mnist.jpg")``

    Args:
        input_tensor: A 4-D tensor whose first dimension is 1 (checked with
            ``assert``, so not enforced under ``python -O``). It is copied,
            detached from autograd and moved to the CPU before saving; the
            caller's tensor is left untouched.
        filename: Destination path passed to ``torchvision.utils.save_image``.
    """
    assert (len(input_tensor.shape) == 4 and input_tensor.shape[0] == 1)

    # NOTE(review): values are saved as-is with no de-normalization;
    # presumably the caller supplies values in save_image's expected
    # [0, 1] range — confirm.
    cpu_copy = input_tensor.clone().detach().to(torch.device('cpu'))
    vutils.save_image(cpu_copy, filename)
