from functools import partial, reduce
from itertools import islice, compress
# from itertools import tee
from collections.abc import Iterable, Iterator
import math
import numbers

import torch

###
# Packing of tensors
def pack_tensors(*tensors, start_dim = 0, end_dim = -1):
    """Flatten tensors and concatenate them into a single packed tensor.

    ``tensors`` may hold tensors or (nested) iterables of tensors; they are
    walked in order with ``flatteniter`` (tensors are kept whole). Each leaf is
    flattened over dims ``start_dim..end_dim`` and the pieces are concatenated
    along ``start_dim``. With the defaults every leaf becomes 1-D and the
    result is one 1-D tensor holding all elements.

    As with ``torch.cat``, the flattened leaves must agree on every dimension
    other than ``start_dim``.
    """
    leaves = flatteniter(tensors, keep_tensors = True)
    flat_pieces = tuple(
        torch.flatten(leaf, start_dim = start_dim, end_dim = end_dim)
        for leaf in leaves
    )
    return torch.cat(flat_pieces, dim = start_dim)

def upack_tensor_iter(packed_tensor, size_iter):
    """Lazily carve ``packed_tensor`` into consecutive pieces of the given sizes.

    For each ``size`` in ``size_iter`` (an int or a sequence of ints), the next
    ``prod(size)`` entries along the first dimension are sliced off and
    reshaped with ``.view(size)``. Because the slice's element count must match
    ``prod(size)``, this effectively expects a 1-D packed tensor, such as the
    default output of ``pack_tensors``. The yielded tensors are views sharing
    storage with ``packed_tensor``. Any entries left after the last size are
    ignored.
    """
    offset = 0
    for shape in size_iter:
        count = prod(shape)
        piece = packed_tensor[offset:offset + count]
        offset += count
        yield piece.view(shape)

def unpack_tensors(packed_tensor, *sizes):
    """Split a packed tensor back into a tuple of tensors with the given sizes.

    This reverses ``pack_tensors`` with its default dims. ``sizes`` are
    consumed in order by ``upack_tensor_iter``, and every returned tensor is a
    view into ``packed_tensor``.
    """
    # The generator already yields the tensors; no pass-through comprehension needed.
    return tuple(upack_tensor_iter(packed_tensor, sizes))

###
# Multiple unsqueeze
def m_unsqueeze(tensor, dims = None, num = 1):
    """Insert one or more singleton dimensions into ``tensor`` (out of place).

    Args:
        tensor: Tensor to unsqueeze; it is not modified.
        dims: Where to insert the new dimensions:
            * an ``int``: a single ``tensor.unsqueeze(dims)``;
            * an iterable of ints: applied left to right, each position
              interpreted against the shape produced by the previous step;
            * ``None``: append ``num`` trailing dimensions (``dim=-1`` each).
        num: Number of trailing dimensions to add when ``dims`` is None.

    Returns:
        The unsqueezed tensor.

    Raises:
        ValueError: if ``dims`` is not an int, an iterable, or None.
    """
    if dims is None:
        positions = [-1] * num
    elif isinstance(dims, int):
        return tensor.unsqueeze(dims)
    elif isinstance(dims, Iterable):
        positions = dims
    else:
        raise ValueError("Dims argument for m_unsqueeze not recognized")
    return reduce(torch.unsqueeze, positions, tensor)

def m_unsqueeze_(tensor, dims = None, num = 1):
    """Insert one or more singleton dimensions into ``tensor`` in place.

    Uses ``Tensor.unsqueeze_``, so ``tensor`` itself is reshaped and that same
    tensor object is returned.

    Args:
        tensor: Tensor to unsqueeze in place.
        dims: Where to insert the new dimensions:
            * an ``int``: a single ``tensor.unsqueeze_(dims)``;
            * an iterable of ints: applied left to right, each position
              interpreted against the shape produced by the previous step;
            * ``None``: append ``num`` trailing dimensions (``dim=-1`` each).
        num: Number of trailing dimensions to add when ``dims`` is None.

    Raises:
        ValueError: if ``dims`` is not an int, an iterable, or None.
    """
    if isinstance(dims, int):
        return tensor.unsqueeze_(dims)
    elif isinstance(dims, Iterable):
        # Tensor.unsqueeze_(tens, d) returns tens, so it folds directly with reduce.
        return reduce(torch.Tensor.unsqueeze_, dims, tensor)
    elif dims is None:
        return reduce(torch.Tensor.unsqueeze_, [-1] * num, tensor)
    raise ValueError("Dims argument for m_unsqueeze_ not recognized")

