"""Functions for generating data."""

import numpy as np

from scipy.optimize import linear_sum_assignment
from scipy.special import expit, logit
from scipy.stats import bernoulli

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression


def generate_data(degree=2, theta_true=np.array([5, 1, -1, 2, 2, -1]), cov_dist='gaussian',
                  cov_mean=0, cov_sigma=1, noise_sigma=1, pol_type='logistic',
                  num_samples=10, pol_theta=1, pol_b=0.5, test_data=False):
    """Generate a synthetic dataset D = (W, Z, Y) from a simple causal model.

    The model has:

    * one covariate W (a scalar per sample),
    * a binary treatment Z in {0, 1} assigned by a policy pi(W),
    * an outcome Y = c(W, Z) (smaller is better) that is a polynomial in
      W and Z plus Gaussian noise. For ``degree=2``::

          c(W, Z) = theta_1 + theta_2 * W + theta_3 * Z
                    + theta_4 * W^2 + theta_5 * W * Z + theta_6 * Z^2

    The nonlinear terms are produced by a polynomial feature generator up to
    ``degree``; ``generate_outcome`` additionally clips negative outcomes to 0.

    Parameters
    ----------
    degree : polynomial degree used for the outcome features.
    theta_true : coefficient vector of the outcome polynomial. The default
        array is shared across calls; it is only read here, never mutated.
    cov_dist, cov_mean, cov_sigma : distribution name, mean and standard
        deviation passed to ``generate_W_sample``.
    noise_sigma : standard deviation of the additive outcome noise.
    pol_type : treatment policy, passed to ``generate_Z_sample``
        ('logistic' or 'threshold').
    num_samples : number of samples to draw.
    pol_theta, pol_b : slope and intercept of the logistic policy.
    test_data : when truthy, only covariates are generated and returned.

    Returns
    -------
    ``W_samples`` alone if ``test_data`` is truthy, otherwise the tuple
    ``(W_samples, Z_samples, Y_samples)``.

    Note: the threshold of the 'threshold' policy is not exposed here, so
    ``generate_Z_sample``'s own default threshold is used.
    """
    W_samples = generate_W_sample(dist=cov_dist, mean=cov_mean, sigma=cov_sigma,
                                  num_samples=num_samples)

    # Test sets consist of covariates only: no treatment or outcome is drawn.
    if test_data:
        return W_samples

    Z_samples = generate_Z_sample(pol_type=pol_type, W_samples=W_samples,
                                  theta=pol_theta, b=pol_b)

    Y_samples = generate_outcome(Ws=W_samples, Zs=Z_samples, theta=theta_true,
                                 noise_sigma=noise_sigma, degree=degree)

    return W_samples, Z_samples, Y_samples


def threshold_pol(w, threshold=0.3):
    """Deterministic threshold policy.

    Returns 1 (treat) when the covariate ``w`` is at least ``threshold``,
    and 0 (do not treat) otherwise.
    """
    return 1 if w >= threshold else 0

def logistic_pol(w, theta=1, b=0.5):
    """Stochastic logistic policy.

    The treatment probability is ``expit(theta * w + b)``; the returned
    treatment is a single Bernoulli draw (0 or 1) with that probability,
    taken from NumPy's global random state.
    """
    return np.random.binomial(1, expit(theta * w + b))

def generate_Z_sample(pol_type='logistic', W_samples=None, theta=1, b=0.5, threshold=0.3):
    """Assign a binary treatment to every covariate sample using a policy.

    Parameters
    ----------
    pol_type : 'logistic' (random draw via ``logistic_pol``) or
        'threshold' (deterministic rule via ``threshold_pol``).
    W_samples : iterable of covariate values, one per individual.
    theta, b : slope and intercept forwarded to the logistic policy.
    threshold : cut-off forwarded to the threshold policy.

    Returns
    -------
    numpy array of treatments (0/1), one entry per element of ``W_samples``.

    Raises
    ------
    ValueError
        If ``pol_type`` is not a supported policy. Previously this case only
        printed a message per sample and returned an empty array, which then
        failed later when stacked against the covariates.
    """
    # The policy choice does not depend on the sample, so decide it once.
    if pol_type == 'logistic':
        Z_samples = [logistic_pol(w, theta=theta, b=b) for w in W_samples]
    elif pol_type == 'threshold':
        Z_samples = [threshold_pol(w, threshold=threshold) for w in W_samples]
    else:
        raise ValueError('Policy type not implemented: {!r}'.format(pol_type))

    return np.array(Z_samples)

def generate_W_sample(dist='gaussian', mean=0, sigma=1, num_samples=100):
    """Draw covariate samples from the requested distribution.

    Parameters
    ----------
    dist : name of the distribution; only 'gaussian' is implemented.
    mean, sigma : mean and standard deviation of the Gaussian.
    num_samples : number of samples to draw.

    Returns
    -------
    1-D numpy array of length ``num_samples`` drawn with ``np.random.normal``.

    Raises
    ------
    ValueError
        If ``dist`` is not supported. Previously this case printed a message
        and then failed with an UnboundLocalError on the return statement.
    """
    if dist != 'gaussian':
        raise ValueError('type of distribution not implemented: {!r}'.format(dist))

    return np.random.normal(mean, sigma, num_samples)

def generate_outcome(Ws=None, Zs=None, theta=None, noise_sigma=1, degree=2):
    """Compute noisy, non-negative polynomial outcomes for (W, Z) pairs.

    Covariates ``Ws`` and treatments ``Zs`` are stacked column-wise and
    expanded with ``PolynomialFeatures(degree=degree)``. The features are
    multiplied by the coefficient vector ``theta``, Gaussian noise with
    standard deviation ``noise_sigma`` is added, and any negative result is
    set to 0.

    Returns a 1-D numpy array with one outcome per entry of ``Ws``.
    """
    n_rows = Ws.shape[0]

    # Polynomial expansion of the two-column (W, Z) matrix.
    design = PolynomialFeatures(degree=degree).fit_transform(
        np.column_stack((Ws, Zs))
    )

    # (n, k) @ (k, 1) -> (n, 1); keep the single column as a flat vector.
    signal = np.dot(design, theta.reshape((-1, 1)))[:, 0]

    noisy = signal + np.random.normal(0, noise_sigma, n_rows)

    # Outcomes are floored at zero.
    noisy[noisy < 0] = 0

    return noisy