import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Sampler

def genDataLinear(n):
    """
    Generate a synthetic linear-regression data set.

    :param n: number of data
    :return: tuple ``(X, Y)``; ``X`` has shape (n, 4) and ``Y`` has shape (n, 1)

    x ~ N(0,1)
    y ~ x1 + 0.5x2 - x3 - 0.5x4 + N(0,1)

    NOTE(review): the x4 coefficient implemented here is -0.5; an earlier
    version of this docstring stated +0.5 -- confirm which one is intended.
    """
    X = torch.randn(n, 4)
    # The noise is drawn after X so that, for a fixed seed, the random
    # stream is consumed in the same order as before.
    noise = torch.randn(n, 1)
    # Column slices (k:k+1) keep the trailing dimension, so the result is
    # (n, 1) and the whole batch is computed without a Python-level loop.
    signal = X[:, 0:1] + 0.5 * X[:, 1:2] - X[:, 2:3] - 0.5 * X[:, 3:4]
    Y = noise + signal

    return X, Y

class Dataset_from_matrix(Dataset):
    """Torch dataset over the rows of a 2-D matrix.

    Every row is one sample; it is split into inputs (all columns except the
    last) and a target (the last column).
    """

    def __init__(self, data_matrix):
        """
        Args: create a torch dataset from a tensor data_matrix with size n * p
        [treatment, features, outcome]
        """
        self.data_matrix = data_matrix
        self.num_data = data_matrix.shape[0]

    def __len__(self):
        """Return the number of rows (samples) in the matrix."""
        return self.num_data

    def __getitem__(self, idx):
        """Return ``(inputs, target)`` for the row(s) selected by ``idx``.

        ``idx`` may be an int, a list of ints, a slice, or a tensor of
        indices (converted with ``tolist()``). For a single index the result
        is a 1-D inputs tensor and a 0-D target; for several indices the
        inputs are 2-D (one row per index) and the target is 1-D.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.data_matrix[idx, :]
        # Split along the LAST axis so that batched indices (2-D ``sample``)
        # are divided by column rather than by row; for a single row this is
        # identical to ``sample[0:-1], sample[-1]``.
        return (sample[..., :-1], sample[..., -1])

def get_iter(data_matrix, batch_size, shuffle=True):
    """Build a ``DataLoader`` over the rows of ``data_matrix``.

    :param data_matrix: matrix handed to ``Dataset_from_matrix``
    :param batch_size: forwarded to ``DataLoader`` as ``batch_size``
    :param shuffle: forwarded to ``DataLoader`` as ``shuffle`` (default True)
    :return: the constructed ``DataLoader``
    """
    return DataLoader(
        Dataset_from_matrix(data_matrix),
        batch_size=batch_size,
        shuffle=shuffle,
    )

class BootstrapSampler(Sampler):
    """Sampler that yields the items of a caller-supplied iterable.

    ``context`` is iterated as-is by ``__iter__``; presumably it holds
    bootstrap-resampled indices into ``data_source`` (confirm with callers).
    """

    def __init__(self, data_source, context):
        """
        :param data_source: the underlying data set; only its length is used,
            and only as a fallback in ``__len__``
        :param context: iterable whose items are yielded by ``__iter__``
        """
        self.data_source = data_source
        self.context = context

    def __iter__(self):
        """Return an iterator over ``context``."""
        return iter(self.context)

    def __len__(self):
        """Return the number of items one pass of ``__iter__`` yields.

        The length must describe what is actually iterated (``context``),
        not the data set, otherwise a ``DataLoader`` built on this sampler
        reports the wrong number of batches whenever the two differ. If
        ``context`` has no length (e.g. a generator), fall back to
        ``len(data_source)``.
        """
        try:
            return len(self.context)
        except TypeError:
            return len(self.data_source)



from scipy.stats import wasserstein_distance
class distDiscrepancy:
    """Discrepancy measures between two batches of empirical distributions."""

    def __init__(self):
        # Stateless helper: there is nothing to initialise.
        pass

    def wasserstain(self, u_value, v_value, u_weight=None, v_weight=None):
        """Sum the row-wise 1-D Wasserstein distances between two tensors.

        Row ``i`` of ``u_value`` is compared with row ``i`` of ``v_value``
        using ``scipy.stats.wasserstein_distance``; the same optional weights
        are passed for every row.

        :param u_value: tensor whose rows are the first set of samples
        :param v_value: tensor whose rows are the second set of samples
        :param u_weight: optional tensor of weights for ``u_value`` rows
        :param v_weight: optional tensor of weights for ``v_value`` rows
        :return: the summed distance (``0`` when there are no rows), or the
            string ``"dim wrong"`` when the two inputs have a different
            number of rows
        """
        u_rows = u_value.data.numpy()
        v_rows = v_value.data.numpy()
        u_w = None if u_weight is None else u_weight.data.numpy()
        v_w = None if v_weight is None else v_weight.data.numpy()

        n_rows = u_rows.shape[0]
        # Callers receive a sentinel string rather than an exception when
        # the row counts disagree.
        if n_rows != v_rows.shape[0]:
            return "dim wrong"

        total = 0
        row = 0
        while row < n_rows:
            total = total + wasserstein_distance(
                u_rows[row, :], v_rows[row, :], u_weights=u_w, v_weights=v_w
            )
            row += 1
        return total

