import shutil
import os
from os.path import join, realpath, expanduser, basename, dirname, isdir, isfile


class LogWriter(object):
    """Placeholder logger whose every "method" is the builtin ``print``.

    Any attribute that normal lookup does not find (``log.info``,
    ``log.hint``, ...) resolves to ``print``, so call sites can use
    logger-style calls without a real logging backend.
    """

    def __getattr__(self, k):
        # ``__getattr__`` only runs when regular attribute lookup fails; the
        # requested name ``k`` is deliberately ignored.
        sink = print
        return sink


# Module-wide logger instance.
log = LogWriter()


# Default value for every known path name; get_path() lets an environment
# variable with the same name override the default.
_PATHS = dict(
    # Datasets go to $LOCAL_TMPDIR when that variable is set, else ~/datasets.
    DATA_ROOT=os.environ.get('LOCAL_TMPDIR', '~/datasets'),

    # Local mirror of the dataset archives; not necessary when S3 is used.
    DATA_REPO_ROOT='~/dataset_repository',

    # TODO: feature_cache folder in the tracking repository, could be discussed.
    FEATURE_CACHE_ROOT=realpath(
        join(dirname(__file__), os.pardir, 'feature_cache')
    ),
)

def get_path(name):
    """Resolve the location registered under ``name``.

    An environment variable called ``name`` takes precedence and is returned
    verbatim. Otherwise the default from ``_PATHS`` is returned with a leading
    ``~`` expanded to the user's home directory.

    Raises:
        KeyError: if ``name`` is not a key of ``_PATHS``.
    """
    if name not in _PATHS:
        raise KeyError(f'{name} is not a valid path name.')

    override = os.environ.get(name)
    if override is not None:
        return override

    default = _PATHS[name]
    return expanduser(default)
    
def get_dataset_path(dataset_name):
    """Return the path of ``dataset_name`` below the ``DATA_ROOT`` location."""
    data_root = get_path('DATA_ROOT')
    return join(data_root, dataset_name)


def _read_s3_credential(path):
    """Return the first line of the credential file at ``path`` (``~`` is expanded)."""
    with open(os.path.expanduser(path), 'r') as handle:
        return handle.read().split('\n')[0]


def get_s3_connection():
    """Create a boto3 S3 client for the ``https://s3.gwdg.de`` endpoint.

    The access key and secret key are read from the first line of
    ``~/.s3/access_key`` and ``~/.s3/secret_key`` respectively.

    Returns:
        The client object returned by ``boto3.session.Session().client``.

    Raises:
        ImportError: if the ``boto3`` package cannot be imported.
        FileNotFoundError: if one of the two credential files is missing.
    """
    try:
        import boto3
    except ImportError as e:
        # Fail here with a clear message instead of continuing and crashing
        # on the unbound ``boto3`` name a few lines later.
        raise ImportError('you need the boto3 library to work with s3') from e

    session = boto3.session.Session()

    try:
        access_key = _read_s3_credential('~/.s3/access_key')
        secret_key = _read_s3_credential('~/.s3/secret_key')
    except FileNotFoundError as e:
        # Keep errno and filename so callers can still inspect them.
        raise FileNotFoundError(
            e.errno, 'S3 credential file is missing', e.filename
        ) from e

    s3_client = session.client(
        service_name='s3',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        endpoint_url='https://s3.gwdg.de',
    )
    return s3_client


def get_object(s3_key):
    """Return the content of object ``s3_key`` in the ``eckerlab-datasets`` bucket.

    The whole response body is read and returned.
    """
    client = get_s3_connection()
    body = client.get_object(Bucket='eckerlab-datasets', Key=s3_key)["Body"]
    return body.read()


def get_file_from_s3(filename, folder, local_dir=None):
    """Download one file from the ``eckerlab-datasets`` bucket unless it is already local.

    The file is stored at ``<local_dir>/<folder>/<filename>``; the S3 key is
    ``filename`` itself.

    Args:
        filename: S3 key of the object, also used as the local file name.
        folder: sub-folder of ``local_dir`` that receives the file.
        local_dir: base directory; defaults to ``get_path('DATA_ROOT')``.

    Returns:
        None, whether or not a download took place.
    """
    local_dir = get_path('DATA_ROOT') if local_dir is None else local_dir
    target = join(local_dir, folder, filename)

    # Nothing to do (and no S3 connection needed) when the file is cached.
    if os.path.isfile(target):
        return None

    # Create the folder that will actually contain the file, not just
    # ``local_dir``; otherwise the download fails for a missing ``folder``.
    os.makedirs(dirname(target), exist_ok=True)

    s3_key = filename
    print(f'download {s3_key} from s3 to {target}')
    get_s3_connection().download_file('eckerlab-datasets', s3_key, target)
    return None


