# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

# Note (rwaleffe): This is a temporary file for hybrid mamba-transformer model checkpoint conversion.
# This functionality should be integrated with the megatron core checkpoint loader/saver.


import copy
import os
import re
import shutil
from collections import OrderedDict

import torch
import argparse


# Parameter-name fragment -> dimension along which that parameter is sharded
# across tensor-parallel ranks. A value of -1 marks a tensor that is replicated
# (copied unchanged to every rank) instead of split.
#
# get_split_dim() resolves the 'norm.weight' cases explicitly and otherwise
# returns the value of the FIRST fragment (in insertion order) that occurs as a
# substring of the parameter name, so the order of entries is significant.
# Note that short fragments such as 'D' match any name containing that text.
tp_split_dim = dict([
    # embeddings, norms and output head
    ('word_embeddings.weight', 0),
    ('norm.weight', -1),
    ('final_norm.weight', -1),
    ('output_layer.weight', 0),
    # mamba1/2 mixer
    ('A_log', 0),
    ('D', 0),
    ('dt_bias', 0),
    ('in_proj.weight', 0),
    ('conv1d.weight', 0),
    ('conv1d.bias', 0),
    ('x_proj.weight', 1),
    ('dt_proj.weight', 0),
    ('dt_proj.bias', 0),
    ('out_proj.weight', 1),
    ('mixer.norm.weight', 0),
    # mlp
    ('linear_fc1.layer_norm_weight', -1),
    ('linear_fc1.weight', 0),
    ('linear_fc2.weight', 1),
    # attention
    ('self_attention.linear_proj.weight', 1),
    ('self_attention.linear_qkv.layer_norm_weight', -1),
    ('self_attention.linear_qkv.weight', 0),
])


def get_split_dim(tensor_name):
    """Return the tensor-parallel split dimension for a checkpoint parameter.

    Args:
        tensor_name: fully-qualified parameter name from a checkpoint's
            'model' state dict.

    Returns:
        The dimension along which the parameter is sharded across
        tensor-parallel ranks, or -1 if it is replicated on every rank.

    Raises:
        ValueError: if no entry of ``tp_split_dim`` occurs in ``tensor_name``.
    """
    # 'norm.weight' is a substring both of 'mixer.norm.weight' (split) and of
    # the plain layer/final norms (replicated), so resolve these explicitly
    # before the generic first-match scan below.
    if 'norm.weight' in tensor_name:
        norm_key = 'mixer.norm.weight' if 'mixer.norm.weight' in tensor_name else 'norm.weight'
        return tp_split_dim[norm_key]

    # Otherwise the first fragment (in dict insertion order) found in the name wins.
    for fragment, split_dim in tp_split_dim.items():
        if fragment in tensor_name:
            return split_dim
    raise ValueError("Unknown tensor name {}".format(tensor_name))


