from typing import Callable, List, Optional

from PIL import Image
import numpy as np
import torch
import torchvision.datasets as datasets


# NOTE(review): the triple-quoted string below is an earlier version of
# CondColoredMNIST kept only for reference. It is a bare string expression, so
# it is evaluated and discarded at import time and defines nothing; the live
# class is defined after it. That version colours digit strokes red/blue
# (digits 0-4 vs 5-9), picks the background by digit parity, and uses a
# class-level p = 0.2. Consider dropping it in favour of version-control history.
"""
class CondColoredMNIST(datasets.MNIST):
    
    p = 0.2
    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
            digits: Optional[List[int]] = None
        ) -> None:
        super().__init__(
            root, 
            train=train,
            transform=transform, 
            target_transform=target_transform,
            download=download
        )
        
        self.digits = list(range(10)) if digits is None else digits
        data_filt, targets_filt = [], []
        for i, target in enumerate(self.targets):
            if target in self.digits:
                data_filt.append(self.data[i])
                targets_filt.append(self.digits.index(target.item()))
        self.data = torch.stack(data_filt)
        self.targets = torch.Tensor(targets_filt)
        self.n_classes = [len(self.digits), 2, 2]
    
    
    def __getitem__(self, index: int, binary_coloring=False):
        
        img, digit_label = self.data[index], int(self.targets[index])
        digit = self.digits[digit_label]

        rand_number = torch.rand(1)
        if digit in [0,1,2,3,4]:
            digit_rgb = torch.Tensor([1,0,0]) if rand_number>self.p else torch.Tensor([0,0,1])
        else:
            digit_rgb = torch.Tensor([0,0,1]) if rand_number>self.p else torch.Tensor([1,0,0])
        digit_target = int(digit_rgb[-1].item())

        rand_number = torch.rand(1)
        if digit in [0,2,4,6,8]:
            back_rgb = torch.Tensor([0,0,0]) if rand_number>self.p else torch.Tensor([1,1,1])
        else:
            back_rgb = torch.Tensor([1,1,1]) if rand_number>self.p else torch.Tensor([0,0,0])
        back_target = int(back_rgb[-1].item())

        if binary_coloring:
            img_digit_rgb = digit_rgb.view(-1,1,1) * 255*(img>0).repeat(3, 1, 1)
        else:
            img_digit_rgb = digit_rgb.view(-1,1,1) * img.repeat(3, 1, 1)
        img_back_rgb = back_rgb.view(-1,1,1) * 255*(img==0).repeat(3, 1, 1)
        
        img_rgb = img_digit_rgb + img_back_rgb
        targets = [digit_label, digit_target, back_target]

        img_rgb = Image.fromarray(
            np.transpose(img_rgb.numpy().astype(np.uint8), (1, 2, 0))
        )
        
        if self.transform is not None:
            img_rgb = self.transform(img_rgb)

        return img_rgb, targets
"""

