import collections.abc
import copy
import random

import numpy as np
import pytorch_lightning as pl
import torch


def rec_update(d, u):
    """Recursively merge mapping ``u`` into ``d`` in place and return ``d``.

    Nested mappings in ``u`` are merged key by key into the corresponding entry
    of ``d``; every other value in ``u`` replaces the entry in ``d`` outright.
    Nested mappings from ``u`` are never inserted by reference: they are merged
    into an existing mutable mapping of ``d`` or into a fresh ``dict``.

    Args:
        d: mutable mapping that receives the update (mutated in place).
        u: mapping whose entries take precedence over those in ``d``.

    Returns:
        ``d`` after the merge.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            base = d.get(k)
            if not isinstance(base, collections.abc.MutableMapping):
                # Missing key, a scalar, or None being overridden by a nested
                # mapping: merge into a fresh dict instead of crashing on item
                # assignment. A read-only Mapping is copied so its entries survive.
                base = dict(base) if isinstance(base, collections.abc.Mapping) else {}
            d[k] = rec_update(base, v)
        else:
            d[k] = v
    return d


def build_config(configs, names):
    """Resolve named config fragments (and their dependencies) into one dict.

    Each name is looked up with ``getattr(configs, name)`` and deep-copied. A
    fragment may list other fragment names under ``"requires"``; those are
    visited first (depth-first, post-order) and the ``"requires"`` key is then
    removed from the copy. Each name is visited at most once, so repeated or
    cyclic requirements are silently skipped.

    Fragments are merged with ``rec_update`` in post-order, so later fragments
    override earlier ones. The result's ``"name"`` is the ``"_"``-joined
    ``"name"`` values of every fragment that defines one, in post-order.
    """
    ordered = []
    seen = set()

    def visit(key):
        if key in seen:
            return
        seen.add(key)

        fragment = copy.deepcopy(getattr(configs, key))
        if "requires" in fragment:
            # Dependencies go first so this fragment can override them.
            for dep in fragment["requires"]:
                visit(dep)
            del fragment["requires"]
        ordered.append(fragment)

    for key in names:
        visit(key)

    # Join the names before merging, as the merge would otherwise keep only
    # the last fragment's "name".
    joined_name = "_".join([frag["name"] for frag in ordered if "name" in frag])

    merged = {}
    for frag in ordered:
        merged = rec_update(merged, frag)

    merged["name"] = joined_name
    return merged


class Cutout:
    """Cutout augmentation: zero one randomly placed square patch of an image.

    Args:
        length: side length, in pixels, of the square patch. ``0`` (or a
            negative value) disables the transform. The patch centre is drawn
            uniformly over the image with ``np.random`` and the patch is clipped
            at the image borders, so patches near an edge are smaller.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        """Zero a ``length`` x ``length`` patch of ``img`` in place and return it.

        ``img`` is indexed as ``(..., H, W)``. For the usual ``(C, H, W)`` tensor,
        every channel gets the same patch, and likewise for any other leading dims.
        Works for any dtype and device.
        """
        if self.length <= 0:
            return img
        h, w = img.shape[-2], img.shape[-1]
        # Same draw order as before (row first, then column) for reproducibility.
        y = int(np.random.randint(h))
        x = int(np.random.randint(w))

        # Split the length as floor/ceil around the centre. Using length // 2 on
        # both sides produced a (length-1)-wide patch for odd lengths, so
        # length=1 masked nothing.
        before = self.length // 2
        after = self.length - before
        y1, y2 = max(y - before, 0), min(y + after, h)
        x1, x2 = max(x - before, 0), min(x + after, w)

        # Assign zeros in place instead of multiplying by a CPU float32 mask.
        # The old multiply failed on integer tensors and on non-CPU tensors.
        img[..., y1:y2, x1:x2] = 0
        return img


class ConcatDataset(torch.utils.data.Dataset):
    """Zip several datasets together, each read through its own random order.

    At construction, one ``np.random.permutation`` is drawn per dataset, in
    argument order. Item ``i`` is the tuple
    ``(datasets[0][perm0[i]], datasets[1][perm1[i]], ...)``. The permutations
    are fixed for the object's lifetime, and the length is that of the shortest
    dataset. Despite the name, this pairs samples rather than concatenating
    datasets like ``torch.utils.data.ConcatDataset``.
    """

    def __init__(self, *datasets):
        self.datasets = datasets
        self.shuffle_indexes = []
        for dataset in datasets:
            self.shuffle_indexes.append(np.random.permutation(len(dataset)))

    def __getitem__(self, i):
        samples = []
        for dataset, order in zip(self.datasets, self.shuffle_indexes):
            samples.append(dataset[order[i]])
        return tuple(samples)

    def __len__(self):
        return min(map(len, self.datasets))


def drop(x, prob):
    """Randomly drop whole samples of a batch and rescale the survivors (drop-path).

    Args:
        x: tensor whose first dimension indexes samples; any number of trailing
            dimensions is supported.
        prob: probability of dropping each sample.

    Returns:
        ``x * mask / (1 - prob)``, where ``mask`` holds one Bernoulli(1 - prob)
        draw per sample, broadcast over all remaining dimensions. With
        ``prob == 1`` every sample is dropped and the result is all zeros,
        rather than NaN from 0/0.
    """
    keep_prob = 1 - prob
    # One draw per sample, shaped (N, 1, ..., 1) to match x's rank. The old
    # fixed [N, 1, 1, 1] shape silently broadcast non-4-D inputs into a larger,
    # wrong tensor (e.g. (N, D) -> (N, 1, N, D)).
    mask_shape = [x.shape[0]] + [1] * (x.dim() - 1)
    mask = torch.bernoulli(x.new_full(mask_shape, keep_prob))
    if keep_prob > 0:
        # Rescale survivors so the expected value matches x; skipped when
        # everything is dropped to avoid 0/0.
        mask = mask / keep_prob
    return x * mask


def set_seed(seed: int):
    """Seed every RNG the training code touches with ``seed``.

    Seeds Python's ``random``, NumPy's global RNG, torch's CPU generator, and
    the current CUDA device's generator. It then calls
    ``pl.seed_everything(seed, workers=True)`` so Lightning can also seed
    dataloader workers.
    """
    # NOTE(review): Lightning's seed_everything likely reseeds these same RNGs;
    # the explicit calls are kept so seeding doesn't depend on that behaviour.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    pl.seed_everything(seed, workers=True)
