# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import os
import torch
import random
import shutil
import numpy as np
from transformers import set_seed, AutoTokenizer
import json
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import torch.nn as nn


def print_rank_0(msg, rank=0):
    """Print ``msg`` only when ``rank`` is zero or negative.

    Args:
        msg: Object handed to ``print``.
        rank: Rank of the calling process. Any value ``<= 0`` prints; every
            positive value stays silent.
    """
    is_printing_rank = rank <= 0
    if is_printing_rank:
        print(msg)


def to_device(batch, device):
    """Return a new dict with every value of ``batch`` moved to ``device``.

    Args:
        batch: Mapping whose values expose a ``.to(device)`` method.
        device: Target passed unchanged to each value's ``.to``.

    Returns:
        A new ``dict`` with the same keys as ``batch`` and the results of
        ``value.to(device)`` as values.

    Raises:
        Exception: Whatever ``value.to(device)`` raises. The error is printed
            and then re-raised unchanged so the traceback points at the real
            failure.
    """
    output = {}
    for k, v in batch.items():
        try:
            output[k] = v.to(device)
        except Exception as e:
            print("to device error: {}".format(e))
            # Re-raise instead of `assert 0`: asserts are stripped under
            # `python -O`, which would silently return the un-moved value.
            raise
    return output

# def to_device(batch, device):
#     output = {}
#     for k, v in batch.items():
#         try:
#             print(f"[to_device] Moving key '{k}': shape={getattr(v, 'shape', None)}, dtype={getattr(v, 'dtype', None)}")
#             output[k] = v.to(device)
#         except Exception as e:
#             print(f"[ERROR] Failed to move key '{k}' to {device}")
#             print(f"Type: {type(v)}, Shape: {getattr(v, 'shape', 'N/A')}, Dtype: {getattr(v, 'dtype', 'N/A')}")
#             print(f"Exception: {e}")
#             raise  # Let it crash naturally for a proper traceback
#     return output


class MovingAverage:
    """Running arithmetic mean over every value passed to ``update``.

    Despite the name there is no window or decay: ``mean`` is always
    ``total / count`` over all values seen so far.
    """

    def __init__(self):
        # count: number of updates; total: sum of values; mean: total / count.
        self.count, self.total, self.mean = 0, 0, 0

    def update(self, num):
        """Add ``num`` to the running statistics and return the new mean."""
        self.total += num
        self.count += 1
        current_mean = self.total / self.count
        self.mean = current_mean
        return current_mean


def get_tokenizer(model_name_or_path, fast_tokenizer=True):
    """Load a tokenizer configured for right-padding with EOS as the pad token.

    Args:
        model_name_or_path: Identifier or path passed to
            ``AutoTokenizer.from_pretrained``.
        fast_tokenizer: Whether to request the fast tokenizer implementation.

    Returns:
        The loaded tokenizer with ``pad_token`` set to its ``eos_token`` and
        ``padding_side`` set to ``"right"``.
    """
    # `use_fast` is the keyword AutoTokenizer uses to pick the fast or slow
    # implementation; the previous `fast_tokenizer=` keyword was not, so the
    # caller's choice never took effect.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, use_fast=fast_tokenizer
    )
    # The pad token is overwritten unconditionally, even when the tokenizer
    # already defines one.
    tokenizer.pad_token = tokenizer.eos_token
    # make sure tokenizer is right pad in our logic
    tokenizer.padding_side = "right"
    return tokenizer


def load_hf_tokenizer(model_name_or_path, fast_tokenizer=True):
    """Load a tokenizer, resolving local checkpoints to their original name.

    When ``model_name_or_path`` is a local directory containing a
    ``config.json``, the tokenizer is loaded from the ``_name_or_path``
    recorded in that config rather than from the directory itself. A few
    known name fragments are mapped to fixed hub identifiers. In every other
    case (path does not exist, no ``config.json``, or no ``_name_or_path``
    entry) ``model_name_or_path`` is used as-is.

    Args:
        model_name_or_path: Local checkpoint directory or model identifier.
        fast_tokenizer: Forwarded to ``get_tokenizer``.

    Returns:
        The tokenizer returned by ``get_tokenizer``.
    """
    # Substring -> identifier overrides, applied in this order.
    known_models = (
        ("pythia-2.8b", "EleutherAI/pythia-2.8b"),
        ("opt-1.3b", "facebook/opt-1.3b"),
        ("llama-7b", "meta-llama/Llama-2-7b-hf"),
    )

    model_name = model_name_or_path
    model_json = os.path.join(model_name_or_path, "config.json")
    if os.path.exists(model_json):
        # Locally tokenizer loading has some issue, so we need to force download
        with open(model_json, encoding="utf-8") as config_file:
            model_config = json.load(config_file)
        # Fall back to the given path when the config does not record a name.
        model_name = model_config.get("_name_or_path", model_name_or_path)
        for fragment, hub_name in known_models:
            if fragment in model_name.lower():
                model_name = hub_name

    return get_tokenizer(model_name, fast_tokenizer=fast_tokenizer)


