from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import random


### For mnist
def get_mnist_loaders():
    """Create shuffled MNIST train/test DataLoaders.

    Both splits are read from ``'../data'`` with ``transforms.ToTensor()``
    applied to each image. No ``download=True`` is passed, so the dataset is
    presumably expected to already be on disk.

    Returns:
        ``(train_loader, test_loader)``, each using ``batch_size=64`` and
        ``shuffle=True``.
    """
    loaders = []
    for is_train in (True, False):
        split = datasets.MNIST(root='../data', transform=transforms.ToTensor(), train=is_train)
        loaders.append(DataLoader(split, batch_size=64, shuffle=True))
    train_loader, test_loader = loaders
    return train_loader, test_loader

def make_blue(batch):
    """Place a batch into the last of three channel slots (blue, assuming RGB order).

    Two all-zero copies of ``batch`` are prepended along dim 1. For a
    presumably single-channel (N, 1, H, W) input this yields (N, 3, H, W),
    with the original values in channel 2.

    Args:
        batch: Tensor whose dim 1 is the channel axis.

    Returns:
        ``torch.cat([0, 0, batch], dim=1)``, on the same device as ``batch``.
    """
    # zeros_like keeps batch's dtype and device; a bare torch.zeros(shape)
    # would always be a CPU tensor and make torch.cat fail for CUDA batches.
    zeros = torch.zeros_like(batch)
    return torch.cat([zeros, zeros, batch], dim=1)

def make_red(batch):
    """Place a batch into the first of three channel slots (red, assuming RGB order).

    Two all-zero copies of ``batch`` are appended along dim 1. For a
    presumably single-channel (N, 1, H, W) input this yields (N, 3, H, W),
    with the original values in channel 0.

    Args:
        batch: Tensor whose dim 1 is the channel axis.

    Returns:
        ``torch.cat([batch, 0, 0], dim=1)``, on the same device as ``batch``.
    """
    # zeros_like keeps batch's dtype and device; a bare torch.zeros(shape)
    # would always be a CPU tensor and make torch.cat fail for CUDA batches.
    zeros = torch.zeros_like(batch)
    return torch.cat([batch, zeros, zeros], dim=1)

def insert_spur(batch, labels, flip_fraction=0.1, fn_1=make_blue, fn_2=make_red, threshold=5):
    """Attach a label-correlated spurious feature to every image in a batch.

    The function first assigns each example to a group:
    group 1 if ``labels < threshold``, otherwise group 2. It then flips each
    example's group independently with probability ``flip_fraction``.
    Group-1 images are transformed with ``fn_1`` and group-2 images with
    ``fn_2``. With the defaults, labels 0-4 become blue and labels 5-9
    become red, except for the randomly flipped ones.

    Args:
        batch: Image tensor indexed as (N, C, H, W). With the default
            colorizers C is presumably 1 (grayscale) -- confirm with callers.
        labels: Tensor of N labels, compared elementwise with ``threshold``.
        flip_fraction: Per-example probability of swapping the group.
        fn_1: Transform applied to group-1 images; its output must have
            shape (k, 3, H, W) to fit into the result.
        fn_2: Transform applied to group-2 images; same shape requirement.
        threshold: Labels strictly below this value form group 1.

    Returns:
        A new (N, 3, H, W) tensor on ``batch``'s device. It uses the default
        floating dtype, as before.
    """
    n = labels.shape[0]
    # Exactly one random.random() draw per example, in batch order, so a
    # fixed random.seed() reproduces the same flips as before.
    flips = torch.tensor(
        [random.random() < flip_fraction for _ in range(n)],
        dtype=torch.bool,
        device=labels.device,
    )
    in_first = (labels < threshold) ^ flips
    in_second = ~in_first

    # Allocate on batch's device: a CPU-only buffer breaks the masked
    # assignments below when batch lives on a GPU.
    colored_batch = torch.zeros(
        (batch.shape[0], 3, batch.shape[2], batch.shape[3]), device=batch.device
    )
    colored_batch[in_first] = fn_1(batch[in_first])
    colored_batch[in_second] = fn_2(batch[in_second])
    return colored_batch

def get_cifar_loaders():
    """Create shuffled CIFAR-10 train/test DataLoaders.

    Both splits are read from ``'../data'`` with ``transforms.ToTensor()``
    applied. No ``download=True`` is passed, so the data is presumably
    expected to already be present.

    Returns:
        ``(train_loader, test_loader)``, each using ``batch_size=64``,
        ``num_workers=4``, ``pin_memory=True`` and ``shuffle=True``.
    """
    def _build(is_train):
        # One dataset split wrapped in its DataLoader.
        split = datasets.CIFAR10(root='../data', transform=transforms.ToTensor(), train=is_train)
        return DataLoader(split, batch_size=64, num_workers=4, pin_memory=True, shuffle=True)

    return _build(True), _build(False)

def alter_lighting(batch, scale=1.25):
    """Multiply every value by ``scale`` and clip the result into [0, 1].

    Args:
        batch: Tensor of any shape.
        scale: Multiplicative brightness factor (>1 brightens, <1 darkens).

    Returns:
        A new tensor of the same shape; ``batch`` itself is not modified.
    """
    return (batch * scale).clamp(0, 1)

def shift_color(batch, scale=0.25, channel_dim=0):
    """Add a constant offset to one channel and clip the result into [0, 1].

    Args:
        batch: Tensor indexed as ``batch[:, channel]``, presumably (N, C, H, W)
            images. Any rank >= 2 works, with dim 1 as the channel axis.
        scale: Offset added to every element of the selected channel.
        channel_dim: Index along dim 1 of the channel to shift. Despite the
            name, this is a channel index, not a dimension. With RGB images,
            0 would be red (assumption about channel order).

    Returns:
        A new tensor of ``batch``'s shape, clamped to [0, 1]. ``batch`` itself
        is not modified.
    """
    # Build the offset on batch's device so `batch + shift` also works for GPU
    # tensors. The default float dtype is kept so the fractional offset is
    # never truncated.
    shift = torch.zeros(batch.shape, device=batch.device)
    # A broadcast scalar fills the whole channel. No explicit (N, H, W) ones
    # tensor is needed, which also removes the hard 4-D requirement.
    shift[:, channel_dim] = scale
    return torch.clamp(batch + shift, 0, 1)

if __name__ == '__main__':
    # Demo: take one MNIST training batch, add the label-correlated
    # blue/red color feature, and save an 8-wide image grid to example.png.
    train_loader, _ = get_mnist_loaders()
    batch, labels = next(iter(train_loader))
    colored_batch = insert_spur(batch, labels)
    grid = make_grid(colored_batch, nrow=8)
    # make_grid yields (C, H, W); matplotlib's imshow expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.savefig('example.png')