def get_archive_from_s3(name, local_dir=None, file_type='tar.gz'):
    """Fetch the archive ``<name>.<file_type>`` and unpack it into ``<local_dir>/<name>``.

    Same as ``get_from_s3`` but without an explicit chunk. If the target
    folder already exists nothing is done. Otherwise the archive is copied
    from ``get_path('DATA_REPO_ROOT')`` when a file with that name exists
    there, else downloaded from the ``eckerlab-datasets`` bucket, then
    extracted with ``extract_archive`` (which removes the temporary archive).

    Args:
        name: dataset name; also the name of the target folder.
        local_dir: base directory; defaults to ``get_path('DATA_ROOT')``.
        file_type: archive extension without the leading dot.

    Returns:
        A short status string when the folder was created, ``None`` when the
        folder already existed.
    """
    local_dir = get_path('DATA_ROOT') if local_dir is None else local_dir
    filename = join(local_dir, name, '__tmp.' + file_type)
    target_dir = dirname(filename)

    # An existing folder is taken to mean the dataset was fetched before, so
    # neither credentials nor an S3 connection are required.
    if os.path.isdir(target_dir):
        return None

    os.makedirs(target_dir, exist_ok=True)
    try:
        s3_key = f'{name}.{file_type}'
        repo_file = join(get_path('DATA_REPO_ROOT'), s3_key)
        if isfile(repo_file):
            # this should be faster than S3
            shutil.copy2(repo_file, filename)
            print('copy from local repo')
        else:
            print(f'download {s3_key} from s3 to {filename}')
            get_s3_connection().download_file('eckerlab-datasets', s3_key, filename)
        extract_archive(filename, target_dir, delete_archive=True)
    except BaseException:
        # Remove the folder created above: leaving it behind would make the
        # next call believe the dataset is already present.
        shutil.rmtree(target_dir, ignore_errors=True)
        raise

    return f'folder {filename} created'




def get_from_s3(name, chunk, local_dir=None, file_type='tar.gz', noarchive_ok=True):
    """Fetch chunk ``chunk`` of dataset ``name`` into ``<local_dir>/<name>/<chunk>``.

    If the target folder already exists nothing is done. Otherwise the file
    ``<name>_<chunk>.<file_type>`` is copied from ``get_path('DATA_REPO_ROOT')``
    when it exists there, else downloaded from the ``eckerlab-datasets``
    bucket, and then handed to ``extract_archive``.

    Args:
        name: dataset name.
        chunk: chunk identifier; converted with ``str`` for the folder name.
        local_dir: base directory; defaults to ``get_path('DATA_ROOT')``.
        file_type: file extension without the leading dot.
        noarchive_ok: forwarded to ``extract_archive``.

    Returns:
        A short status string when the data was fetched, ``None`` when the
        folder already existed.
    """
    local_dir = get_path('DATA_ROOT') if local_dir is None else local_dir
    filename = join(local_dir, name, str(chunk), '__tmp.' + file_type)
    target_dir = dirname(filename)

    # An existing folder is taken to mean the chunk was fetched before, so
    # neither credentials nor an S3 connection are required.
    if os.path.isdir(target_dir):
        return None

    os.makedirs(target_dir, exist_ok=True)
    try:
        s3_key = f'{name}_{chunk}.{file_type}'
        repo_file = join(get_path('DATA_REPO_ROOT'), s3_key)
        if isfile(repo_file):
            # this should be faster than S3
            shutil.copy2(repo_file, filename)
            print('copy from local repo')
        else:
            print(f'download {s3_key} from s3 to {filename}')
            get_s3_connection().download_file('eckerlab-datasets', s3_key, filename)
        extract_archive(filename, target_dir, noarchive_ok=noarchive_ok, delete_archive=True)
    except BaseException:
        # Remove the folder created above: leaving it behind would make the
        # next call believe the chunk is already present.
        shutil.rmtree(target_dir, ignore_errors=True)
        raise

    return f'folder {filename} created'


