# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from typing import List, Literal, Optional, Union

import aiohttp
from modelscope.hub.utils.utils import get_cache_dir

from swift.utils import get_logger, safe_ddp_context

# Module-level logger (obtained from `swift.utils.get_logger`) used for the download progress messages below.
logger = get_logger()


class MediaResource:
    """Download, extract and cache media resources under a local cache directory.

    All methods are static; the class only serves as a namespace holding the cache
    locations and the list of media types that can be referred to by short name.
    """

    # Root folder under which every resource is stored: {cache_dir}/{local_alias}.
    cache_dir = os.path.join(get_cache_dir(), 'media_resources')
    # Not referenced inside this class; presumably used by locking helpers elsewhere -- verify.
    lock_dir = os.path.join(get_cache_dir(), 'lockers')

    # Short names that `get_url` can expand to a full download url.
    media_type_urls = {
        'llava', 'coco', 'sam', 'gqa', 'ocr_vqa', 'textvqa', 'VG_100K', 'VG_100K_2', 'share_textvqa', 'web-celebrity',
        'web-landmark', 'wikiart'
    }

    URL_PREFIX = 'https://www.modelscope.cn/api/v1/datasets/hjh0119/sharegpt4v-images/repo?Revision=master&FilePath='

    @staticmethod
    def get_url(media_type):
        """Build the download url of a named media type.

        Args:
            media_type: The short name of the resource, e.g. one of `media_type_urls`.

        Returns:
            `URL_PREFIX` followed by `{media_type}.tar` for `ocr_vqa` and `{media_type}.zip` otherwise.
        """
        extension = 'tar' if media_type == 'ocr_vqa' else 'zip'
        return f'{MediaResource.URL_PREFIX}{media_type}.{extension}'

    @staticmethod
    def download(media_type_or_url: Union[str, List[str]],
                 local_alias: Optional[str] = None,
                 file_type: Literal['compressed', 'file', 'sharded'] = 'compressed'):
        """Download and extract a resource from a http link.

        Args:
            media_type_or_url: `str` or `List[str]`, Either belongs to the `media_type_urls`
                listed in the class field, or a remote url (or a list of urls for sharded resources)
                to download and extract.
                Be aware that, this media type or url needs to contain a zip or tar file
                unless `file_type` is 'file'.
            local_alias: `Optional[str]`, The local alias name for the `media_type_or_url`. If the first arg is a
                media_type listed in this class, local_alias can leave None. else please pass in a name for the url.
                The local dir contains the extracted files will be: {cache_dir}/{local_alias}
            file_type: The file type, if is a compressed file, un-compressed the file,
                if is an original file, only download it, if is a sharded file, download all files and extract.

        Returns:
            The local dir contains the extracted files.
        """
        # The first url (or the only one) identifies the resource for `safe_ddp_context`;
        # presumably this serializes the download across distributed ranks -- see swift.utils.
        media_file = media_type_or_url if isinstance(media_type_or_url, str) else media_type_or_url[0]
        with safe_ddp_context(hash_id=media_file):
            return MediaResource._safe_download(
                media_type=media_type_or_url, media_name=local_alias, file_type=file_type)

    @staticmethod
    def move_directory_contents(src_dir, dst_dir):
        """Move every file below `src_dir` into `dst_dir`, keeping the relative layout.

        Directories are (re)created on the destination side; only files are moved, so the
        directory skeleton of `src_dir` is left behind.

        Args:
            src_dir: The folder whose files should be moved.
            dst_dir: The folder to move the files into. Created if missing.
        """
        os.makedirs(dst_dir, exist_ok=True)

        for dirpath, _dirnames, filenames in os.walk(src_dir):
            relative_path = os.path.relpath(dirpath, src_dir)
            target_dir = os.path.join(dst_dir, relative_path)
            # exist_ok avoids the check-then-create race of a separate os.path.exists test.
            os.makedirs(target_dir, exist_ok=True)

            for file in filenames:
                shutil.move(os.path.join(dirpath, file), os.path.join(target_dir, file))

    @staticmethod
    def _safe_download(media_type: Union[str, List[str]],
                       media_name: Optional[str] = None,
                       file_type: Literal['compressed', 'file', 'sharded'] = 'compressed'):
        """Download a resource into `{cache_dir}/{media_name}` unless it is already there.

        Args:
            media_type: A name listed in `media_type_urls`, a url, or a list of urls (sharded).
            media_name: The folder name under `cache_dir`. Falls back to `media_type`,
                which therefore has to be a `str` when no name is given.
            file_type: 'compressed' downloads and extracts one archive, 'file' only downloads
                a single file, 'sharded' downloads and extracts several archives into one folder.

        Returns:
            The local folder holding the downloaded / extracted content.
        """
        media_name = media_name or media_type
        assert isinstance(media_name, str), f'{media_name} is not a str'
        if isinstance(media_type, str) and media_type in MediaResource.media_type_urls:
            media_type = MediaResource.get_url(media_type)

        # Function-scope import: `datasets` is only needed once this method is called.
        from datasets.download.download_manager import DownloadManager, DownloadConfig
        final_folder = os.path.join(MediaResource.cache_dir, media_name)

        final_path = None
        if file_type == 'file':
            # A plain file lives inside final_folder, so the folder existing is not enough:
            # check for the file itself.
            filename = media_type.split('/')[-1]
            final_path = os.path.join(final_folder, filename)
            if os.path.exists(final_path):
                return final_folder
            os.makedirs(final_folder, exist_ok=True)
        elif os.path.exists(final_folder):
            # For archives the presence of the folder marks the resource as already downloaded.
            return final_folder

        logger.info('# #################Resource downloading#################')
        logger.info('Downloading necessary resources...')
        logger.info(f'Resource package: {media_type}')
        logger.info(f'Extracting to local dir: {final_folder}')
        logger.info('If the downloading fails or lasts a long time, '
                    'you can manually download the resources and extracting to the local dir.')
        logger.info('Now begin.')
        download_config = DownloadConfig(cache_dir=MediaResource.cache_dir)
        # Allow up to 24 hours in total for a single download.
        download_config.storage_options = {'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=86400)}}
        manager = DownloadManager(download_config=download_config)
        if file_type == 'file':
            downloaded = manager.download(media_type)
            shutil.move(str(downloaded), final_path)
        elif file_type == 'compressed':
            extracted = manager.download_and_extract(media_type)
            shutil.move(str(extracted), final_folder)
        else:
            # A single url must be treated as one shard, not iterated character by character.
            shard_urls = [media_type] if isinstance(media_type, str) else media_type
            # Fetch every shard before touching final_folder: its existence is what marks the
            # resource as complete, so a failed download must not leave it half populated.
            extracted_dirs = [manager.download_and_extract(media_url) for media_url in shard_urls]
            for extracted in extracted_dirs:
                MediaResource.move_directory_contents(str(extracted), final_folder)
        logger.info('# #################Resource downloading finished#################')
        return final_folder

    @staticmethod
    def safe_save(image, file_name, folder, format='JPEG'):
        """Save an image into `{cache_dir}/{folder}/{file_name}` unless that file already exists.

        Args:
            image: An object exposing `save(path, format=...)` (presumably a PIL image -- verify
                against callers).
            file_name: The file name to save as.
            folder: The sub folder of `cache_dir` to save into. Created if missing.
            format: The image format passed through to `image.save`.

        Returns:
            The full path of the (existing or newly written) file.
        """
        folder = os.path.join(MediaResource.cache_dir, folder)
        os.makedirs(folder, exist_ok=True)
        file = os.path.join(folder, file_name)
        if not os.path.exists(file):
            image.save(file, format=format)
        return file
