from contextlib import contextmanager
import logging
import os, sys
from termcolor import colored
import copy
import numpy as np
import torch
import random
import torchvision
import time
import torch_pruning as tp


def set_random_seed(num):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        num: Seed value handed unchanged to every generator.
    """
    # Same seed for each library so a run can be reproduced from one number.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(num)

    if not torch.cuda.is_available():
        return
    # Only reached when CUDA is usable in this process.
    torch.cuda.manual_seed(num)


def print_args(title, args):
    """Print every attribute of ``args`` to stdout between two banner lines.

    Each attribute becomes one row of the form ``  name ..... value`` where
    the dot run is ``48 - len(name)`` characters long (empty for longer
    names). Rows are ordered case-insensitively by their full text.

    Args:
        title: Label embedded in the opening and closing banners.
        args: Object whose ``vars()`` are listed (e.g. an argparse namespace).
    """
    print(f"------------------------ {title} ------------------------", flush=True)
    rows = [
        "  {} {} {}".format(key, "." * (48 - len(key)), getattr(args, key))
        for key in vars(args)
    ]
    # Sort the formatted rows, not the bare names, ignoring case.
    rows.sort(key=str.lower)
    for row in rows:
        print(row, flush=True)
    print(f"--------------------- end of {title} ---------------------", flush=True)


def log_args(title, args, logger):
    """Write every attribute of ``args`` to ``logger`` at INFO level.

    Emits an opening banner, one ``  name ..... value`` row per attribute
    (dot run of ``48 - len(name)`` characters, rows ordered
    case-insensitively by their full text), then a closing banner.

    Args:
        title: Label embedded in the opening and closing banners.
        args: Object whose ``vars()`` are listed (e.g. an argparse namespace).
        logger: Object exposing ``info(message)``.
    """
    logger.info(f"------------------------ {title} ------------------------")
    rows = [
        "  {} {} {}".format(key, "." * (48 - len(key)), getattr(args, key))
        for key in vars(args)
    ]
    # Sort the formatted rows, not the bare names, ignoring case.
    rows.sort(key=str.lower)
    for row in rows:
        logger.info(row)
    logger.info(f"--------------------- end of {title} ---------------------")


def set_logger(args, name=""):
    """Configure the ``train_logger`` and ``result_logger`` loggers for a run.

    A sub-directory named after the current local time (``%m_%d_%H_%M_%S``)
    is created under ``args.ckpt_save_dir``, and ``args.ckpt_save_dir`` is
    overwritten in place to point at that new directory. ``train_logger``
    writes to ``info.log`` there, ``result_logger`` writes to ``result.log``,
    and both share one console handler. Interpreter and library versions are
    logged to ``train_logger``; the argument list is logged to both.

    Args:
        args: Namespace-like object with a ``ckpt_save_dir`` attribute;
            mutated as described above and dumped via ``log_args``.
        name: Unused; kept so existing callers keep working.

    Note:
        Handlers are added on every call, so calling this more than once in
        a process attaches duplicate handlers to both loggers.
    """
    train_logger = logging.getLogger("train_logger")
    train_logger.setLevel(logging.DEBUG)

    result_logger = logging.getLogger("result_logger")
    result_logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s: - %(message)s", datefmt="%m-%d %H:%M"
    )

    args.ckpt_save_dir = os.path.join(
        args.ckpt_save_dir, time.strftime("%m_%d_%H_%M_%S", time.localtime())
    )
    # exist_ok avoids the check-then-create race: several processes started
    # in the same second would otherwise crash with FileExistsError.
    os.makedirs(args.ckpt_save_dir, exist_ok=True)

    train_file = logging.FileHandler(f"{args.ckpt_save_dir}/info.log")
    train_file.setLevel(logging.INFO)
    train_file.setFormatter(formatter)

    result_file = logging.FileHandler(f"{args.ckpt_save_dir}/result.log")
    result_file.setLevel(logging.INFO)
    result_file.setFormatter(formatter)

    # One console handler shared by both loggers.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(formatter)

    train_logger.addHandler(console)
    train_logger.addHandler(train_file)
    result_logger.addHandler(console)
    result_logger.addHandler(result_file)

    train_logger.info("PyThon  version : {}".format(sys.version.replace("\n", " ")))
    train_logger.info("PyTorch version : {}".format(torch.__version__))
    train_logger.info("cuDNN   version : {}".format(torch.backends.cudnn.version()))
    train_logger.info("Vision  version : {}".format(torchvision.__version__))
    log_args("arg_list", args, train_logger)
    log_args("arg_list", args, result_logger)


def load_model(model, save_dir):
    """Load a checkpoint file into ``model`` via ``tp.load_state_dict``.

    Args:
        model: Model object handed to ``tp.load_state_dict``.
        save_dir: Path to the checkpoint *file* (despite the parameter name,
            it is checked with ``os.path.isfile``).

    Returns:
        The ``model`` object that was passed in.

    Raises:
        FileNotFoundError: If ``save_dir`` is not an existing regular file.
    """
    logger = logging.getLogger("train_logger")

    if not os.path.isfile(save_dir):
        # Report through the logger rather than print so the failure is
        # recorded wherever the rest of this module's messages go.
        logger.error("=> no checkpoint found at '{}'".format(save_dir))
        raise FileNotFoundError(f"Checkpoint not found at {save_dir}")

    logger.info("=> loading checkpoint '{}'".format(save_dir))
    # NOTE(security): torch.load is pickle-based; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(save_dir, map_location="cpu")

    tp.load_state_dict(model, state_dict=checkpoint)
    logger.info("=> loaded pruned model")
    return model


def save_model(model, save_dir, args):
    """Build the torch_pruning state dict for ``model`` and optionally save it.

    Args:
        model: Model whose gradients are zeroed before ``tp.state_dict`` is
            called on it.
        save_dir: Destination passed to ``torch.save``.
        args: Object with a ``save_model`` attribute; nothing is written to
            disk when it is falsy.

    Returns:
        The object returned by ``tp.state_dict(model)``, whether or not it
        was written to disk.
    """
    model.zero_grad()
    # Describes the pruned model (e.g. a resnet-18-half), per the original note.
    pruned_state = tp.state_dict(model)
    if not args.save_model:
        return pruned_state
    torch.save(pruned_state, save_dir)
    return pruned_state

