import torch
# torch._six (which provided `inf`) is gone in newer PyTorch releases; math.inf is the same float('inf').
from math import inf
import logging
from termcolor import colored
import sys
import os
import time


def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total ``norm_type``-norm of the gradients of ``parameters``.

    Args:
        parameters: a single tensor, an iterable of tensors, or None. Tensors whose
            ``.grad`` is None are skipped.
        norm_type: order of the norm; ``inf`` gives the max absolute gradient entry.

    Returns:
        A 0-d tensor. It is ``torch.tensor(0.)`` when ``parameters`` is None or no
        tensor has a gradient. Otherwise it lives on the device of the first gradient.
    """
    # None is the default forwarded by NativeScalerWithGradNormCount when no
    # parameters are given; treat it like "no gradients" instead of failing to iterate.
    if parameters is None:
        return torch.tensor(0.)
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)
    device = grads[0].device
    if norm_type == inf:
        return max(g.abs().max().to(device) for g in grads)
    # Norm of the per-tensor norms equals the norm over all gradient entries at once.
    per_tensor_norms = torch.stack([torch.norm(g, norm_type).to(device) for g in grads])
    return torch.norm(per_tensor_norms, norm_type)

class NativeScalerWithGradNormCount:
    """Wrap ``torch.cuda.amp.GradScaler`` for one scaled backward/step cycle per call.

    Each call scales ``loss`` and runs backward through the scaler. When
    ``update_grad`` is true, it also unscales the gradients in place, optionally
    clips them, steps the optimizer via the scaler, and updates the scale factor.
    """

    # Key under which callers may store this object's state_dict in a checkpoint.
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, retain_graph=False):
        """Backpropagate ``loss`` and, if ``update_grad``, take an optimizer step.

        Args:
            loss: scalar loss tensor to backpropagate.
            optimizer: optimizer whose parameters' gradients are unscaled and stepped.
            clip_grad: if not None, max gradient norm passed to ``clip_grad_norm_``.
            parameters: tensors whose gradients are clipped (required with
                ``clip_grad``) or measured by ``ampscaler_get_grad_norm`` otherwise.
            create_graph, retain_graph: forwarded to ``backward()``.
            update_grad: when False, only accumulate gradients (no unscale/step).

        Returns:
            The gradient norm from clipping or from ``ampscaler_get_grad_norm``, or
            None when ``update_grad`` is False.

        Raises:
            ValueError: if ``clip_grad`` is set with ``update_grad`` but
                ``parameters`` is None. This is checked before backward runs, so no
                gradients are accumulated on error.
        """
        if update_grad and clip_grad is not None and parameters is None:
            raise ValueError("parameters must be provided when clip_grad is set")

        self._scaler.scale(loss).backward(create_graph=create_graph, retain_graph=retain_graph)
        if not update_grad:
            return None

        # Unscale the gradients of the optimizer's params in place first, so the
        # clip threshold and the reported norm apply to the true (unscaled) values.
        self._scaler.unscale_(optimizer)
        if clip_grad is not None:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            norm = ampscaler_get_grad_norm(parameters)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped GradScaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped GradScaler from ``state_dict``."""
        self._scaler.load_state_dict(state_dict)


def create_logger(output_dir, dist_rank=0, name=''):
    """Configure and return the logger ``name`` with console and per-rank file output.

    Args:
        output_dir: directory for the log file; created if it does not exist.
        dist_rank: process rank. Only rank 0 gets a colored stdout handler. Every
            rank gets a file handler appending to
            ``log_rank{dist_rank}_{unix_time}.txt`` in ``output_dir``.
        name: name passed to ``logging.getLogger`` (``''`` selects the root logger).

    Returns:
        The ``logging.Logger``, set to INFO level with propagation disabled.

    NOTE(review): handlers are added on every call and never removed, so calling
    this twice with the same ``name`` duplicates every record. Call it once per name.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False

    datefmt = '%Y-%m-%d %H:%M:%S'
    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'

    # Only the master process writes to the console, so the colored format is
    # built only there.
    if dist_rank == 0:
        color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
                    colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(logging.Formatter(fmt=color_fmt, datefmt=datefmt))
        logger.addHandler(console_handler)

    # FileHandler does not create missing directories; do it here instead of failing.
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, f'log_rank{dist_rank}_{int(time.time())}.txt')
    file_handler = logging.FileHandler(log_path, mode='a')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt=datefmt))
    logger.addHandler(file_handler)

    return logger

    
def use_old_forward(module: torch.nn.Module, recurse=False):
    """Swap ``module.forward`` back to ``module._old_forward``, stashing the current one.

    Only modules that have an ``_old_forward`` attribute are changed. Their current
    ``forward`` is saved in ``_new_forward`` so ``use_new_forward`` can restore it.
    A module that already has a stashed ``_new_forward`` is left alone. Calling this
    twice, or reaching a shared submodule twice while recursing, therefore never
    overwrites the stashed forward with the old one.

    Args:
        module: module to patch.
        recurse: also apply to all child modules, recursively.
    """
    if hasattr(module, '_old_forward') and not hasattr(module, '_new_forward'):
        module._new_forward = module.forward
        module.forward = module._old_forward

    if recurse:
        for child in module.children():
            use_old_forward(child, recurse)


def use_new_forward(module: torch.nn.Module, recurse=False):
    """Restore the forward stashed in ``module._new_forward`` and drop the stash.

    Modules without a ``_new_forward`` attribute are left unchanged, so calling this
    again after a restore is a no-op.

    Args:
        module: module to restore.
        recurse: also apply to all child modules, recursively.
    """
    if hasattr(module, '_new_forward'):
        module.forward = module._new_forward
        delattr(module, "_new_forward")

    if recurse:
        for child in module.children():
            use_new_forward(child, recurse)

def init_learn_sparsity(compress_layers, sparsity_step=0.01, prune_n=0, prune_m=0, blocksize=-1, sigmoid_smooth=False, lora_rank=-1, lora_alpha=1):
    """Call ``init_learn_sparsity`` on every layer stored in the ``compress_layers`` mapping.

    The settings are forwarded positionally and unchanged, in this order, to each
    layer's own ``init_learn_sparsity``. Their meaning is defined by that method.
    """
    settings = (sparsity_step, prune_n, prune_m, blocksize, sigmoid_smooth, lora_rank, lora_alpha)
    for layer in compress_layers.values():
        layer.init_learn_sparsity(*settings)


def finish_learn_sparsity(compress_layers):
    """Call ``finish_learn_sparsity()`` on every layer stored in the ``compress_layers`` mapping."""
    for layer in compress_layers.values():
        layer.finish_learn_sparsity()

def get_sparsity(compress_layers):
    """Return the parameter-count-weighted mean sparsity of the layers in ``compress_layers``.

    Each layer contributes ``layer.sparsity * (layer.param_num / total_param_num)``.
    Returns 0 when the mapping is empty or the layers hold no parameters, rather
    than dividing by zero.
    """
    layers = list(compress_layers.values())
    total_param = sum(layer.param_num for layer in layers)
    if total_param == 0:
        return 0
    sparsity = 0
    for layer in layers:
        sparsity += layer.sparsity * (layer.param_num / total_param)
    return sparsity