import torch
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from synthetic_image_data import SyntheticImageDataset

# Default modulus for the modular-arithmetic task.
p = 97
# Special token ids sit just past the residues 0..p-1 (eq = p, op = p + 1),
# so they can never collide with a residue value.
eq_token, op_token = p, p + 1

def division_mod_p_data(p, eq_token, op_token, include_ops=False, shuffle=True, test_size=0.5, random_seed=42, seperate_labels=True):
    """
    Build the modular-division dataset x◦y = x/y (mod p) for 0 ≤ x < p, 0 < y < p.

    Division mod p is multiplication by the modular inverse of y, so every y in 1..p-1 must be
    invertible mod p (i.e. p should be prime); otherwise ``pow`` raises ValueError.

    :param p: modulus (expected to be prime).
    :param eq_token: token id used for "=" when include_ops is True.
    :param op_token: token id used for "◦" when include_ops is True.
    :param include_ops: if True each row is [x, op, y, eq, result], otherwise [x, y, result].
    :param shuffle: forwarded to sklearn's train_test_split.
    :param test_size: forwarded to sklearn's train_test_split.
    :param random_seed: random_state for train_test_split.
    :param seperate_labels: if True, the last column (result) is split off as the labels.
    :return: X_train, X_test, y_train, y_test (y_* are None when seperate_labels is False).
    :raises ValueError: if some y in 1..p-1 has no inverse modulo p.
    """
    modulus = int(p)
    x = torch.arange(p)
    y = torch.arange(1, p)
    # Every (x, y) pair: x in 0..p-1, y in 1..p-1.
    x, y = torch.cartesian_prod(x, y).T

    # Lookup table of modular inverses indexed by residue; index 0 is a placeholder because y is never 0.
    inverses = torch.tensor([0] + [pow(v, -1, modulus) for v in range(1, modulus)], dtype=x.dtype)
    # x / y (mod p) == x * y^{-1} (mod p). Plain x * y % p would be multiplication, not division.
    result = x * inverses[y] % p

    # "All of our experiments used a small transformer trained on datasets of
    # equations of the form a◦b = c, where each of “a”, “◦”, “b”, “=”, and “c”
    # is a separate token"
    if include_ops:
        eq = torch.ones_like(x) * eq_token
        op = torch.ones_like(x) * op_token
        dset_tensor = torch.stack([x, op, y, eq, result])
    else:
        dset_tensor = torch.stack([x, y, result])

    # Transpose so each row is one equation before splitting.
    X_train, X_test = train_test_split(dset_tensor.T, test_size=test_size, shuffle=shuffle, random_state=random_seed)
    y_train, y_test = None, None
    if seperate_labels:
        y_train, y_test = X_train[:, -1], X_test[:, -1]
        X_train, X_test = X_train[:, 0:-1], X_test[:, 0:-1]
    return X_train, X_test, y_train, y_test


def make_sklearn_classification(*args, shuffle=True, test_size=0.5, random_seed=42, **kwargs):
    """
    Generate a synthetic classification problem with sklearn's make_classification and split it.

    All positional arguments and any keyword arguments other than the split options are forwarded
    to make_classification. Note that random_seed only seeds the train/test split; the generated
    data itself is reproducible only if a ``random_state`` is passed through kwargs.

    :return: X_train, X_test, y_train, y_test as torch tensors (features are torch.float).
    """
    features, labels = make_classification(*args, **kwargs)
    features = torch.tensor(features, dtype=torch.float)
    labels = torch.tensor(labels)
    split = train_test_split(features, labels, test_size=test_size, shuffle=shuffle, random_state=random_seed)
    return tuple(split)

def get_mnist(*args, shuffle=True, test_size=0.5, random_seed=42, num_samples=None, flatten=False, **kwargs):
    '''
    Load the MNIST training split and divide it into train and test sets.

    Probably a much better way to do this, but this keeps consistency with other methods. Note that we don't use the
    MNIST test set, we just make a test set from the train split.

    Images are converted with ToTensor and normalized with mean 0.5 / std 0.5, i.e. (v - 0.5) / 0.5;
    with flatten=True each image tensor is reshaped into a 1-D vector.

    :param args: accepted for interface compatibility with the other loaders; unused.
    :param shuffle: forwarded to sklearn's train_test_split.
    :param test_size: forwarded to sklearn's train_test_split.
    :param random_seed: random_state for train_test_split.
    :param num_samples: if not None, only the first num_samples training images (dataset order) are used.
    :param flatten: flatten each image into a vector.
    :param kwargs: accepted for interface compatibility with the other loaders; unused.
    :raises ValueError: if num_samples selects no images.
    :return: X_train, X_test, y_train, y_test
    '''
    steps = [
        transforms.ToTensor(),  # Convert PIL image to Tensor
        transforms.Normalize((0.5,), (0.5,)),  # (v - 0.5) / 0.5
    ]
    if flatten:
        steps.append(transforms.Lambda(lambda x: x.view(-1)))  # Flatten the tensor
    transform = transforms.Compose(steps)

    # Download (if needed) and load the training dataset
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

    # Pick the requested prefix *before* loading so only those images are decoded and transformed.
    # Slicing a range has the same semantics as the tensor slicing it replaces (None -> everything).
    indices = range(len(mnist_train))[:num_samples]
    if len(indices) == 0:
        raise ValueError(f"num_samples={num_samples!r} selects no MNIST images")
    subset = torch.utils.data.Subset(mnist_train, indices)

    # A single batch holding the whole selection; shuffle=False keeps dataset order.
    train_loader = torch.utils.data.DataLoader(dataset=subset, batch_size=len(subset), shuffle=False)
    train_images, train_labels = next(iter(train_loader))

    X_train, X_test, y_train, y_test = train_test_split(train_images, train_labels, test_size=test_size, shuffle=shuffle,
                                                        random_state=random_seed)
    return X_train, X_test, y_train, y_test


