# Parts of this module are adapted from https://github.com/elcorto/psweep/blob/master/psweep/psweep.py
import functools
import itertools
from typing import Iterable, Dict, Text, Union, Callable, List

# A scalar value a single parameter may take.
ParamType = Union[
    int, float, str, bool,
]
# One parameter set ("pset"): parameter name -> value.
ArgDict = Dict[Text, ParamType]
# A list of parameter sets.
ParamList = List[ArgDict]
# Predicate over a parameter set, used to decide whether an operation applies.
Condition = Callable[[ArgDict], bool]
# Computes one value, or an iterable of values, from a parameter set.
ParamFunc = Callable[
    [ArgDict],
    Union[ParamType, Iterable[ParamType]],
]


def listify(func):
  """Decorator which turns an iterator return type into a list.

  The decorated callable calls ``func`` with the same positional AND keyword
  arguments and returns ``list(...)`` of whatever it produced. This is
  intended for generator functions. ``functools.wraps`` keeps ``func``'s
  name and docstring on the wrapper.
  """
  @functools.wraps(func)
  def wrapper(*args, **kwargs) -> List:
    # Forward keyword arguments too; otherwise e.g. ``cond=...`` on a
    # decorated function raises TypeError.
    return list(func(*args, **kwargs))
  return wrapper


def itr(func):
  """Decorator which makes functions take a sequence of args or individual
  args.

  The decorated function must take the sequence as its first positional
  parameter. Argument handling of the wrapper:

  - several (or zero) positional args are packed into a tuple;
  - exactly one positional arg is passed through unchanged if ``is_seq``
    accepts it. Otherwise (dict, str, scalar) it is wrapped in a 1-tuple,
    so a lone dict is treated as one individual arg instead of being
    iterated over key by key;
  - keyword arguments are forwarded unchanged.

  ::
    @itr
    def func(seq):
        for arg in seq:
            ...
    func([x, y, z])   # equivalent to
    func(x, y, z)
  """
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    if len(args) == 1 and is_seq(args[0]):
      return func(args[0], **kwargs)
    return func(args, **kwargs)
  return wrapper


def is_seq(seq):
  """Return True if ``seq`` can be iterated as a sequence of items.

  Strings and dicts are iterable but are deliberately reported as
  non-sequences, so that they are treated as single values. Anything on
  which ``iter()`` raises TypeError is not a sequence either.
  """
  if isinstance(seq, (str, dict)):
    return False
  try:
    iter(seq)
  except TypeError:
    return False
  return True


def flatten(seq):
  """Recursively yield the leaves of an arbitrarily nested sequence.

  Items rejected by ``is_seq`` (scalars, strings, dicts) are leaves and are
  yielded as-is. Every other item is descended into, depth first, preserving
  order.
  """
  for element in seq:
    if not is_seq(element):
      yield element
      continue
    yield from flatten(element)


def plist(name, seq):
  """Build one single-entry parameter dict per value in ``seq``.

  Each returned dict maps ``name`` to one value of ``seq``, in order.

  >>> plist('a', [1,2,3])
  [{'a': 1}, {'a': 2}, {'a': 3}]
  """
  result = []
  for value in seq:
    result.append({name: value})
  return result


@itr
def merge_dicts(args):
  """Merge the given dicts into a new dict.

  The dicts are applied left-to-right with ``dict.update``, so for keys
  present in several inputs the right-most value wins. The inputs are not
  modified.

  >>> merge_dicts({'a': 1}, {'a': 2, 'b': 3})
  {'a': 2, 'b': 3}
  """
  merged = dict()
  for mapping in args:
    merged.update(mapping)
  return merged


@itr
def itr2params(loops):
  """Transform the (possibly nested) result of a loop over plists (or
  whatever has been used to create psets) to a proper list of psets
  by flattening and merging dicts.

  Each entry of ``loops`` is flattened with ``flatten``. The resulting dicts
  are then merged left-to-right with ``merge_dicts``, so later dicts win on
  duplicate keys.

  Example
  -------
  >>> a = plist('a', [1,2])
  >>> b = plist('b', [77,88])
  >>> c = plist('c', ['const'])

  Result of the loops:

  >>> list(itertools.product(a,b,c))  # doctest: +NORMALIZE_WHITESPACE
  [({'a': 1}, {'b': 77}, {'c': 'const'}),
   ({'a': 1}, {'b': 88}, {'c': 'const'}),
   ({'a': 2}, {'b': 77}, {'c': 'const'}),
   ({'a': 2}, {'b': 88}, {'c': 'const'})]

  Flattened into a list of psets:

  >>> itr2params(itertools.product(a,b,c))  # doctest: +NORMALIZE_WHITESPACE
  [{'a': 1, 'b': 77, 'c': 'const'},
   {'a': 1, 'b': 88, 'c': 'const'},
   {'a': 2, 'b': 77, 'c': 'const'},
   {'a': 2, 'b': 88, 'c': 'const'}]

  More deeply nested loops are no problem either:

  >>> list(itertools.product(zip(a,b),c))  # doctest: +NORMALIZE_WHITESPACE
  [(({'a': 1}, {'b': 77}), {'c': 'const'}),
   (({'a': 2}, {'b': 88}), {'c': 'const'})]
  >>> itr2params(itertools.product(zip(a,b),c))  # doctest: +NORMALIZE_WHITESPACE
  [{'a': 1, 'b': 77, 'c': 'const'},
   {'a': 2, 'b': 88, 'c': 'const'}]
  """
  return [merge_dicts(flatten(entry)) for entry in loops]


