"""
Adapted from https://github.com/mlfoundations/open_clip. Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, Ludwig Schmidt
"""

import argparse
import json
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Union

import torch

try:
    from huggingface_hub import (
        create_repo,
        get_hf_file_metadata,
        hf_hub_download,
        hf_hub_url,
        repo_type_and_id_from_hf_id,
        upload_folder,
        list_repo_files,
    )
    from huggingface_hub.utils import EntryNotFoundError
    _has_hf_hub = True
except ImportError:
    _has_hf_hub = False

try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
from .tokenizer import HFTokenizer

# File names used for open_clip artifacts inside a Hugging Face Hub repo.
# Weights written with torch.save (PyTorch pickle format).
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin"
# Weights written with safetensors.
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"
# JSON file holding the model config and preprocessing (mean/std) settings.
HF_CONFIG_NAME = "open_clip_config.json"

def save_config_for_hf(
        model,
        config_path: Union[str, Path],
        model_config: Optional[dict]
):
    """Write the open_clip Hub config JSON (model config + preprocessing stats).

    Args:
        model: object exposing ``visual.image_mean`` and ``visual.image_std``; these
            are stored as the preprocessing ``mean`` / ``std`` and must be
            JSON-serializable (e.g. tuples/lists of floats).
        config_path: destination file path, as a ``str`` or ``pathlib.Path``.
        model_config: model configuration dict stored under ``model_cfg``.
    """
    preprocess_cfg = {
        'mean': model.visual.image_mean,
        'std': model.visual.image_std,
    }
    hf_config = {
        'model_cfg': model_config,
        'preprocess_cfg': preprocess_cfg,
    }

    # Normalize so plain string paths work too (str has no .open()).
    with Path(config_path).open('w', encoding='utf-8') as f:
        json.dump(hf_config, f, indent=2)


def save_for_hf(
    model,
    tokenizer: HFTokenizer,
    model_config: dict,
    save_directory: str,
    safe_serialization: Union[bool, str] = False,
    skip_weights : bool = False,
):
    """Serialize model weights, tokenizer and open_clip config into ``save_directory``.

    Args:
        model: model whose ``state_dict()`` is saved (unless ``skip_weights``); it is
            also passed to ``save_config_for_hf`` for the preprocessing mean/std.
        tokenizer: tokenizer exposing ``save_pretrained(directory)``.
        model_config: model config dict stored in the open_clip config JSON.
        save_directory: target directory; created (with parents) if missing.
        safe_serialization: ``True`` -> safetensors file only, ``False`` -> PyTorch
            pickle (``torch.save``) only, ``"both"`` -> both files.
        skip_weights: if True, only the tokenizer and config are written.

    Raises:
        ValueError: if weights are being saved and ``safe_serialization`` is not
            ``True``, ``False`` or ``"both"``.
        ImportError: if safetensors output is requested but ``safetensors`` is not
            installed.
    """
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    if not skip_weights:
        write_safetensors = safe_serialization is True or safe_serialization == "both"
        write_pickle = safe_serialization is False or safe_serialization == "both"
        if not (write_safetensors or write_pickle):
            # Previously an unknown value silently produced a repo with no weights.
            raise ValueError(
                f"safe_serialization must be True, False or 'both', got {safe_serialization!r}"
            )
        tensors = model.state_dict()
        if write_safetensors:
            # Explicit check instead of `assert` so it survives `python -O`.
            if not _has_safetensors:
                raise ImportError("`pip install safetensors` to use .safetensors")
            safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
        if write_pickle:
            torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

    tokenizer.save_pretrained(save_directory)

    config_path = save_directory / HF_CONFIG_NAME
    save_config_for_hf(model, config_path, model_config=model_config)