def get_synthetic_images(*args, width, height, num_images, test_size=0.5, label_noise=0.0,
                         class_weights=None, flatten=True, **kwargs):
    """
    Generate a synthetic image dataset with SyntheticImageDataset and return its train/test split.

    :param args: extra positional arguments forwarded to SyntheticImageDataset.generate_dataset.
    :param width: image width passed to SyntheticImageDataset.
    :param height: image height passed to SyntheticImageDataset.
    :param num_images: number of images passed to SyntheticImageDataset.
    :param test_size: forwarded to generate_dataset as ``test_split``.
    :param label_noise: forwarded to generate_dataset.
    :param class_weights: forwarded to generate_dataset.
    :param flatten: forwarded to generate_dataset.
    :param kwargs: extra keyword arguments forwarded to generate_dataset.
    :return: X_train, X_test, y_train, y_test (generate_dataset is expected to return these four items).
    """
    dataset_builder = SyntheticImageDataset(width, height, num_images)
    split = dataset_builder.generate_dataset(
        *args,
        test_split=test_size,
        label_noise=label_noise,
        class_weights=class_weights,
        flatten=flatten,
        **kwargs,
    )
    X_train, X_test, y_train, y_test = split
    return X_train, X_test, y_train, y_test


def get_cifar10(*args, shuffle=True, test_size=0.5, random_seed=42, num_samples=None, flatten=False, **kwargs):
    '''
    Load the CIFAR-10 training split and divide it into train and test sets.

    Note: The CIFAR-10 test set is not used; instead, a split is made from the training set.
    Images are converted with ToTensor only (no normalization). They are flattened into 1-D vectors
    only when flatten=True; otherwise the tensor layout produced by ToTensor is kept.

    :param args: accepted for interface compatibility with the other loaders; unused.
    :param shuffle: forwarded to sklearn's train_test_split.
    :param test_size: forwarded to sklearn's train_test_split.
    :param random_seed: random_state for train_test_split.
    :param num_samples: if not None, only the first num_samples training images (dataset order) are used.
    :param flatten: flatten each image into a vector.
    :param kwargs: accepted for interface compatibility with the other loaders; unused.
    :raises ValueError: if num_samples selects no images.
    :return: X_train, X_test, y_train, y_test
    '''
    steps = [transforms.ToTensor()]  # Convert PIL image to Tensor
    if flatten:
        steps.append(transforms.Lambda(lambda x: x.view(-1)))  # Flatten the tensor into a vector
    transform = transforms.Compose(steps)

    # Download (if needed) and load the CIFAR-10 training dataset
    cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

    # Pick the requested prefix *before* loading so only those images are decoded and transformed.
    # Slicing a range has the same semantics as the tensor slicing it replaces (None -> everything).
    indices = range(len(cifar10_train))[:num_samples]
    if len(indices) == 0:
        raise ValueError(f"num_samples={num_samples!r} selects no CIFAR-10 images")
    subset = torch.utils.data.Subset(cifar10_train, indices)

    # A single batch holding the whole selection; shuffle=False keeps dataset order.
    train_loader = torch.utils.data.DataLoader(dataset=subset, batch_size=len(subset), shuffle=False)
    train_images, train_labels = next(iter(train_loader))

    X_train, X_test, y_train, y_test = train_test_split(
        train_images, train_labels, test_size=test_size, shuffle=shuffle, random_state=random_seed
    )

    return X_train, X_test, y_train, y_test


def make_tensor_datasets(X_train, X_test, y_train, y_test):
    """
    Wrap train/test tensors in TensorDatasets.

    When y_train is None the datasets hold features only (y_test is then ignored);
    otherwise each dataset pairs features with labels.

    :return: (train_dset, test_dset)
    """
    if y_train is None:
        return TensorDataset(X_train), TensorDataset(X_test)
    return TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)

def make_dataloaders(train_data, test_data, *args, **kwargs):
    """
    Build one DataLoader per dataset, passing the same extra arguments to both.

    :return: (train_loader, test_loader)
    """
    return tuple(DataLoader(dset, *args, **kwargs) for dset in (train_data, test_data))


# Registry from dataset name to loader; each loader returns (X_train, X_test, y_train, y_test).
DATASET_DICT = {
    "synthetic_classification": make_sklearn_classification,
    "div_mod_p": division_mod_p_data,
    "MNIST": get_mnist,
    "synth_images": get_synthetic_images,
    "CIFAR": get_cifar10,
}