import time
from typing import *
from pathlib import Path


def catch_exception(fn):
    """
    Decorator that converts any ``Exception`` raised by ``fn`` into a ``None`` return.

    When ``fn`` raises, the wrapper prints the failed call (function name and the
    repr of every positional and keyword argument), prints the traceback without
    chained exceptions, sleeps for 0.1 seconds and returns ``None``.
    When ``fn`` succeeds, its return value is passed through unchanged.
    """
    import functools

    @functools.wraps(fn)  # keep fn's __name__ / __doc__ visible on the wrapper
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            import traceback
            # Format positional and keyword arguments as one list so the message
            # has no dangling ", " when either group is empty.
            call_args = [repr(arg) for arg in args]
            call_args.extend(f'{k}={v!r}' for k, v in kwargs.items())
            print(f"Exception in {fn.__name__}({', '.join(call_args)})")
            traceback.print_exc(chain=False)
            time.sleep(0.1)
            return None
    return wrapper


class CallbackOnException:
    """
    Context manager that suppresses one kind of exception and runs a callback instead.

    If the ``with`` body raises an instance of ``exception``, ``callback`` is
    called with no arguments and the exception is swallowed. Any other
    exception propagates, and nothing happens when the body finishes normally.
    """

    def __init__(self, callback: Callable, exception: type):
        self.callback = callback
        self.exception = exception

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # exc_val is None on a clean exit, so isinstance is False in that case too.
        handled = isinstance(exc_val, self.exception)
        if handled:
            self.callback()
        return handled
    
def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
    """
    Yields the key path, as a tuple, of every non-dict value in a nested dictionary.

    Paths are produced depth-first in the dictionaries' iteration order.
    A nested dict that is empty contributes no path at all.
    """
    for key, value in d.items():
        if not isinstance(value, dict):
            yield (key, )
            continue
        # Prefix every path found below this key with the key itself.
        yield from ((key, *tail) for tail in traverse_nested_dict_keys(value))


def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
    """
    Looks up the value stored under the key path ``keys`` in a nested dictionary.

    Returns ``default`` when a key along the path is missing, or when the path
    runs into a value that has no ``get`` method (for example a leaf value or
    ``None``) before all keys are consumed. An empty ``keys`` returns ``d`` itself.
    """
    # Unique sentinel, so a stored None / falsy value is never mistaken for a missing key.
    missing = object()
    for k in keys:
        try:
            d = d.get(k, missing)
        except AttributeError:
            # The path continues below a value that is not a mapping.
            return default
        if d is missing:
            return default
    return d

def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
    """
    Stores ``value`` under the key path ``keys`` in a nested dictionary, in place.

    Intermediate dictionaries that do not exist yet are created on the way down.
    """
    node = d
    parent_keys = keys[:-1]
    for key in parent_keys:
        node = node.setdefault(key, {})
    leaf_key = keys[-1]
    node[leaf_key] = value


def key_average(list_of_dicts: list) -> Dict[str, Any]:
    """
    Averages every leaf value across a list of (possibly nested) dictionaries.

    The result has the nested layout of the union of all key paths found in the
    inputs. For each path, only the dictionaries where the looked-up value is
    not ``None`` contribute to the mean; a path with no such value gets NaN.
    """
    key_paths = set()
    for entry in list_of_dicts:
        key_paths.update(traverse_nested_dict_keys(entry))

    averaged = {}
    for path in sorted(key_paths):
        looked_up = (get_nested_dict(entry, path) for entry in list_of_dicts)
        present = [item for item in looked_up if item is not None]
        if present:
            mean = sum(present) / len(present)
        else:
            mean = float('nan')
        set_nested_dict(averaged, path, mean)
    return averaged


def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
    """
    Collapses a nested dictionary into one level whose keys are tuple key paths.

    ``parent_key`` is the path prefix placed in front of every produced key;
    it defaults to the empty tuple. Nested mappings that are empty produce no entry.
    """
    prefix = () if parent_key is None else parent_key
    flat = {}
    for name, value in d.items():
        path = prefix + (name, )
        if isinstance(value, MutableMapping):
            # Recurse, carrying the path accumulated so far.
            flat.update(flatten_nested_dict(value, path))
        else:
            flat[path] = value
    return flat


