import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from datasets import modular_arithmetic
from math import ceil

from torch.utils.data import TensorDataset, Subset

def _flatten_image(x):
    """Collapse an image tensor into a single 1-D vector.

    Defined at module level (rather than as an inline ``lambda``) so that
    the transform pipelines below stay picklable, which DataLoader worker
    processes require on platforms that start workers with ``spawn``.
    """
    return torch.flatten(x)


# CIFAR-10 pipeline for MLP models: convert to a tensor, shift/scale every
# channel with mean 0.5 and std 0.5, then flatten the image to one vector.
cf10_mlp_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Lambda(_flatten_image),
])

# CIFAR-10 pipeline for CNN models: same normalization as the MLP pipeline,
# but the image keeps its spatial layout (no flattening).
cf10_cnn_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# MNIST pipeline for MLP models: convert to a tensor and flatten; no
# normalization step is applied here.
mnist_mlp_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_flatten_image),
])

def _one_hot_10(label):
    """Encode an integer class label as a length-10 one-hot tensor.

    The result keeps the integer dtype returned by ``F.one_hot``; it is not
    cast to float (unlike the modular-arithmetic targets).
    """
    return F.one_hot(torch.tensor(label), 10)


def _binary_label_01(label):
    """Map class ids below 5 to 0 and all others to 1."""
    return 0 if label < 5 else 1


def _binary_label_pm1(label):
    """Map class ids below 5 to -1 and all others to +1."""
    return -1 if label < 5 else 1


def load_dataset(args):
    """Build the train/test datasets selected by ``args.dataset``.

    Args:
        args: Namespace-like object. ``args.dataset`` picks the branch; the
            other attributes read depend on it:

            * ``'modular_arithmetic'``: ``operation``, ``prime``,
              ``training_fraction``.
            * ``'cifar10'`` / ``'cifar10_subsample'``: ``model``,
              ``data_root``.
            * ``'binary_cf10_single_logit'``: ``loss``, ``data_root``.
            * ``'thm2_synthetic'``: nothing else.
            * ``'mnist'``: ``data_root``.

    Returns:
        Tuple ``(train, test, inp_dim, out_dim)`` where ``train`` and
        ``test`` are datasets and ``inp_dim`` / ``out_dim`` are the integer
        input and output sizes reported for that dataset.

    Raises:
        NotImplementedError: If ``args.dataset == 'svhn'``; that loader is
            unfinished.
        ValueError: If ``args.dataset`` is not a recognised name.
    """
    if args.dataset == 'modular_arithmetic':
        X, y = modular_arithmetic.operation_mod_p_data(args.operation, args.prime)
        X_tr, y_tr, X_te, y_te = modular_arithmetic.make_data_splits(
            X, y, args.training_fraction)

        # One-hot each operand and concatenate them into one row of width
        # 2 * prime (presumably X holds integer operand pairs -- TODO confirm
        # against datasets.modular_arithmetic).
        X_tr = F.one_hot(X_tr, args.prime).view(-1, 2*args.prime).float()
        y_tr_onehot = F.one_hot(y_tr, args.prime).float()
        X_te = F.one_hot(X_te, args.prime).view(-1, 2*args.prime).float()
        y_te_onehot = F.one_hot(y_te, args.prime).float()

        train = TensorDataset(X_tr, y_tr_onehot)
        test = TensorDataset(X_te, y_te_onehot)

        inp_dim = 2*args.prime
        out_dim = args.prime
    elif args.dataset in ('cifar10', 'cifar10_subsample'):
        is_cnn = args.model == 'CNN'
        transform = cf10_cnn_transform if is_cnn else cf10_mlp_transform

        train = torchvision.datasets.CIFAR10(root=args.data_root, train=True,
                                             download=True, transform=transform,
                                             target_transform=_one_hot_10)
        test = torchvision.datasets.CIFAR10(root=args.data_root, train=False,
                                            download=True, transform=transform,
                                            target_transform=_one_hot_10)

        if args.dataset == 'cifar10_subsample':
            # Random 1000-example subset of the training split. Convert the
            # permutation to plain ints so the wrapped dataset is indexed
            # with Python integers rather than 0-d tensors.
            rand_idx = torch.randperm(len(train))[:1000].tolist()
            train = Subset(train, rand_idx)

        # CNN input size is reported as 3; the MLP gets the flattened
        # 3*32*32 image.
        inp_dim = 3 if is_cnn else 3*32*32
        out_dim = 10
    elif args.dataset == 'binary_cf10_single_logit':
        # Binary relabelling of CIFAR-10: {0, 1} targets for the 'xent'
        # loss, {-1, +1} targets for any other loss.
        tg_tf = _binary_label_01 if args.loss == 'xent' else _binary_label_pm1

        train = torchvision.datasets.CIFAR10(root=args.data_root, train=True,
                                             download=True, transform=cf10_mlp_transform,
                                             target_transform=tg_tf)
        test = torchvision.datasets.CIFAR10(root=args.data_root, train=False,
                                            download=True, transform=cf10_mlp_transform,
                                            target_transform=tg_tf)
        inp_dim = 3*32*32
        out_dim = 1
    elif args.dataset == 'thm2_synthetic':
        X_tr, y_tr, X_te, y_te = thm2_synthetic(1000, 10000, 0.1, 1.0)
        train = TensorDataset(X_tr, y_tr)
        test = TensorDataset(X_te, y_te)

        inp_dim = 4
        out_dim = 1
    elif args.dataset == 'svhn':
        # TODO: load SVHN (cf10_mlp_transform was the intended transform)
        # and isolate it to some subset of digits. Fail before touching the
        # disk, since this branch cannot return anything useful yet.
        raise NotImplementedError("the 'svhn' dataset is not implemented yet")
    elif args.dataset == 'mnist':
        train = torchvision.datasets.MNIST(root=args.data_root, train=True,
                                           transform=mnist_mlp_transform,
                                           download=True,
                                           target_transform=_one_hot_10)
        test = torchvision.datasets.MNIST(root=args.data_root, train=False,
                                          transform=mnist_mlp_transform,
                                          download=True,
                                          target_transform=_one_hot_10)

        inp_dim = 28*28
        out_dim = 10
    else:
        raise ValueError(f"unknown dataset: {args.dataset!r}")

    return train, test, inp_dim, out_dim

