### SIGMA-conf: the lone wolf configuration manager.
###
import os
import pickle

import yaml

from typing import Dict, Union, List, Callable
from functools import reduce

import re
import datetime
import importlib
from functools import partial
from itertools import zip_longest

import sys
# Directory containing this module; used as the starting point for the git lookup below.
_this_module_path = os.path.dirname(__file__)
# sys.path.append(_this_module_path)
print(_this_module_path)

# Root config dir is configurable through the ROOT_CONFIG_DIR environment variable;
# it defaults to the 'config' directory that sits next to this file.
ROOT_CONFIG_DIR = os.environ['ROOT_CONFIG_DIR'] if 'ROOT_CONFIG_DIR' in os.environ else os.path.join(os.path.dirname(os.path.realpath(__file__)), 'config')

# Bind ROOT_REPO_DIR up front so the name always exists, even when gitpython is
# missing or the lookup fails (it used to stay undefined in that case).
ROOT_REPO_DIR = None
try:
    import git
    git_repo = git.Repo(os.path.dirname(_this_module_path), search_parent_directories=True) # look in the parent dir from here
    ROOT_REPO_DIR = git_repo.git.rev_parse("--show-toplevel")
    GIT_CURRENT_HEAD_SHA = git_repo.head.object.hexsha
    print(f'Sigma conf enabled: {_this_module_path} {ROOT_REPO_DIR} {GIT_CURRENT_HEAD_SHA}')
except Exception as e:
    # Best effort only: git metadata is optional, so report the problem and fall back.
    print(f'Could not do the modules thing, possibly needs to install gitpython: {str(e)}')
    GIT_CURRENT_HEAD_SHA = 'N/A'

# Flip to True to make print_debug() actually print.
_DEBUG_MODE = False


def print_debug(*args, **kwargs):
    """Forward all arguments to ``print`` while ``_DEBUG_MODE`` is truthy; otherwise do nothing."""
    if not _DEBUG_MODE:
        return
    print(*args, **kwargs)


def read_config(
    name: str,
    subject: str = "",
    abspath = False
):
    '''
    Read one YAML file and return its parsed content.

    :param name:
        config name without the '.yaml' extension, or a path to the file itself
        when ``abspath`` is truthy
    :param subject:
        sub-directory of ROOT_CONFIG_DIR to look in (ignored when ``abspath`` is truthy)
    :param abspath:
        if truthy, ``name`` is used as the file path as-is
    :return:
        whatever the YAML safe loader produces for the file
    '''
    if abspath:
        fpath = os.path.abspath(name)
    else:
        relative = os.path.join(ROOT_CONFIG_DIR, subject, name + '.yaml')
        fpath = os.path.abspath(relative)

    with open(fpath, 'rt') as f:
        return yaml.safe_load(f)


def load_master_config(
    name: str,
    abspath = False,
    return_original_conf = False,
):
    '''
    Read the master config file and run the recursive substitution over it.

    :param name:
        name of the configuration file
    :param abspath:
        whether the name is an absolute path
    :param return_original_conf:
        if truthy, the config as read from the file is appended to the result
    :return:
        ``(full, raw)`` - the second and third values produced by
        ``config_sub_recursive`` for the whole tree - or ``(full, raw, original)``
        when ``return_original_conf`` is truthy
    '''
    original = read_config(name, '', abspath)

    # the returned root key is of no interest here
    _, fully_substituted, name_substituted = config_sub_recursive(0, original)

    # substituted versions are for use, the original one is for tracking
    if return_original_conf:
        return fully_substituted, name_substituted, original
    return fully_substituted, name_substituted

# Key-prefix instructions understood by config_sub_recursive (csr):
#   '_pb.'  - push back: the processed branch is merged into the enclosing mapping
#             instead of being stored under its own key
#   '_skp.' - skip: the branch is taken as-is, nothing inside it is substituted
# If both need to be used, write them in this order: '_pb._skp.'
pushback_prefix = '_pb.'
skip_prefix = '_skp.'