def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuilds a nested dictionary from a single-level one keyed by tuple key paths.

    Every element of a key except the last names a nested dictionary (created
    when absent); the last element is the key the value is stored under.
    """
    nested = {}
    for path, value in d.items():
        node = nested
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return nested


def read_jsonl(file):
    """
    Reads a JSON Lines file and returns its records as a list.

    The file is decoded as UTF-8 and parsed line by line with ``json.loads``.
    Lines that are empty or contain only whitespace are skipped, so blank
    separator or trailing lines do not raise a decode error.
    """
    import json
    with open(file, 'r', encoding='utf-8') as f:
        # Iterate the file object directly instead of buffering all lines first.
        return [json.loads(line) for line in f if line.strip()]


def write_jsonl(data: List[dict], file):
    """
    Writes ``data`` to ``file`` in JSON Lines format, overwriting the file.

    Each item is serialized with ``json.dumps`` onto its own line.
    """
    import json
    with open(file, 'w') as out:
        # The generator is consumed lazily, so records are serialized one at a time.
        out.writelines(json.dumps(record) + '\n' for record in data)


def save_metrics(save_path: Union[str, Path], all_metrics: Dict[str, List[Dict]]):
    """
    Writes ``all_metrics`` to ``save_path`` as JSON indented by 4 spaces.

    An existing file at ``save_path`` is overwritten.
    """
    import json

    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(all_metrics, f, indent=4)


def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
    """
    Builds a pandas DataFrame with MultiIndex columns from a list of nested dictionaries.

    Each dictionary is flattened to tuple key paths and becomes one row; the
    columns are sorted and then converted to a ``pd.MultiIndex`` built from
    those tuples.
    """
    import pandas as pd
    flat_rows = [flatten_nested_dict(row) for row in data]
    frame = pd.DataFrame(flat_rows).sort_index(axis=1)
    frame.columns = pd.MultiIndex.from_tuples(frame.columns)
    return frame


def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
    """
    Applies every ``old -> new`` substring replacement in ``mapping`` to all strings in ``d``.

    Strings are returned as new, replaced strings. Lists and dicts are updated
    in place (element by element, value by value; dict keys are untouched) and
    the same object is returned. Any other value is returned unchanged.
    Replacements are applied one after another in the mapping's iteration order.
    """
    if isinstance(d, str):
        text = d
        for old, new in mapping.items():
            text = text.replace(old, new)
        return text
    if isinstance(d, list):
        for index, element in enumerate(d):
            d[index] = recursive_replace(element, mapping)
        return d
    if isinstance(d, dict):
        # Only values of existing keys are reassigned, so iterating the dict is safe.
        for key in d:
            d[key] = recursive_replace(d[key], mapping)
        return d
    return d


class timeit:
    """
    Timer based on ``time.time()`` that works as a context manager or a decorator.

    As a context manager it measures the ``with`` body. As a decorator it
    measures every call of the wrapped (sync or async) function, using the
    function's ``__qualname__`` as the name when none was given.

    Parameters:
        name: label used in the printed message and as the history key.
        verbose: print the measured time when the timed section ends.
        multiple: record each finished timer in a history shared by all timers
            with the same name, and report the average over that history
            instead of the single measurement.
    """
    # Finished timers created with multiple=True, grouped by name; shared class-wide.
    _history: Dict[str, List['timeit']] = {}

    def __init__(self, name: str = None, verbose: bool = True, multiple: bool = False):
        self.name = name
        self.verbose = verbose
        self.start = None
        self.end = None
        self.multiple = multiple
        if multiple and name not in timeit._history:
            timeit._history[name] = []

    def __call__(self, func: Callable):
        """Wraps ``func`` so that every call runs inside a fresh timer."""
        import functools
        import inspect

        def new_timer() -> 'timeit':
            # A new timer per call keeps concurrent or recursive calls independent.
            # verbose / multiple are forwarded so decorator options take effect.
            return timeit(self.name or func.__qualname__, verbose=self.verbose, multiple=self.multiple)

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                with new_timer():
                    return await func(*args, **kwargs)
        else:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with new_timer():
                    return func(*args, **kwargs)
        return wrapper

    def __enter__(self):
        self.start = time.time()
        # Returning self makes `with timeit(...) as t:` usable to read t.time afterwards.
        return self

    @property
    def time(self) -> float:
        """Elapsed seconds between entering and leaving the timed section."""
        assert self.start is not None, "Time not yet started."
        assert self.end is not None, "Time not yet ended."
        return self.end - self.start

    @property
    def history(self) -> List['timeit']:
        """Finished timers recorded under this timer's name (empty list if none)."""
        return timeit._history.get(self.name, [])

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()
        if self.multiple:
            runs = timeit._history.setdefault(self.name, [])
            runs.append(self)
        if self.verbose:
            if self.multiple:
                avg = sum(t.time for t in runs) / len(runs)
                print(f"{self.name or 'It'} took {avg} seconds in average.")
            else:
                print(f"{self.name or 'It'} took {self.time} seconds.")


def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
    """
    Removes the longest prefix and the longest suffix shared by all strings.

    The common prefix is removed first; the common suffix is then computed on
    what remains, so the two can never overlap. An empty input list returns an
    empty list, and strings of different lengths (including empty strings) are
    handled without raising.
    """
    from os.path import commonprefix  # character-wise common prefix of a list of strings

    prefix_len = len(commonprefix(strings))
    remainders = [s[prefix_len:] for s in strings]
    # The common suffix is the common prefix of the reversed remainders.
    suffix_len = len(commonprefix([r[::-1] for r in remainders]))
    return [r[:len(r) - suffix_len] for r in remainders]


def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
    """
    Returns a decorator that immediately runs the decorated function on every input in a thread pool.

    Parameters:
        inputs: items passed one by one to the decorated function.
        num_workers: maximum number of worker threads.
        pbar: optional progress bar; it must be usable as a context manager and
            provide ``total``, ``refresh()`` and ``update()``. When omitted, a
            ``tqdm`` bar is created.

    The progress bar total is ``len(inputs)`` when ``inputs`` has a length,
    otherwise ``None``. Each call is wrapped with ``catch_exception``, so a
    failing item is reported and does not stop the remaining ones. The
    decorator waits for all items to finish, discards the return values and
    returns ``None`` (the decorated name is therefore bound to ``None``).
    """
    from concurrent.futures import ThreadPoolExecutor

    total = len(inputs) if hasattr(inputs, '__len__') else None
    if pbar is not None:
        pbar.total = total
    else:
        # tqdm is only required when the caller does not bring a progress bar.
        from tqdm import tqdm
        pbar = tqdm(total=total)

    def decorator(fn: Callable):
        @catch_exception
        def _fn(item):
            try:
                return fn(item)
            finally:
                # Count failed items too, so the bar reaches its total.
                pbar.update()

        # Context managers exit in reverse order: the executor is listed last so
        # it drains all pending work before the progress bar is closed.
        with pbar, ThreadPoolExecutor(max_workers=num_workers) as executor:
            pbar.refresh()
            executor.map(_fn, inputs)

    return decorator