import torch
import numpy as np

import torchvision
import torchvision.transforms as transforms


# Sample geometry: width, height and channel count. Generated samples are
# flattened vectors of length w * h * c.
h = 28
w = h
c = 1

# Descriptor of the synthetic dataset; load_gaussian_teacher adds the
# 'num_classes' entry when it is called.
DATA_DESC = dict(
    data='gaussian_teacher',
    w=w,
    h=h,
    d=c,
)


def load_gaussian_teacher(teacher_model, num_classes=10, train_size=10000, test_size=2000,
                          train_means=None, train_covariances=None,
                          test_means=None, test_covariances=None,
                          use_augmentation=False, validation=False, val_size=1024):
    """
    Build synthetic Gaussian-mixture train/test datasets, one Gaussian per class.

    Rows of each split are divided into ``num_classes`` contiguous blocks; block
    ``i`` is sampled from N(means[i], covariances[i]) and labelled ``i``.
    Every sample is a flattened vector of length ``w * h * c`` (module constants).
    Side effect: sets ``DATA_DESC['num_classes'] = num_classes``.

    Arguments:
        teacher_model: currently unused; kept for interface compatibility.
        num_classes (int): number of classes / mixture components.
        train_size (int): number of training samples.
        test_size (int): number of test samples.
        train_means (Tensor): per-class means, indexed as ``train_means[i]``;
            expected shape (num_classes, w*h*c).
        train_covariances (Tensor): per-class covariance matrices, indexed as
            ``train_covariances[i]``; each must be symmetric positive-definite
            (required by the Cholesky factorisation).
        test_means, test_covariances: same as above, for the test split.
        use_augmentation (bool): currently unused.
        validation (bool): if True, additionally return samplers splitting the
            training set into a random train / validation partition.
        val_size (int): number of training indices held out for validation
            when ``validation`` is True (default 1024).
    Returns:
        ``(train_dataset, test_dataset)``, or, when ``validation`` is True,
        ``(train_dataset, test_dataset, train_sampler, val_sampler)``.
    Raises:
        ValueError: if a mean/covariance argument is missing, or if ``val_size``
            is negative or would leave no training indices.
    """
    if train_means is None or train_covariances is None or test_means is None or test_covariances is None:
        raise ValueError("train_means, train_covariances, test_means and test_covariances are required")
    DATA_DESC['num_classes'] = num_classes
    dim = w * h * c

    # torch.cholesky is deprecated; torch.linalg.cholesky returns the same
    # lower-triangular factor L with Sigma = L @ L.T (batched over classes).
    train_As = torch.linalg.cholesky(train_covariances)
    test_As = torch.linalg.cholesky(test_covariances)

    X_train = torch.zeros((train_size, dim), device='cpu')
    X_test = torch.zeros((test_size, dim), device='cpu')
    y_train = torch.zeros(train_size, dtype=torch.long, device='cpu')
    y_test = torch.zeros(test_size, dtype=torch.long, device='cpu')

    for i in range(num_classes):
        # Class boundaries; the number of samples is the slice length so that
        # sizes not divisible by num_classes still fill every row.
        tr_lo, tr_hi = i * train_size // num_classes, (i + 1) * train_size // num_classes
        te_lo, te_hi = i * test_size // num_classes, (i + 1) * test_size // num_classes

        # Row-vector sampling: x = mu + z @ L.T has covariance L @ L.T = Sigma.
        # (z @ L would give covariance L.T @ L, which differs unless L is diagonal.)
        X_train[tr_lo:tr_hi] = train_means[i] + torch.randn((tr_hi - tr_lo, dim)).cpu() @ train_As[i].T
        X_test[te_lo:te_hi] = test_means[i] + torch.randn((te_hi - te_lo, dim)).cpu() @ test_As[i].T

        y_train[tr_lo:tr_hi] = i
        y_test[te_lo:te_hi] = i

    train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
    test_dataset = torch.utils.data.TensorDataset(X_test, y_test)

    if validation:
        dataset_size = len(train_dataset)
        # An empty training sampler would silently make training a no-op.
        if not 0 <= val_size < dataset_size:
            raise ValueError(
                f"val_size={val_size} must be in [0, {dataset_size}) so some training indices remain"
            )
        indices = list(range(dataset_size))
        np.random.shuffle(indices)
        train_indices, val_indices = indices[val_size:], indices[:val_size]

        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
        val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices)
        return train_dataset, test_dataset, train_sampler, val_sampler
    return train_dataset, test_dataset