import os
from tqdm import tqdm
from spaghettini import quick_register
import torch
import random

import numpy as np
from scipy.stats import bernoulli
from scipy.signal import convolve2d
from torch.utils.data import Dataset

import matplotlib.pyplot as plt

# 3x3 stamp of a "white" plus: the centre pixel and its 4-neighbours are 1.
WHITE_PLUS = np.array([
    [0, 1, 0],
    [1, 1, 1],
    [0, 1, 0],
])
# 3x3 stamp of a "black" plus: the pixel-wise complement of WHITE_PLUS
# (centre and 4-neighbours are 0, corners are 1).
BLACK_PLUS = 1 - WHITE_PLUS


@quick_register
class FindThePlusDataset(Dataset):
    """Synthetic "find the plus" dataset of binary images.

    Each sample starts from a random 0/1 image. Every white plus (centre and
    its 4-neighbours equal to 1) and every black plus (centre and its
    4-neighbours equal to 0) is stripped from it. A plus of the sampled kind
    is then stamped at a random interior position. Label 1 means a white plus
    was stamped, label 0 a black plus.

    ``__getitem__`` ignores its index and draws a fresh sample on every
    access. ``epoch_len`` therefore only controls what ``len()`` reports.

    Args:
        im_size: ``(im_width, im_height)``. Images have shape
            ``(im_width, im_height)``, so ``im_width`` indexes axis 0.
        white_plus_probability: probability that a sampled label is 1.
        fill_ratios: tuple or list of probabilities that a background pixel
            is 1. One is picked uniformly at random per sample.
        epoch_len: value returned by ``__len__``.
        dataset_size: ``(train, valid, test)`` sizes used by ``save_dataset``.
        save_path: directory into which ``save_dataset`` writes ``.npy`` files.
        normalize_to_unit_interval: if True, pixel values are rescaled to
            [-1, 1] (despite the name, not [0, 1]).
    """

    def __init__(self, im_size=(10, 10), white_plus_probability=0.5, fill_ratios=(0.5,), epoch_len=50000,
                 dataset_size=(50000, 10000, 10000), save_path="./", normalize_to_unit_interval=True):
        assert isinstance(fill_ratios, (tuple, list))
        self.train_size, self.valid_size, self.test_size = dataset_size
        self.im_width, self.im_height = im_size
        self.white_plus = WHITE_PLUS
        self.black_plus = BLACK_PLUS
        self.white_plus_probability = white_plus_probability
        self.fill_ratios = fill_ratios  # Candidate probabilities of a background pixel being 1.
        self.save_path = save_path
        self.epoch_len = epoch_len
        self.normalize_to_unit_interval = normalize_to_unit_interval

    def __len__(self):
        return self.epoch_len

    def __getitem__(self, idx):
        # `idx` is ignored: every access generates a brand-new random sample.
        return self.generate_sample_image(label=self._sample_label())

    def _sample_label(self):
        """Draw a label: 1 (white plus) with probability ``white_plus_probability``, else 0."""
        # Draw a scalar directly. Calling int() on a size-1 array is
        # deprecated since NumPy 1.25.
        return int(bernoulli.rvs(p=self.white_plus_probability))

    def generate_sample_image(self, label):
        """Generate one image containing a plus of the kind given by ``label``.

        Args:
            label: 1 to stamp a white plus, 0 to stamp a black plus.

        Returns:
            dict with keys:
              * ``"img"``: array of shape ``(im_width, im_height)``, values in
                {-1, 1} if ``normalize_to_unit_interval`` else {0, 1}.
              * ``"label"``: the given label.
              * ``"coords"``: ``np.array([normalized_height, normalized_width])``,
                the plus centre along axis 1 and axis 0. Each is mapped as
                ``(pos - (n - 1) / 2) / (n / 2)``.
        """
        assert label in (0, 1)
        fill_ratio = random.choice(self.fill_ratios)

        # Random binary background: each pixel is 1 with probability `fill_ratio`.
        data_bern = bernoulli.rvs(size=self.im_width * self.im_height, p=fill_ratio)
        img = data_bern.reshape((self.im_width, self.im_height))
        # Strip pluses of both colours. Removing one colour can complete a plus
        # of the other, so keep alternating until neither remains.
        while _is_white_plus_in_img(img, self.white_plus) or _is_black_plus_in_img(img, self.white_plus):
            img = _remove_white_plus(img, self.white_plus)
            img = _remove_black_plus(img, white_plus=self.white_plus)
        clean_img = np.copy(img)

        stamp = self.white_plus if label == 1 else self.black_plus

        def only_label_plus_present(candidate):
            """True iff `candidate` has a plus of the label's colour and none of the other colour."""
            has_white = _is_white_plus_in_img(candidate, self.white_plus)
            has_black = _is_black_plus_in_img(candidate, self.white_plus)
            if label == 1:
                return has_white and not has_black
            return has_black and not has_white

        # Stamp the plus at a random interior centre, so the 3x3 stamp stays in
        # bounds. Retry from the clean background until the result passes the
        # check (e.g. when the stamp completes a plus of the opposite colour).
        while True:
            img = np.copy(clean_img)
            row = np.random.randint(low=1, high=self.im_width - 1)   # axis-0 centre
            col = np.random.randint(low=1, high=self.im_height - 1)  # axis-1 centre
            img[row - 1:row + 2, col - 1:col + 2] = stamp
            if only_label_plus_present(img):
                break

        # Centre positions rescaled to roughly (-1, 1) around the image centre.
        normalized_height = (col - (self.im_height - 1) / 2) / (self.im_height / 2)
        normalized_width = (row - (self.im_width - 1) / 2) / (self.im_width / 2)

        # Min-max rescale to [-1, 1]. The stamp always contains both 0s and 1s,
        # so max > min.
        if self.normalize_to_unit_interval:
            img = 2 * ((img - img.min()) / (img.max() - img.min())) - 1.

        return dict(img=img, label=label, coords=np.array([normalized_height, normalized_width]))

    def generate_dataset(self, size):
        """Generate ``size`` samples.

        Returns:
            ``(imgs, labels)`` with shapes ``(size, im_width, im_height)`` and
            ``(size,)``, both float arrays.
        """
        imgs = np.zeros((size, self.im_width, self.im_height))
        labels = np.zeros((size,))

        for i in tqdm(range(size)):
            label = self._sample_label()
            # generate_sample_image returns a dict. The previous tuple-unpacking
            # of it raised "too many values to unpack".
            imgs[i] = self.generate_sample_image(label)["img"]
            labels[i] = label
        return imgs, labels

    def save_dataset(self):
        """Generate train/valid/test splits and save them as ``.npy`` files under ``save_path``."""
        train_imgs, train_labels = self.generate_dataset(self.train_size)
        valid_imgs, valid_labels = self.generate_dataset(self.valid_size)
        test_imgs, test_labels = self.generate_dataset(self.test_size)

        os.makedirs(self.save_path, exist_ok=True)

        np.save(os.path.join(self.save_path, "train_data.npy"), train_imgs)
        np.save(os.path.join(self.save_path, "train_labels.npy"), train_labels)
        np.save(os.path.join(self.save_path, "valid_data.npy"), valid_imgs)
        np.save(os.path.join(self.save_path, "valid_labels.npy"), valid_labels)
        np.save(os.path.join(self.save_path, "test_data.npy"), test_imgs)
        np.save(os.path.join(self.save_path, "test_labels.npy"), test_labels)


