# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""

import typing as tp
from pathlib import Path

import torch
from omegaconf import OmegaConf

from audiocraft import __version__


def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Write a release package holding only the best EnCodec weights.

    Use this when you trained your own EnCodec model. The full training
    checkpoint is reduced to four entries:

    - its best model state (``best_state["model"]``);
    - its experiment config, rendered to YAML;
    - the current library version;
    - an ``exported`` marker.

    Args:
        checkpoint_path: Training checkpoint to read; it is loaded with
            ``map_location="cpu"``.
        out_file: Destination of the package. Missing parent directories are
            created.

    Returns:
        ``out_file``, exactly as passed in.
    """
    checkpoint = torch.load(checkpoint_path, "cpu")

    # Pull out the two pieces we keep before assembling the release package.
    model_state = checkpoint["best_state"]["model"]
    config_yaml = OmegaConf.to_yaml(checkpoint["xp.cfg"])

    release = {
        "best_state": model_state,
        "xp.cfg": config_yaml,
        "version": __version__,
        "exported": True,
    }

    destination_dir = Path(out_file).parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    torch.save(release, out_file)
    return out_file


def export_pretrained_compression_model(
    pretrained_encodec: str, out_file: tp.Union[Path, str]
):
    """Export a compression model (potentially EnCodec) from a pretrained model.

    This is required for packaging the audio tokenizer along with a MusicGen or
    AudioGen model. Do not include the //pretrained/ prefix. For instance, if you
    trained a model with `facebook/encodec_32khz`, just put that as a name. The
    same goes for `dac_44khz`.

    Two cases are handled:

    - If ``pretrained_encodec`` is the path of an existing file, it is loaded
      (onto CPU). It must already be an exported package containing the keys
      ``best_state``, ``xp.cfg``, ``version`` and ``exported``. It is re-saved
      as-is to ``out_file``.
    - Otherwise ``pretrained_encodec`` is treated as a pretrained model name.
      The written package does not contain a copy of the model, only a reference
      to that name, together with the library version and an ``exported`` flag.

    Args:
        pretrained_encodec: Path to an exported checkpoint, or a pretrained
            model name.
        out_file: Destination of the package. Missing parent directories are
            created.

    Returns:
        ``out_file``, exactly as passed in (consistent with the other export
        helpers).

    Raises:
        ValueError: If ``pretrained_encodec`` is an existing file that lacks any
            of the required exported-package keys.
    """
    if Path(pretrained_encodec).exists():
        # Map to CPU so a checkpoint saved from GPU can still be exported on a
        # CPU-only machine.
        pkg = torch.load(pretrained_encodec, "cpu")
        required_keys = ("best_state", "xp.cfg", "version", "exported")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        missing = [key for key in required_keys if key not in pkg]
        if missing:
            raise ValueError(
                f"{pretrained_encodec} is not an exported checkpoint; "
                f"missing keys: {missing}"
            )
    else:
        pkg = {
            "pretrained": pretrained_encodec,
            "exported": True,
            "version": __version__,
        }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(pkg, out_file)
    return out_file


def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given MusicGen or AudioGen checkpoint.

    Selecting the weights:

    - If the checkpoint has a non-empty ``fsdp_best_state``, the weights are
      taken from it. The ``long_model`` entry is preferred, then
      ``short_model``, falling back to ``model``.
    - Otherwise the weights are ``best_state["model"]``.

    The written package contains those weights, the experiment config rendered
    to YAML, the current library version, and an ``exported`` flag.

    Args:
        checkpoint_path: Training checkpoint to read; it is loaded with
            ``map_location="cpu"``.
        out_file: Destination of the package. Missing parent directories are
            created.

    Returns:
        ``out_file``, exactly as passed in.

    Raises:
        ValueError: If the checkpoint has neither a non-empty ``fsdp_best_state``
            nor a non-empty ``best_state``.
    """
    pkg = torch.load(checkpoint_path, "cpu")
    # `.get` so checkpoints that do not carry an `fsdp_best_state` entry at all
    # fall through to the regular best state instead of raising KeyError.
    fsdp_best_state = pkg.get("fsdp_best_state")
    if fsdp_best_state:
        if "long_model" in fsdp_best_state:
            best_state = fsdp_best_state["long_model"]
        elif "short_model" in fsdp_best_state:
            best_state = fsdp_best_state["short_model"]
        else:
            best_state = fsdp_best_state["model"]
    else:
        regular_best_state = pkg.get("best_state")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not regular_best_state:
            raise ValueError(
                f"Checkpoint {checkpoint_path} has no non-empty "
                "'fsdp_best_state' or 'best_state' to export."
            )
        best_state = regular_best_state["model"]
    new_pkg = {
        "best_state": best_state,
        "xp.cfg": OmegaConf.to_yaml(pkg["xp.cfg"]),
        "version": __version__,
        "exported": True,
    }

    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(new_pkg, out_file)
    return out_file