@itr
def pgrid(plists):
  """Convenience function for the most common loop: nested loops with
  ``itertools.product``: ``ps.itr2params(itertools.product(a,b,c,...))``.

  A single plist passed on its own (``pgrid(a)``) is treated like
  ``pgrid([a])`` and yields one pset per entry of ``a``.

  >>> a = plist('a', [1,2])
  >>> b = plist('b', [77,88])
  >>> c = plist('c', ['const'])
  >>> pgrid(a,b,c)  # doctest: +NORMALIZE_WHITESPACE
  [{'a': 1, 'b': 77, 'c': 'const'},
   {'a': 1, 'b': 88, 'c': 'const'},
   {'a': 2, 'b': 77, 'c': 'const'},
   {'a': 2, 'b': 88, 'c': 'const'}]
  >>> pgrid(zip(a,b),c)  # doctest: +NORMALIZE_WHITESPACE
  [{'a': 1, 'b': 77, 'c': 'const'},
   {'a': 2, 'b': 88, 'c': 'const'}]
  >>> pgrid(a)
  [{'a': 1}, {'a': 2}]
  """
  # Materialize the outer container only; any inner iterators (e.g. zip
  # objects) are left untouched for ``itertools.product`` to consume.
  plists = list(plists)
  # ``itr`` passes a lone sequence argument through unchanged, so ``pgrid(a)``
  # arrives here as the plist itself, i.e. a sequence of dicts. A product
  # over dicts would iterate their keys, so wrap it into a list of plists.
  if plists and all(isinstance(entry, dict) for entry in plists):
    plists = [plists]
  return itr2params(itertools.product(*plists))


def always_true(arg: ArgDict) -> bool:
  """Condition that accepts every parameter set; ``arg`` is ignored."""
  del arg  # unused; the parameter exists only to fit the Condition signature
  return True


def check_condition(cond: Condition, arg: ArgDict) -> bool:
  """Evaluate ``cond`` on ``arg``, mapping a KeyError to False.

  If ``cond(arg)`` raises ``KeyError`` (for instance because it indexes a
  parameter that ``arg`` does not contain), the condition counts as not
  satisfied. Otherwise ``cond``'s return value is returned as-is.
  """
  try:
    outcome = cond(arg)
  except KeyError:
    outcome = False
  return outcome


@listify
def add_dependent(
    plist: ParamList, name: Text, func: ParamFunc, cond: Condition = None
) -> ParamList:
  """Add a parameter ``name`` whose value is computed from each pset.

  For every pset in ``plist`` that satisfies ``cond`` (all psets when
  ``cond`` is None; see ``check_condition`` for KeyError handling), call
  ``func(pset)``:

  - if the result is a sequence (per ``is_seq``), emit one copy of the pset
    per element, with ``name`` set to that element. An empty sequence
    therefore drops the pset;
  - otherwise emit one copy with ``name`` set to the result.

  Psets that do not satisfy ``cond`` are emitted unchanged. The input dicts
  are never mutated; new dicts are built instead.
  """
  for base in plist:
    applies = cond is None or check_condition(cond, base)
    if not applies:
      yield base
      continue
    produced = func(base)
    values = produced if is_seq(produced) else [produced]
    for value in values:
      yield dict(base, **{name: value})


@listify
def add_grid(
    plist: ParamList, subplist: ParamList, cond: Condition = None
) -> ParamList:
  """Expand selected psets with every entry of ``subplist``.

  Each pset in ``plist`` that satisfies ``cond`` is replaced by one merged
  copy per entry of ``subplist`` (``dict(pset, **subpset)``, so
  ``subplist`` values win on shared keys). Other psets are emitted
  unchanged. ``cond=None`` means every pset is expanded, mirroring
  ``add_dependent``. See ``check_condition`` for how a KeyError raised by
  ``cond`` is handled.
  """
  # Materialize once: the inner loop runs for every matching pset, and a
  # one-shot iterator (generator, zip, ...) would be exhausted after the first.
  subplist = list(subplist)
  for arg in plist:
    if cond is None or check_condition(cond, arg):
      for subarg in subplist:
        yield dict(arg, **subarg)
    else:
      yield arg


@listify
def prune(plist: ParamList, cond: Condition) -> ParamList:
  """Drop every pset for which ``cond`` holds; keep the rest in order.

  A pset whose condition raises KeyError is kept (see ``check_condition``).
  """
  yield from itertools.filterfalse(
      lambda pset: check_condition(cond, pset), plist
  )


def add_const(plist: ParamList, const: ArgDict) -> ParamList:
  """Return copies of every pset in ``plist`` extended with ``const``.

  Entries of ``const`` override same-named entries of the pset. The input
  dicts are not modified.
  """
  extended = []
  for pset in plist:
    extended.append(dict(pset, **const))
  return extended
