
import numpy as np
import matplotlib.pyplot as plt


def currupt_data(X, feature_id=0, method='MCAR', missing_rate=0.1, feature_id2=None):
    """
    Return a copy of X in which part of one feature column is set to NaN.

    Three missingness mechanisms are supported:

    * 'MCAR' - rows are picked uniformly at random.
    * 'MAR'  - a contiguous window of rows, taken in the sort order of
      *another* feature, is picked.
    * 'MNAR' - a contiguous window of rows, taken in the sort order of the
      corrupted feature itself, is picked.

    :param X: n by p matrix of input data; its dtype must be able to hold NaN
    :param feature_id: feature index to corrupt, defaults to 0
    :param method: Corruption method ('MCAR', 'MAR' or 'MNAR'), defaults to 'MCAR'
    :param missing_rate: 0-1 float for sample missing rate (values outside
    this interval are clamped to it), defaults to 0.1
    :param feature_id2: Used for MAR corruption to decide the dependent feature;
    when None a random feature other than feature_id is drawn, defaults to None
    :return: Xc, the corrupted version of X (X itself is left untouched)
    :raises ValueError: if method is not recognized, or if 'MAR' is requested
    without feature_id2 on a matrix that has no other feature to depend on

    """

    # Adopted from https://github.com/schelterlabs/jenga

    # Fail fast, before any random numbers are consumed.
    if method not in ('MCAR', 'MAR', 'MNAR'):
        raise ValueError('sampling type not recognized')

    n, p = X.shape

    # Clamp so that a rate above 1 corrupts every row and a negative rate none.
    rate = min(max(missing_rate, 0.0), 1.0)
    n_values_to_discard = int(n * rate)

    if method == 'MCAR':
        rows = np.random.permutation(n)[:n_values_to_discard]

    else:
        # Choose where the contiguous window starts in the sorted order.
        # randint requires low < high, so the window that spans every row
        # (or an empty input) has to be handled without calling it.
        max_start = n - n_values_to_discard
        if max_start > 0:
            perc_lower_start = np.random.randint(0, max_start)
        else:
            perc_lower_start = 0
        perc_idx = np.arange(perc_lower_start,
                             perc_lower_start + n_values_to_discard)

        if method == 'MNAR':
            # Missingness depends on the corrupted feature's own values.
            depends_on_col = feature_id
        elif feature_id2 is None:
            candidates = [col for col in range(p) if col != feature_id]
            if not candidates:
                raise ValueError('MAR needs a second feature to depend on')
            depends_on_col = np.random.choice(candidates)
        else:
            depends_on_col = feature_id2

        rows = np.argsort(X[:, depends_on_col])[perc_idx]

    Xc = X.copy()
    Xc[rows, feature_id] = np.nan

    return Xc


def simulate_xor_data(n_samples=1000, noise=0.2, plot=False, s=10, alpha=1, 
                      axes=None, colors=['#CD5C5C', '#6495ED']):
    """
    Simulate a noisy two-dimensional XOR classification problem.

    The samples are split, in order, into four consecutive quarters placed
    at the corners (0, 0), (1, 0), (0, 1) and (1, 1). The two middle quarters
    form class 1 and the outer two form class 0. Gaussian noise is then added
    to every coordinate.

    :param n_samples: Number of simulated samples, defaults to 1000
    :param noise: Scale (standard deviation) of the added Gaussian noise,
    defaults to 0.2
    :param plot: Flag to draw a scatter plot of the data, defaults to False
    :param s: Marker size in the scatter plot, defaults to 10
    :param alpha: Marker opacity in the scatter plot, a value in 0-1, defaults to 1
    :param axes: Object whose ``scatter`` method is used for drawing; when None
    the ``plt`` module is used instead, defaults to None
    :param colors: Colors for class 0 and class 1 in the scatter plot,
    defaults to ['#CD5C5C', '#6495ED']
    :return: X,y Simulated data (n_samples by 2) and targets (n_samples,)

    """

    # Row indices at which the corner (and possibly the class) changes.
    first_cut = int(np.floor(n_samples / 4))
    middle_cut = int(np.floor(n_samples / 2))
    last_cut = int(np.floor(n_samples / 4 * 3))

    # Noise-free corner coordinates.
    X = np.zeros([n_samples, 2])
    X[first_cut:middle_cut, 0] = 1
    X[middle_cut:last_cut, 1] = 1
    X[last_cut:, :] = 1

    # Only the two middle quarters belong to the positive class.
    y = np.zeros([n_samples,])
    y[first_cut:last_cut] = 1

    X = X + np.random.randn(*X.shape) * noise

    if plot:
        canvas = plt if axes is None else axes
        for label in (0, 1):
            members = y == label
            canvas.scatter(X[members, 0], X[members, 1],
                           s=s, alpha=alpha, c=colors[label])

    return X, y