class CondColoredMNIST(datasets.MNIST):
    """MNIST restricted to digits 0-3, rendered in RGB with two spurious colour cues.

    Each sample is returned as ``(image, [digit_label, digit_target, back_target])``:

    * ``digit_label``  -- index of the sample's digit within ``self.digits``.
    * ``digit_target`` -- stroke colour, ``0`` = red, ``1`` = green. Digits 0 and 1
      are red with probability ``1 - p`` (green otherwise); digits 2 and 3 are
      green with probability ``1 - p`` (red otherwise).
    * ``back_target``  -- background colour, ``0`` = black, ``1`` = white. Digits
      0 and 2 get a black background with probability ``1 - p`` (white
      otherwise); digits 1 and 3 get a white one with probability ``1 - p``.

    Colours are re-drawn with ``torch.rand`` on every ``__getitem__`` call, so
    the same index can yield differently coloured images across epochs.

    Args:
        root, train, transform, download: forwarded to ``torchvision.datasets.MNIST``.
        target_transform: applied to the whole
            ``[digit_label, digit_target, back_target]`` list.
        digits: subset of ``SUPPORTED_DIGITS`` to keep; ``None`` (the default)
            keeps all of them. Labels are re-indexed by position in this list
            (the first occurrence wins for repeated digits).
        p: probability of drawing the minority colour for each cue; must lie
            in [0, 1].

    Raises:
        ValueError: if ``digits`` is empty or contains a digit without a
            colouring rule, or if ``p`` lies outside [0, 1].
    """

    # Digits that have a colouring rule in ``__getitem__``.
    SUPPORTED_DIGITS = (0, 1, 2, 3)
    # Digits whose stroke is green with probability 1 - p (the others are red).
    _GREEN_MAJORITY_DIGITS = (2, 3)
    # Digits whose background is white with probability 1 - p (the others are black).
    _WHITE_MAJORITY_DIGITS = (1, 3)

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
            digits: Optional[List[int]] = None,
            p: float = 0.2
        ) -> None:
        # Validate before calling the parent constructor so bad arguments fail
        # fast, without touching (or downloading) the dataset on disk. Copying
        # into a fresh list avoids aliasing the caller's sequence.
        digits = list(self.SUPPORTED_DIGITS if digits is None else digits)
        if not digits:
            raise ValueError("digits must contain at least one digit")
        unsupported = sorted(set(digits) - set(self.SUPPORTED_DIGITS))
        if unsupported:
            raise ValueError(
                f"digits {unsupported} have no colouring rule; "
                f"supported digits are {list(self.SUPPORTED_DIGITS)}"
            )
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p}")

        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download
        )

        self.p = p
        self.digits = digits

        # Map each digit to its first position in ``digits`` (list.index semantics).
        label_of = {}
        for label, digit in enumerate(digits):
            label_of.setdefault(digit, label)

        # Vectorised filter: one tensor comparison per kept digit instead of a
        # Python-level membership test per sample. Boolean-mask indexing keeps
        # the original sample order.
        keep = torch.zeros_like(self.targets, dtype=torch.bool)
        for digit in label_of:
            keep |= self.targets == digit
        kept_targets = self.targets[keep].tolist()
        self.data = self.data[keep]
        # torch.Tensor (default float dtype) keeps the targets dtype used so far.
        self.targets = torch.Tensor([label_of[t] for t in kept_targets])
        # Class counts for [digit label, stroke colour, background colour].
        self.n_classes = [len(self.digits), 2, 2]

    def __getitem__(self, index: int, binary_coloring=False):
        """Return ``(image, [digit_label, digit_target, back_target])`` for ``index``.

        Args:
            index: position in the filtered dataset.
            binary_coloring: if True, every stroke pixel (``> 0``) is painted at
                full intensity; otherwise the grayscale intensity scales the
                stroke colour.

        Returns:
            The RGB ``PIL.Image`` (passed through ``transform`` if set) and the
            target list (passed through ``target_transform`` if set).
        """
        img, digit_label = self.data[index], int(self.targets[index])
        digit = self.digits[digit_label]

        # Two independent draws, in this order: stroke colour, then background.
        is_green = self._draw_cue(digit in self._GREEN_MAJORITY_DIGITS)
        is_white = self._draw_cue(digit in self._WHITE_MAJORITY_DIGITS)
        digit_rgb = torch.Tensor([0, 1, 0] if is_green else [1, 0, 0])
        back_rgb = torch.Tensor([1, 1, 1] if is_white else [0, 0, 0])
        targets = [digit_label, int(is_green), int(is_white)]

        # The (3, 1, 1) colour vectors broadcast against the image to give a
        # channel-first RGB tensor. Stroke pixels (> 0) and background pixels
        # (== 0) are disjoint, so the sum stays within [0, 255].
        stroke = 255 * (img > 0) if binary_coloring else img
        img_rgb = (
            digit_rgb.view(-1, 1, 1) * stroke
            + back_rgb.view(-1, 1, 1) * (255 * (img == 0))
        )

        img_rgb = Image.fromarray(
            np.transpose(img_rgb.numpy().astype(np.uint8), (1, 2, 0))
        )

        if self.transform is not None:
            img_rgb = self.transform(img_rgb)
        if self.target_transform is not None:
            targets = self.target_transform(targets)

        return img_rgb, targets

    def _draw_cue(self, majority: bool) -> bool:
        """Return ``majority`` with probability ``1 - p``, otherwise its negation.

        Consumes exactly one ``torch.rand(1)`` draw, so seeding torch makes the
        colours reproducible.
        """
        return majority if bool(torch.rand(1) > self.p) else not majority