def is_pattern_in_image(img, pattern):
    """Return True if ``pattern`` occurs anywhere in ``img`` (sliding window, no padding).

    A window matches when its element-wise product with ``pattern`` sums to
    ``pattern.sum()``. For a 0/1 ``img`` and a 0/1 ``pattern``, this means
    every cell that is 1 in ``pattern`` is also 1 in the window. Cells that
    are 0 in ``pattern`` act as wildcards.

    Args:
        img: 2-D array to search.
        pattern: 2-D numpy array that should fit inside ``img``.

    Returns:
        A Python ``bool``.
    """
    # convolve2d flips its kernel. Pre-flipping the pattern turns the
    # convolution into a correlation, i.e. a true template match. This only
    # matters for asymmetric patterns; symmetric ones such as the plus are
    # unaffected.
    conv_out = convolve2d(img, pattern[::-1, ::-1], mode='valid')
    return bool(np.any(conv_out == pattern.sum()))


def _is_white_plus_in_img(img, white_plus):
    """Check whether a white plus occurs in ``img``.

    A thin wrapper that delegates the search to ``is_pattern_in_image``.
    """
    found = is_pattern_in_image(img, pattern=white_plus)
    return found


def _is_black_plus_in_img(img, white_plus):
    """Check whether a black plus occurs in ``img``.

    A black plus is a white plus in the inverted image. ``img`` is therefore
    flipped via ``|1 - img|`` (0 <-> 1 for a binary image) and searched for
    ``white_plus``.
    """
    return is_pattern_in_image(np.abs(1.0 - img), white_plus)


