import os
import random

import numpy as np
import torchvision.transforms as transforms
import yaml
from PIL import ImageFilter
from torch.utils.data import DataLoader

from .backdoor import *
from .cifar import CIFAR10
from .prefetch import PrefetchLoader


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR.

    Borrowed from https://github.com/facebookresearch/moco/blob/master/moco/loader.py.

    Every call draws a blur radius uniformly from ``[sigma[0], sigma[1]]`` and applies
    ``PIL.ImageFilter.GaussianBlur`` with that radius.
    """

    def __init__(self, sigma=None):
        """
        Args:
            sigma (list | tuple | None): ``[low, high]`` range of the random blur radius.
                ``None`` (the default) means ``[0.1, 2.0]``. A fresh list is built per
                instance so the default is never shared (and never mutated) across instances.
        """
        self.sigma = [0.1, 2.0] if sigma is None else sigma

    def __call__(self, x):
        """Return ``x`` blurred with a randomly drawn radius.

        ``x`` is expected to be a PIL image (anything with a PIL-style ``filter`` method).
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


def _split_apply_prob(kwargs):
    """Return ``(p, rest)``: the ``"p"`` entry and a copy of ``kwargs`` without it."""
    # Work on a copy so the caller's config mapping is never modified.
    rest = dict(kwargs)
    p = rest.pop("p")
    return p, rest


def query_transform(name, kwargs):
    """Build a single torchvision transform from one transform-config entry.

    Args:
        name (str): Transform key, e.g. ``"random_crop"`` or ``"normalize"``.
        kwargs: Options for the transform. For most names this is a mapping unpacked into
            the corresponding torchvision constructor. ``"random_color_jitter"`` and
            ``"gaussian_blur"`` also require a ``"p"`` key (the ``RandomApply``
            probability); the remaining keys go to the wrapped transform. ``kwargs`` itself
            is left unmodified. For ``"to_tensor"`` the value is only tested for truthiness.

    Returns:
        The transform, or ``None`` when ``"to_tensor"`` is given a falsy value.

    Raises:
        KeyError: If a ``RandomApply``-wrapped transform has no ``"p"`` key.
        ValueError: If ``name`` is not a supported transform.
    """
    if name == "random_crop":
        return transforms.RandomCrop(**kwargs)
    elif name == "random_resize_crop":
        return transforms.RandomResizedCrop(**kwargs)
    elif name == "resize":
        return transforms.Resize(**kwargs)
    elif name == "center_crop":
        return transforms.CenterCrop(**kwargs)
    elif name == "random_horizontal_flip":
        return transforms.RandomHorizontalFlip(**kwargs)
    elif name == "random_color_jitter":
        p, jitter_kwargs = _split_apply_prob(kwargs)
        return transforms.RandomApply([transforms.ColorJitter(**jitter_kwargs)], p=p)
    elif name == "random_grayscale":
        return transforms.RandomGrayscale(**kwargs)
    elif name == "gaussian_blur":
        p, blur_kwargs = _split_apply_prob(kwargs)
        return transforms.RandomApply([GaussianBlur(**blur_kwargs)], p=p)
    elif name == "to_tensor":
        # The config value is an on/off flag, not constructor arguments.
        return transforms.ToTensor() if kwargs else None
    elif name == "normalize":
        return transforms.Normalize(**kwargs)
    else:
        raise ValueError("Transformation {} is not supported!".format(name))


def get_transform(transform_config):
    """Compose the transforms described by ``transform_config``.

    Args:
        transform_config (dict | None): Mapping from transform name (as accepted by
            ``query_transform``) to its options, applied in iteration order. Entries whose
            value is ``None`` are skipped; a ``None`` config yields an empty pipeline.

    Returns:
        transforms.Compose: The composed pipeline.
    """
    pipeline = []
    if transform_config is not None:
        for name, options in transform_config.items():
            if options is None:
                continue
            step = query_transform(name, options)
            # query_transform returns None for a disabled entry (e.g. ``to_tensor: false``);
            # Compose would later try to call it, so leave it out.
            if step is not None:
                pipeline.append(step)
    return transforms.Compose(pipeline)


