import pandas as pd
import torch
import argparse
from transformers import AutoTokenizer


def salvage_lightning_checkpoint(lightning_checkpoint_path: str):
    """Rebuild a ``SummarizationModule`` from a Lightning checkpoint and re-save it.

    The checkpoint's ``hyper_parameters`` entry is turned into an
    ``argparse.Namespace`` and used to:

    * read extra tokens from ``<data_dir>/special_tokens.txt`` (one per line),
    * load the tokenizer named by ``model_name_or_path`` and add those tokens,
    * construct a ``SummarizationModule``, resize its token embeddings to the
      extended tokenizer, and load the checkpoint's ``state_dict`` into it.

    Finally ``model.on_save_checkpoint`` is called with
    ``save_path="<default_root_dir>/final_hf_checkpoint/"``.

    Args:
        lightning_checkpoint_path: Path to the Lightning checkpoint file.

    Returns:
        The reconstructed ``SummarizationModule`` with the checkpoint weights
        loaded.

    Raises:
        KeyError: If the checkpoint has no ``"hyper_parameters"`` or
            ``"state_dict"`` entry.
    """
    # NOTE(review): SummarizationModule is not imported in the visible part of
    # this file -- confirm it is brought into scope elsewhere.
    # NOTE(security): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # may reject non-tensor hyper-parameter objects -- confirm against the
    # PyTorch version in use.
    # map_location="cpu" lets a checkpoint written on a GPU be read on a
    # machine without CUDA; load_state_dict below copies the values into the
    # model's own parameters, so the result is unaffected.
    ckpt = torch.load(lightning_checkpoint_path, map_location="cpu")

    hparams = ckpt["hyper_parameters"]
    # Defensive: accept hyper-parameters stored either as a mapping or as an
    # already-built Namespace (the latter would break the ** expansion).
    if isinstance(hparams, argparse.Namespace):
        args = hparams
    else:
        args = argparse.Namespace(**hparams)

    # One token per line; parsed with pandas exactly as before so the number
    # of added tokens (and thus the embedding size) stays the same.
    additional_tokens_list = pd.read_csv(
        f"{args.data_dir}/special_tokens.txt", names=["special_tokens"]
    )["special_tokens"].tolist()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.add_tokens(additional_tokens_list)

    model = SummarizationModule(args, tokenizer=tokenizer)
    # The embedding matrix must match the extended vocabulary before the
    # checkpoint weights can be loaded.
    model.model.resize_token_embeddings(len(model.tokenizer))
    model.load_state_dict(ckpt["state_dict"])

    # Presumably writes a Hugging Face-format checkpoint to this directory --
    # confirm against SummarizationModule.on_save_checkpoint.
    model.on_save_checkpoint(
        save_path=f"{args.default_root_dir}/final_hf_checkpoint/"
    )
    return model

if __name__ == "__main__":
    # Command-line entry point: convert the checkpoint given by --pl_ckpt.
    parser = argparse.ArgumentParser(
        description="Restore a model from a PyTorch Lightning checkpoint and re-save it."
    )
    # Required: without it, None would be handed to torch.load and fail with
    # an unhelpful error instead of a usage message.
    parser.add_argument(
        "--pl_ckpt",
        type=str,
        required=True,
        help="Path to the PyTorch Lightning checkpoint file to convert.",
    )
    args = parser.parse_args()
    salvage_lightning_checkpoint(args.pl_ckpt)