def _remove_white_plus(img, white_plus):
    """Remove white pluses from ``img`` by subtracting 1 at each match centre.

    For a binary image and the 3x3 plus stamp, this turns each matched centre
    from 1 to 0. The search repeats until ``is_pattern_in_image`` no longer
    finds ``white_plus``. The input array is never modified in place.
    """
    target = white_plus.sum()
    while is_pattern_in_image(img, white_plus):
        # With an odd-sized kernel, 'same' mode aligns each output cell with
        # the input pixel at the kernel centre, so the hits mark plus centres.
        centres = convolve2d(img, white_plus, mode='same') == target
        img = img - centres.astype(np.int64)
    return img


def _remove_black_plus(img, white_plus):
    """Remove black pluses from ``img`` by adding 1 at each match centre.

    Black pluses are searched for as white pluses in the inverted image
    ``|1 - img|``. For a binary image, adding 1 turns each matched centre from
    0 to 1. The search repeats until none is found. The input array is never
    modified in place.
    """
    target = white_plus.sum()
    inverted = np.abs(1.0 - img)
    while is_pattern_in_image(inverted, white_plus):
        hits = convolve2d(inverted, white_plus, mode='same') == target
        img = img + hits.astype(np.int64)
        # Re-derive the inverted view from the updated image before re-checking.
        inverted = np.abs(1.0 - img)

    return img


def _tmp_proof_generator(imgs, coords):
    """
    Extract the 3x3 parts of the plus from the image.
    This snippet is for debugging this function:

    fig, axs = plt.subplots(3, 1)
    axs[0].imshow(imgs[0, 0])
    axs[1].imshow(masks[0, 0])
    axs[2].imshow(proofs[0, 0])
    plt.title(f"coords: ({r},{c})")
    """
    masks = torch.zeros_like(imgs)
    for i in range(imgs.shape[0]):
        r, c = coords[0][i], coords[1][i]

        masks[i, :, r - 1:r + 2, c - 1:c + 2] = 1

    proofs = imgs * masks

    return proofs

########################################################################################################################
# ____ Debugging related functions. ____


def _find_plus_plot_generated_samples(xs, ys, coords):
    """Show the first nine samples of a batch in a 3x3 grid for visual inspection.

    Channel 0 of each ``xs[k]`` is displayed, titled with ``ys[k]`` and
    captioned with ``coords[k]``. Assumes ``xs`` is a torch tensor with at least
    9 entries along axis 0 (it is ``.detach()``-ed) — TODO confirm.
    """
    _, axes = plt.subplots(3, 3)
    # axes.flat walks the grid row by row: position k -> axes[k // 3, k % 3].
    for k, ax in enumerate(axes.flat):
        ax.imshow(X=xs[k, 0].detach().numpy())
        ax.set_title(f"label: {ys[k]}")
        ax.set(xlabel=f"{coords[k]}")

    plt.tight_layout()
    plt.show()

########################################################################################################################


if __name__ == "__main__":
    # Run from root:
    #   python -m src.data.datasets.find_plus_dataset
    test_num = 0

    if test_num == 0:
        # The dataset only stores configuration, so one instance serves every figure.
        dataset = FindThePlusDataset(fill_ratios=(0.5, 0.1, 0.01))
        for _ in range(5):
            # Generate images and check visually that the plusses are where they should be.
            fig, axs = plt.subplots(3, 3)
            fig.set_figheight(10)
            fig.set_figwidth(10)
            # Distinct loop variables: the inner loop previously reused `i` from the outer one.
            for i in range(9):
                sample_dict = dataset[i]
                ax = axs[i // 3, i % 3]
                ax.imshow(X=sample_dict["img"])
                ax.set_title(f"label: {sample_dict['label']}, coords: {sample_dict['coords']}")
            plt.tight_layout()
            plt.show()
