import torch
from torch.utils.data import Dataset
import numpy as np
from gen_neg_toy.utils.random import RNG

from .ng_utils import compute_infraction


class Checkerboard(Dataset):
    """Synthetic 2-D checkerboard dataset with points in ``[-1, 1)^2``.

    Points are drawn uniformly from alternating unit cells of a 4x4 grid laid
    over ``[-2, 2)^2`` and are then divided by 2. When ``slack`` is given,
    every point closer than ``slack`` (measured in the final, rescaled
    coordinates) to a grid line along either axis is rejected and redrawn.

    Args:
        n_samples: Number of points to generate.
        seed: Value passed to the ``RNG`` context manager that wraps all of
            the ``numpy.random`` calls made during generation.
        slack: Optional margin kept free around the grid lines. ``None`` or
            ``0`` disables rejection.

    Raises:
        ValueError: If ``slack`` is 0.25 or larger. Rescaled cells are 0.5
            wide, so every point lies within 0.25 of a grid line and the
            rejection loop could never collect a sample.
    """

    def __init__(self, n_samples, seed=123, slack=None):
        # Data is generated in [-2, 2) but rescaled to [-1, 1), so the margin
        # has to be doubled to be expressed in generation coordinates.
        margin = None if slack is None else slack * 2
        if margin is not None and margin >= 0.5:
            raise ValueError(
                f"slack must be smaller than 0.25 but got {slack}: "
                "no sample can be that far away from every grid line."
            )

        with RNG(seed):
            chunks = []
            n_kept = 0
            # Always draw at least once, then keep drawing full batches until
            # enough samples survived the rejection step.
            while not chunks or n_kept < n_samples:
                candidates = self._draw(n_samples)
                survivors = candidates[self._keep_mask(candidates, margin)]
                chunks.append(survivors)
                n_kept += len(survivors)

            data = np.concatenate(chunks)[:n_samples] / 2
            np.random.shuffle(data)

        self.data = data

    @staticmethod
    def _draw(count):
        """Draw ``count`` points from the checkerboard cells of ``[-2, 2)^2``."""
        x1 = np.random.rand(count) * 4 - 2
        # Pick a position inside a unit cell, then one of the two rows
        # ([0, 1) or [-2, -1)) available before the column-dependent shift.
        within_cell = np.random.rand(count)
        lower_row = np.random.randint(0, 2, count)
        # Odd columns are shifted up by one cell, producing the checker pattern.
        x2 = (within_cell - lower_row * 2) + (np.floor(x1) % 2)
        return np.stack([x1, x2], axis=1)

    @staticmethod
    def _keep_mask(points, margin):
        """Return a boolean mask of the points not within ``margin`` of a grid line."""
        if margin is None or margin == 0:
            return np.ones(len(points), dtype=bool)
        grid_lines = np.arange(-2, 3)
        # Distance of every coordinate to every grid line: (n_points, 2, 5).
        dists = np.abs(points[:, :, None] - grid_lines)
        return ~(dists < margin).any(axis=(1, 2))

    def __len__(self):
        """Return the number of generated points."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the point(s) stored at ``index``."""
        return self.data[index]


class ValidatedDataset(Dataset):
    """Pair every item of a base dataset with a validity label.

    Labels are the negation of ``compute_infraction``: a label of 1 (or
    ``True``) marks an item for which no infraction was reported.

    Args:
        base: Indexable, sized collection of samples. When labels are
            computed in bulk it must also support slicing.
        cache_labels: If true, all labels are computed (or taken from
            ``labels``) once in the constructor; otherwise each label is
            computed on access.
        slack: Forwarded to ``compute_infraction`` as its ``slack`` keyword,
            both for cached and for on-the-fly labels.
        labels: Optional fixed label(s). A scalar or an array that broadcasts
            against ``len(base)``; requires ``cache_labels`` to be true.
    """

    def __init__(self, base, cache_labels=True, slack=None, labels=None):
        self.base = base
        # Kept so that on-the-fly labelling in __getitem__ uses the same slack
        # as the cached path below.
        self.slack = slack
        super().__init__()
        if labels is not None:
            assert cache_labels, "If labels are provided, cache_labels must be True"

        if not cache_labels:
            self.labels = None
        elif labels is not None:
            self.labels = np.ones((len(base),)) * labels
        else:
            computed = np.zeros((len(base),))
            # Label in chunks to bound the size of each compute_infraction call.
            max_batch_size = 100000
            for start in range(0, len(base), max_batch_size):
                stop = start + max_batch_size
                computed[start:stop] = (
                    ~compute_infraction(base[start:stop], slack=slack)
                ).float()
            self.labels = computed

    def __getitem__(self, index):
        """Return ``(item, label)`` for the sample at ``index``."""
        item = self.base[index]
        if self.labels is not None:
            return item, self.labels[index]
        # Add a leading batch axis for compute_infraction and strip it again.
        # ``item[None]`` equals ``item.unsqueeze(0)`` for tensors and also
        # works for NumPy arrays.
        label = ~compute_infraction(item[None], slack=self.slack).squeeze(0)
        return item, label

    def __len__(self):
        """Return the number of samples in the base dataset."""
        return len(self.base)


