"""
Misc shared utility functions
"""
import os
import glob
import hashlib
import logging
from typing import *
import ast
import argparse
from argparse import Namespace
import pickle
import requests
import re
import numpy as np

def str2bool(v):
    """
    Interpret a command-line value as a boolean (usable as an argparse ``type``).

    Values that are already ``bool`` are returned unchanged. Any other value
    is lower-cased and matched against the accepted spellings below.

    Raises:
        argparse.ArgumentTypeError: if the value matches neither list.
    """
    if isinstance(v, bool):
        return v

    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")

    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

def str2dict(v):
    """
    Parse a string of colon-separated ``<int>-<int>`` pairs into a dict.

    Each pair ``a-b`` becomes the entry ``int(a): int(b)``; if the same key
    appears more than once, the last pair wins.

    >>> str2dict("1-2:30-40")
    {1: 2, 30: 40}

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not of the form
            ``a-b[:c-d[:...]]`` with non-negative decimal integers.
    """
    # Raw string literal: in a normal literal "\d" is an invalid escape
    # sequence, which newer Python versions warn about at compile time.
    if not re.match(r'\d+-\d+(?::\d+-\d+)*$', v):
        raise argparse.ArgumentTypeError("Wrong format, see help.")
    return {int(a): int(b) for a, b in re.findall(r'(\d+)-(\d+)', v)}

def str2dictorint(v):
    """
    Parse a command-line string as either a plain integer or a pair mapping.

    Strings for which ``str.isdigit()`` is true are converted with ``int``;
    every other string is handed to ``str2dict``.
    """
    return int(v) if v.isdigit() else str2dict(v)

def int_or_inf(x: str):
    """
    Convert a string to an int, or to ``float("inf")`` for "inf"/"infinity".

    The infinity spellings are matched case-insensitively.

    Raises:
        argparse.ArgumentTypeError: if ``x`` is neither an infinity spelling
            nor accepted by ``int``.
    """
    infinite_spellings = ("inf", "infinity")
    if x.lower() not in infinite_spellings:
        try:
            return int(x)
        except ValueError:
            raise argparse.ArgumentTypeError(f"‘{x}’ is not an integer or ‘inf’")
    return float("inf")
    

def load_args_from_txt(args_path: str) -> Namespace:
    """
    Build a Namespace from a text file of "name: value" lines.

    Each line is split at the first ": ". Lines without that separator
    (including blank lines) are ignored. The value text is parsed with
    ``ast.literal_eval``; if that fails, the raw text is kept as a string.
    """
    parsed = Namespace()
    with open(args_path, "r") as handle:
        for raw_line in handle:
            key, separator, text = raw_line.rstrip("\n").partition(": ")
            if not separator:
                # No "name: value" structure on this line.
                continue
            try:
                # Numbers, dicts, booleans, None, quoted strings, ...
                value = ast.literal_eval(text)
            except Exception:
                # Not a Python literal: keep the text as-is.
                value = text
            setattr(parsed, key, value)
    return parsed

def validate_args_match(current: Union[Namespace, Dict[str, Any]], loaded: Union[Namespace, Dict[str, Any]], skip: Optional[list] = None):
    """
    Check that every field of `loaded` has the same value in `current`.

    Both arguments may be a Namespace or a dict. Fields whose name ends in
    "_dir" or "_path" are compared as absolute paths, so a relative and an
    absolute spelling of the same location are considered equal. Values of
    such fields that are not path-like (e.g. None for an unset option) are
    compared unchanged.

    Args:
        current: the arguments in effect now.
        loaded: the previously saved arguments; its fields drive the check.
        skip: optional names of fields to leave out of the comparison.

    Raises:
        AssertionError: if a field of `loaded` is missing from `current`,
            or if the two values differ.
    """
    skipped = set(skip or [])

    def get_item(obj, name):
        if isinstance(obj, dict):
            return obj[name]
        return getattr(obj, name)

    def has_item(obj, name):
        if isinstance(obj, dict):
            return name in obj
        return hasattr(obj, name)

    def normalize_path(value):
        # os.path.abspath raises TypeError on None and other non-path
        # values, so only path-like values are normalised.
        if isinstance(value, (str, bytes, os.PathLike)):
            return os.path.abspath(value)
        return value

    loaded_items = loaded.items() if isinstance(loaded, dict) else vars(loaded).items()
    for name, loaded_val in loaded_items:
        if name in skipped:
            continue

        if not has_item(current, name):
            raise AssertionError(f"Current args has no field '{name}'")

        cur_val = get_item(current, name)
        # Normalize directories/files to absolute paths before comparing.
        if name.endswith(("_dir", "_path")):
            loaded_val = normalize_path(loaded_val)
            cur_val = normalize_path(cur_val)

        if cur_val != loaded_val:
            raise AssertionError(
                f"Mismatch for '{name}': current={cur_val!r}  saved={loaded_val!r}"
            )


def pickle_copy(t):
    """Return a copy of ``t`` produced by a pickle dump/load round trip."""
    serialized = pickle.dumps(t, protocol=pickle.HIGHEST_PROTOCOL)
    return pickle.loads(serialized)


