import argparse
import collections
import os

import torch
from fairseq.file_io import PathManager


def read_model(fpath):
    """Load a checkpoint onto the CPU and return its model parameters.

    Args:
        fpath: checkpoint path; opened through ``PathManager``.

    Returns:
        ``(model_params, params_keys, state)``: ``state`` is the whole
        deserialized checkpoint, ``model_params`` is ``state["model"]`` and
        ``params_keys`` lists its keys in stored order.

    Raises:
        KeyError: if the checkpoint has no ``"model"`` entry.
    """
    import inspect

    # The whole checkpoint object (not just tensors) is unpickled and later
    # re-saved, so it may hold arbitrary Python objects. PyTorch >= 2.6
    # defaults to weights_only=True, which rejects those; opt out explicitly
    # on versions that have the flag (older ones always unpickle fully).
    # SECURITY: full unpickling can execute code -- only load trusted files.
    load_kwargs = {}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    with PathManager.open(fpath, "rb") as f:
        # map_location="cpu" remaps every storage to the CPU, so checkpoints
        # saved from GPU tensors also load on CPU-only machines.
        state = torch.load(f, map_location="cpu", **load_kwargs)
    model_params = state["model"]
    return model_params, list(model_params.keys()), state


def _interpolate(p1, p2, weight):
    """Return the linear interpolation ``(1 - weight) * p1 + weight * p2``.

    ``weight`` may lie outside [0, 1] to extrapolate beyond either endpoint.

    * float16/bfloat16 ``p1`` is upcast (together with ``p2``) and the result
      is float32, so the arithmetic is not done at reduced precision.
    * Other floating-point inputs keep the dtype given by torch promotion.
    * Non-floating ``p1`` (integer or bool) is interpolated in float64,
      rounded to the nearest value and cast back to ``p1.dtype``.
    """
    if p1.is_floating_point():
        if p1.dtype in (torch.float16, torch.bfloat16):
            p1, p2 = p1.float(), p2.float()
        return (1 - weight) * p1 + weight * p2
    # Round instead of truncating: for identical integer values the weighted
    # sum can land a hair below the true value due to floating-point error,
    # and truncation would then shift it down by one. Casting back to
    # p1.dtype keeps e.g. int64 or bool tensors from turning into int32.
    blended = (1 - weight) * p1.double() + weight * p2.double()
    return torch.round(blended).to(p1.dtype)


def main():
    """Write a sweep of checkpoints interpolated between two runs.

    Loads ``--in1`` and ``--in2``, which must have the same parameter names
    in the same order with matching shapes. For every integer ``alpha`` in
    ``range(-10, 20)`` it saves ``(1 - alpha/10) * in1 + alpha/10 * in2`` as
    ``<out>/merged{alpha}.pt``. All non-model entries of the saved
    checkpoints are taken from ``--in1``.

    Raises:
        ValueError: if the two checkpoints' parameter names or shapes differ.
    """
    parser = argparse.ArgumentParser(
        description="Merge checkpoints from different runs",
    )
    # fmt: off
    parser.add_argument('--in1', required=True, help='input checkpoint 1')
    parser.add_argument('--in2', required=True, help='input checkpoint 2')
    parser.add_argument('--out', required=True, help='output directory')
    # fmt: on
    args = parser.parse_args()
    print(args)

    model_params1, params_keys1, state = read_model(args.in1)
    model_params2, params_keys2, _ = read_model(args.in2)
    if params_keys1 != params_keys2:
        raise ValueError("The two checkpoints have different parameters")
    # Validate every shape up front: torch broadcasting would otherwise
    # silently produce a tensor whose shape matches neither input, and we
    # want to fail before writing any output file.
    for k in params_keys1:
        shape1, shape2 = model_params1[k].shape, model_params2[k].shape
        if shape1 != shape2:
            raise ValueError(
                f"Parameter {k!r} has shape {tuple(shape1)} in --in1 "
                f"but {tuple(shape2)} in --in2"
            )

    PathManager.mkdirs(args.out)
    # alpha / 10 is the weight on --in2: -1.0, -0.9, ..., 1.9 (range()'s
    # upper bound is exclusive). Weights outside [0, 1] extrapolate.
    for alpha in range(-10, 20):
        weight = alpha / 10
        state["model"] = collections.OrderedDict(
            (k, _interpolate(model_params1[k], model_params2[k], weight))
            for k in params_keys1
        )
        path = os.path.join(args.out, f"merged{alpha}.pt")
        with PathManager.open(path, "wb") as f:
            torch.save(state, f)
        print(f"Saved to {path}")


if __name__ == "__main__":
    main()
