import os
import json
import logging
from datetime import datetime
from typing import Dict

import torch
import transformers
from transformers import PreTrainedModel
import torch.distributed as dist


def add_filehandler(logger, output):
    """Attach a DEBUG-level file handler to ``logger``.

    Args:
        logger: The ``logging.Logger`` that receives the new handler.
        output: Either a path ending in ``.txt`` or ``.log``, which is used
            as the log file itself, or any other path, which is treated as a
            directory in which a timestamped ``log-MM-DD-HH-MM-SS.log`` file
            is created.

    The parent directory of the log file is created if it does not exist.
    """
    if output.endswith((".txt", ".log")):
        filename = output
    else:
        # ``datetime`` here is the class (``from datetime import datetime``),
        # so ``now`` must be called on it directly.
        now = datetime.now()
        filename = os.path.join(output, now.strftime("log-%m-%d-%H-%M-%S.log"))

    # abspath makes dirname non-empty even when ``filename`` has no directory part.
    os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
    fh = logging.FileHandler(filename)
    fh.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "[%(asctime)s] [%(levelname)s] "
        "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s"
    )
    fh.setFormatter(formatter)
    logger.addHandler(fh)

def set_seed(seed):
    """Seed via ``transformers.set_seed``, offset by rank when distributed.

    When ``torch.distributed`` is initialized, the process rank is added to
    ``seed`` so each rank is seeded with a different value; otherwise
    ``seed`` is used unchanged.
    """
    if dist.is_initialized():
        seed = seed + dist.get_rank()
    transformers.set_seed(seed)

def get_global_rank():
    """Return this process's rank, or 0 if ``torch.distributed`` is not initialized."""
    if not dist.is_initialized():
        return 0
    return dist.get_rank()

def get_global_size():
    """Return the world size, or 1 if ``torch.distributed`` is not initialized."""
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()

def set_no_grad(
    model: torch.nn.Module,
    logger=None,
):
    """Freeze every parameter except those whose name contains ``"lora"``.

    Despite the function name, parameters with ``"lora"`` in their name are
    explicitly made trainable; all others get ``requires_grad = False``.

    Args:
        model: Module whose ``named_parameters()`` are updated in place.
        logger: Optional logger. When given, each trainable parameter name
            and a final summary of trainable vs. total parameter counts are
            logged at INFO level.
    """
    trainable_params = 0
    all_params = 0
    for name, param in model.named_parameters():
        param.requires_grad = "lora" in name

        num_params = param.numel()
        # Fall back to ``ds_numel`` when the local tensor reports zero elements
        # (presumably a DeepSpeed ZeRO-3 partitioned parameter; confirm).
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        all_params += num_params
        if param.requires_grad:
            trainable_params += num_params
            if logger is not None:
                logger.info(f"Trainable: {name}")

    if logger is not None:
        # A model without parameters would otherwise divide by zero here.
        trainable_pct = 100 * trainable_params / all_params if all_params else 0.0
        logger.info(
            f"Trainable params: {trainable_params:,d}, All params: {all_params:,d}, trainable: {trainable_pct:.2f}"
        )

def to_device(
    obj: Dict[str, torch.Tensor],
    device: torch.device,
):
    """Return a new dict with every value that has a ``to`` method moved to ``device``.

    Values without a ``to`` attribute are passed through unchanged. The input
    mapping itself is not modified.
    """
    moved = {}
    for key, value in obj.items():
        if hasattr(value, "to"):
            value = value.to(device=device)
        moved[key] = value
    return moved


def maybe_zero_3(param, ignore_status=False, name=None, logger=None):
    """Return a detached CPU copy of ``param``, gathering it first if needed.

    A parameter that carries a ``ds_id`` attribute is gathered inside
    ``deepspeed.zero.GatheredParameters`` before being copied; any other
    tensor is copied directly.

    Args:
        param: Tensor or parameter to copy.
        ignore_status: When False, a warning is logged if the parameter's
            ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Parameter name, used only in the warning message.
        logger: Logger for the warning. When None, the module logger
            (``logging.getLogger(__name__)``) is used.

    Returns:
        A detached clone of the parameter's data on the CPU.
    """
    if not hasattr(param, "ds_id"):
        # Plain tensor: DeepSpeed is not needed, so it is never imported here.
        return param.detach().cpu().clone()

    # Imported lazily so callers without ZeRO parameters do not need DeepSpeed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        if logger is None:
            logger = logging.getLogger(__name__)
        # NOTE(review): the message says "!=" although the check above is "==";
        # the text is kept verbatim so existing log searches keep working.
        logger.warning(
            f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
        )
    with zero.GatheredParameters([param]):
        param = param.data.detach().cpu().clone()
    return param


