import os
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from scipy.stats import ortho_group
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

def get_name(s):
    """Build a file name from the sequence ``s``.

    Every entry except the last is converted with ``str`` and has its dots
    replaced by dashes; those pieces are joined with underscores to form the
    stem. The last entry is appended after a single ``"."`` as the extension
    and is not dot-escaped.

    Raises:
        IndexError: if ``s`` has fewer than two entries.
    """
    stem_parts = [str(part).replace(".", "-") for part in s[:-2]]
    # s[-2] is indexed explicitly so that too-short inputs raise IndexError.
    stem_parts.append(str(s[-2]).replace(".", "-"))
    return "_".join(stem_parts) + "." + str(s[-1])

# Image datasets that are loaded from ``torchvision.datasets`` under the same
# attribute name as the configuration value.
_TORCHVISION_DATASETS = ("MNIST", "CIFAR10", "SVHN")


def get_data(cf):
    """Build the train/test datasets selected by ``cf["data"]["dataset"]``.

    Supported values:

    * ``"classification_gradient"`` -- synthetic data produced by
      ``generate_half_space_classify_data_from_teacher(cf)``.
    * ``"MNIST"``, ``"CIFAR10"``, ``"SVHN"`` -- the torchvision dataset of the
      same name, downloaded into ``./data`` and reduced to the samples with
      labels 0 and 1 via ``binary_selection``.

    Args:
        cf: nested configuration mapping; only ``cf["data"]["dataset"]`` is
            read here (the synthetic generator reads further keys).

    Returns:
        A 5-tuple ``(train_data, test_data, Q, W, label_weight)``. ``Q`` and
        ``W`` are ``None`` for the torchvision datasets. ``label_weight`` is
        computed from the training split only.

    Raises:
        ValueError: if the dataset name is not one of the supported values.
    """
    dataset = cf["data"]["dataset"]

    if dataset == "classification_gradient":
        return generate_half_space_classify_data_from_teacher(cf)

    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would make this function silently return None.
    if dataset not in _TORCHVISION_DATASETS:
        raise ValueError(
            "unsupported dataset {!r}; expected 'classification_gradient' or one of {}".format(
                dataset, ", ".join(_TORCHVISION_DATASETS)))

    train_transform, test_transform = _data_transforms(dataset)

    # SVHN selects its partition with ``split=``; MNIST and CIFAR10 use the
    # boolean ``train=`` flag.
    if dataset == "SVHN":
        train_split, test_split = {"split": "train"}, {"split": "test"}
    else:
        train_split, test_split = {"train": True}, {"train": False}

    dataset_cls = getattr(torchvision.datasets, dataset)
    train_data = dataset_cls(
        root="./data", download=True, transform=train_transform, **train_split)
    test_data = dataset_cls(
        root="./data", download=True, transform=test_transform, **test_split)

    train_data, label_weight = binary_selection(dataset, train_data)
    # The class weights of the test split are intentionally discarded.
    test_data, _ = binary_selection(dataset, test_data)
    return train_data, test_data, None, None, label_weight

def binary_selection(data_class, dataset):
    """Restrict ``dataset`` in place to the samples whose label is below 2.

    Args:
        data_class: dataset name; ``"SVHN"`` is special-cased because its
            labels are read from (and written back to) ``dataset.labels``
            in addition to ``dataset.targets``.
        dataset: object exposing ``data`` and ``targets`` (``labels`` for
            SVHN). It is mutated and also returned.

    Returns:
        ``(dataset, label_weight)`` where ``label_weight`` is a tensor holding
        the reciprocal of the label-0 fraction and of the label-1 fraction
        among the kept samples, in that order.
    """
    is_svhn = data_class == "SVHN"
    if is_svhn:
        # Mirror the labels onto ``targets`` so one code path handles all sets.
        dataset.targets = dataset.labels

    if type(dataset.targets) is not torch.Tensor:
        dataset.targets = torch.Tensor(np.array(dataset.targets)).type(torch.long)

    keep = dataset.targets < 2
    dataset.data = dataset.data[keep]
    print(dataset.data.shape)
    dataset.targets = dataset.targets[keep]

    if is_svhn:
        dataset.labels = dataset.targets

    # With only 0/1 labels left, the sum of the labels counts the 1s.
    fraction_of_ones = torch.sum(dataset.targets) / dataset.targets.shape[0]
    fraction_of_zeros = 1 - fraction_of_ones

    print("label ratio: ", fraction_of_zeros, fraction_of_ones)
    return dataset, torch.Tensor([1 / fraction_of_zeros, 1 / fraction_of_ones])