def get_dataset(dataset_dir, transform, train=True, prefetch=False, selected_classes=None):
    """Instantiate the dataset stored under ``dataset_dir``.

    Only CIFAR-10 is supported. It is recognized by the (case-sensitive) substring
    ``"cifar"`` in the path.

    Args:
        dataset_dir (str): Dataset root directory.
        transform: Transform forwarded to ``CIFAR10``.
        train (bool): Forwarded to ``CIFAR10`` as ``train``.
        prefetch (bool): Forwarded to ``CIFAR10`` as ``prefetch``.
        selected_classes: Accepted for interface compatibility but currently ignored.

    Returns:
        CIFAR10: The constructed dataset.

    Raises:
        ValueError: If ``dataset_dir`` does not contain ``"cifar"``.
    """
    if "cifar" not in dataset_dir:
        raise ValueError("Dataset in {} is not supported.".format(dataset_dir))
    return CIFAR10(dataset_dir, transform=transform, train=train, prefetch=prefetch)


def get_loader(dataset, loader_config=None, **kwargs):
    """Wrap ``dataset`` in a ``DataLoader``, adding prefetching when the dataset requests it.

    Args:
        dataset: Dataset exposing a ``prefetch`` attribute (plus ``mean``/``std`` when
            ``prefetch`` is truthy).
        loader_config (dict | None): ``DataLoader`` options, e.g. from a config file.
        **kwargs: Extra ``DataLoader`` options. A key given both here and in
            ``loader_config`` raises ``TypeError`` (multiple values for one keyword).

    Returns:
        The ``DataLoader``, or a ``PrefetchLoader`` wrapping it when ``dataset.prefetch``
        is truthy.
    """
    config_opts = {} if loader_config is None else loader_config
    loader = DataLoader(dataset, **config_opts, **kwargs)
    if not dataset.prefetch:
        return loader
    return PrefetchLoader(loader, dataset.mean, dataset.std)


def gen_poison_idx(dataset, target_label, poison_ratio=None):
    """Flag the samples that should be poisoned.

    A sample whose label already equals ``target_label`` is never flagged. When
    ``dataset.train`` is truthy and ``poison_ratio`` is given, every sample consumes one
    ``random.random()`` draw and is flagged if the draw is below ``poison_ratio``, so the
    poisoned fraction is only approximate. Otherwise every non-target sample is flagged.

    Args:
        dataset: Object with ``len()``, a ``train`` attribute and iterable ``targets``.
        target_label: Label of the backdoor target class.
        poison_ratio (float | None): Per-sample poisoning probability for training sets.

    Returns:
        np.ndarray: Float array of length ``len(dataset)``: 1 for flagged samples, else 0.
    """
    flags = np.zeros(len(dataset))
    sample_randomly = dataset.train and poison_ratio is not None
    for pos, label in enumerate(dataset.targets):
        selected = random.random() < poison_ratio if sample_randomly else True
        if selected and label != target_label:
            flags[pos] = 1
    return flags