def get_lora_param_maybe_zero_3(named_params):
    """Collect CPU copies of all parameters whose name contains ``"lora"``.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs.

    Returns:
        Dict mapping each matching name to the result of
        ``maybe_zero_3(tensor, ignore_status=True).cpu()``.
    """
    markers = ("lora",)

    # First pass: keep only the entries whose key matches a marker.
    selected = {}
    for key, tensor in named_params:
        if any(marker in key for marker in markers):
            selected[key] = tensor

    # Second pass: materialize each kept entry as a CPU tensor.
    gathered = {}
    for key, tensor in selected.items():
        gathered[key] = maybe_zero_3(tensor, ignore_status=True).cpu()
    return gathered


def save_pretrain(
    model: PreTrainedModel, output_dir: str, epoch: int, step: int, logger: logging.Logger
) -> None:
    """Save the model's LoRA parameters to ``<output_dir>/checkpoint``.

    ``model.state_dict()`` is called before the rank check, so every process
    executes it. Only the process whose global rank is 0 filters the state
    dict and writes ``checkpoint-<epoch>-<step>.pt`` with ``torch.save``.

    Args:
        model: Model whose state dict is filtered for LoRA parameters.
        output_dir: Base directory; a ``checkpoint`` subdirectory is created.
        epoch: Epoch number used in the file name.
        step: Step number used in the file name.
        logger: Logger that receives the "Model saved" message.
    """
    full_state = model.state_dict()
    if get_global_rank() != 0:
        return

    lora_state = get_lora_param_maybe_zero_3(full_state.items())
    checkpoint_dir = os.path.join(output_dir, "checkpoint")
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint-{epoch}-{step}.pt")
    torch.save(lora_state, path)
    logger.info(f"Model saved to {path}")
        
    
def adjust_deepspeed_config(args):
    """Load the DeepSpeed JSON config and override fields from ``args``.

    The file at ``args.deepspeed_config`` is read as UTF-8 JSON. For each of
    the following attributes that exists on ``args``, the matching config
    entry is overwritten:

    - ``learning_rate`` -> ``optimizer.params.lr``
    - ``weight_decay`` -> ``optimizer.params.weight_decay``
    - ``batch_size`` -> ``train_micro_batch_size_per_gpu``
    - ``gradient_accumulation_steps`` -> ``gradient_accumulation_steps``
    - ``gradient_clipping`` -> ``gradient_clipping``
    - ``warmup_steps`` -> ``scheduler.params.warmup_num_steps``
    - ``num_train_steps`` -> ``scheduler.params.total_num_steps``

    Optimizer and scheduler overrides are skipped when the config has no
    ``optimizer`` / ``scheduler`` section.

    Args:
        args: Namespace-like object; must have a ``deepspeed_config`` path.

    Returns:
        The adjusted config as a dict.

    Raises:
        AssertionError: If ``args`` has no ``deepspeed_config`` attribute.
    """
    assert hasattr(args, "deepspeed_config"), "Deepspeed config not found in args"
    # Context manager so the file handle is closed deterministically.
    with open(args.deepspeed_config, "r", encoding="utf-8") as config_file:
        deepspeed_config = json.load(config_file)

    # Optimizer overrides are guarded the same way as the scheduler ones, so
    # a config without an "optimizer" section no longer raises KeyError.
    if "optimizer" in deepspeed_config:
        if hasattr(args, "learning_rate"):
            deepspeed_config["optimizer"].setdefault("params", {})[
                "lr"
            ] = args.learning_rate

        if hasattr(args, "weight_decay"):
            deepspeed_config["optimizer"].setdefault("params", {})[
                "weight_decay"
            ] = args.weight_decay

    if hasattr(args, "batch_size"):
        deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size

    if hasattr(args, "gradient_accumulation_steps"):
        deepspeed_config["gradient_accumulation_steps"] = (
            args.gradient_accumulation_steps
        )

    if hasattr(args, "gradient_clipping"):
        deepspeed_config["gradient_clipping"] = args.gradient_clipping

    if "scheduler" in deepspeed_config:
        if hasattr(args, "warmup_steps"):
            deepspeed_config["scheduler"].setdefault("params", {})[
                "warmup_num_steps"
            ] = args.warmup_steps

        if hasattr(args, "num_train_steps"):
            deepspeed_config["scheduler"].setdefault("params", {})[
                "total_num_steps"
            ] = args.num_train_steps

    return deepspeed_config