def config_sub_recursive(key, node):
    '''
    Recursively apply scommand substitution to one (key, node) pair of the config tree.

    Two results are tracked side by side: ``buffer`` is the processed value after the
    outline command found in the key (if any) has been applied, ``raw_buffer`` is the
    processed value from before that outline command ran. When the outline command asks
    for another pass, both are replaced by the result of that extra pass.

    :param key:
        dict key or list index under which ``node`` is stored; a string key may carry
        the skip prefix and/or scommands
    :param node:
        the value to process: a list, a dict or a leaf
    :return:
        tuple ``(rekey, buffer, raw_buffer)``
    '''
    if isinstance(key, str) and key.startswith(skip_prefix):
        # remove the skip prefix
        key = key[len(skip_prefix):]
        # skipped branch: keep the node untouched, no descent and no substitution
        raw_buffer = node
        buffer = node
    else:
        if isinstance(node, List):
            buffer = []
            raw_buffer = []
            # process every element; the list index plays the role of the key
            for i, v in enumerate(node):
                rei, rev, raw_rev = config_sub_recursive(i, v)
                buffer.append(rev)
                raw_buffer.append(raw_rev)
        elif isinstance(node, Dict):
            buffer = {}
            raw_buffer = {}
            for k, v in node.items():
                if isinstance(k, str) and k.startswith(pushback_prefix):
                    # push back: strip the prefix and merge the child's result into this
                    # mapping (the child result has to be a mapping for update() to work)
                    k = k[len(pushback_prefix):]
                    rek, rev, raw_rev = config_sub_recursive(k, v)
                    buffer.update(rev)
                    raw_buffer.update(raw_rev)
                else:
                    # regular entry: store the child's result under its (possibly rewritten) key
                    rek, rev, raw_rev = config_sub_recursive(k, v)
                    buffer[rek] = rev
                    raw_buffer[rek] = raw_rev
        else:
            # if leaf node, substitute the value
            buffer = run_scommand_inline(node)
            raw_buffer = buffer


    # now can run the possible key command (raw_buffer is unaffected)
    # inline commands inside the key are resolved first, then the key is tried as an outline command
    rekey = run_scommand_inline(key)
    need_recurse, rekey, buffer = run_scommand_outline(rekey, buffer)

    if need_recurse:
        # if the outline command suggest going over the stuff again, go in (e.g. when include is used)
        # can be used by scommands operating on the skipped branches
        rekey, buffer, raw_buffer = config_sub_recursive(rekey, buffer)

    # return the new key and the new value
    return rekey, buffer, raw_buffer


# Matches one inline scommand of the form ``${!command arg1 arg2}``. The body may not
# contain '$', '{' or '}', so inline commands cannot be nested textually.
# Raw string: '\$' is not a valid str escape and triggers a warning in a normal literal;
# the pattern itself is unchanged.
search_command = r"\${!(?:[^${}]|)+}"
search_re_command = re.compile(search_command)


def run_scommand_inline(
        maybe_expression: str,
        max_depth: int = 3,
    ):
    '''
    Resolve every inline scommand (``${!command arg ...}``) found in a string.

    Non-string values are returned untouched. A string without any inline command is
    parsed with ``yaml.safe_load`` and returned. Otherwise each command is replaced by
    the text its handler returns and the resulting string is examined again, because a
    handler may itself produce new commands.

    :param maybe_expression:
        the value to process; only ``str`` values are substituted
    :param max_depth:
        maximum number of substitution rounds. A value that is fully resolved after
        exactly ``max_depth`` rounds is returned (it used to be rejected).
    :return:
        the input itself for non-strings, otherwise the YAML-parsed substituted string
    :raises KeyError:
        if a command name is not registered in ``inline_scommands``
    :raises Exception:
        if commands are still present after ``max_depth`` substitution rounds
    '''
    # one extra pass so that the result of the last allowed round still gets inspected
    for round_no in range(max_depth + 1):
        if not isinstance(maybe_expression, str):
            return maybe_expression
        stuff = search_re_command.findall(maybe_expression)
        print_debug(stuff)
        if not stuff:
            # nothing (left) to substitute: let YAML decide on the final type
            return yaml.safe_load(maybe_expression)
        if round_no == max_depth:
            # commands remain but the substitution budget is used up
            break
        # literal text surrounding the commands; always one piece more than commands
        around = search_re_command.split(maybe_expression)
        print_debug('around, stuff: ', around, stuff)
        restuff = []
        for s in stuff:
            # drop the leading '${!' and the trailing '}', then split into command and arguments
            command, *args = s[3:-1].split(' ')
            if command not in inline_scommands:
                raise KeyError(f'Unknown inline scommand {command!r} in {s!r}')
            # arguments are passed on as plain strings
            restuff.append(inline_scommands[command](*args))
        print_debug('restuff: ', restuff)
        # interleave literal text and command results in their original order
        all_pieces = []
        for pair in zip_longest(around, restuff, fillvalue=''):
            all_pieces.extend(pair)
        print_debug('all_pieces: ', all_pieces)
        maybe_expression = ''.join(all_pieces)

    raise Exception(
        "Max iteration exceeded when substituting configuration "
        f"(max_depth={max_depth}); still unresolved: {maybe_expression!r}"
    )


