import torch
import json
import orjson
import orjsonl
from deepspeed.utils import safe_get_full_grad, safe_get_local_grad
from deepspeed.runtime.utils import get_global_norm

# read with orjsonl (jsonl in, list out)
def read_with_orjsonl(file_path):
    """Load the JSON Lines file at ``file_path`` with ``orjsonl.load`` and return the result (a list of records)."""
    return orjsonl.load(file_path)

# write with orjsonl (list in, jsonl out)
def write_with_orjsonl(data, output_file_path):
    """Write the records in ``data`` to ``output_file_path`` as JSON Lines using ``orjsonl.save``."""
    target, records = output_file_path, data
    orjsonl.save(target, records)

# write with orjsonl, extend (list in, jsonl out)
def write_with_orjsonl_extend(data, output_file_path):
    """Extend the JSON Lines file ``output_file_path`` with the records in ``data`` using ``orjsonl.extend``."""
    target, records = output_file_path, data
    orjsonl.extend(target, records)

def get_nb_trainable_parameters(model, logger, torch_dtype=torch.bfloat16):
    r"""
    Count the trainable and total parameters of ``model`` and log a summary.

    Adapted from
    https://github.com/huggingface/peft/blob/main/src/peft/peft_model.py

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        logger: object with an ``info(str)`` method that receives the summary lines.
        torch_dtype: dtype used to estimate parameter memory. The memory line is
            logged only when this is a ``torch.dtype``; the bytes per parameter
            come from that dtype's element size (2 for bf16/fp16, 4 for fp32, ...).

    Returns:
        tuple[int, int]: ``(trainable_params, all_param)``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # Due to the design of 4bit linear layers from bitsandbytes
        # one needs to multiply the number of parameters by 2 to get
        # the correct number of parameters
        if param.__class__.__name__ == "Params4bit":
            num_params = num_params * 2

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    # Guard against parameter-less models (previously a ZeroDivisionError).
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    logger.info(f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {trainable_pct}")
    # Counts are reported in decimal millions (M) and billions (B).
    logger.info(f"trainable params: {trainable_params / 1e6:.2f} M || all params: {all_param / 1e9:.2f} B")

    if isinstance(torch_dtype, torch.dtype):
        # Bytes per element of the requested dtype; memory reported in GiB (1024**3).
        bytes_per_param = torch.empty(0, dtype=torch_dtype).element_size()
        gib = 1024 ** 3
        logger.info(f"trainable params memory: {trainable_params * bytes_per_param / gib:.2f} GB || all params memory: {all_param * bytes_per_param / gib:.2f} GB, assuming {torch_dtype}")

    return trainable_params, all_param

def get_grad_norm(model, name=None):
    """Return the global L2 norm of the gradients of ``model``'s parameters.

    Each parameter's gradient is fetched with DeepSpeed's
    ``safe_get_full_grad``; parameters whose gradient is ``None`` are skipped.
    Per-parameter L2 norms are computed in float64 and combined with
    DeepSpeed's ``get_global_norm``.

    Args:
        model: object exposing ``named_parameters()``.
        name: optional name filter. Either a single substring or an iterable
            of substrings; only parameters whose name contains at least one of
            them are included. ``None`` (or empty) includes every parameter.

    Returns:
        The combined norm as returned by ``get_global_norm``.
    """
    # A bare string would otherwise be iterated character by character,
    # matching any parameter that shares a single letter with the filter.
    if isinstance(name, str):
        name = (name,)

    grads = []
    for n, p in model.named_parameters():
        if name and not any(x in n for x in name):
            continue
        grad = safe_get_full_grad(p)
        if grad is not None:
            grads.append(grad.detach().double().norm(2))
    return get_global_norm(grads)