# Copyright (c) Alibaba, Inc. and its affiliates.

import concurrent.futures
import os
import shutil
from multiprocessing import Manager, Process, Value

from swift.utils.logger import get_logger
from .api import HubApi
from .constants import DEFAULT_REPOSITORY_REVISION, ModelVisibility

logger = get_logger()

# Process pool (at most 8 workers) that `push_to_hub_async` submits uploads to.
_executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
# State for the named upload queues managed by `push_to_hub_in_queue` and
# `wait_for_done`; all three dicts are keyed by queue name.
# queue name -> manager queue carrying the pending upload requests.
_queues = dict()
# queue name -> shared boolean `Value`; the worker sets it to True while it is
# handling an item and back to False before it waits for the next one.
_flags = dict()
# queue name -> the `Process` running `submit_task` for that queue.
_tasks = dict()
# `multiprocessing.Manager` instance, created on the first call to
# `push_to_hub_in_queue` and used to build the queues above.
_manager = None


def _api_push_to_hub(repo_name,
                     output_dir,
                     token,
                     private=True,
                     commit_message='',
                     tag=None,
                     source_repo='',
                     ignore_file_pattern=None,
                     revision=DEFAULT_REPOSITORY_REVISION):
    """Log in with `token` and push `output_dir` to `repo_name` in one attempt.

    Args:
        repo_name: The repo name for the modelhub repo; also passed as `chinese_name`
        output_dir: The local directory to upload
        token: The user api token used to log in
        private: Push with private visibility when truthy, public otherwise
        commit_message: The commit message
        tag: The tag of this commit
        source_repo: The source repo (model id), passed as `original_model_id`
        ignore_file_pattern: The file pattern to be ignored in uploading
        revision: The branch to commit to
    Returns:
        True if the push finished without raising, False if any exception was
        raised (the exception is logged, not propagated).
    """
    try:
        hub = HubApi()
        hub.login(token)
        if private:
            visibility = ModelVisibility.PRIVATE
        else:
            visibility = ModelVisibility.PUBLIC
        hub.push_model(
            repo_name,
            output_dir,
            visibility=visibility,
            chinese_name=repo_name,
            commit_message=commit_message,
            tag=tag,
            original_model_id=source_repo,
            ignore_file_pattern=ignore_file_pattern,
            revision=revision)
        # Only the log line gets the placeholder; the push above used the
        # message exactly as given.
        commit_message = commit_message or 'No commit message'
        logger.info(
            f'Successfully upload the model to {repo_name} with message: {commit_message}'
        )
    except Exception as e:
        logger.error(
            f'Error happens when uploading model {repo_name} with message: {commit_message}: {e}'
        )
        return False
    else:
        return True


def push_to_hub(repo_name,
                output_dir,
                token=None,
                private=True,
                retry=3,
                commit_message='',
                tag=None,
                source_repo='',
                ignore_file_pattern=None,
                revision=DEFAULT_REPOSITORY_REVISION):
    """Upload a local checkpoint directory to the hub, retrying on failure.

    Args:
        repo_name: The repo name for the modelhub repo
        output_dir: The local output_dir for the checkpoint
        token: The user api token, function will check the `MODELSCOPE_API_TOKEN` variable if this argument is None
        private: If is a private repo, default True
        retry: Maximum number of upload attempts, default 3
        commit_message: The commit message
        tag: The tag of this commit
        source_repo: The source repo (model id) which this model comes from
        ignore_file_pattern: The file pattern to be ignored in uploading, function will check the
            `UPLOAD_IGNORE_FILE_PATTERN` variable if this argument is None
        revision: The branch to commit to
    Returns:
        True as soon as one attempt succeeds; False if every attempt fails
        (including the case where `retry` is less than 1 and nothing is attempted).
    Raises:
        AssertionError: If `repo_name` or the token is missing, `output_dir` is not a
            directory, or it contains no `configuration.json/.yaml/.yml` file.
    """
    if token is None:
        token = os.environ.get('MODELSCOPE_API_TOKEN')
    if ignore_file_pattern is None:
        ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')
    assert repo_name is not None, 'Please specify a valid repo name.'
    assert token is not None, 'Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.'
    assert os.path.isdir(
        output_dir), f'{output_dir} is not an existing directory.'
    # List the directory once so all three names are checked against the same
    # snapshot instead of hitting the filesystem up to three times.
    existing_files = set(os.listdir(output_dir))
    config_names = ('configuration.json', 'configuration.yaml',
                    'configuration.yml')
    assert any(name in existing_files for name in config_names), (
        f'{output_dir} must contain one of: {", ".join(config_names)}')

    logger.info(
        f'Uploading {output_dir} to {repo_name} with message {commit_message}')
    for _ in range(retry):
        if _api_push_to_hub(repo_name, output_dir, token, private,
                            commit_message, tag, source_repo,
                            ignore_file_pattern, revision):
            return True
    return False