def combine_tp_tensors(params, key, dim, tensors):
    """Merge the tensor-parallel shards of one parameter into a single tensor.

    Most parameters are merged by concatenating the shards along ``dim``.
    Mamba ``mixer.in_proj.weight`` (and, for Mamba2, ``mixer.conv1d.*``) pack
    several sub-tensors (x, z[, B, C, dt]) back to back along ``dim``, and each
    rank holds its own slice of every sub-tensor. A plain concatenation would
    interleave them (x0 z0 x1 z1 ...), so each shard is unpacked, every
    sub-tensor is concatenated across ranks, and the results are re-packed in
    the original order.

    Args:
        params: namespace with ``mamba_version`` and ``mamba_d_inner``; the
            Mamba2 paths also need ``mamba_d_state``, ``mamba2_n_groups`` and
            ``mamba2_n_heads``.
        key: parameter name, used to pick the merge strategy.
        dim: dimension along which the parameter was split across ranks.
        tensors: the shards, ordered by tensor-parallel rank.

    Returns:
        The merged tensor.

    Raises:
        ValueError: if ``tensors`` is empty, or a ``mixer.conv1d`` key is
            neither a weight nor a bias.
    """
    if not tensors:
        raise ValueError("No tensor-parallel shards given for {}".format(key))
    tp_size = len(tensors)

    if 'mixer.in_proj.weight' in key and params.mamba_version == 1:
        # Each shard is laid out as [x_r, z_r].
        local_inner = params.mamba_d_inner // tp_size
        xs, zs = [], []
        for tensor in tensors:
            x, z = torch.split(tensor, [local_inner, local_inner], dim=dim)
            xs.append(x)
            zs.append(z)
        return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim)], dim=dim)

    if 'mixer.in_proj.weight' in key and params.mamba_version == 2:
        # Each shard is laid out as [x_r, z_r, B_r, C_r, dt_r]. All sizes come
        # from ``params`` (not from any module-level namespace).
        local_inner = params.mamba_d_inner // tp_size
        local_bc = (params.mamba2_n_groups // tp_size) * params.mamba_d_state
        local_heads = params.mamba2_n_heads // tp_size
        xs, zs, Bs, Cs, dts = [], [], [], [], []
        for tensor in tensors:
            x, z, B, C, dt = torch.split(
                tensor, [local_inner, local_inner, local_bc, local_bc, local_heads], dim=dim)
            xs.append(x)
            zs.append(z)
            dts.append(dt)
            # Expose a leading (group, state) structure on B and C before merging.
            Bs.append(torch.reshape(B, (-1, params.mamba_d_state, B.shape[-1])))
            Cs.append(torch.reshape(C, (-1, params.mamba_d_state, C.shape[-1])))
        B = torch.cat(Bs, dim=dim)
        C = torch.cat(Cs, dim=dim)
        return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim),
                          B.flatten(0, 1), C.flatten(0, 1), torch.cat(dts, dim=dim)], dim=dim)

    if 'mixer.conv1d' in key and params.mamba_version == 2:
        # Each shard is laid out as [x_r, B_r, C_r].
        local_inner = params.mamba_d_inner // tp_size
        local_bc = (params.mamba2_n_groups // tp_size) * params.mamba_d_state

        def grouped(t):
            # Reshape B/C so the leading axes are (group, state).
            if 'weight' in key:
                return torch.reshape(t, (-1, params.mamba_d_state, t.shape[-2], t.shape[-1]))
            if 'bias' in key:
                return torch.reshape(t, (-1, params.mamba_d_state))
            raise ValueError("Unknown key")

        xs, Bs, Cs = [], [], []
        for tensor in tensors:
            x, B, C = torch.split(tensor, [local_inner, local_bc, local_bc], dim=dim)
            xs.append(x)
            Bs.append(grouped(B))
            Cs.append(grouped(C))
        B = torch.cat(Bs, dim=dim)
        C = torch.cat(Cs, dim=dim)
        return torch.cat([torch.cat(xs, dim=dim), B.flatten(0, 1), C.flatten(0, 1)], dim=dim)

    return torch.cat(tensors, dim=dim)


def _chunk_evenly(tensor, chunks, dim):
    """Split ``tensor`` into exactly ``chunks`` equal pieces along ``dim``.

    ``torch.chunk`` silently returns uneven pieces (or fewer than ``chunks``
    pieces) when the size is not divisible, which would yield tensor-parallel
    shards of mismatched shapes; raise instead.

    Raises:
        ValueError: if ``tensor.shape[dim]`` is not divisible by ``chunks``.
    """
    size = tensor.shape[dim]
    if size % chunks != 0:
        raise ValueError(
            "Cannot split dimension {} of size {} evenly into {} tensor-parallel "
            "shards".format(dim, size, chunks))
    return torch.chunk(tensor, chunks, dim=dim)


def split_tensor_for_tp(params, key, dim, tensor):
    """Split one full parameter into ``params.target_tp_size`` tensor-parallel shards.

    Mamba ``mixer.in_proj.weight`` (and, for Mamba2, ``mixer.conv1d.*``) pack
    several sub-tensors (x, z[, B, C, dt]) along ``dim``. Each sub-tensor is
    split separately and every rank's pieces are re-packed in the same order.
    All other parameters are chunked directly along ``dim``.

    Args:
        params: namespace with ``target_tp_size``, ``mamba_version`` and
            ``mamba_d_inner``; the Mamba2 paths also need ``mamba_d_state``,
            ``mamba2_n_groups`` and ``mamba2_n_heads``.
        key: parameter name, used to pick the split strategy.
        dim: dimension to split along.
        tensor: the full (unsharded) parameter.

    Returns:
        A list with one shard per target tensor-parallel rank.

    Raises:
        ValueError: if a split dimension is not divisible by the target TP
            size, or a ``mixer.conv1d`` key is neither a weight nor a bias.
    """
    tp_size = params.target_tp_size

    if 'mixer.in_proj.weight' in key and params.mamba_version == 1:
        x, z = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim)
        return [torch.cat((x_r, z_r), dim=dim)
                for x_r, z_r in zip(_chunk_evenly(x, tp_size, dim),
                                    _chunk_evenly(z, tp_size, dim))]

    if 'mixer.in_proj.weight' in key and params.mamba_version == 2:
        bc_size = params.mamba2_n_groups * params.mamba_d_state
        x, z, B, C, dt = torch.split(
            tensor,
            [params.mamba_d_inner, params.mamba_d_inner, bc_size, bc_size, params.mamba2_n_heads],
            dim=dim)
        # Expose a leading (group, state) structure so that, when splitting
        # along dim 0 (as tp_split_dim configures for in_proj), each rank
        # receives whole groups of B and C.
        B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-1]))
        C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-1]))
        pieces = zip(_chunk_evenly(x, tp_size, dim), _chunk_evenly(z, tp_size, dim),
                     _chunk_evenly(B, tp_size, dim), _chunk_evenly(C, tp_size, dim),
                     _chunk_evenly(dt, tp_size, dim))
        return [torch.cat((x_r, z_r, B_r.flatten(0, 1), C_r.flatten(0, 1), dt_r), dim=dim)
                for x_r, z_r, B_r, C_r, dt_r in pieces]

    if 'mixer.conv1d' in key and params.mamba_version == 2:
        bc_size = params.mamba2_n_groups * params.mamba_d_state
        x, B, C = torch.split(tensor, [params.mamba_d_inner, bc_size, bc_size], dim=dim)
        if 'weight' in key:
            B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-2], B.shape[-1]))
            C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-2], C.shape[-1]))
        elif 'bias' in key:
            B = torch.reshape(B, (-1, params.mamba_d_state))
            C = torch.reshape(C, (-1, params.mamba_d_state))
        else:
            raise ValueError("Unknown key")
        pieces = zip(_chunk_evenly(x, tp_size, dim), _chunk_evenly(B, tp_size, dim),
                     _chunk_evenly(C, tp_size, dim))
        return [torch.cat((x_r, B_r.flatten(0, 1), C_r.flatten(0, 1)), dim=dim)
                for x_r, B_r, C_r in pieces]

    # torch.chunk returns views of ``tensor``. torch.save serializes a view's
    # whole underlying storage, so saving un-cloned views would write the full
    # unsplit tensor into every rank's file; clone so each shard owns its data.
    return [piece.clone() for piece in _chunk_evenly(tensor, tp_size, dim)]


