import gc
import json
import shutil
import sys
from pathlib import Path

import torch

from litgpt import Config
from litgpt.utils import incremental_save, lazy_load
from convert_pretrained_checkpoint import convert_checkpoint
from convert_lit_checkpoint import check_conversion_supported, copy_weights_prefix_suffix_net

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import create_repo

# Support running without installing as a package: append the parent of this
# file's directory to the module search path.
# NOTE(review): this executes after the import statements above, so it does
# not influence how those imports are resolved — confirm whether it should
# come before them.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

@torch.inference_mode()
def convert_lit_checkpoint(
    checkpoint_path: Path, output_path: Path, config_path: Path, axonn_patch: bool = False, model_type: str = "prefix"
) -> None:
    """Convert a lit checkpoint file into a weights file written at ``output_path``.

    The checkpoint is opened with ``lazy_load``; if it holds a ``"model"`` entry
    that entry is used as the weights, otherwise the loaded object itself is.
    The weights are passed through ``check_conversion_supported`` and then
    copied into a fresh state dict by ``copy_weights_prefix_suffix_net``, which
    is finally written out through the ``incremental_save`` saver.

    Args:
        checkpoint_path: Path of the lit checkpoint to read.
        output_path: Destination file; its parent directory is created if missing.
        config_path: Accepted for interface compatibility; not used in this function.
        axonn_patch: Accepted for interface compatibility; not used in this function.
        model_type: Forwarded unchanged to ``copy_weights_prefix_suffix_net``.
    """
    destination_dir = output_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)

    # Fresh state dict that receives the converted weights.
    converted = {}
    with incremental_save(output_path) as saver:
        weights = lazy_load(checkpoint_path)
        # Rebind the same name so the outer container is not kept alive
        # when the weights live under a "model" key.
        weights = weights.get("model", weights)
        check_conversion_supported(weights)
        copy_weights_prefix_suffix_net(converted, weights, saver=saver, model_type=model_type)
        gc.collect()
        saver.save(converted)


@torch.inference_mode()
def convert_checkpoint_to_hf(
    checkpoint_file: Path,
    tokenizer_dir: Path,
    model_name: str,
    parent_dir: Path = None,
    axonn_patch: bool = False,
    push_to_hub: bool = True,
    model_type: str = "prefix",
) -> None:
    """Convert a training checkpoint to a Hugging Face checkpoint and optionally push it.

    Steps, in order:
      1. Read ``model_config.json`` from ``parent_dir``.
      2. Run ``convert_checkpoint`` into ``lit_checkpoint_{model_type}_{model_name}``.
      3. Run ``convert_lit_checkpoint`` to write ``pytorch_model.bin`` into
         ``hf_checkpoint_{model_type}_{model_name}``.
      4. Copy ``tokenizer*`` files and, when present, a fixed set of JSON files
         from ``tokenizer_dir`` into the HF checkpoint directory.
      5. Write ``config.json`` from the ``AutoConfig`` of the org/name listed
         under ``hf_config`` in the model config (this overwrites a
         ``config.json`` copied in step 4).
      6. Load the tokenizer and model from the HF checkpoint directory.
      7. If ``push_to_hub`` is true, create a private repo and push both.

    Args:
        checkpoint_file: Training checkpoint to convert.
        tokenizer_dir: Directory holding the tokenizer files to copy.
        model_name: Used in the output directory names and in the hub repo name.
        parent_dir: Working directory holding ``model_config.json`` and receiving
            the outputs; defaults to the absolute parent of ``checkpoint_file``.
        axonn_patch: Forwarded to ``convert_lit_checkpoint``.
        push_to_hub: When false, the function returns after loading the model.
        model_type: Used in the output directory names and forwarded to
            ``convert_lit_checkpoint``.

    NOTE(review): ``create_repo`` is called without ``exist_ok``, and the
    ``__main__`` block calls this function twice with the same ``model_name``,
    so the second push presumably targets a repo that already exists — confirm
    the intended behavior.
    """
    if parent_dir is None:
        parent_dir = checkpoint_file.parent.absolute()

    with open(parent_dir / "model_config.json") as config_file:
        model_config = json.load(config_file)

    # Both output directories are derived from the same (model_type, model_name) pair.
    lit_dir = parent_dir / f"lit_checkpoint_{model_type}_{model_name}"
    hf_dir = parent_dir / f"hf_checkpoint_{model_type}_{model_name}"
    weights_file = hf_dir / "pytorch_model.bin"

    ### convert training checkpoint to lit checkpoint
    convert_checkpoint(checkpoint_file, tokenizer_dir, model_config["name"], lit_dir)

    ### convert training checkpoint to hf checkpoint
    convert_lit_checkpoint(
        lit_dir / "lit_model.pth",
        weights_file,
        lit_dir / "lit_config.json",
        axonn_patch=axonn_patch,
        model_type=model_type,
    )

    for tokenizer_file in tokenizer_dir.glob("tokenizer*"):
        shutil.copyfile(tokenizer_file, hf_dir / tokenizer_file.name)

    # Auxiliary files are copied only when the tokenizer directory provides them.
    optional_files = (
        "generation_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
        "config.json",
    )
    for file_name in optional_files:
        candidate = tokenizer_dir / file_name
        if candidate.is_file():
            shutil.copyfile(candidate, hf_dir / file_name)

    hf_source = model_config["hf_config"]
    hf_org = hf_source["org"]
    hf_name = hf_source["name"]
    hf_config = AutoConfig.from_pretrained(f"{hf_org}/{hf_name}").to_dict()
    with open(hf_dir / "config.json", "w") as config_out:
        json.dump(hf_config, config_out, indent=4)

    ### push to hub
    repo_name = f"XXXX-6/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(hf_dir)
    state_dict = torch.load(weights_file)
    model = AutoModelForCausalLM.from_pretrained(hf_dir, state_dict=state_dict)

    if push_to_hub:
        create_repo(repo_name, private=True)
        model.push_to_hub(repo_name, use_temp_dir=True)
        tokenizer.push_to_hub(repo_name, use_temp_dir=True)

        print(f"Model pushed to {repo_name}")


if __name__ == "__main__":
    # Command-line entry point: parse the arguments once, then convert and
    # save the prefix model followed by the suffix model.
    import argparse

    cli = argparse.ArgumentParser()
    for required_option in ("--checkpoint_file", "--tokenizer_dir", "--model_name"):
        cli.add_argument(required_option, type=str, required=True)
    cli.add_argument("--parent_dir", type=str, default=None)
    for switch in ("--axonn_patch", "--push_to_hub"):
        cli.add_argument(switch, action="store_true")
    args = cli.parse_args()

    # The same arguments are used for both conversions; only model_type differs.
    for model_type in ("prefix", "suffix"):
        convert_checkpoint_to_hf(
            Path(args.checkpoint_file),
            Path(args.tokenizer_dir),
            args.model_name,
            Path(args.parent_dir) if args.parent_dir else None,
            args.axonn_patch,
            args.push_to_hub,
            model_type=model_type,
        )