def _data_transforms(dataset):
    """Return the ``(train_transform, test_transform)`` pair for ``dataset``.

    Both pipelines are identical: ``ToTensor`` followed by ``Normalize`` with
    the normalization constants registered for the dataset. They are built as
    two separate ``Compose`` objects so callers can alter one without
    affecting the other.

    Args:
        dataset: one of ``"CIFAR10"``, ``"MNIST"`` or ``"SVHN"``.

    Raises:
        ValueError: if ``dataset`` has no registered normalization constants
            (previously this surfaced as an ``UnboundLocalError``).
    """
    # (mean, std) passed to transforms.Normalize, keyed by dataset name.
    normalization = {
        "CIFAR10": ([0.49139968, 0.48215827, 0.44653124],
                    [0.24703233, 0.24348505, 0.26158768]),
        "MNIST": ([0.1307], [0.3081]),
        "SVHN": ([0.43768211, 0.44376971, 0.47280443],
                 [0.19803012, 0.20101562, 0.19703614]),
    }
    if dataset not in normalization:
        raise ValueError(
            "no normalization constants for dataset {!r}; expected one of {}".format(
                dataset, ", ".join(sorted(normalization))))
    mean, std = normalization[dataset]

    def build_pipeline():
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    train_transform = build_pipeline()
    test_transform = build_pipeline()
    return train_transform, test_transform

def process_data(data, USE_CUDA, cf):
    """Prepare one ``(input, label)`` batch for the model.

    The input is cast to ``float32`` and, if it has more than two dimensions,
    flattened to ``(batch, -1)``. The label is cast to ``int64`` and gains a
    trailing dimension when ``cf["train"]["loss"]`` is ``"hinge"``.

    Args:
        data: pair ``(x, label)`` of tensors.
        USE_CUDA: when truthy, both results are moved with ``.cuda()``.
        cf: configuration mapping; only ``cf["train"]["loss"]`` is read.

    Returns:
        The processed ``(x, label)`` pair.
    """
    raw_input, raw_label = data
    inputs = raw_input.to(torch.float32)
    targets = raw_label.to(torch.int64)
    # NOTE(review): Variable is a legacy autograd wrapper; confirm whether the
    # pinned PyTorch version still needs it before removing.
    inputs = Variable(inputs)

    if inputs.dim() > 2:
        # Collapse everything after the batch dimension into one axis.
        inputs = inputs.reshape(inputs.shape[0], -1)

    if cf["train"]["loss"] == "hinge":
        targets = targets.unsqueeze(1)

    if not USE_CUDA:
        return inputs, targets
    return inputs.cuda(), targets.cuda()

