import numpy as np
from numpy import linalg as LA
from sklearn.datasets import load_boston
import pickle as pkl
import pandas as pd
from sklearn.linear_model import Ridge


def extract_features(dataset):
    """Load a dataset and return its row-normalized covariates and outcome vectors.

    Parameters
    ----------
    dataset : str
        One of ``'boston'``, ``'lalonde'``, ``'ihdp'`` or ``'twins'``.

    Returns
    -------
    X : ndarray
        Covariate matrix (one sample per row), scaled by a single factor so the
        largest row Euclidean norm is 1. If every row has norm 0, ``X`` is
        returned unscaled.
    Y1 : ndarray
        Outcome vector for the treatment arm, as reported by the loader.
    Y0 : ndarray
        Outcome vector for the control arm, as reported by the loader.
        For ``'boston'`` and ``'lalonde'`` only one outcome vector exists, so
        ``Y1`` and ``Y0`` are the same array.

    Raises
    ------
    ValueError
        If ``dataset`` is not a supported name.
    """
    if dataset == 'boston':
        # NOTE: pickle can execute arbitrary code; only load trusted local files.
        with open("../datasets/boston_dataset.pkl", "rb") as f:
            X, y = pkl.load(f)
        Y1 = Y0 = y
    elif dataset == 'lalonde':
        X, y = lalonde()
        Y1 = Y0 = y
    elif dataset == 'ihdp':
        # ihdp() and twins() return (X, y0, y1), in that order.
        X, Y0, Y1 = ihdp()
    elif dataset == 'twins':
        X, Y0, Y1 = twins()
    else:
        raise ValueError(
            "unknown dataset %r; expected one of 'boston', 'lalonde', "
            "'ihdp', 'twins'" % (dataset,)
        )

    # Scale the whole matrix by its largest row norm, so every row has norm <= 1.
    # `initial=0.0` keeps an empty matrix from raising in max().
    norm_max = LA.norm(X, axis=1).max(initial=0.0)
    if norm_max > 0:
        # Build a new array rather than assigning in place: in-place assignment
        # into an integer-dtype X would truncate the scaled values.
        X = X * (1.0 / norm_max)

    return X, Y1, Y0


def generate_dataset(dataset='boston'):
    """Return the ``(X, Y1, Y0)`` triple that ``extract_features`` yields for ``dataset``.

    Convenience wrapper around ``extract_features``; ``dataset`` defaults to
    ``'boston'``.
    """
    covariates, y_first, y_second = extract_features(dataset)
    return covariates, y_first, y_second


def ihdp(path="../datasets/ihdp_npci_1.csv"):
    """Load the IHDP CSV and return covariates plus the ``mu0`` / ``mu1`` columns.

    The file has no header row. Its 30 columns are, in order: ``treatment``,
    ``y_factual``, ``y_cfactual``, ``mu0``, ``mu1``, then covariates
    ``x1`` ... ``x25``.

    Parameters
    ----------
    path : str, optional
        Location of the CSV file (defaults to the original hard-coded path).

    Returns
    -------
    X : ndarray of shape (n, 25)
        The covariate columns ``x1`` ... ``x25``.
    Y0 : ndarray of shape (n,)
        The ``mu0`` column.
    Y1 : ndarray of shape (n,)
        The ``mu1`` column.
    """
    columns = ["treatment", "y_factual", "y_cfactual", "mu0", "mu1"]
    columns += ["x" + str(i) for i in range(1, 26)]

    ihdp_dataset = pd.read_csv(path, header=None)
    # Raises if the file does not have exactly len(columns) columns.
    ihdp_dataset.columns = columns

    # Convert the whole frame at once, so X shares the frame's common dtype.
    labelled_data = ihdp_dataset.to_numpy()
    X = labelled_data[:, columns.index("x1"):]
    Y0 = labelled_data[:, columns.index("mu0")]
    Y1 = labelled_data[:, columns.index("mu1")]

    return X, Y0, Y1


def lalonde(path='../datasets/lalonde.csv'):
    """Load the LaLonde CSV and split it into a covariate matrix and one outcome.

    Column 8 (0-based) of the file is returned as the outcome ``y``. Columns 8
    and 11 are both removed from the covariate matrix.

    Parameters
    ----------
    path : str, optional
        Location of the CSV file (defaults to the original hard-coded path).
        The first row is read as a header.

    Returns
    -------
    X : ndarray of shape (n, d - 2)
        Every column except 8 and 11.
    y : ndarray of shape (n,)
        A copy of column 8.

    Raises
    ------
    IndexError
        If the file has 11 or fewer columns.
    """
    outcome_col = 8
    # NOTE(review): the reason column 11 is excluded from the covariates is not
    # visible here (perhaps an ID or a second outcome) -- confirm against the header.
    excluded_col = 11

    labelled_data = pd.read_csv(path).to_numpy()

    X = np.delete(labelled_data, [outcome_col, excluded_col], axis=1)
    y = labelled_data[:, outcome_col].copy()

    return X, y


