#!/usr/bin/env python

import os
from pathlib import Path
from typing import Dict, List

import fire
import torch

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.utils.logging import get_logger


# Module-level logger, obtained from the transformers logging utilities and named after this module.
logger = get_logger(__name__)


def remove_prefix(text: str, prefix: str):
    """Return ``text`` with one leading occurrence of ``prefix`` stripped.

    If ``text`` does not start with ``prefix`` it is returned unchanged. Only the
    first occurrence is removed, so ``"model.model.x"`` becomes ``"model.x"``.
    """
    if not text.startswith(prefix):
        # Nothing to strip: hand the input back untouched.
        return text
    cut = len(prefix)
    return text[cut:]


def sanitize(sd):
    """Return a copy of the state dict ``sd`` whose keys have a leading ``"model."`` removed.

    Values are carried over as-is. If two keys collapse to the same name after
    stripping, the one that comes later in ``sd`` wins.
    """
    cleaned = {}
    for key, value in sd.items():
        cleaned[remove_prefix(key, "model.")] = value
    return cleaned


def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]):
    """Average several state dicts element-wise, key by key.

    The set of keys is taken from the first state dict; keys that only appear in
    later state dicts are ignored.

    Args:
        state_dicts: Non-empty list of state dicts mapping names to tensors.

    Returns:
        A new dict mapping every key of the first state dict to the mean of the
        corresponding tensors across all state dicts.

    Raises:
        ValueError: If ``state_dicts`` is empty.
        KeyError: If a key of the first state dict is missing from another one.
        TypeError: If the averaged value for a key is not a ``torch.Tensor``.
    """
    if not state_dicts:
        # Fail with a clear message instead of an IndexError from state_dicts[0].
        raise ValueError("state_dicts must contain at least one state dict to average")

    n_dicts = len(state_dicts)
    new_sd = {}
    for k in state_dicts[0]:
        tensors = [sd[k] for sd in state_dicts]
        new_t = sum(tensors) / n_dicts
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(new_t, torch.Tensor):
            raise TypeError(f"expected a torch.Tensor for key {k!r}, got {type(new_t).__name__}")
        new_sd[k] = new_t
    return new_sd


def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None:
    """Cleanup a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict.
    Silently allows extra pl keys (like teacher.) Puts all ckpt models into CPU RAM at once!

    Args:
        pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning or dir containing ckpt files.
            If a directory is passed, all .ckpt files inside it will be averaged!
        hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint
        save_path (:obj:`str`): Directory to save the new model

    Raises:
        FileNotFoundError: If ``pl_ckpt_path`` is neither a file nor a directory, or is a directory
            that contains no ``*.ckpt`` files.
        RuntimeError: If the (averaged) state dict lacks keys that the huggingface model expects.

    """
    hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir)
    # Validation uses explicit raises instead of `assert`, which is stripped under `python -O`.
    if os.path.isfile(pl_ckpt_path):
        ckpt_files = [pl_ckpt_path]
    elif os.path.isdir(pl_ckpt_path):
        # glob order depends on the filesystem; sort so the averaging order is deterministic.
        ckpt_files = sorted(Path(pl_ckpt_path).glob("*.ckpt"))
        if not ckpt_files:
            raise FileNotFoundError(f"could not find any ckpt files inside the {pl_ckpt_path} directory")
    else:
        raise FileNotFoundError(f"{pl_ckpt_path} is neither a ckpt file nor a directory")

    if len(ckpt_files) > 1:
        logger.info(f"averaging the weights of {ckpt_files}")

    # NOTE(security): torch.load unpickles the file; only run this on checkpoints from a trusted source.
    # Every checkpoint is loaded onto the CPU before averaging, so all of them are in RAM together.
    state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files]
    state_dict = average_state_dicts(state_dicts)

    # strict=False: extra pytorch-lightning keys are tolerated on purpose; only missing keys are fatal.
    missing, _unexpected = hf_model.load_state_dict(state_dict, strict=False)
    if missing:
        raise RuntimeError(f"missing keys: {missing}")
    hf_model.save_pretrained(save_path)

    # Copying the tokenizer is best effort: a failure must not abort the conversion,
    # but it is reported instead of being swallowed silently.
    try:
        tok = AutoTokenizer.from_pretrained(hf_src_model_dir)
        tok.save_pretrained(save_path)
    except Exception as err:
        logger.warning(f"could not copy the tokenizer from {hf_src_model_dir} to {save_path}: {err!r}")
        # dont copy tokenizer if cant


# When executed as a script, hand convert_pl_to_hf to fire.Fire to run it from the command line.
if __name__ == "__main__":
    fire.Fire(convert_pl_to_hf)
