import torch


@torch.no_grad()
def f1_score(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    """Calculate the binary F1 score. Works with tensors on any device (e.g. GPU).

    The original implementation is written by Michal Haltuf on Kaggle.

    Parameters
    ----------
    y_true : torch.Tensor
        1-D tensor of ground-truth labels. The formula treats values as 0/1
        indicators; bool tensors are accepted.
    y_pred : torch.Tensor
        Either a 1-D tensor of predictions (0/1 indicators; bool accepted), or a
        2-D tensor of per-class scores, in which case the prediction is the
        ``argmax`` over dim 1. Must have the same number of rows as ``y_true``.

    Returns
    -------
    torch.Tensor
        0-dim float32 tensor; 0 <= val <= 1 for 0/1 inputs.

    Reference
    ---------
    - https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
    - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score
    - https://discuss.pytorch.org/t/calculating-precision-recall-and-f1-score-in-case-of-multi-label-classification/28265/6

    """
    assert y_true.ndim == 1, "y_true must be 1-D"
    assert y_pred.ndim in (1, 2), "y_pred must be 1-D or 2-D"

    if y_pred.ndim == 2:
        y_pred = y_pred.argmax(dim=1)

    # Without this check a length-1 tensor would silently broadcast.
    assert y_true.shape[0] == y_pred.shape[0], "y_true and y_pred lengths differ"

    # torch does not support `1 - x` on bool tensors; promote them to integers.
    if y_true.dtype == torch.bool:
        y_true = y_true.long()
    if y_pred.dtype == torch.bool:
        y_pred = y_pred.long()

    tp = (y_true * y_pred).sum().to(torch.float32)
    fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
    fn = (y_true * (1 - y_pred)).sum().to(torch.float32)

    # Guards against division by zero when there are no positives.
    epsilon = 1e-7

    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    return f1


def require_lock(locks, name):
    """Return a decorator that runs the wrapped function while holding ``locks[name]``.

    Parameters
    ----------
    locks : mapping
        Mapping of name -> lock-like context manager.
    name :
        Key of the lock to acquire on every call.

    The lock is looked up at call time, not at decoration time, and it is
    released even if the wrapped function raises.
    """
    import functools

    def decorator(f):
        # Preserve the wrapped function's __name__, __doc__, etc.
        @functools.wraps(f)
        def locked_f(*args, **kwargs):
            with locks[name]:
                return f(*args, **kwargs)

        return locked_f

    return decorator


def map_dict(f, d, map_keys=False):
    """Return a new dict with ``f`` applied to every value (or every key).

    When ``map_keys`` is true, ``f`` transforms the keys and the values are
    kept; otherwise ``f`` transforms the values and the keys are kept.
    Raises ValueError if ``d`` is not a dict.
    """
    if not isinstance(d, dict):
        raise ValueError(f"Wrong type {type(d)}")
    if map_keys:
        pairs = ((f(key), value) for key, value in d.items())
    else:
        pairs = ((key, f(value)) for key, value in d.items())
    return dict(pairs)


def filter_dict(f, d, filter_keys=False):
    """Return a new dict keeping only the entries for which ``f`` is truthy.

    ``f`` is applied to each key when ``filter_keys`` is true, otherwise to
    each value. Raises ValueError if ``d`` is not a dict.
    """
    if not isinstance(d, dict):
        raise ValueError(f"Wrong type {type(d)}")
    # Index into the (key, value) tuple that ``f`` should inspect.
    position = 0 if filter_keys else 1
    return dict(entry for entry in d.items() if f(entry[position]))


def flatten(n):
    """Yield the leaves of a nested structure in depth-first order.

    Lists and tuples are descended into element by element; dicts are
    descended into through their values, in key iteration order. Anything
    else (including strings) is yielded as a single leaf.
    """
    if isinstance(n, (tuple, list)):
        children = n
    elif isinstance(n, dict):
        children = (n[key] for key in n)
    else:
        yield n
        return
    for child in children:
        yield from flatten(child)


def zip(*nests):
    """Pair each element of the first nest with the next leaf of every other nest.

    Returns a lazy ``map`` object yielding lists ``[item, leaf_1, leaf_2, ...]``.
    Only the nests after the first are flattened; the first is iterated at its
    top level. NOTE(review): confirm this asymmetry is intended.
    If another nest runs out of leaves, the StopIteration raised inside the
    callback ends the ``map`` iteration early.
    """
    head, *rest = nests
    streams = [flatten(nest) for nest in rest]

    def combine(item):
        row = [item]
        for stream in streams:
            row.append(next(stream))
        return row

    return map(combine, head)


def map_many(f, *nests):
    """Lazily call ``f`` on ``[item, leaf_1, leaf_2, ...]`` for each item of the first nest.

    ``leaf_k`` is the next leaf of the k-th remaining nest (flattened); the
    first nest is iterated at its top level only. Returns a lazy ``map``
    object. If another nest runs out of leaves, the StopIteration raised
    inside the callback ends the ``map`` iteration early.
    """
    head, *rest = nests
    streams = [flatten(nest) for nest in rest]

    def apply(item):
        args = [item]
        for stream in streams:
            args.append(next(stream))
        return f(args)

    return map(apply, head)


class Batcher:
    """Accumulate dicts of tensors and hand them back in fixed-size chunks.

    Each value in an appended dict is treated as a tensor whose leading
    dimension indexes rows; ``size`` counts the buffered rows. New rows are
    concatenated at the end and ``get_batch`` removes rows from the front.
    Once ``size`` reaches ``_approx_max_size``, each ``append`` first evicts
    up to as many of the oldest rows as it adds.
    """

    def __init__(self, batch_size=None, _approx_max_size=1000):
        # Default row count used by ``ready``/``get_batch`` when none is passed.
        self.batch_size = batch_size
        # Soft cap on buffered rows; see ``append``.
        self._approx_max_size = _approx_max_size
        # Buffered tensors, keyed like the appended dicts.
        self._batch = {}
        # Number of buffered rows.
        self.size = 0

    def ready(self, batch_size=None):
        """Return whether at least ``batch_size`` rows (default ``self.batch_size``) are buffered."""
        wanted = self.batch_size if batch_size is None else batch_size
        return self.size >= wanted

    def append(self, batch, mask=None, device=None):
        """Buffer the rows of ``batch``.

        Parameters
        ----------
        batch : dict
            Mapping of key -> tensor. When ``mask`` is None, all tensors must
            share the same leading dimension.
        mask : torch.Tensor, optional
            Row selector applied to every tensor; the number of added rows is
            taken as ``mask.sum()`` (presumably a boolean mask; TODO confirm).
        device : optional
            If given, each tensor is moved there before buffering.
        """
        if mask is None:
            incoming = self._dict_size(batch)
        else:
            incoming = mask.sum().item()

        # Past the soft cap, evict as many of the oldest rows as are being added.
        over_cap = self.size >= self._approx_max_size
        evicted = min(self.size, incoming) if over_cap else 0
        self.size = self.size - evicted + incoming

        for key, tensor in batch.items():
            if mask is not None:
                tensor = tensor[mask]
            if device is not None:
                tensor = tensor.to(device=device)

            if key not in self._batch:
                # NOTE(review): without mask/device this may alias the caller's tensor.
                self._batch[key] = tensor
            else:
                # The slice is a view; torch.cat allocates a fresh tensor.
                self._batch[key] = torch.cat((self._batch[key][evicted:], tensor), 0)

    def get_batch(self, batch_size=None):
        """Remove and return the oldest ``batch_size`` rows (default ``self.batch_size``).

        Asserts (message "not ready") that enough rows are buffered.
        """
        assert self.ready(batch_size), "not ready"
        take = self.batch_size if batch_size is None else batch_size

        head = {}
        tail = {}
        for key, tensor in self._batch.items():
            head[key] = tensor[:take]
            tail[key] = tensor[take:]
        self.size -= take

        if self.size == 0:
            self._batch = {}
        else:
            assert self._dict_size(tail) == self.size
            # The leftovers are views; clone them so the buffer owns its storage.
            self._batch = {key: tensor.clone() for key, tensor in tail.items()}
        return head

    def _dict_size(self, d):
        """Return the shared leading-dimension length of the tensors in ``d``.

        Raises ValueError for an empty dict and asserts that all values agree.
        """
        if not d:
            raise ValueError("Empty dictionary")
        values = iter(d.values())
        reference = next(values).shape[0]
        assert all(v.shape[0] == reference for v in values), "unequal shapes"
        return reference
