import numpy as np
from scipy.stats import multivariate_normal
from sklearn.linear_model import LogisticRegression
import random

def Gaussian_Sample(s, mu_1, Sigma_1, mu_2, Sigma_2, sed=1):
    """
    Build a two-class dataset from two multivariate Gaussians.

    The first ``s`` rows are drawn from N(mu_1, Sigma_1) and labelled 1.
    The next ``s`` rows are drawn from N(mu_2, Sigma_2) and labelled 0.
    Note that this reseeds NumPy's *global* random state via
    ``np.random.seed(sed)``.

    :param s: number of samples drawn from each Gaussian
    :param mu_1: mean vector of the class-1 ("real") Gaussian
    :param Sigma_1: covariance matrix of the class-1 Gaussian
    :param mu_2: mean vector of the class-0 ("synthetic") Gaussian
    :param Sigma_2: covariance matrix of the class-0 Gaussian
    :param sed: seed passed to ``np.random.seed``
    :return: list ``[X, Y]``. ``X`` stacks both samples row-wise with shape
             ``(2*s, d)``. ``Y`` is a float array of ``s`` ones followed by
             ``s`` zeros.
    """
    np.random.seed(sed)
    # Draw class 1 before class 0 so a given seed yields a fixed dataset.
    class_one = np.random.multivariate_normal(mean=mu_1, cov=Sigma_1, size=s)
    class_zero = np.random.multivariate_normal(mean=mu_2, cov=Sigma_2, size=s)
    features = np.vstack((class_one, class_zero))
    labels = np.repeat([1.0, 0.0], s)
    return [features, labels]


def Psi_Trans(X):
    """
    Append all degree-2 monomials to a feature vector.

    The result is the original entries of ``X``, followed by ``X[i]*X[j]``
    for every pair ``i <= j``. Pairs are listed in row-major order:
    (0,0), (0,1), ..., (0,n-1), (1,1), ...

    :param X: indexable sequence of numbers (list or 1-D array)
    :return: a new list of ``n + n*(n+1)/2`` values
    """
    n = len(X)
    quadratic_terms = [X[a] * X[b] for a in range(n) for b in range(a, n)]
    return list(X) + quadratic_terms


def dividing_train_test(x, y, s, train_size, test_size, sed=None):
    """
    Split a dataset of ``2*s`` rows into disjoint train and test subsets.

    Indices are drawn once *without replacement* and then cut into two parts.
    As a result, no row appears twice within a subset, and no row is shared
    between train and test. The previous implementation used
    ``random.choices`` twice, which allowed both.

    :param x: feature array supporting fancy indexing by a list of row
              indices (e.g. a NumPy array)
    :param y: label array aligned row-for-row with ``x``
    :param s: per-class sample size. The data is assumed to have ``2*s``
              rows, as returned by ``Gaussian_Sample``.
    :param train_size: number of rows in the training subset
    :param test_size: number of rows in the test subset
    :param sed: optional seed for a private ``random.Random``. When ``None``,
                the global ``random`` module state is used.
    :return: tuple ``(x_train, y_train, x_test, y_test)``
    :raises ValueError: if a size is negative or
                        ``train_size + test_size > 2*s``
    """
    n_total = s * 2
    if train_size < 0 or test_size < 0:
        raise ValueError("train_size and test_size must be non-negative")
    if train_size + test_size > n_total:
        raise ValueError(
            f"train_size + test_size ({train_size + test_size}) exceeds "
            f"the number of available samples ({n_total})"
        )

    rng = random if sed is None else random.Random(sed)
    # A single draw without replacement guarantees the two subsets are disjoint.
    chosen = rng.sample(range(n_total), train_size + test_size)
    train_ind = chosen[:train_size]
    test_ind = chosen[train_size:]

    x_train = x[train_ind]
    y_train = y[train_ind]
    x_test = x[test_ind]
    y_test = y[test_ind]
    return (x_train, y_train, x_test, y_test)


def Gaussian_Sample_noise(s, mu_1, Sigma_1, mu_2, Sigma_2, noise_var, sed=1):
    """
    Build a two-class Gaussian dataset whose class-0 samples carry extra noise.

    The first ``s`` rows are drawn from N(mu_1, Sigma_1) and labelled 1. The
    next ``s`` rows are drawn from N(mu_2, Sigma_2), have independent
    zero-mean Gaussian noise added, and are labelled 0. This reseeds NumPy's
    *global* random state via ``np.random.seed(sed)``.

    :param s: number of samples drawn from each Gaussian
    :param mu_1: mean vector of the class-1 ("real") Gaussian; its length sets
                 the noise dimension
    :param Sigma_1: covariance matrix of the class-1 Gaussian
    :param mu_2: mean vector of the class-0 ("synthetic") Gaussian
    :param Sigma_2: covariance matrix of the class-0 Gaussian
    :param noise_var: passed as ``scale`` to ``np.random.normal``. It therefore
                      acts as the noise *standard deviation*, despite its name.
                      NOTE(review): confirm whether callers intend a variance.
    :param sed: seed passed to ``np.random.seed``
    :return: list ``[X, Y]``. ``X`` has shape ``(2*s, d)``. ``Y`` is a float
             array of ``s`` ones followed by ``s`` zeros.
    """
    np.random.seed(sed)
    dim = len(mu_1)
    # Draw order (class 1, noise, class 0) is fixed so a seed reproduces the data.
    class_one = np.random.multivariate_normal(mean=mu_1, cov=Sigma_1, size=s)
    perturbation = np.random.normal(0, noise_var, (s, dim))
    clean_zero = np.random.multivariate_normal(mean=mu_2, cov=Sigma_2, size=s)
    class_zero = clean_zero + perturbation
    features = np.vstack((class_one, class_zero))
    labels = np.repeat([1.0, 0.0], s)
    return [features, labels]