def finalize_checkpoint(sample_model, model, params, verbose=False):
    """Fill in the non-weight entries of a converted checkpoint.

    ``model`` is expected to already hold the converted 'model' state dict.
    This copies 'args' (with the target tensor/pipeline parallel sizes),
    'checkpoint_version' and 'iteration' from ``sample_model``. It also copies
    'opt_param_scheduler' and 'rng_state', but only when ``sample_model``
    contains them. Optimizer state is deliberately not carried over.

    Args:
        sample_model: a loaded source checkpoint dict, used as the template.
        model: the output checkpoint dict; mutated in place.
        params: namespace with ``reset_iterations``, ``target_tp_size`` and
            ``target_pp_size``.
        verbose: if True, print differences between the source and final args.

    Returns:
        ``model``.
    """
    reset_iterations = params.reset_iterations

    # 'args': start from the source run's args, then record the new layout.
    model['args'] = copy.deepcopy(sample_model['args'])
    model['args'].tensor_model_parallel_size = params.target_tp_size
    model['args'].pipeline_model_parallel_size = params.target_pp_size
    if reset_iterations:
        for counter in ('iteration', 'consumed_valid_samples', 'consumed_train_samples',
                        'train_iters', 'train_samples'):
            setattr(model['args'], counter, 0)

    model['checkpoint_version'] = copy.deepcopy(sample_model['checkpoint_version'])
    model['iteration'] = 0 if reset_iterations else copy.deepcopy(sample_model['iteration'])

    # 'optimizer' is intentionally ignored. The remaining entries are optional
    # (e.g. a checkpoint saved without RNG state has no 'rng_state'), so copy
    # them only when the source checkpoint has them.
    for optional_key in ('opt_param_scheduler', 'rng_state'):
        if optional_key in sample_model:
            model[optional_key] = copy.deepcopy(sample_model[optional_key])

    # report on argument difference
    if verbose:
        original_args = vars(sample_model['args'])
        final_args = vars(model['args'])
        for key, original_value in original_args.items():
            if key not in final_args:
                print("KEY MISSING from final: {}, value {}".format(key, original_value))
            elif final_args[key] != original_value:
                print("KEY MISMATCH: {}".format(key))
                print("\toriginal: {}\n\tfinal: {}".format(original_value, final_args[key]))
        print("")
        for key, final_value in final_args.items():
            if key not in original_args:
                print("KEY ADDED to final: {}, value {}".format(key, final_value))

    return model