def push_to_hf_hub(
    model,
    tokenizer,
    model_config: Optional[dict],
    repo_id: str,
    commit_message: str = 'Add model',
    token: Optional[str] = None,
    revision: Optional[str] = None,
    private: bool = False,
    create_pr: bool = False,
    model_card: Optional[dict] = None,
    safe_serialization: Union[bool, str] = False,
):
    """Save a model, tokenizer and config to a temp dir and upload it to the HF Hub.

    Args:
        model: open_clip model to serialize (weights + preprocessing mean/std).
        tokenizer: tokenizer to save with the model. Anything that is not an
            ``HFTokenizer`` is replaced by the ``openai/clip-vit-large-patch14``
            HF tokenizer.
        model_config: model config dict written to the open_clip config JSON.
        repo_id: destination repo id; the owner may be omitted, in which case the
            full id is recovered from the URL returned by ``create_repo``.
        commit_message: commit message for the upload.
        token: Hub access token, forwarded to every Hub call made here.
        revision: revision checked for an existing README and uploaded to.
        private: create the repo as private if it does not exist yet.
        create_pr: open a pull request instead of committing directly.
        model_card: optional dict used to generate README.md when the repo
            (at ``revision``) has none.
        safe_serialization: weight format selector passed to ``save_for_hf``
            (``True``, ``False`` or ``"both"``).

    Returns:
        The result of ``huggingface_hub.upload_folder``.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
    """
    if not _has_hf_hub:
        raise ImportError("`pip install huggingface_hub` to push models to the Hugging Face Hub")

    if not isinstance(tokenizer, HFTokenizer):
        # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
        tokenizer = HFTokenizer('openai/clip-vit-large-patch14')

    # Create repo if it doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

    # Infer complete repo_id from repo_url
    # Can be different from the input `repo_id` if repo_owner was implicit
    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
    repo_id = f"{repo_owner}/{repo_name}"

    # Only generate a README if one does not already exist at this revision.
    # The token is forwarded so private repos / explicit credentials work.
    try:
        get_hf_file_metadata(
            hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision),
            token=token,
        )
        has_readme = True
    except EntryNotFoundError:
        has_readme = False

    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
        save_for_hf(
            model,
            tokenizer=tokenizer,
            model_config=model_config,
            save_directory=tmpdir,
            safe_serialization=safe_serialization,
        )

        # Add readme if it does not exist
        if not has_readme:
            model_card = model_card or {}
            model_name = repo_id.split('/')[-1]
            readme_path = Path(tmpdir) / "README.md"
            readme_text = generate_readme(model_card, model_name)
            readme_path.write_text(readme_text)

        # Upload model and return (must happen before the temp dir is removed).
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            revision=revision,
            create_pr=create_pr,
            commit_message=commit_message,
            token=token,
        )


def push_pretrained_to_hf_hub(
    model_name,
    pretrained: str,
    repo_id: str,
    precision: str = 'fp32',
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    commit_message: str = 'Add model',
    token: Optional[str] = None,
    revision: Optional[str] = None,
    private: bool = False,
    create_pr: bool = False,
    model_card: Optional[dict] = None,
    safe_serialization: Union[bool, str] = 'both',
):
    """Instantiate a pretrained open_clip model and push it to the HF Hub.

    Args:
        model_name: open_clip model architecture name.
        pretrained: pretrained tag or checkpoint path passed to
            ``create_model_from_pretrained``.
        repo_id: destination Hub repo id.
        precision: precision passed to ``create_model_from_pretrained``.
        image_mean: optional override of the model's preprocessing mean.
        image_std: optional override of the model's preprocessing std.
        commit_message, token, revision, private, create_pr, model_card:
            forwarded unchanged to ``push_to_hf_hub``.
        safe_serialization: weight format selector (``True``, ``False`` or
            ``"both"``); defaults to ``"both"`` as before.

    Returns:
        The result of ``push_to_hf_hub``.

    Raises:
        ValueError: if no model config can be found for ``model_name``.
    """
    # The eval transform is not needed for export.
    model, _ = create_model_from_pretrained(
        model_name,
        pretrained=pretrained,
        precision=precision,
        image_mean=image_mean,
        image_std=image_std,
    )

    model_config = get_model_config(model_name)
    if not model_config:
        # Explicit check instead of `assert` so it survives `python -O`.
        raise ValueError(f"No model config found for model '{model_name}'")

    tokenizer = get_tokenizer(model_name)

    return push_to_hf_hub(
        model=model,
        tokenizer=tokenizer,
        model_config=model_config,
        repo_id=repo_id,
        commit_message=commit_message,
        token=token,
        revision=revision,
        private=private,
        create_pr=create_pr,
        model_card=model_card,
        safe_serialization=safe_serialization,
    )