class SyntheticDataset(ValidatedDataset):
    """Validated dataset whose samples are loaded from a NumPy file on disk.

    Args:
        path: Path of the file read with ``np.load``.
        size: If given, only the first ``size`` samples are kept; the file
            must contain at least that many.
        enforce_labels: ``"-"`` requires every label to be 0, ``"+"``
            requires every label to be 1, ``None`` performs no check.
        slack: Forwarded to ``ValidatedDataset``.
        labels: Forwarded to ``ValidatedDataset``.

    Raises:
        AssertionError: If ``enforce_labels`` is not one of the accepted
            values, the file holds fewer than ``size`` samples, or the labels
            violate ``enforce_labels``.
    """

    def __init__(self, path, size=None, enforce_labels=None, slack=None, labels=None):
        assert enforce_labels is None or enforce_labels in ["+", "-"]
        with open(path, "rb") as f:
            base = np.load(f)
        if size is not None:
            assert (
                len(base) >= size
            ), f"Expected the negative dataset to be at least {size} samples long but only found {len(base)} samples in the dataset at {path}"
            base = base[:size]
        super().__init__(base, slack=slack, labels=labels)

        if enforce_labels == "-":
            assert np.all(
                self.labels == 0
            ), f"Expected all the samples in a negative dataset to have infraction but found {self.labels.sum()} infraction-free samples in the dataset at {path}"
        elif enforce_labels == "+":
            # Label 1 means "no infraction", so the offending samples are the
            # ones whose label differs from 1.
            assert np.all(
                self.labels == 1
            ), f"Expected all the samples in a positive dataset to be infraction-free but found {(self.labels != 1).sum()} infractionful samples in the dataset at {path}"

        mode_str = {'+': 'positive', '-': 'negative', None: ''}[enforce_labels]
        print(f"Loaded {len(self)} {mode_str} examples from {path}.")


def get_train_set(dataset_name, n_samples=1000000, slack=None):
    """Build the labelled training set.

    Generates a ``Checkerboard`` with the fixed seed 123 and wraps it in a
    ``ValidatedDataset`` created with its default arguments.

    Args:
        dataset_name: Must be ``"checkerboard"``.
        n_samples: Number of training points to generate.
        slack: Margin forwarded to ``Checkerboard``.
    """
    assert dataset_name == "checkerboard", "The only supported dataset is checkerboard."
    board = Checkerboard(n_samples=n_samples, seed=123, slack=slack)
    return ValidatedDataset(board)


def get_test_set(dataset_name, n_samples=10000, slack=None):
    """Build the labelled test set.

    Generates a ``Checkerboard`` with the fixed seed 456 and wraps it in a
    ``ValidatedDataset`` created with its default arguments.

    Args:
        dataset_name: Must be ``"checkerboard"``.
        n_samples: Number of test points to generate.
        slack: Margin forwarded to ``Checkerboard``.
    """
    assert dataset_name == "checkerboard", "The only supported dataset is checkerboard."
    board = Checkerboard(n_samples=n_samples, seed=456, slack=slack)
    return ValidatedDataset(board)


def get_datasets(data_config):
    """Create the training and test datasets described by ``data_config``.

    Args:
        data_config: Object exposing ``dataset`` and ``train_set_size``
            attributes and, optionally, ``slack``. A ``train_set_size`` of
            ``None`` selects the default size of ``get_train_set``.

    Returns:
        A ``(train_set, test_set)`` tuple. The test set always holds 1000
        samples.
    """
    slack = getattr(data_config, "slack", None)

    # The slack applies to both splits regardless of how the training set is
    # sized; only the sample count is optional.
    train_kwargs = {"slack": slack}
    if data_config.train_set_size is not None:
        train_kwargs["n_samples"] = data_config.train_set_size

    train_set = get_train_set(data_config.dataset, **train_kwargs)
    test_set = get_test_set(data_config.dataset, n_samples=1000, slack=slack)
    return train_set, test_set


def merge_neg_dataset(base_dataset, neg_dataset_path, neg_dataset_size):
    """Append a negative dataset loaded from disk to ``base_dataset``.

    Args:
        base_dataset: Dataset placed first in the concatenation.
        neg_dataset_path: File from which the negative samples are loaded.
        neg_dataset_size: Number of negative samples to keep (``None`` keeps
            all of them).

    Returns:
        A ``ConcatDataset`` of ``base_dataset`` followed by the negatives,
        all of which are required to carry label 0.
    """
    negatives = SyntheticDataset(
        neg_dataset_path,
        size=neg_dataset_size,
        enforce_labels="-",
    )
    combined = [base_dataset, negatives]
    return torch.utils.data.ConcatDataset(combined)