def is_huggingface_hub_id(s: str, timeout: float = 10.0) -> bool:
    """
    Check whether ``https://huggingface.co/<s>`` answers with HTTP 200.

    Args:
        s: candidate identifier, appended verbatim to the site URL.
        timeout: seconds to wait for the server before giving up. Without a
            timeout the request could block indefinitely.

    Returns:
        True if the GET request returns status 200, otherwise False. An
        empty ``s`` returns False without any request, because it would
        only probe the site root.

    Raises:
        requests.RequestException: network failures (including timeouts)
            are not caught here.
    """
    if not s:
        return False
    r = requests.get(f"https://huggingface.co/{s}", timeout=timeout)
    return r.status_code == 200


def extract(a, t, x_shape):
    """
    For each index in t, pick the corresponding entry of a.

    The gather runs along the last dimension of ``a`` using a CPU copy of
    ``t``. The result is reshaped to ``(t.shape[0], 1, ..., 1)`` with as many
    dimensions as ``x_shape`` has entries, then moved to the device of ``t``.

    NOTE(review): assumes ``a`` and ``t`` are torch tensors, with ``a`` on
    the CPU -- confirm against callers.
    """
    singleton_dims = (1,) * (len(x_shape) - 1)
    target_shape = (t.shape[0],) + singleton_dims
    gathered = a.gather(-1, t.cpu())
    return gathered.reshape(target_shape).to(t.device)


def num_to_groups(num: int, divisor: int) -> List[int]:
    """
    Split num into a list of ints, each at most divisor, that sum to num.

    >>> num_to_groups(18, 16)
    [16, 2]
    >>> num_to_groups(33, 8)
    [8, 8, 8, 8, 1]
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover > 0:
        # Final, smaller group holding what is left over.
        sizes.append(leftover)
    assert sum(sizes) == num
    return sizes


def seq_to_groups(seq:Sequence[Any], divisor:int) -> List[Sequence[Any]]:
    """
    Cut seq into consecutive slices of at most <divisor> items each.

    >>> seq_to_groups([1,2,3,4,5,6,7,8,9], 3)
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    >>> seq_to_groups([1,2,3,4,5,6,7,8,9], 4)
    [[1, 2, 3, 4], [5, 6, 7, 8], [9]]
    """
    chunks = []
    for start in range(0, len(seq), divisor):
        chunks.append(seq[start:start + divisor])
    return chunks


def tolerant_comparison_check(values, cmp: Literal[">=", "<="], v):
    """
    Check a bound on values while tolerating small numerical error.

    For ">=" the NaN-ignoring minimum of values is compared against the
    lower bound v; for "<=" the NaN-ignoring maximum is compared against the
    upper bound v. A difference within 1e-5 of zero counts as satisfying the
    bound.

    >>> tolerant_comparison_check(-3.1415927410125732, ">=", -np.pi)
    True
    """
    if cmp not in (">=", "<="):
        raise ValueError(f"Illegal comparator: {cmp}")

    is_lower_bound = cmp == ">="
    extreme = np.nanmin(values) if is_lower_bound else np.nanmax(values)
    gap = extreme - v
    if np.isclose(gap, 0, atol=1e-5):
        return True  # Within tolerance of the bound
    return gap > 0 if is_lower_bound else gap < 0


def modulo_with_wrapped_range(
    vals, range_min: float = -np.pi, range_max: float = np.pi
):
    """
    Wrap vals into a range whose lower end may be negative.

    The values are shifted so the range starts at zero, reduced modulo the
    range width, and shifted back.

    >>> modulo_with_wrapped_range(3, -2, 2)
    -1
    """
    assert range_min <= 0.0
    assert range_min < range_max

    width = range_max - range_min
    # Shift to start at zero, wrap by the width, then undo the shift.
    # No check is made here that the result lies inside the range.
    return (vals - range_min) % width + range_min


def update_dict_nonnull(d: Dict[str, Any], vals: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge vals into d in place and return d.

    Keys missing from d are added with whatever value vals holds. For keys
    already in d, the value is replaced (and the replacement logged) only
    when the new value differs and is not None.

    >>> update_dict_nonnull({'a': 1, 'b': 2}, {'b': 3, 'c': 4})
    {'a': 1, 'b': 3, 'c': 4}
    """
    for k, v in vals.items():
        if k not in d:
            d[k] = v
            continue
        if not (d[k] != v and v is not None):
            # Same value, or a None that must not overwrite an existing entry.
            continue
        logging.info(f"Replacing key {k} original value {d[k]} with {v}")
        d[k] = v
    return d


def md5_all_py_files(dirname: str) -> str:
    """
    Return a single md5 hex digest over all ``*.py`` files directly in dirname.

    Files are visited in sorted path order and their bytes are fed into one
    running hash, so the digest depends on file contents and ordering but not
    on the file names themselves.
    """
    # https://stackoverflow.com/questions/36099331/how-to-grab-all-files-in-a-folder-and-get-their-md5-hash-in-python
    pattern = os.path.join(dirname, "*.py")
    paths = sorted(glob.glob(pattern))
    digest = hashlib.md5()
    for path in paths:
        with open(path, "rb") as stream:
            # Read in 1 MiB pieces so large files are never fully in memory.
            while True:
                piece = stream.read(2**20)
                if piece == b"":
                    break
                digest.update(piece)
    return digest.hexdigest()


if __name__ == "__main__":
    # When run as a script, execute the doctest examples in this module.
    from doctest import testmod

    testmod()