###
# Proper product
def prod(x):
    """Return the product of ``x``'s elements, or ``x`` itself if it is a scalar.

    Scalars are any ``numbers.Number`` and are returned unchanged. This covers
    int, float, complex, ``Fraction`` and ``Decimal``, as well as numeric types
    registered with the ``numbers`` ABCs, such as NumPy scalars. Iterables
    (e.g. a ``torch.Size`` or a tuple of ints) are multiplied with
    ``math.prod``, so an empty iterable gives 1.

    Raises:
        ValueError: if ``x`` is neither a number nor an iterable.
    """
    # numbers.Number rather than (int, float, complex): NumPy integer scalars,
    # Fraction and Decimal are not instances of those builtins.
    if isinstance(x, numbers.Number):
        return x
    if not isinstance(x, Iterable):
        raise ValueError("Input to prod is not numeric or an iterable")
    return math.prod(x)

###
# Iterable manipulation convenience functions

def flatteniter(x, levels = None, keep_tensors = True):
    """Recursively yield the leaves of a nested iterable.

    Args:
        x: Value to flatten. Non-iterables are yielded as-is.
        levels: Maximum recursion depth.
            * ``None`` recurses until leaves are reached.
            * ``0`` yields ``x`` itself.
            * ``n > 0`` iterates ``x`` and flattens each element with
              ``n - 1`` levels, so ``levels=1`` yields the elements of ``x``
              unchanged.
        keep_tensors: If True, tensors are yielded whole. If False, tensors
            are iterated like any other iterable down to 0-dim tensors, which
            are always yielded as leaves. The setting applies at every depth.

    One-character strings are yielded as leaves: iterating one yields itself,
    so recursing into it would never terminate.
    """
    if not isinstance(x, Iterable) or levels == 0:
        yield x
    elif isinstance(x, torch.Tensor) and (keep_tensors or len(x.shape) == 0):
        yield x
    elif isinstance(x, str) and len(x) == 1:
        # Fixed point of iteration; without this, any non-empty str recursed forever.
        yield x
    else:
        next_levels = None if levels is None else max(levels - 1, 0)
        for elem in x:
            # Forward keep_tensors; it used to silently reset to True below the top level.
            yield from flatteniter(elem, next_levels, keep_tensors)

# Apply to the nth member of an iterable and return the result and an equivalent iterable
# def apply_on_nth(iterable, func, n = 0, replace = False):
#     if not isinstance(iterable, Iterator): iterable = iter(iterable)
#     head = (*islice(iterable, n), )
#     func_input, result = (), ()
#     for func_input in islice(iterable, 0, 1):
#         result = (func(func_input),)
#         func_input = (func_input,)
#     if replace:
#         return flatteniter(((*head, *result), iterable), levels = 2), *result
#     else:
#         return flatteniter(((*head, *func_input), iterable), levels = 2), *result

# Peek at the nth element
# def peek_at_nth(iterable, n = 0):
#     return apply_on_nth(iterable, lambda x:x, n = n)

# Apply func to the nth member of each inner iterable and return equivalent
# iterables with their nth elements replaced by the result
# def map_on_nth(func, iterable, n = 0):
#     return (new_iterable for new_iterable, *res in map(partial(apply_on_nth, func = func, n = n, replace =
#                                                                True), iterable))

# Keep only the inner iterables whose nth member satisfies func, returned as equivalent iterables
# def filter_on_nth(func, iterable, n = 0):
#     filter_results = filter(lambda outputs : outputs[1] if len(outputs) == 2 else
#                             False,
#                             (apply_on_nth(inner_iterabl,
#                                           func, n = n) for
#                              inner_iterabl in
#                              iterable))
#     return (new_iterable for new_iterable, *res in filter_results)

###
# Boilerplate functions for checking values

def check(val, func = None, default = None):
    """Replace ``val`` with ``default`` when it fails a test.

    The test is ``func(val)`` if ``func`` is given, otherwise ``bool(val)``.

    Returns:
        ``(val, passed)`` if the test result is truthy, else
        ``(default, passed)``. ``passed`` is returned exactly as the test
        produced it; it is not coerced to bool.
    """
    if func is None:
        passed = bool(val)
    else:
        passed = func(val)
    if passed:
        return (val, passed)
    return (default, passed)

###
# General function manipulation