def twins(data_path='../datasets/twins.csv',
          labels_path='../datasets/twins_labels.csv'):
    """Load the Twins covariates and labels and return ``(X, y0, y1)``.

    The covariate file and the label file are joined column-wise, aligned on
    their default row index. Any row with a missing value in either file is
    dropped. The last two columns of the joined table (the last two label
    columns) become ``y0`` and ``y1``; everything before them is ``X``.

    Parameters
    ----------
    data_path : str, optional
        Covariate CSV (defaults to the original hard-coded path).
    labels_path : str, optional
        Label CSV (defaults to the original hard-coded path).

    Returns
    -------
    X : ndarray
        Covariate columns of the rows that survived NaN removal.
    y0 : ndarray of shape (n,)
        Second-to-last column of the joined table.
    y1 : ndarray of shape (n,)
        Last column of the joined table.
    """
    data = pd.read_csv(data_path)
    labels = pd.read_csv(labels_path)
    all_data = pd.concat([data, labels], axis=1)
    # Bug fix: drop NaNs from (and extract outcomes from) the JOINED table;
    # the original used `data` alone, so the labels were silently ignored.
    labelled_data = all_data.dropna(axis=0, how='any').to_numpy()

    # NOTE(review): if the label file carries its own index column, it will
    # land in X here -- confirm the CSV layouts.
    X = labelled_data[:, :-2]
    y0 = labelled_data[:, -2]
    y1 = labelled_data[:, -1]

    return X, y0, y1


def synthetic_ate_data_1(num_samples=10000, num_covariates=50, std=0.2, interval=5):
    """Draw a random synthetic dataset for average-treatment-effect experiments.

    Generative process:

    * ``X``: uniform [0, 1) entries. Each row is multiplied by its own factor,
      drawn uniform in [0, 1000).
    * ``mu = X @ b / 100 + N(0, std)``, with weights ``b`` uniform in [0, 1).
    * ``y0``: uniform in [0, interval).
    * ``y1 = mu - y0``.

    Draws come from NumPy's global random state (``np.random``), so seed that
    state for reproducible output.

    Returns
    -------
    X : ndarray of shape (num_samples, num_covariates)
    y0, y1 : ndarray
        ``np.squeeze``-d column vectors, i.e. shape ``(num_samples,)`` when
        ``num_samples > 1``.
    """
    features = np.random.rand(num_samples, num_covariates)
    row_scale = np.random.rand(num_samples) * 1000
    features *= row_scale[:, np.newaxis]  # rescale each row by its own factor
    weights = np.random.rand(num_covariates, 1)

    noise = np.random.normal(0, std, (num_samples, 1))
    mu = features @ weights / 100 + noise

    control = np.random.rand(num_samples, 1) * interval
    treated = mu - control

    return features, np.squeeze(control), np.squeeze(treated)


def synthetic_ite_data_1(num_samples=10000, num_covariates=50, std=0.2, interval=5):
    """Draw a random synthetic dataset for individual-treatment-effect experiments.

    Generative process:

    * ``X``: uniform [0, 1) entries. Each row is multiplied by its own factor,
      drawn uniform in [0, 100).
    * Per-sample effect ``atevec = X @ b / 100 + N(0, std)``, with weights
      ``b`` uniform in [0, 1).
    * ``y0``: uniform in [0, interval).
    * ``y1 = y0 + atevec``.

    Draws come from NumPy's global random state (``np.random``), so seed that
    state for reproducible output.

    Returns
    -------
    X : ndarray of shape (num_samples, num_covariates)
    y0, y1 : ndarray
        ``np.squeeze``-d column vectors, i.e. shape ``(num_samples,)`` when
        ``num_samples > 1``.
    """
    features = np.random.rand(num_samples, num_covariates)
    row_scale = np.random.rand(num_samples) * 100
    features *= row_scale[:, np.newaxis]  # rescale each row by its own factor
    weights = np.random.rand(num_covariates, 1)

    noise = np.random.normal(0, std, (num_samples, 1))
    effect = features @ weights / 100 + noise

    control = np.random.rand(num_samples, 1) * interval
    treated = control + effect

    return features, np.squeeze(control), np.squeeze(treated)
    