def thm2_synthetic(n_train, n_test, p, t):
    """Generate the synthetic regression data used for 'thm2_synthetic'.

    Each split consists of copies of the fixed point ``[1, 1, 0, 0]``
    followed by ``int(n * p)`` rows whose four entries are drawn uniformly
    from the integers ``{0, 1, 2}``. Targets are
    ``t * x0 * x1 + x2 * x3``.

    Args:
        n_train: Number of training rows.
        n_test: Number of test rows.
        p: Fraction of each split that is randomly sampled.
        t: Weight on the ``x0 * x1`` term of the target.

    Returns:
        Tuple ``(X_tr, y_tr, X_te, y_te)`` of float tensors; the ``X``
        tensors have 4 columns and the ``y`` tensors are 1-D.
    """
    x_static = torch.tensor([1, 1, 0, 0]).float().unsqueeze(0)

    # Derive the static-row count by subtraction so the two parts always add
    # up to the requested size; int(n*(1-p)) + int(n*p) can fall one short
    # because of floating-point truncation.
    n_rand_tr = int(n_train * p)
    n_static_tr = int(n_train) - n_rand_tr
    n_rand_te = int(n_test * p)
    n_static_te = int(n_test) - n_rand_te

    # Train samples are drawn before test samples. randint yields int64, so
    # cast explicitly instead of relying on dtype promotion inside torch.cat.
    samp_tr = torch.randint(low=0, high=3, size=(n_rand_tr, 4)).float()
    X_tr = torch.cat([x_static.repeat(n_static_tr, 1), samp_tr], dim=0)

    samp_te = torch.randint(low=0, high=3, size=(n_rand_te, 4)).float()
    X_te = torch.cat([x_static.repeat(n_static_te, 1), samp_te], dim=0)

    y_tr = t*X_tr[:, 0]*X_tr[:, 1] + X_tr[:, 2]*X_tr[:, 3]
    y_te = t*X_te[:, 0]*X_te[:, 1] + X_te[:, 2]*X_te[:, 3]

    return X_tr, y_tr, X_te, y_te

def make_dataloader(dataset, batch_size, shuffle=False, drop_last=False, num_workers=0):
    """Wrap ``dataset`` in a ``DataLoader`` with a capped batch size.

    The requested ``batch_size`` is reduced to at most half the dataset
    length (rounded up), so any dataset with two or more items is split
    into at least two batches.

    Args:
        dataset: Dataset supporting ``len()``.
        batch_size: Requested batch size (upper bound).
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    # Clamp the cap to 1: for an empty dataset ceil(0 / 2) is 0, which
    # DataLoader rejects as a non-positive batch size.
    half_len = max(1, ceil(len(dataset) / 2))
    batch_size = min(batch_size, half_len)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=shuffle, drop_last=drop_last,
                                       num_workers=num_workers)