# Matches one outline scommand of the form ``${-newkey command arg1 arg2}``. The body
# may not contain '$', '{' or '}'.
# Raw string: '\$' is not a valid str escape and triggers a warning in a normal literal;
# the pattern itself is unchanged.
search_o_command = r"\${-(?:[^${}]|)+}"
search_o_re_command = re.compile(search_o_command)


def run_scommand_outline(key, value):
    '''
    Run the outline scommand encoded in ``key`` (``${-newkey command arg ...}``), if any.

    :param key:
        candidate key; anything that is not a string fully matching the outline
        pattern is passed through unchanged
    :param value:
        the value stored under the key; it is unpacked as keyword arguments for the command
    :return:
        ``(need_recurse, key, value)``; ``(False, key, value)`` when there is no command
    '''
    if not isinstance(key, str) or search_o_re_command.fullmatch(key) is None:
        return False, key, value

    # drop the leading '${-' and the trailing '}', then split on single spaces
    tokens = key[3:-1].split(' ')
    # first token is the key to use from now on, second one names the command
    newkey = tokens[0]
    command = tokens[1]
    # the remaining tokens are arguments, each interpreted as YAML
    parsed_args = [yaml.safe_load(token) for token in tokens[2:]]

    handler = larger_scommands[command]
    need_recurse, updated_key, newval = handler(*parsed_args, **value, __sc_key=newkey)
    # a command may override the key by returning something other than None
    final_key = newkey if updated_key is None else updated_key
    return need_recurse, final_key, newval


def scommand_construct(module: str, cls_: str, *args, __sc_key=None, **kwargs):
    '''
    Outline scommand: import ``module``, fetch attribute ``cls_`` and optionally call it.

    module: module to import; a single leading '.' is stripped before importing
    cls_: name of class/function/variable to fetch from the module
    args[0] = return_instance: if truthy (default), the fetched object is called with
        ``kwargs`` and the result is returned, otherwise the object itself is returned
    returns: ``(False, None, result)`` - no further recursion, key left unchanged
    '''
    return_instance = args[0] if args else True

    dotted = module.startswith('.')
    package = '..' if dotted else None
    module_name = module[1:] if dotted else module

    target = getattr(importlib.import_module(module_name, package=package), cls_)
    if return_instance:
        return False, None, target(**kwargs)
    return False, None, target


def scommand_partial(module: str, cls_: str, *args, __sc_key=None, **kwargs):
    '''
    Outline scommand: import ``module``, fetch attribute ``cls_`` and pre-bind arguments.

    module: module to import; a single leading '.' is stripped before importing
    cls_: name of class/function/variable to fetch from the module
    args, kwargs: bound to the fetched callable through ``functools.partial``
    returns: ``(False, None, partial_object)`` - no further recursion, key left unchanged
    '''
    dotted = module.startswith('.')
    imported = importlib.import_module(
        module[1:] if dotted else module,
        package='..' if dotted else None,
    )
    bound = partial(getattr(imported, cls_), *args, **kwargs)
    return False, None, bound


def scommand_switch(var:str, *args, __sc_key=None, **kwargs):
    '''
    Outline scommand: pick one of the options in ``kwargs`` by an environment variable.

    var: name of the environment variable to look up (through ``scommand_getenv``)
    args[0]: default option name; used when the variable is unset and also when its
        value does not name an available option
    kwargs: the options, keyed by option name
    returns: ``(True, option_name, option_value)`` - the caller is asked to process the
        selected option again, under the option name as key
    '''
    selected = scommand_getenv(var, *args)
    # if specified, but not available, use default
    option_name = selected if selected in kwargs else args[0]
    return True, option_name, kwargs[option_name]