def push_to_hub_async(repo_name,
                      output_dir,
                      token=None,
                      private=True,
                      commit_message='',
                      tag=None,
                      source_repo='',
                      ignore_file_pattern=None,
                      revision=DEFAULT_REPOSITORY_REVISION):
    """Validate the arguments, then submit a single upload to the module's process pool.

    Unlike `push_to_hub`, the upload is attempted once and not retried.

    Args:
        repo_name: The repo name for the modelhub repo
        output_dir: The local output_dir for the checkpoint
        token: The user api token, function will check the `MODELSCOPE_API_TOKEN` variable if this argument is None
        private: If is a private repo, default True
        commit_message: The commit message
        tag: The tag of this commit
        source_repo: The source repo (model id) which this model comes from
        ignore_file_pattern: The file pattern to be ignored in uploading, function will check the
            `UPLOAD_IGNORE_FILE_PATTERN` variable if this argument is None
        revision: The branch to commit to
    Returns:
        The future returned by `_executor.submit`; its result is the boolean
        returned by `_api_push_to_hub` (True on success, False on failure).
    Raises:
        AssertionError: If `repo_name` or the token is missing, `output_dir` is not a
            directory, or it contains no `configuration.json/.yaml/.yml` file.
    """
    if token is None:
        token = os.environ.get('MODELSCOPE_API_TOKEN')
    if ignore_file_pattern is None:
        ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')
    assert repo_name is not None, 'Please specify a valid repo name.'
    assert token is not None, 'Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.'
    assert os.path.isdir(
        output_dir), f'{output_dir} is not an existing directory.'
    # List the directory once so all three names are checked against the same
    # snapshot instead of hitting the filesystem up to three times.
    existing_files = set(os.listdir(output_dir))
    config_names = ('configuration.json', 'configuration.yaml',
                    'configuration.yml')
    assert any(name in existing_files for name in config_names), (
        f'{output_dir} must contain one of: {", ".join(config_names)}')

    logger.info(
        f'Uploading {output_dir} to {repo_name} with message {commit_message}')
    return _executor.submit(_api_push_to_hub, repo_name, output_dir, token,
                            private, commit_message, tag, source_repo,
                            ignore_file_pattern, revision)


def submit_task(q, b):
    """Worker loop: handle upload requests from `q` until a `done` item arrives.

    Each item is a dict of keyword arguments for `push_to_hub`, optionally
    carrying two control keys that are removed before the call:
        done: When truthy, the loop exits without uploading.
        delete_dir: When truthy, `output_dir` is removed after `push_to_hub`
            returns (whatever its result), provided the path still exists.

    Args:
        q: The queue to read request dicts from.
        b: A shared flag object; `b.value` is set to False while waiting for
            the next item and to True while an item is being handled.
    """
    while True:
        b.value = False
        item = q.get()
        # The request may carry the user's API token; log a copy with the
        # secret masked instead of writing the credential to the log.
        loggable = {
            key: ('***' if key == 'token' and value is not None else value)
            for key, value in item.items()
        }
        logger.info(loggable)
        b.value = True
        if item.pop('done', False):
            break
        delete_dir = item.pop('delete_dir', False)
        output_dir = item.get('output_dir')
        try:
            push_to_hub(**item)
            if delete_dir and os.path.exists(output_dir):
                shutil.rmtree(output_dir)
        except Exception as e:
            # Keep the worker alive for later items; just report the failure.
            logger.error(e)


class UploadStrategy:
    """String constants for the `strategy` argument of `push_to_hub_in_queue`."""

    # Put the request on the queue even if an upload is currently running.
    wait = 'wait'
    # Drop the request (and log an error) if an upload is currently running.
    cancel = 'cancel'


def push_to_hub_in_queue(queue_name, strategy=UploadStrategy.cancel, **kwargs):
    """Hand an upload request to the background worker of a named queue.

    The first call for a given `queue_name` creates the queue, its busy flag
    and a worker process running `submit_task`.

    Args:
        queue_name: Non-empty name identifying the queue.
        strategy: What to do when the worker is busy: `UploadStrategy.cancel`
            drops the request with an error log, any other value enqueues it.
        **kwargs: The request dict put on the queue. The worker passes it to
            `push_to_hub` after removing the control keys `done` and
            `delete_dir`. A request with a truthy `done` is always enqueued.
    """
    global _manager
    assert queue_name is not None and len(
        queue_name) > 0, 'Please specify a valid queue name!'
    if _manager is None:
        _manager = Manager()

    if queue_name not in _queues:
        # First use of this name: set up the queue, the flag and the worker.
        task_queue = _queues[queue_name] = _manager.Queue()
        busy_flag = _flags[queue_name] = Value('b', False)
        worker = Process(target=submit_task, args=(task_queue, busy_flag))
        worker.start()
        _tasks[queue_name] = worker

    pending = _queues[queue_name]
    busy = _flags[queue_name]
    is_final = kwargs.get('done', False)
    if not is_final and busy.value and strategy == UploadStrategy.cancel:
        logger.error(
            f'Another uploading is running, '
            f'this uploading with message {kwargs.get("commit_message")} will be canceled.'
        )
        return
    pending.put(kwargs)


def wait_for_done(queue_name):
    """Join the worker process of `queue_name` and drop the queue's bookkeeping.

    Does nothing when no worker is registered under `queue_name`. This function
    only joins; it does not enqueue anything itself, and the `submit_task` loop
    exits only after receiving an item with a truthy `done`.

    Args:
        queue_name: The name previously used with `push_to_hub_in_queue`.
    """
    worker: Process = _tasks.pop(queue_name, None)
    if worker is not None:
        worker.join()

        # The worker has exited; forget its queue and busy flag as well.
        _queues.pop(queue_name)
        _flags.pop(queue_name)