def gen_noise_idx(dataset, noise_ratio, poison_idx):
    """Randomly flag ``int(len(dataset) * noise_ratio)`` unpoisoned samples as noisy.

    Samples are picked by repeated index-order scans that accept each still-eligible
    sample with probability ``noise_ratio`` until the quota is met. Because the final scan
    stops as soon as the quota fills, lower indices are slightly favoured.

    Args:
        dataset: Anything supporting ``len()``.
        noise_ratio (float): Fraction of the whole dataset to flag; also the per-sample
            acceptance probability during each scan.
        poison_idx: Indexable 0/1 flags of length at least ``len(dataset)``; samples
            flagged 1 are never chosen.

    Returns:
        np.ndarray: Float array of length ``len(dataset)`` containing exactly
        ``max(0, int(len(dataset) * noise_ratio))`` ones.

    Raises:
        ValueError: If more noisy samples are requested than there are unpoisoned samples
            (the scan loop would otherwise never terminate).
    """
    noise_idx = np.zeros(len(dataset))
    num_noise = int(len(dataset) * noise_ratio)
    num_clean = sum(1 for i in range(len(noise_idx)) if poison_idx[i] == 0)
    if num_noise > num_clean:
        raise ValueError(
            "Cannot select {} noisy samples from only {} unpoisoned samples.".format(
                num_noise, num_clean
            )
        )
    # The guard above consumes no random draws, so seeded runs select the same samples as before.
    while num_noise > 0:
        for i in range(len(noise_idx)):
            if random.random() < noise_ratio and poison_idx[i] == 0 and noise_idx[i] == 0:
                noise_idx[i] = 1
                num_noise -= 1
                if num_noise == 0:
                    break

    return noise_idx


def get_bd_transform(bd_config):
    """Build the backdoor trigger transform selected by ``bd_config["attack_type"]``.

    Only ``"badnets"`` is supported. Its trigger image path is read from
    ``bd_config["badnets"]["trigger_path"]``.

    Args:
        bd_config (dict): Backdoor section of the experiment configuration.

    Returns:
        BadNets: The trigger transform.

    Raises:
        KeyError: If the required config keys are missing.
        ValueError: If the attack type is not supported.
    """
    attack_type = bd_config["attack_type"]
    if attack_type != "badnets":
        raise ValueError("Backdoor {} is not supported.".format(bd_config))
    print("using badnets attack")
    print("====================================")
    return BadNets(bd_config["badnets"]["trigger_path"])


class NormalizeByChannelMeanStd(nn.Module):
    """Module computing ``(x - mean) / std`` per channel, for use inside a network.

    ``mean`` and ``std`` are stored as buffers, not trainable parameters. In ``forward``
    they are reshaped to ``(1, C, 1, 1)``, so they must be 1-D. The input is presumably a
    batched ``(N, C, H, W)`` image tensor (assumption; confirm with callers).
    """

    def __init__(self, mean, std):
        super().__init__()
        mean_t = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std_t = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean_t)
        self.register_buffer("std", std_t)

    def forward(self, tensor):
        # Insert singleton axes so the per-channel stats broadcast along dim 1.
        shift = self.mean[None, :, None, None]
        scale = self.std[None, :, None, None]
        return tensor.sub(shift).div(scale)

    def extra_repr(self):
        return "mean={}, std={}".format(self.mean, self.std)


def load_config(config_path):
    """Load a YAML configuration file located under the ``config`` directory.

    Args:
        config_path (str): "/"-separated path beginning with ``config/`` or ``./config/``,
            e.g. ``./config/inner_dir/example.yaml`` or ``config/example.yaml``.

    Returns:
        config (dict): Parsed configuration (``yaml.safe_load`` result; ``None`` for an
            empty file).
        inner_dir (str): Directories between ``config/`` and the file, joined with
            ``os.path.join``. Empty string when the file sits directly in ``config/``.
        config_name (str): File name up to its first ``.yaml``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        RuntimeError: If ``config_path`` is not inside the ``config`` directory.
    """
    # An explicit exception instead of ``assert``, which is stripped under ``python -O``.
    if not os.path.exists(config_path):
        raise FileNotFoundError("Configuration file {} does not exist".format(config_path))
    parts = config_path.split("/")
    # "./config/..." is handled exactly like "config/...".
    if parts[0] == ".":
        parts = parts[1:]
    if not parts or parts[0] != "config":
        raise RuntimeError("Configuration file {} must be in config dir".format(config_path))
    inner_dir = os.path.join(*parts[1:-1]) if len(parts) > 2 else ""
    print("Load configuration file from {}:".format(config_path))
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    config_name = parts[-1].split(".yaml")[0]

    return config, inner_dir, config_name