def scommand_cache(cachepath: str, __sc_key=None, **kwargs):
    '''
    Outline scommand placeholder for caching; not implemented yet.

    :raises NotImplementedError: always
    '''
    # TODO: implement
    # (the unreachable return that used to follow the raise was removed)
    raise NotImplementedError('caching not implemented (yet) in sigma conf!')


def scommand_include(name: str, *args, __sc_key=None):
    '''
    An scommand to include another yaml file. Instructs the parser to recur into it.

    name: config name handed to ``read_config``
    args[0]: optional ``abspath`` flag for ``read_config`` (defaults to False)
    returns: ``(True, None, included_config)``
    '''
    abspath = False
    if args:
        flag = args[0]
        # run_scommand_outline has already YAML-parsed its arguments, so the flag normally
        # arrives as a bool; feeding that to yaml.safe_load again would fail. Only parse
        # when a raw string is handed in.
        abspath = yaml.safe_load(flag) if isinstance(flag, str) else flag
    return True, None, read_config(name=name, abspath=abspath)


def scommand_getenv(env: str, *args):
    '''
    Inline scommand: return environment variable ``env``, or ``args[0]`` if it is unset.

    The value that was used is also recorded in the environment under
    ``CONFIG_USED_<env>`` so that ``getenv_roundup`` can report it later.

    :param env: name of the environment variable
    :param args: optional; ``args[0]`` is the default value
    :return: the variable's value, or the default exactly as passed in
    :raises IndexError: if the variable is unset and no default was given
    '''
    try:
        retval = os.environ[env] if env in os.environ else args[0]
        # os.environ only accepts strings; a default may arrive already YAML-parsed
        # (e.g. an int when called from the switch scommand), so record its text form
        os.environ[f'CONFIG_USED_{env}'] = str(retval)
    except Exception as e:
        # print some context before propagating the original error
        print(str(e))
        print(env, args)
        raise
    return retval


## IMPORTANT FOR COLLECTING ENV VARIABLES USED IN A RUN!
def getenv_roundup(clear=False):
    '''
    Collect every ``CONFIG_USED_<name>`` environment variable into a dict.

    :param clear: if truthy, the collected ``CONFIG_USED_*`` variables are removed
        from the environment afterwards
    :return: dict mapping ``<name>`` (prefix stripped) to the recorded value
    '''
    prefix = 'CONFIG_USED_'
    used_env_vars = {}
    for full_name, recorded in os.environ.items():
        if full_name.startswith(prefix):
            used_env_vars[full_name[len(prefix):]] = recorded
    if clear:
        for short_name in used_env_vars:
            del os.environ[prefix + short_name]
    return used_env_vars


def scommand_setenv(env: str, value: str):
    '''Inline scommand: set environment variable ``env`` to ``str(value)``; expands to an empty string.'''
    as_text = str(value)
    os.environ[env] = as_text
    return ''

def scommand_delenv(env: str):
    '''
    Inline scommand: remove environment variable ``env``; expands to an empty string.

    :raises KeyError: if the variable is not set
    '''
    os.environ.pop(env)
    return ''

def scommand_prompt(prompt: str):
    '''Inline scommand: show ``prompt`` through the built-in ``input`` and return what it returns.'''
    return input(prompt)


def scommand_gitid():
    '''Inline scommand: return the module-level ``GIT_CURRENT_HEAD_SHA`` as a string.'''
    return str(GIT_CURRENT_HEAD_SHA)


def scommand_timestamp(*args):
    '''
    Inline scommand: return the current timestamp as text.

    args[0]: optional format string handed to ``timestamp_str``; anything that is not
        a non-empty string is ignored and the default format is used instead
    '''
    fmt = args[0] if args else None
    if isinstance(fmt, str) and fmt:
        return timestamp_str(fmt)
    return timestamp_str()

def scommand_launchtimestamp():
    '''Inline scommand: return the module-level ``launch_time`` value.'''
    return launch_time

def scommand_range(start, end):
    '''
    Inline scommand: return the text of the list ``[start, ..., end - 1]``.

    Both bounds are converted with ``int``; the result is the ``str`` of the list.
    '''
    first = int(start)
    stop = int(end)
    return str([*range(first, stop)])

