import functools
import re
from collections import namedtuple

import YYY
import os
import numpy as np


# https://codereview.stackexchange.com/questions/85311/transform-snake-case-to-camelcase
def camel_case_name(snake_case_name):
    """Convert a snake_case name to camelCase.

    Every underscore that is immediately followed by a lowercase ASCII letter
    is dropped and that letter is upper-cased. Underscores followed by
    anything else (digits, uppercase letters, another underscore) are kept.
    """
    def _capitalize_after_underscore(found):
        return found.group(1).upper()

    return re.sub("_([a-z])", _capitalize_after_underscore, snake_case_name)


# Cache of namedtuple classes created by to_namedtuple(), keyed by the tuple of
# a dict's original (unsanitized) keys, so dicts with the same key layout reuse
# one class (named after whichever dict first produced that layout).
__namedtuples = {}


def sanitize_keys(keys):
    """Return a new list holding ``sanitize_key(key)`` for each key in *keys*."""
    return [sanitize_key(key) for key in keys]


def sanitize_key(key):
    """Return *key* with every "." replaced by "_" (e.g. "a.b" -> "a_b")."""
    return "_".join(key.split("."))


def to_namedtuple(obj, name):
    """Recursively convert nested dicts into namedtuple instances.

    * A dict whose keys are all ``str`` becomes an instance of a namedtuple
      class whose fields are the keys with "." replaced by "_". Classes are
      cached in ``__namedtuples`` by the dict's original key tuple, so the
      type name is the one computed the first time that key layout is seen.
    * A dict with any non-``str`` key stays a dict; only its values are
      converted, each using ``str(key)`` as its name.
    * ``list``/``set``/``tuple`` (including subclasses, e.g. namedtuples)
      become a plain ``list``/``set``/``tuple`` of converted items, whose
      names derive from this type name plus ``"Item"``.
    * Anything else is returned unchanged.

    The type name is ``"_"`` + *name* with "." replaced by "_" and then
    camel-cased. ``namedtuple`` raises ``ValueError`` if a field or type name
    is not a valid identifier (e.g. keys containing "-", starting with a
    digit or an underscore, or equal to a keyword).
    """
    type_name = "_" + camel_case_name(sanitize_key(name))

    if isinstance(obj, dict):
        keys = tuple(obj.keys())
        if not all(isinstance(key, str) for key in keys):
            # Non-string keys cannot be field names: keep the mapping.
            return {key: to_namedtuple(value, str(key)) for key, value in obj.items()}
        nt_cls = __namedtuples.get(keys)
        if nt_cls is None:
            nt_cls = namedtuple(type_name, sanitize_keys(keys))
            __namedtuples[keys] = nt_cls
        return nt_cls(*(to_namedtuple(value, key) for key, value in obj.items()))

    item_type_name = type_name + "Item"
    # Order matters: namedtuples are tuples and end up as plain tuples here.
    for container_cls in (list, set, tuple):
        if isinstance(obj, container_cls):
            return container_cls(to_namedtuple(item, item_type_name) for item in obj)

    return obj


def get_any(d: dict):
    """Return the first value of *d* in iteration order.

    Raises ``StopIteration`` if *d* is empty.
    """
    values = iter(d.values())
    return next(values)


def handle_map_funcs(func_kv, func_k, func_v, default=None):
    """Build a mapping function over ``(key, value)`` pairs.

    At most one of the three functions may be given:

    * ``func_kv(key, value)`` -- called with the pair unpacked; its result is
      returned as-is (callers such as ``map_dict`` expect a new pair);
    * ``func_k(key)`` -- replaces the key, keeps the value;
    * ``func_v(value)`` -- replaces the value, keeps the key.

    Returns:
        A callable taking one ``(key, value)`` tuple, or *default* when none
        of the functions is given.

    Raises:
        ValueError: if more than one function is given. (This used to be an
            ``assert``, which is skipped under ``python -O``.)
    """
    if func_kv is not None:
        if func_k is not None or func_v is not None:
            raise ValueError("func_kv cannot be combined with func_k or func_v")

        def inner(kv):
            return func_kv(*kv)

    elif func_k is not None:
        if func_v is not None:
            raise ValueError("func_k cannot be combined with func_v")

        def inner(kv):
            return func_k(kv[0]), kv[1]

    elif func_v is not None:

        def inner(kv):
            return kv[0], func_v(kv[1])

    else:
        return default
    return inner


def handle_unary_funcs(pred_kv, pred_k, pred_v, default=None):
    """Build a single-result function over ``(key, value)`` pairs.

    Used for predicates and sort/group keys. At most one of the three
    functions may be given:

    * ``pred_kv(key, value)`` -- called with the pair unpacked;
    * ``pred_k(key)`` -- called with the key only;
    * ``pred_v(value)`` -- called with the value only.

    Returns:
        A callable taking one ``(key, value)`` tuple and returning the chosen
        function's result, or *default* when none of the functions is given.

    Raises:
        ValueError: if more than one function is given. (This used to be an
            ``assert``, which is skipped under ``python -O``.)
    """
    if pred_kv is not None:
        if pred_k is not None or pred_v is not None:
            raise ValueError("pred_kv cannot be combined with pred_k or pred_v")

        def inner(kv):
            return pred_kv(*kv)

    elif pred_k is not None:
        if pred_v is not None:
            raise ValueError("pred_k cannot be combined with pred_v")

        def inner(kv):
            return pred_k(kv[0])

    elif pred_v is not None:

        def inner(kv):
            return pred_v(kv[1])

    else:
        return default
    return inner