def save_hf_format(model, tokenizer, args, sub_folder=""):
    """Write weights, config and vocabulary into ``args.output_dir/sub_folder``.

    The files use the names ``pytorch_model.bin`` and ``config.json`` so the
    directory can later be loaded with ``from_pretrained``. State-dict
    entries whose key contains ``"lora"`` are dropped before saving.

    Args:
        model: Model to save; its ``.module`` attribute is used when present.
        tokenizer: Object whose ``save_vocabulary(dir)`` is called.
        args: Object providing ``output_dir``.
        sub_folder: Optional sub-directory below ``args.output_dir``.
    """
    weights_name = "pytorch_model.bin"
    config_name = "config.json"

    # Unwrap an engine/DDP-style wrapper if there is one.
    unwrapped = getattr(model, "module", model)

    target_dir = os.path.join(args.output_dir, sub_folder)
    os.makedirs(target_dir, exist_ok=True)

    weights = unwrapped.state_dict()
    # Collect the keys first so the dict is not mutated while iterating it.
    lora_keys = [name for name in weights if "lora" in name]
    for name in lora_keys:
        del weights[name]

    torch.save(weights, os.path.join(target_dir, weights_name))
    unwrapped.config.to_json_file(os.path.join(target_dir, config_name))
    tokenizer.save_vocabulary(target_dir)