from secrets import token_hex
def scommand_randtoken(length: str):
    '''
    Return a random hexadecimal token from ``secrets.token_hex``.

    length: number of random bytes (converted with ``int``); the returned text has two
        hex digits per byte
    '''
    return token_hex(int(length))


def scommand_ifdef(condenv: str, arg_a: str, arg_b: str):
    '''Inline scommand: return ``arg_a`` if environment variable ``condenv`` is set, else ``arg_b``.'''
    return arg_a if condenv in os.environ else arg_b


def scommand_ifeq(condenv: str, value: str, string: str, *alternative):
    '''
    Inline scommand: return ``string`` if environment variable ``condenv`` equals ``value``.

    Otherwise the first element of ``alternative`` is returned, or an empty string if
    no alternative was given.

    :raises KeyError: if ``condenv`` is not set
    '''
    if os.environ[condenv] != value:
        return alternative[0] if alternative else ''
    return string


def _flatten_dict(subj: dict, depth: int, stack_names=False):
    '''
    Remove one level of nesting from a dict of dicts, then repeat ``depth`` more times.

    :param subj: mapping whose values are all mappings
    :param depth: number of additional flattening passes after the first one; 0 (or
        less) means a single pass
    :param stack_names: if truthy, new keys are '<outer key>_<inner key>'; otherwise
        the inner key is used as-is (later entries overwrite earlier ones on clashes)
    :return: the flattened dict
    '''
    assert isinstance(subj, dict), "Trying to flatten non dictionary is not supported"
    resubj = {}
    for k_top, v_top in subj.items():
        assert isinstance(v_top, dict), "Trying to flatten non dictionary is not supported"
        for k, v in v_top.items():
            nuname = f'{k_top}_{k}' if stack_names else k
            resubj[nuname] = v
    # '<=' rather than '==' so that a negative depth cannot recurse without end
    if depth <= 0:
        return resubj
    return _flatten_dict(resubj, depth - 1, stack_names=stack_names)


def scommand_flatten(depth: int, *args, __sc_key=None, **kwargs):
    ''' Use this scommand to flatten the hierarchy, if for
    whatever reason we have more nesting than desired due
    to large number of outline scommands stacked.

    depth: number of additional flattening passes after the first one
    args[0]: optional; if truthy, flattened keys are '<outer key>_<inner key>'
    kwargs: the dict of dicts to flatten
    returns: ``(False, None, flattened)`` - the same 3-tuple shape as the other outline
        scommands (no further recursion, key left unchanged)
    '''
    # __sc_key is accepted separately so that it does not end up inside the dict to flatten
    concatenate_names = args[0] if args else False

    return False, None, _flatten_dict(kwargs, depth, stack_names=concatenate_names)


# Registry of inline scommands: maps the command name used inside ``${!name args}``
# to the function that produces the replacement text.
inline_scommands = {
    'gitid': scommand_gitid,
    'getenv': scommand_getenv,
    'setenv': scommand_setenv,
    'delenv': scommand_delenv,
    'prompt': scommand_prompt,
    'starttime': scommand_launchtimestamp,
    'nowtime': scommand_timestamp,
    'range': scommand_range,
    'ifdef': scommand_ifdef,
    'ifeq': scommand_ifeq,
    # was defined in this module but never registered, hence unreachable from configs
    'randtoken': scommand_randtoken,
}

# Registry of outline scommands: maps the command name used inside
# ``${-newkey name args}`` keys to its handler. Handlers are called with the parsed
# arguments, the keyed value unpacked as keyword arguments and ``__sc_key``; the
# dispatcher unpacks their result as ``(need_recurse, updated_key, new_value)``.
larger_scommands = {
    'switch': scommand_switch,
    'construct': scommand_construct,
    'partial': scommand_partial,
    'cache': scommand_cache, 
    'include': scommand_include,  
    'dflatten': scommand_flatten,
}


def timestamp_str(ft = "%m_%dT%H_%M_%S"):
    '''
    Format the current UTC time with ``strftime``.

    :param ft: strftime format; the default yields 'month_dayThour_minute_second'
    :return: the formatted time as text
    '''
    now_utc = datetime.datetime.now(tz=datetime.timezone.utc)
    return now_utc.strftime(ft)

# timestamp taken once, when this module is first executed
launch_time = timestamp_str()