def map_dict(d: dict, *, kv=None, k=None, v=None):
    """Return a new dict built by transforming each item of *d*.

    Pass at most one of:
        kv: ``kv(key, value)`` returning a new ``(key, value)`` pair.
        k: ``k(key)`` returning a new key (value kept).
        v: ``v(value)`` returning a new value (key kept).

    With none of them, a shallow copy of *d* is returned. This matches
    ``filter_dict``/``sort_dict``, which also work without a function.
    Previously ``map(None, ...)`` failed here with an opaque TypeError.
    If several new keys collide, the last item wins (plain ``dict()``).
    """
    transform = handle_map_funcs(kv, k, v, default=lambda item: item)
    return dict(map(transform, d.items()))


def filter_dict(d: dict, *, kv=None, k=None, v=None):
    """Return a new dict with only the items of *d* that pass a predicate.

    Pass at most one predicate: ``kv(key, value)``, ``k(key)`` or ``v(value)``.
    With none, ``filter(None, ...)`` keeps every truthy item. ``(key, value)``
    tuples are always truthy, so the result is a shallow copy of *d*.
    """
    predicate = handle_unary_funcs(kv, k, v)
    return {key: value for key, value in filter(predicate, d.items())}


def sort_dict(d: dict, *, reverse=False, kv=None, k=None, v=None):
    """Return a new dict with the items of *d* in sorted order.

    The sort key comes from at most one of ``kv(key, value)``, ``k(key)`` or
    ``v(value)``. With none given, items are ordered by their key.
    ``reverse=True`` sorts in descending order.
    """
    def by_original_key(item):
        return item[0]

    sort_key = handle_unary_funcs(kv, k, v, default=by_original_key)
    ordered_items = sorted(d.items(), key=sort_key, reverse=reverse)
    return dict(ordered_items)


def groupby_dict(d: dict, *, key_kv=None, key_k=None, key_v=None, agg=None):
    """Group the items of *d* by a computed group key.

    The group key comes from exactly one of ``key_kv(key, value)``,
    ``key_k(key)`` or ``key_v(value)``. With none given, a non-empty *d*
    fails with TypeError because the key function is None.

    Returns:
        ``{group_key: {original_key: value, ...}, ...}``. If *agg* is given,
        each inner dict is replaced by ``agg(inner_dict)``.
    """
    group_of = handle_unary_funcs(key_kv, key_k, key_v)

    groups = {}
    for item in d.items():
        group_key = group_of(item)
        original_key, value = item
        groups.setdefault(group_key, {})[original_key] = value

    if agg is None:
        return groups
    return map_dict(groups, v=agg)


def parse_enum_str(enum_str: str, enum_cls):
    """Look up an enum member from its string form.

    ``str(member)`` of a plain ``Enum`` is ``"ClassName.MEMBER"``, but
    ``EnumClass["ClassName.MEMBER"]`` raises ``KeyError``, and the standard
    library offers no inverse of ``str()``. This drops everything up to and
    including the first ``"."`` and looks up the rest by member name. A bare
    member name (no ``"."``) is accepted too.

    Raises:
        KeyError: if the resulting name is not a member of *enum_cls*.
    """
    _class_part, dot, member_name = enum_str.partition(".")
    if not dot:
        # No "ClassName." prefix: the whole string is the member name
        # (this used to fail with IndexError).
        member_name = enum_str
    return enum_cls[member_name]


def get_YYY_files(YYY_dir):
    """Map every ``.py`` file below *YYY_dir* (recursively) to its absolute path.

    Each key is the joined path with its first ``len(YYY_dir)`` characters
    sliced off. It therefore starts with a path separator unless *YYY_dir*
    ends with one. Entries are inserted in ``os.walk(topdown=False)``
    (bottom-up) order.
    """
    found = {}
    for dirpath, _subdirs, filenames in os.walk(YYY_dir, topdown=False):
        for filename in filenames:
            if filename.endswith(".py"):
                joined = os.path.join(dirpath, filename)
                found[joined[len(YYY_dir):]] = os.path.abspath(joined)
    return found


def load_YYY_files(path=None, files=None, vanilla=False, tag=None, prefix=None):
    """Load YYY files and return them keyed by (optionally prefixed) name.

    Args:
        path: directory scanned with ``get_YYY_files`` when *files* is None.
        files: mapping of result name -> file path; takes precedence over *path*.
        vanilla: if true, return the loaded stores as-is; otherwise convert each
            store with ``to_namedtuple`` (type name "Result").
        tag: value written under ``"tag"`` in every store.
        prefix: if not None, prepended to each name to form the result key.

    Returns:
        dict mapping ``f"{prefix}{name}"`` (or ``name``) to the loaded store.
        Each store also gets ``"actual_name"`` and ``"actual_path"`` entries.

    Raises:
        TypeError: if neither *path* nor *files* is given.
    """
    if files is None:
        if path is None:
            raise TypeError("load_YYY_files() requires either 'path' or 'files'")
        files = get_YYY_files(path)

    stores = {}
    # Separate loop variable so the ``path`` argument is not shadowed.
    for name, file_path in files.items():
        store = YYY.safe_load(
            file_path,
            exposed_symbols=[np.array],
            # Presumably lets loaded files refer to nan/inf by bare name --
            # confirm against YYY.safe_load's documentation.
            extra_mappings=dict(nan=float("nan"), inf=float("inf")),
        )
        # NOTE(review): assumes safe_load returns a mutable mapping.
        store["actual_name"] = name
        store["actual_path"] = file_path
        store["tag"] = tag

        key = f"{prefix}{name}" if prefix is not None else name
        stores[key] = store

    # Decide the return form only after every file has been loaded. The
    # previous version returned from inside the loop when vanilla=True,
    # so it yielded just the first file.
    if vanilla:
        return stores
    return map_dict(stores, v=functools.partial(to_namedtuple, name="Result"))