class Non_image_dataset(Dataset):
    """Dataset over two parallel indexable containers of inputs and labels."""

    def __init__(self, data_input, data_label):
        """Store both containers; the length is taken from ``data_label``."""
        self.input, self.label = data_input, data_label
        # Computed once here; later changes to ``data_label`` are not tracked.
        self.length = len(data_label)

    def __len__(self):
        """Return the number of labels recorded at construction time."""
        return self.length

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at ``idx``."""
        sample = self.input[idx]
        target = self.label[idx]
        return sample, target


def _pseudo_inverse(W):
    """Return the SVD-based pseudo-inverse of the NumPy matrix ``W`` as a tensor."""
    u, s, vh = np.linalg.svd(W, full_matrices=False)
    # Invert only the non-zero singular values; zero ones stay zero.
    nonzero = s != 0
    s[nonzero] = 1 / s[nonzero]
    return torch.Tensor(np.dot(vh.T, np.dot(np.diag(s), u.T)))


def _sample_teacher_features(label_function, structure, data_size, feature_dim,
                             effective_dim, p1, p2):
    """Draw one split's 0/1 feature matrix and return it as a float32 tensor.

    The order of the ``np.random`` calls is part of the contract: it decides
    which data a given NumPy seed produces.
    """
    rest_dim = feature_dim - effective_dim
    if label_function == "parity":
        if structure:
            quarter = data_size // 4
            # Half of the rows are fully random; a quarter have the effective
            # coordinates all set to 1 and a quarter have them all set to 0.
            random_rows = np.random.binomial(1, p1, (data_size // 2, feature_dim))
            ones_rows = np.concatenate(
                (np.ones((quarter, effective_dim)),
                 np.random.binomial(1, p1, (quarter, rest_dim))), axis=1)
            zeros_rows = np.concatenate(
                (np.zeros((quarter, effective_dim)),
                 np.random.binomial(1, p1, (quarter, rest_dim))), axis=1)
            features = np.concatenate((random_rows, ones_rows, zeros_rows), axis=0)
        else:
            features = np.random.binomial(1, p1, (data_size, feature_dim))
    elif label_function == "interval":
        features = np.concatenate(
            (np.random.binomial(1, p1, (data_size, effective_dim)),
             np.random.binomial(1, p2, (data_size, rest_dim))), axis=1)
    elif label_function == "all":
        half = data_size // 2
        # Deterministic: first half has the effective coordinates set to 1,
        # everything else is 0.
        active_rows = np.concatenate(
            (np.ones((half, effective_dim)), np.zeros((half, rest_dim))), axis=1)
        features = np.concatenate((active_rows, np.zeros((half, feature_dim))), axis=0)
    else:
        raise ValueError(
            "unsupported label_function {!r}; expected 'parity', 'interval' or 'all'".format(
                label_function))
    return torch.Tensor(features.astype(np.float32))


def _teacher_labels(features, label_function, effective_dim, p1):
    """Label each feature row with 0 or 1 from its first ``effective_dim`` columns."""
    effective = features[:, :effective_dim]
    if label_function == "parity":
        # Map {0, 1} to {-1, +1}; the label is 1 when the product is +1.
        sign_product = torch.prod(effective * 2 - 1, dim=1)
        return np.where((sign_product == 1), 1, 0)
    # Otherwise the label is 1 when the count of active coordinates lies in
    # the half-open interval (t1, t2].
    t1 = round(effective_dim * p1)
    t2 = effective_dim
    active_count = torch.sum(effective, dim=1)
    return np.where((active_count > t1) * (active_count <= t2), 1, 0)


def _features_to_inputs(features, W_inverse):
    """Map features through ``W_inverse.T``, centre the columns, unit-normalise rows."""
    inputs = features @ W_inverse.T
    inputs = inputs - torch.mean(inputs, dim=0)
    return inputs / torch.unsqueeze(torch.norm(inputs, dim=1), dim=1)


def generate_half_space_classify_data_from_teacher(cf):
    """Load from cache, or generate and cache, a synthetic binary classification set.

    Binary features are sampled, labelled by the configured label function on
    their first ``effective_dim`` coordinates, and mapped to inputs through
    the pseudo-inverse of ``W`` (the first ``feature_dim`` rows of a random
    matrix ``Q`` drawn with ``scipy.stats.ortho_group``). Inputs are then
    column-centred and row-normalised. The result is stored with
    ``torch.save`` under ``data/<name>`` and reused on later calls.

    Args:
        cf: configuration mapping. The keys read from ``cf["data"]`` are
            ``distribution``, ``data_size``, ``dim_full``, ``dim_input``,
            ``feature_dim``, ``effective_dim``, ``label_function``
            (``"parity"``, ``"interval"`` or ``"all"``) and ``structure``.
            ``distribution`` and ``dim_input`` only influence the cache name.

    Returns:
        ``(train_dataset, test_dataset, Q, W, label_weight)`` where the
        datasets are ``Non_image_dataset`` instances, ``Q`` and ``W`` are
        tensors, and ``label_weight`` holds the reciprocal label-0 and label-1
        fractions of the training labels.

    Raises:
        ValueError: if ``label_function`` is unknown, or if it is
            ``"parity"`` and ``effective_dim`` is not odd.
    """
    data_cf = cf["data"]
    distribution = data_cf["distribution"]
    data_size = data_cf["data_size"]
    dim_full = data_cf["dim_full"]
    dim_input = data_cf["dim_input"]
    feature_dim = data_cf["feature_dim"]
    effective_dim = data_cf["effective_dim"]
    label_function = data_cf["label_function"]
    structure = data_cf["structure"]

    print(structure)
    # NOTE(review): feature_dim is passed twice below and neither
    # effective_dim nor label_function is part of the cache name, so
    # configurations differing only in those settings share one cache file.
    # Kept as-is so existing cache files still resolve -- confirm whether the
    # second entry was meant to be effective_dim.
    structure_tag = "structure" if structure else "no_structure"
    data_name = get_name(["cla", distribution, data_size,
                          dim_full, dim_input, feature_dim, feature_dim,
                          structure_tag, "data"])
    data_path = os.path.join("data", data_name)
    print(data_path)

    if os.path.exists(data_path):
        # NOTE(review): torch.load unpickles the file, so only trusted cache
        # files should be placed here. Recent PyTorch releases may default to
        # weights_only=True, which could reject the NumPy arrays stored in
        # this dict -- confirm against the pinned torch version.
        dic = torch.load(data_path)
        train_input = dic["train_input"]
        train_label = dic["train_label"]
        test_input = dic["test_input"]
        test_label = dic["test_label"]
        Q = dic["Q"]
        W = dic["W"]
        label_weight = dic["label_weight"]
    else:
        # Round-trip through torch.Tensor to obtain a float32 NumPy matrix.
        Q = torch.Tensor(ortho_group.rvs(dim=dim_full)).numpy()
        W = Q[:feature_dim, :]
        W_inverse = _pseudo_inverse(W)

        # Explicit raises instead of asserts so validation survives ``-O``.
        if label_function == "parity":
            if effective_dim % 2 != 1:
                raise ValueError(
                    "effective_dim must be odd for the parity label function, got {!r}".format(
                        effective_dim))
            p1 = 0.5
        elif structure:
            p1 = 0.667
        else:
            p1 = 0.5
        p2 = 0.5

        # Train is sampled before test; swapping them changes the data
        # produced for a given NumPy seed.
        feature_train = _sample_teacher_features(
            label_function, structure, data_size, feature_dim, effective_dim, p1, p2)
        feature_test = _sample_teacher_features(
            label_function, structure, data_size, feature_dim, effective_dim, p1, p2)

        train_label = _teacher_labels(feature_train, label_function, effective_dim, p1)
        test_label = _teacher_labels(feature_test, label_function, effective_dim, p1)

        train_input = _features_to_inputs(feature_train, W_inverse)
        test_input = _features_to_inputs(feature_test, W_inverse)

        label_ratio_1 = torch.sum(torch.Tensor(train_label)) / train_label.shape[0]
        label_ratio_0 = 1 - label_ratio_1

        print("label ratio: ", label_ratio_0, label_ratio_1)
        label_weight = torch.Tensor([1 / label_ratio_0, 1 / label_ratio_1])

        dic = {
            "train_input": train_input,
            "train_label": train_label,
            "test_input": test_input,
            "test_label": test_label,
            "Q": Q,
            "W": W,
            "label_weight": label_weight,
        }
        # The cache directory may not exist on a first run; without this the
        # save below fails after all the data has been generated.
        os.makedirs(os.path.dirname(data_path), exist_ok=True)
        torch.save(dic, data_path)

    return Non_image_dataset(train_input, train_label), Non_image_dataset(test_input, test_label), \
        torch.Tensor(Q), torch.Tensor(W), label_weight