def _read_latest_iteration(load_dir):
    """Return the iteration stored in ``load_dir``/latest_checkpointed_iteration.txt.

    Raises:
        ValueError: if the tracker file does not contain an integer.
    """
    tracker_filename = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
    try:
        return int(metastring)
    except ValueError as err:
        raise ValueError("Expected an integer iteration in {}, found {!r}".format(
            tracker_filename, metastring)) from err


def _load_checkpoint(path):
    """Load one model_optim_rng.pt file onto the CPU."""
    # map_location='cpu': the conversion only concatenates/splits tensors, so
    # GPU-saved shards need not be placed back on a GPU. weights_only=False:
    # the checkpoint also holds non-tensor Python objects (e.g. the 'args'
    # object read by main), which full unpickling is needed for. That
    # unpickles arbitrary objects, so only convert trusted checkpoints.
    return torch.load(path, map_location='cpu', weights_only=False)


def _layer_number(key):
    """Return the first integer in ``key`` (treated as its layer index), or None."""
    match = re.search(r'\d+', key)
    return None if match is None else int(match.group())


def _load_full_model(args, input_model_dir, input_tp_size, input_pp_size, layers_per_input_pp):
    """Load all input (tp, pp) shards and merge them into one state dict with global layer numbers."""
    full_model = OrderedDict()
    for pp in range(input_pp_size):
        print("[INFO] Processing input pipeline rank {}".format(pp))
        tp_models = []
        for tp in range(input_tp_size):
            dir_name = "mp_rank_{:02d}".format(tp)
            if input_pp_size > 1:
                dir_name += "_{:03d}".format(pp)
            model_file = os.path.join(input_model_dir, dir_name, "model_optim_rng.pt")
            tp_models.append(_load_checkpoint(model_file))
            print(f"Model {model_file} is loaded.")

        if input_tp_size > 1:
            combined_tp_model = OrderedDict()
            for key, original_tensor in tp_models[0]['model'].items():
                if "_extra_state" in key:
                    combined_tp_model[key] = original_tensor
                    continue
                split_dim = get_split_dim(key)
                if split_dim != -1:
                    # Gather the shards of this parameter from every TP rank.
                    combined_tp_model[key] = combine_tp_tensors(
                        args, key, split_dim, [m['model'][key].cpu() for m in tp_models])
                else:
                    # Replicated parameter: every rank holds the same copy.
                    combined_tp_model[key] = original_tensor
        else:
            combined_tp_model = tp_models[0]['model']

        # Turn this pipeline rank's local layer numbers into global ones.
        layer_offset = pp * layers_per_input_pp
        for key, tensor in combined_tp_model.items():
            layer_num = _layer_number(key)
            if layer_num is None:
                new_key = key
            else:
                new_key = key.replace(str(layer_num), str(layer_num + layer_offset), 1)
            full_model[new_key] = tensor
    return full_model


def _save_split_model(args, sample_model, full_model, layers_per_output_pp, out_iteration):
    """Split ``full_model`` into the target (tp, pp) shards and save each under ``args.save_dir``."""
    # Number of leading full_model entries already assigned to earlier pipeline
    # ranks. This relies on full_model being ordered by layer, with digit-free
    # keys (e.g. embeddings before, final norm/output layer after the layers).
    keys_consumed = 0

    for pp in range(args.target_pp_size):
        print("[INFO] Processing output pipeline rank {}".format(pp))
        tp_models = [{'model': OrderedDict()} for _ in range(args.target_tp_size)]
        layer_offset = pp * layers_per_output_pp

        for ii, (key, original_tensor) in enumerate(full_model.items()):
            layer_num = _layer_number(key)
            if layer_num is None:
                new_key = key
            else:
                if layer_num >= layers_per_output_pp * (pp + 1):
                    # First layer owned by a later pipeline rank: stop here.
                    break
                new_key = key.replace(str(layer_num), str(layer_num - layer_offset), 1)

            if ii < keys_consumed:
                continue
            keys_consumed += 1

            if "_extra_state" in new_key:
                for tp_model in tp_models:
                    tp_model['model'][new_key] = original_tensor
                continue

            split_dim = get_split_dim(new_key)
            if split_dim != -1:
                shards = split_tensor_for_tp(args, new_key, split_dim, original_tensor)
            else:
                shards = [original_tensor] * args.target_tp_size
            # Index explicitly (rather than zip) so a short shard list fails loudly.
            for jj, tp_model in enumerate(tp_models):
                tp_model['model'][new_key] = shards[jj]

        for tp, tp_model in enumerate(tp_models):
            dir_name = "mp_rank_{:02d}".format(tp)
            if args.target_pp_size > 1:
                dir_name += "_{:03d}".format(pp)

            model = finalize_checkpoint(sample_model, tp_model, args, verbose=False)

            save_dir = os.path.join(args.save_dir, 'iter_{:07d}'.format(out_iteration), dir_name)
            os.makedirs(save_dir, exist_ok=True)
            model_file = os.path.join(save_dir, "model_optim_rng.pt")
            torch.save(model, model_file)
            print(f"Model {model_file} is saved.")