def generate_readme(model_card: dict, model_name: str):
    """Render README.md (model card) text for an open_clip model on the HF Hub.

    The text starts with a YAML front-matter header (tags, library, pipeline tag,
    license, and a ``datasets`` list when ``model_card['details']['Dataset']`` is
    present), followed by Markdown sections for each optional ``model_card`` key.

    Args:
        model_card: dict with optional keys:
            ``license`` (default ``'mit'``); ``description``;
            ``details`` (mapping rendered as bullets; list/tuple values become
            sub-bullets, dict values become ``key: value`` sub-bullets; a
            ``Dataset`` entry may be a string or a list/tuple of strings);
            ``usage`` and ``comparison`` (strings appended verbatim);
            ``citation`` (a BibTeX string or list/tuple of them).
        model_name: name used in the ``# Model card for ...`` heading.

    Returns:
        The README text.
    """
    parts = [
        "---\n",
        "tags:\n- clip\n",
        "library_name: open_clip\n",
        "pipeline_tag: zero-shot-image-classification\n",
        f"license: {model_card.get('license', 'mit')}\n",
    ]
    if 'details' in model_card and 'Dataset' in model_card['details']:
        datasets = model_card['details']['Dataset']
        # Accept a single dataset name or a list/tuple of names (the details
        # section below already renders list values).
        if not isinstance(datasets, (list, tuple)):
            datasets = [datasets]
        if datasets:
            parts.append('datasets:\n')
            parts.extend(f"- {d.lower()}\n" for d in datasets)
    parts.append("---\n")
    parts.append(f"# Model card for {model_name}\n")

    if 'description' in model_card:
        parts.append(f"\n{model_card['description']}\n")

    if 'details' in model_card:
        parts.append("\n## Model Details\n")
        for k, v in model_card['details'].items():
            if isinstance(v, (list, tuple)):
                parts.append(f"- **{k}:**\n")
                parts.extend(f"  - {vi}\n" for vi in v)
            elif isinstance(v, dict):
                parts.append(f"- **{k}:**\n")
                parts.extend(f"  - {ki}: {vi}\n" for ki, vi in v.items())
            else:
                parts.append(f"- **{k}:** {v}\n")

    if 'usage' in model_card:
        parts.append("\n## Model Usage\n")
        parts.append(model_card['usage'])
        parts.append('\n')

    if 'comparison' in model_card:
        parts.append("\n## Model Comparison\n")
        parts.append(model_card['comparison'])
        parts.append('\n')

    if 'citation' in model_card:
        parts.append("\n## Citation\n")
        citations = model_card['citation']
        if not isinstance(citations, (list, tuple)):
            citations = [citations]
        parts.extend(f"```bibtex\n{c}\n```\n" for c in citations)

    return "".join(parts)


if __name__ == "__main__":
    # CLI: load a pretrained open_clip model and push it to the Hugging Face Hub.
    parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
    parser.add_argument(
        # Required: the model config and tokenizer are looked up by this name.
        "--model", type=str, required=True, help="Name of the model to use.",
    )
    parser.add_argument(
        "--pretrained", type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        # Required: there is no sensible default destination repo.
        "--repo-id", type=str, required=True,
        help="Destination HF Hub repo-id ie 'organization/model_id'.",
    )
    parser.add_argument(
        "--precision", type=str, default='fp32',
        help="Model precision passed to create_model_from_pretrained (default: fp32).",
    )
    parser.add_argument(
        '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
        help='Override default image mean value of dataset')
    parser.add_argument(
        '--image-std', type=float, nargs='+', default=None, metavar='STD',
        help='Override default image std deviation of dataset')
    args = parser.parse_args()

    print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')

    # TODO: add support to pass model_card json / template from file via cmd line

    push_pretrained_to_hf_hub(
        args.model,
        args.pretrained,
        args.repo_id,
        precision=args.precision,
        image_mean=args.image_mean,  # override image mean/std if trained w/ non defaults
        image_std=args.image_std,
    )

    print(f'{args.model} saved.')