def from_s3(name, local_dir=None, file_type='tar.gz', noarchive_ok=True):
    """Fetch ``<name>.<file_type>`` from S3, unpacking it when it is an archive.

    Same as ``get_from_s3`` but without an explicit chunk.

    * Non-archive file types are downloaded to ``<local_dir>/<name>.<file_type>``
      unless that file already exists.
    * Archive types (``tar.gz``, ``tar``, ``zip``, ``gz``, ``tgz``) are fetched
      into ``<local_dir>/<name>`` unless that folder already exists. The file
      is copied from ``get_path('DATA_REPO_ROOT')`` when present there, else
      downloaded, and then handed to ``extract_archive``.

    Args:
        name: dataset name.
        local_dir: base directory; defaults to ``get_path('DATA_ROOT')``.
        file_type: file extension without the leading dot.
        noarchive_ok: forwarded to ``extract_archive``.

    Returns:
        A short status string when an archive folder was created, otherwise
        ``None``.
    """
    local_dir = get_path('DATA_ROOT') if local_dir is None else local_dir
    s3_key = f'{name}.{file_type}'

    # 'tar.gz' was previously misspelled 'tar.gt', which sent the default
    # file type down the plain-file branch so it was never extracted.
    if file_type not in {'tar.gz', 'tar', 'zip', 'gz', 'tgz'}:
        target_file = join(local_dir, s3_key)
        if os.path.isfile(target_file):
            print(f'file {target_file} already exists.')
            return None
        parent = dirname(target_file)
        if parent:
            # download_file needs the destination directory to exist.
            os.makedirs(parent, exist_ok=True)
        get_s3_connection().download_file('eckerlab-datasets', s3_key, target_file)
        return None

    filename = join(local_dir, name, '__tmp.' + file_type)
    target_dir = dirname(filename)

    # An existing folder is taken to mean the dataset was fetched before, so
    # neither credentials nor an S3 connection are required.
    if os.path.isdir(target_dir):
        return None

    os.makedirs(target_dir, exist_ok=True)
    try:
        repo_file = join(get_path('DATA_REPO_ROOT'), s3_key)
        if isfile(repo_file):
            # this should be faster than S3
            shutil.copy2(repo_file, filename)
            print('copy from local repo')
        else:
            print(f'download {s3_key} from s3 to {filename}')
            get_s3_connection().download_file('eckerlab-datasets', s3_key, filename)
        extract_archive(filename, target_dir, noarchive_ok=noarchive_ok, delete_archive=True)
    except BaseException:
        # Remove the folder created above: leaving it behind would make the
        # next call believe the dataset is already present.
        shutil.rmtree(target_dir, ignore_errors=True)
        raise

    return f'folder {filename} created'




def extract_archive(filename, target_folder=None, noarchive_ok=False, delete_archive=False):
    """Unpack ``filename`` with the external ``tar`` or ``unzip`` program.

    ``.tar``/``.tgz`` files are extracted with ``tar -xf``, ``.tar.gz`` with
    ``tar -xzf`` and ``.zip`` with ``unzip``. For any other ending the file is
    either moved next to its folder (``noarchive_ok=True``) or rejected.

    Args:
        filename: path of the archive.
        target_folder: extraction directory, passed as ``-C`` (tar) or ``-d``
            (unzip); when ``None`` no destination option is passed.
        noarchive_ok: if true, a file with another ending is moved to
            ``<dirname(filename)>.<extension>`` instead of raising.
        delete_archive: for archives, delete ``filename`` after a successful
            extraction; for non-archives, remove the now-empty folder that
            contained the file.

    Returns:
        ``[filename]`` for archives, ``None`` for a moved non-archive file.

    Raises:
        ValueError: unsupported file ending and ``noarchive_ok`` is false.
    """
    from subprocess import run, PIPE

    if filename.endswith(('.tar', '.tgz')):
        args, dest_flag = ['tar', '-xf', filename], '-C'
    elif filename.endswith('.tar.gz'):
        args, dest_flag = ['tar', '-xzf', filename], '-C'
    elif filename.endswith('.zip'):
        args, dest_flag = ['unzip', filename], '-d'
    else:
        if not noarchive_ok:
            raise ValueError(f'unsuppored file ending of {filename}')
        file_type = filename.split('.')[-1]
        shutil.move(filename, f'{dirname(filename)}.{file_type}')
        if delete_archive:
            print('delete', dirname(filename))
            os.removedirs(dirname(filename))
        return

    if target_folder is not None:
        args += [dest_flag, target_folder]

    log.hint(' '.join(args))
    # Pass the argument list directly: splitting a command string on
    # whitespace breaks for paths that contain spaces.
    result = run(args, stdout=PIPE, stderr=PIPE)
    if result.returncode != 0:
        print(result.stdout, result.stderr)
        # Keep the archive when extraction failed so it is not lost.
        return [filename]

    if delete_archive:
        os.remove(filename)
    return [filename]