def curry(*funcs_and_args):
    """Partially apply keyword arguments to a sequence of functions.

    Each positional argument is either a bare callable or a tuple
    ``(func, kwargs)``. ``kwargs`` is anything ``dict()`` accepts: a mapping
    or an iterable of key/value pairs. A tuple with more than two items makes
    ``dict`` raise TypeError.

    Returns:
        A list of ``functools.partial`` objects, one per entry and in the same
        order. Bare callables get no bound arguments.
    """
    curried = []
    for func, *kwargs_source in ensuretuples(funcs_and_args):
        # dict(*[]) is {}; dict(*[kw]) copies kw.
        curried.append(partial(func, **dict(*kwargs_source)))
    return curried

def compose(*funcs):
    """Return a function that applies ``funcs`` from right to left.

    ``compose(f, g)(x) == f(g(x))``. With no functions the result is the
    identity.
    """
    def composition_func(x):
        # Fold over the functions last-to-first, threading the running value.
        return reduce(lambda acc, func: func(acc), reversed(funcs), x)
    return composition_func

###
# List convenience functions

def islist(obj):
    """Return True if ``obj`` is a ``list`` or an instance of a ``list`` subclass."""
    return isinstance(obj, (list,))

def ensurelist(first_elem, *rest_elems):
    """Return the argument(s) as a list.

    With several arguments, they are collected into a new list in order.

    With a single argument:
        * a list is returned unchanged (the same object, not a copy);
        * anything else, including tuples and other iterables, is wrapped as
          ``[first_elem]``.
    """
    if rest_elems:
        return [first_elem, *rest_elems]
    return first_elem if isinstance(first_elem, list) else [first_elem]

def flattentolist(iterator, levels = None, keep_tensors = True):
    """Return the leaves produced by ``flatteniter`` as a list.

    Args:
        iterator: (Nested) iterable to flatten.
        levels: Recursion depth forwarded to ``flatteniter``
            (None = unbounded).
        keep_tensors: Forwarded to ``flatteniter``. True (the default, as
            before) keeps tensors whole.
    """
    return list(flatteniter(iterator, levels, keep_tensors = keep_tensors))

###
# Tuple convenience functions

def istuple(obj):
    """Return True if ``obj`` is a ``tuple`` or an instance of a ``tuple`` subclass (e.g. a namedtuple)."""
    return isinstance(obj, (tuple,))

def ensuretuple(first_elem, *rest_elems):
    """Return the argument(s) as a tuple.

    With several arguments, they are collected into a new tuple in order.

    With a single argument:
        * a tuple is returned unchanged (the same object);
        * anything else, including lists and other iterables, is wrapped as
          ``(first_elem,)``.
    """
    if rest_elems:
        return (first_elem, *rest_elems)
    return first_elem if isinstance(first_elem, tuple) else (first_elem,)

def ensuretuples(iterable):
    """Lazily pass every element of ``iterable`` through ``ensuretuple``.

    Tuples pass through unchanged; any other element is wrapped as
    ``(elem,)``. Returns a ``map`` iterator, so elements are converted only as
    they are consumed.
    """
    wrapped = map(ensuretuple, iterable)
    return wrapped

def components(iterable):
    """Transpose an iterable of tuples.

    Returns a ``zip`` iterator whose k-th item is the tuple of every input's
    k-th component. It stops at the shortest input. ``iterable`` itself is
    consumed immediately.
    """
    rows = tuple(iterable)
    return zip(*rows)

def component(iterable, n):
    """Lazily yield ``elem[n]`` for each element of ``iterable``.

    Returns a generator. It produces the same values as
    ``map(operator.itemgetter(n), iterable)``.
    """
    return (row[n] for row in iterable)

def merge(iterable1, iterable2):
    """Lazily concatenate paired elements of two iterables into flat tuples.

    Steps:
        1. Elements are made into tuples via ``ensuretuples`` (non-tuples
           become 1-tuples).
        2. They are paired with ``zip``, stopping at the shorter input.
        3. Each pair ``(a, b)`` becomes ``(*a, *b)``.
    """
    left_rows = ensuretuples(iterable1)
    right_rows = ensuretuples(iterable2)
    return ((*left, *right) for left, right in zip(left_rows, right_rows))

def flattentotuple(iterator, levels = None, keep_tensors = True):
    """Return the leaves produced by ``flatteniter`` as a tuple.

    Args:
        iterator: (Nested) iterable to flatten.
        levels: Recursion depth forwarded to ``flatteniter``
            (None = unbounded).
        keep_tensors: Forwarded to ``flatteniter``. True (the default, as
            before) keeps tensors whole.
    """
    return tuple(flatteniter(iterator, levels, keep_tensors = keep_tensors))