def set_random_seed(seed):
    """Seed every random number generator used here; no-op for ``None``.

    Calls, in order, ``transformers.set_seed``, ``random.seed``,
    ``np.random.seed``, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all`` with the same ``seed``.
    """
    if seed is None:
        return

    set_seed(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def save_code(save_dir):
    """Copy the parent directory of this file's folder to ``save_dir + "/code"``.

    Entries matching the ignore patterns below (outputs, logs, results,
    datasets, checkpoints, VCS/IDE metadata, bytecode) are not copied.
    """
    this_dir = os.path.dirname(__file__)
    project_dir = os.path.abspath(os.path.join(this_dir, ".."))

    skip_patterns = (
        "output*",
        "log*",
        "result*",
        "dataset*",
        "checkpoint*",
        ".git",
        "*.pyc",
        ".idea",
        ".DS_Store",
    )
    shutil.copytree(
        project_dir,
        save_dir + "/code",
        ignore=shutil.ignore_patterns(*skip_patterns),
    )


def get_all_reduce_mean(tensor):
    """Sum ``tensor`` across processes and return the sum divided by world size.

    The all-reduce is applied to ``tensor`` itself; the division result is
    what gets returned.
    """
    dist = torch.distributed
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    world_size = dist.get_world_size()
    return tensor / world_size


# This function is a modified version of code available in the from_pretrained API of HuggingFace Transformers
# The code is copied and modified from: https://github.com/huggingface/transformers/blob/5ee9693a1c77c617ebc43ef20194b6d3b674318e/src/transformers/modeling_utils.py#L498
# This function helps load a HF format checkpoint into a DeepSpeed wrapped model that has been sharded using ZeRO Stage 3
def load_state_dict_into_model(
    model_to_load=None, state_dict=None, start_prefix="", zero_stage=0
):
    """Load ``state_dict`` into ``model_to_load`` module by module.

    Args:
        model_to_load: Root module to load into.
        state_dict: Mapping of parameter names to values; a copy is used, the
            argument itself is not modified.
        start_prefix: Key prefix used for the root module.
        zero_stage: When ``3``, each module's own parameters that appear in
            ``state_dict`` are gathered with
            ``deepspeed.zero.GatheredParameters`` and only rank 0 performs
            the load.

    Returns:
        The list handed to every ``_load_from_state_dict`` call as its error
        message accumulator.
    """
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def _load_recursive(module: nn.Module, weights, prefix=""):
        module_metadata = (
            metadata.get(prefix[:-1], {}) if metadata is not None else {}
        )
        load_args = (weights, prefix, module_metadata, True, [], [], error_msgs)

        # Parameters of module and children will start with prefix. Nothing
        # to load at this level if no key in the state dict carries it.
        has_matching_keys = any(key.startswith(prefix) for key in weights)

        if has_matching_keys and zero_stage != 3:
            module._load_from_state_dict(*load_args)
        elif has_matching_keys:
            # In sharded models, each shard has only part of the full state_dict, so only gather
            # parameters that are in the current state_dict.
            own_params = dict(
                module.named_parameters(prefix=prefix[:-1], recurse=False)
            )
            to_gather = [own_params[key] for key in weights if key in own_params]
            if to_gather:
                # because zero3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer, then loads from
                # the state dict and then re-partitions them again
                with deepspeed.zero.GatheredParameters(to_gather, modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(*load_args)

        for child_name, child in module._modules.items():
            if child is None:
                continue
            _load_recursive(child, weights, prefix + child_name + ".")

    _load_recursive(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs


def get_optimizer_grouped_parameters(
    model,
    weight_decay,
    lora_lr=5e-4,
    no_decay_name_list=["bias", "LayerNorm.weight"],
    lora_name_list=["lora_right_weight", "lora_left_weight"],
):
    """Split the trainable parameters of ``model`` into optimizer groups.

    Parameters with ``requires_grad`` false are ignored. The remaining ones
    are classified by substring match on their name:

    * name matches ``no_decay_name_list``: ``weight_decay`` of ``0.0``
    * otherwise, name matches ``lora_name_list``: ``weight_decay`` plus
      ``lr=lora_lr``
    * everything else: ``weight_decay``

    Returns:
        A list of group dicts ordered regular, LoRA, no-decay. The LoRA group
        is omitted when it would be empty.
    """

    def _matches(name, fragments):
        return any(fragment in name for fragment in fragments)

    regular_params = []
    lora_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if _matches(name, no_decay_name_list):
            # No-decay wins over LoRA when a name matches both lists.
            if param.requires_grad:
                no_decay_params.append(param)
            continue
        if not param.requires_grad:
            continue
        if _matches(name, lora_name_list):
            lora_params.append(param)
        else:
            regular_params.append(param)

    groups = [{"params": regular_params, "weight_decay": weight_decay}]
    if lora_params:
        groups.append(
            {"params": lora_params, "weight_decay": weight_decay, "lr": lora_lr}
        )
    groups.append({"params": no_decay_params, "weight_decay": 0.0})
    return groups


def _z3_params_to_fetch(param_list):
    return [
        p
        for p in param_list
        if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
    ]


def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0):
    """Blend the parameters of ``model`` into ``model_ema`` in place.

    Each EMA parameter is overwritten with
    ``torch.lerp(param, param_ema, beta)``, i.e. the interpolation from the
    live parameter towards the EMA parameter with weight ``beta``.

    Args:
        model: Source of the current parameters.
        model_ema: Model whose parameters are updated.
        beta: Interpolation weight passed to ``torch.lerp``.
        device: When not ``None``, the live parameter data is moved there
            before interpolating.
        zero_stage: When ``3``, parameters reported by
            ``_z3_params_to_fetch`` are gathered around the update.
    """
    is_stage_3 = zero_stage == 3
    with torch.no_grad():
        for live_param, ema_param in zip(model.parameters(), model_ema.parameters()):
            # TODO: use prefiltering for efficiency
            if is_stage_3:
                to_gather = _z3_params_to_fetch([live_param, ema_param])
            else:
                to_gather = []
            with deepspeed.zero.GatheredParameters(
                to_gather, enabled=bool(to_gather)
            ):
                live_data = live_param.data
                if device is not None:
                    live_data = live_data.to(device)
                blended = torch.lerp(live_data, ema_param.data, beta)
                ema_param.data.copy_(blended)


def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    """Save model weights to ``save_dir/pytorch_model.bin`` from rank 0.

    Outside ZeRO stage 3, rank 0 saves the module's ``state_dict()`` directly.
    Under stage 3, every rank walks the named parameters and gathers those
    carrying a ``ds_id``; rank 0 collects CPU copies of all parameters whose
    name does not contain ``"lora"`` and writes them.

    Args:
        model_ema: Model to save; its ``.module`` attribute is used when present.
        global_rank: Rank of the calling process; only rank 0 writes the file.
        save_dir: Output directory, created if missing.
        zero_stage: ZeRO stage in use; ``3`` selects the gathering path.
    """
    is_stage_3 = zero_stage == 3
    os.makedirs(save_dir, exist_ok=True)
    weights_path = os.path.join(save_dir, "pytorch_model.bin")

    unwrapped = getattr(model_ema, "module", model_ema)
    is_writer = global_rank == 0

    if not is_stage_3:
        if is_writer:
            torch.save(unwrapped.state_dict(), weights_path)
        return

    collected = {}
    for name, param in unwrapped.named_parameters():
        if hasattr(param, "ds_id"):
            # Entered on every rank, not just the writer.
            with deepspeed.zero.GatheredParameters(
                _z3_params_to_fetch([param]), enabled=is_stage_3
            ):
                cpu_copy = param.data.cpu()
        else:
            cpu_copy = param.cpu()
        if is_writer and "lora" not in name:
            collected[name] = cpu_copy

    if is_writer:
        torch.save(collected, weights_path)
    del collected