def main(args):
    """Re-shard a hybrid Mamba/Transformer checkpoint to a new TP/PP layout.

    Reads the latest iteration under ``args.load_dir``, merges all of its
    (tp, pp) shards into one full state dict, re-splits it for
    ``args.target_tp_size`` x ``args.target_pp_size`` and writes the shards
    plus a latest_checkpointed_iteration.txt tracker under ``args.save_dir``.
    Adds the derived ``mamba_d_inner`` and ``mamba2_n_heads`` to ``args``.
    """
    print("\n====RUNNING CHECKPOINT CONVERSION====\n")

    # Derived Mamba sizes; the expand factor of 2 is hard-coded here.
    args.mamba_d_inner = args.d_model * 2
    args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim

    iteration = _read_latest_iteration(args.load_dir)
    out_iteration = 0 if args.reset_iterations else iteration

    input_model_dir = os.path.join(args.load_dir, 'iter_{:07d}'.format(iteration))
    # The first rank directory (sorted for determinism) is loaded only to read
    # the training args and the non-weight checkpoint entries.
    input_sub_models = sorted(os.listdir(input_model_dir))
    sample_model_file = os.path.join(input_model_dir, input_sub_models[0], "model_optim_rng.pt")
    sample_model = _load_checkpoint(sample_model_file)
    print(f"Sample model {sample_model_file} is loaded.\n")

    input_tp_size = sample_model['args'].tensor_model_parallel_size
    input_pp_size = sample_model['args'].pipeline_model_parallel_size
    num_layers = sample_model['args'].num_layers

    full_model = _load_full_model(args, input_model_dir, input_tp_size, input_pp_size,
                                  num_layers // input_pp_size)
    print("\n[INFO] Loaded combined model\n")

    _save_split_model(args, sample_model, full_model, num_layers // args.target_pp_size,
                      out_iteration)

    os.makedirs(args.save_dir, exist_ok=True)
    tracker_filename = os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt')
    with open(tracker_filename, 'w') as f:
        f.write(str(out_iteration))


if __name__ == "__main__":
    # Example run command:
    #   python hybrid_conversion.py \
    #       --load-dir mamba2-840m-test/checkpoints/ \
    #       --save-dir mamba2-840m-test-conversion/checkpoints/ \
    #       --target-pp-size 1 \
    #       --target-tp-size 1

    parser = argparse.ArgumentParser(
        description="Re-shard a hybrid Mamba/Transformer checkpoint to a new "
                    "tensor/pipeline parallel layout.")
    # Both directories are mandatory: main() joins paths onto them.
    parser.add_argument('--load-dir', type=str, required=True,
                        help="Input checkpoint directory (contains "
                             "latest_checkpointed_iteration.txt).")
    parser.add_argument('--save-dir', type=str, required=True,
                        help="Output checkpoint directory.")
    parser.add_argument('--target-tp-size', type=int, default=1,
                        help="Tensor-parallel size of the output checkpoint.")
    parser.add_argument('--target-pp-size', type=int, default=1,
                        help="Pipeline-parallel size of the output checkpoint.")
    parser.add_argument('--reset-iterations', action='store_true',
                        help="Write the output as iteration 0 and zero the "
                             "iteration/sample counters in the saved args.")

    parser.add_argument('--d-model', type=int, default=4096,
                        help="Model hidden size; the Mamba inner size is 2 * d_model.")
    parser.add_argument('--mamba-version', type=int, default=2,
                        help="Mamba mixer version (1 or 2).")
    parser.add_argument('--mamba-d-state', type=int, default=128,
                        help="Mamba state size.")
    parser.add_argument('--mamba2-n-groups', type=int, default=8,
                        help="Number of B/C groups in the Mamba2 mixer.")
    parser.add_argument('--mamba2-head-dim', type=int, default=64,
                        help="Mamba2 head dimension; heads = inner size // head dim.")

    args = parser.parse_args()

    main(args)