def allequal(iterable):
    """Return True if every element of ``iterable`` compares equal to the first.

    Elements are compared with ``!=`` against the first element, stopping at
    the first mismatch. An empty iterable is vacuously all-equal and returns
    True, rather than letting ``next()`` raise ``StopIteration``.
    """
    iterator = iter(iterable)
    missing = object()
    first = next(iterator, missing)
    if first is missing:
        return True
    return not any(elem != first for elem in iterator)

def issorted(iterable):
    """Return True if ``iterable`` is in non-decreasing order.

    Adjacent elements are compared with ``>``, stopping at the first descent.
    Empty and single-element iterables are sorted.
    """
    iterator = iter(iterable)
    missing = object()
    previous = next(iterator, missing)
    if previous is missing:
        return True
    for current in iterator:
        if previous > current:
            return False
        previous = current
    return True

def convert_to_3gram(elem, index=0):
    """Split a list/tuple around position ``index`` into ``(head, middle, tail)``.

    The parts are ``elem[:index]``, ``elem[index:index+1]`` and
    ``elem[index+1:]``, each of the same type as ``elem``. Any other input
    gives ``((), (), ())``.

    ``middle`` is empty when ``index`` is past the end. It is also empty for
    ``index=-1``, because the slice is ``elem[-1:0]``; use a non-negative
    index to select the last member.
    """
    if not isinstance(elem, (list, tuple)):
        return ((), (), ())
    head = elem[0:index]
    middle = elem[index:index + 1]
    tail = elem[index + 1:]
    return (head, middle, tail)

def restore_from_3gram(elem):
    """Concatenate a 3-gram ``(head, middle, tail)`` back into one sequence.

    The result is a list if ``head`` is a list (checked via the type of an
    empty slice of it), otherwise a tuple.
    """
    as_list = isinstance(elem[0][0:0], list)
    joined = (*elem[0], *elem[1], *elem[2])
    return list(joined) if as_list else joined

def from_3gram_iterable(iterable):
    """Lazily rebuild each 3-gram in ``iterable`` with ``restore_from_3gram``.

    Returns a ``map`` iterator.
    """
    # Pass the function directly; the wrapping lambda added nothing.
    return map(restore_from_3gram, iterable)

def non_trivial_3grams(iterable, index=0):
    """Split each element of ``iterable`` into a 3-gram and drop trivial ones.

    Every element goes through ``convert_to_3gram(elem, index)``. A 3-gram is
    kept only if its middle slice ``elem[index:index+1]`` is non-empty. This
    drops non-list/tuple values and list/tuple values with nothing at
    ``index``. Returns a lazy ``filter`` iterator.
    """
    three_grams = (convert_to_3gram(elem, index = index) for elem in iterable)
    # Test emptiness rather than comparing with (): for list elements the middle
    # slice is a list, and [] != () is True, so empty list slices slipped through.
    return filter(lambda tple: len(tple[1]) > 0, three_grams)

def extra_info_filter(func, iterable, index=0):
    """Filter list/tuple records by a predicate on the member at ``index``.

    Records are first passed through ``non_trivial_3grams``, which drops
    trivial ones such as non-list/tuple values. ``func`` is then called with
    the one-element slice ``record[index:index+1]`` (e.g. ``(x,)``), not with
    the bare member. ``extra_info_map``, by contrast, unpacks it. Records whose
    slice passes are rebuilt unchanged.

    Returns a lazy ``map`` iterator, not a ``filter`` one.
    """
    kept = filter(lambda three_gram : func(three_gram[1]),
                  non_trivial_3grams(iterable, index))
    return from_3gram_iterable(kept)

def extra_info_map(func, iterable, index=0):
    """Apply ``func`` to the member at ``index`` of each list/tuple record.

    Records are first passed through ``non_trivial_3grams``, which drops
    trivial ones such as non-list/tuple values. Each remaining record is
    rebuilt with its member at ``index`` replaced by ``func(member)``. A record
    whose head is a list comes back as a list; otherwise it comes back as a
    tuple.

    Returns a lazy ``map`` iterator.
    """
    def _apply_to_middle(three_gram):
        # The middle slice holds the selected member; unpack it into func.
        return (three_gram[0], (func(*three_gram[1]),), three_gram[2])

    return from_3gram_iterable(map(_apply_to_middle,
                                   non_trivial_3grams(iterable, index)))
