import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import time
import pickle
import argparse

from sklearn.linear_model import Lasso
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler

def load_mnist():
    """Read the pickled MNIST dictionary stored at ``data/mnist.pkl``.

    Returns:
        tuple: ``(training_images, training_labels, test_images, test_labels)``,
        taken from the identically named keys of the unpickled dict.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use a trusted file.
    with open("data/mnist.pkl", 'rb') as handle:
        payload = pickle.load(handle)
    wanted = ("training_images", "training_labels", "test_images", "test_labels")
    return tuple(payload[name] for name in wanted)

def poisoned_lasso(alpha, target, poison_rows, X, y, heuristic_dot_prod=None, heuristic=0):
    """Fit a Lasso after appending ``poison_rows`` crafted rows aimed at one feature.

    Every poison row is all zeros except column ``target``, and every poison
    response is 1. With ``heuristic == 0`` the target column holds +1;
    otherwise its sign follows ``heuristic_dot_prod`` (-1 if negative, else +1).

    Args:
        alpha: Regularisation strength passed unchanged to ``Lasso``.
        target: Column index of the feature to push into the support.
        poison_rows: Number of poison rows to append (NumPy rejects negatives).
        X: Clean design matrix, shape (n_samples, n_features).
        y: Clean responses, length n_samples.
        heuristic_dot_prod: Signed residual dot product for ``target``;
            required when ``heuristic`` is non-zero.
        heuristic: 0 to always use +1 in the target column; non-zero to use
            the sign of ``heuristic_dot_prod``.

    Returns:
        tuple: (coefficient of ``target``, full coefficient vector) of the
        poisoned fit.

    Raises:
        ValueError: if ``heuristic`` is non-zero and ``heuristic_dot_prod`` is None.
    """
    if heuristic != 0 and heuristic_dot_prod is None:
        raise ValueError("heuristic_dot_prod is required when heuristic is non-zero")

    # Match the sign of the poison column to the feature's residual correlation.
    if heuristic != 0 and heuristic_dot_prod < 0:
        poison_value = -1.0
    else:
        poison_value = 1.0

    # np.shape works for arrays and nested lists, and for zero-row inputs.
    X_poison = np.zeros((poison_rows, np.shape(X)[1]))
    X_poison[:, target] = poison_value
    y_poison = np.ones(poison_rows)

    X_poisoned = np.vstack([X, X_poison])
    y_poisoned = np.append(y, y_poison)

    # alpha is used as given and is not rescaled for the extra rows
    # (an alternative would be alpha = lambd / (len(X) + poison_rows)).
    model = Lasso(alpha=alpha).fit(X_poisoned, y_poisoned)
    return model.coef_[target], model.coef_

def residue_heuristic(X, y, lasso_coef, support, seed, partial_info_pct=1.0):
    """Choose the off-support feature most correlated with the Lasso residual.

    The adversary only sees ``int(partial_info_pct * len(y))`` rows, sampled
    without replacement using ``seed``. On that sample it forms the residual
    ``V = y - X @ lasso_coef`` and the dot product of every feature column with
    ``V``. It returns the feature with the largest absolute dot product that is
    not in ``support``.

    Args:
        X: Design matrix, shape (n_samples, n_features).
        y: Responses, length n_samples.
        lasso_coef: Coefficients of the clean Lasso fit, length n_features.
        support: Indices of features already in the Lasso support.
        seed: Seed for the row subsample (one seed per adversary trial).
        partial_info_pct: Fraction of rows the adversary can see, in [0, 1].

    Returns:
        tuple: (feature index, signed dot product of that feature with V).

    Raises:
        ValueError: if every feature is in ``support``, so no target exists.
    """
    # A local RandomState draws exactly what np.random.seed(seed) followed by
    # np.random.choice would, without reseeding NumPy's global RNG.
    rng = np.random.RandomState(seed)
    n_rows = len(y)
    rows = rng.choice(np.arange(n_rows), int(partial_info_pct * n_rows), replace=False)
    X_sample = X[rows]
    y_sample = y[rows]

    # Residual of the clean fit on the visible rows, and one dot product per feature.
    residual = y_sample - np.matmul(X_sample, lasso_coef)
    dot_prods = X_sample.T @ residual

    # Set membership is O(1); scanning a NumPy array would be O(|support|) per check.
    support_set = set(np.asarray(support).tolist())
    for feature in np.flip(np.argsort(np.abs(dot_prods))):
        if feature not in support_set:
            return feature, dot_prods[feature]
    raise ValueError("every feature is in the support; no target feature to choose")

def load_dataset(dataset_name):
    if(dataset_name == "boston"):
        X, y = load_boston(return_X_y=True)
        alpha = 0.1
    elif(dataset_name == "SMK"):
        mat_file = sio.loadmat('data/SMK_CAN_187.mat')
        X = mat_file['X']
        y = mat_file['Y']
        y = y[:, 0]
        alpha = 0.05
    elif(dataset_name == "TOX"):
        mat_file = sio.loadmat('data/TOX_171.mat')
        X = mat_file['X']
        y = mat_file['Y']
        y = y[:, 0]
        alpha = 0.1
    elif(dataset_name == "PROSTATE"):
        mat_file = sio.loadmat('data/Prostate_GE.mat')
        X = mat_file['X']
        y = mat_file['Y']
        y = y[:, 0]
        alpha = 0.05
    elif(dataset_name == "MNIST"):
        _, _, X, y = load_mnist() # last two returns are training set (60k), so we take the smaller one
        print(X.shape)
        print(y.shape)
        alpha = 0.1
    elif(dataset_name == "SYNTHETIC"):
        n = 300
        p = 500000
        s = 5
        sigma = 0.5
        sigma_X = 1
        X, y, alpha = synthetic_data(n, p, s, sigma, sigma_X)

    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    y_mean = np.mean(y)
    y_std = np.std(y)
    y = (y - y_mean)/y_std
    return X, y, alpha

def synthetic_data(n, p, s, sigma, sigma_X, seed=42):
    """Generate a sparse linear-regression problem and its Lasso alpha.

    Seeds NumPy's global RNG with ``seed``. The first ``s`` true coefficients
    are drawn uniformly from [0.75, 1) and the remaining ``p - s`` are zero.
    The design entries are N(0, sigma_X), except the last column, which is
    forced to zero. The noise is N(0, sigma). The regularisation is
    ``lambd = 2 * sigma * sqrt(n * log(p))``; it is printed and returned as
    ``alpha = lambd / n``.

    Returns:
        tuple: (X of shape (n, p), y of length n, alpha).
    """
    np.random.seed(seed)

    # Draw order (noise, coefficients, design) is fixed for reproducibility.
    noise = np.random.normal(0, sigma, n)
    active = np.random.uniform(0.75, 1, s)
    theta = np.concatenate((active, np.zeros(p - s)))

    design = np.random.normal(0, sigma_X, (n, p))
    design[:, p - 1] = 0
    response = design @ theta + noise

    lambd = 2 * sigma * np.sqrt(n * np.log(p))
    print("Lambda={}".format(lambd))
    return design, response, lambd / n

if __name__ == "__main__":
    import os

    # Coefficients with absolute value at or below this count as zero (off-support).
    ZERO_TOL = 1e-6
    # Poison-row counts 0 .. K_RANGE - 1 are searched for each target.
    K_RANGE = 1000

    # Parse arguments from the command line
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", help="Name of the dataset to do.")
    parser.add_argument("--num_trials", type=int, default=5,
                        help="Number of trials for adversary.")
    parser.add_argument("--partial_pct", type=float,
                        help="Percent of data the adversary knows.")
    args = parser.parse_args()

    if not args.dataset:
        # Stop with a usage error; nothing below can run without a dataset.
        parser.error("Dataset unspecified. Please specify with --dataset.")
    dataset = args.dataset
    num_trials = args.num_trials
    # Use `is not None` so an explicit 0 is honoured (0.0 is falsy).
    if args.partial_pct is not None and 0 <= args.partial_pct <= 100:
        partial_pct = args.partial_pct / 100
    else:
        if args.partial_pct is not None:
            print("partial_pct {} is outside [0, 100]; using 100.".format(args.partial_pct))
        partial_pct = 1.0

    # Load the data
    X, y, alpha = load_dataset(dataset)

    # Fit the clean Lasso. Features with (near-)zero coefficients are candidate targets.
    lasso_coef = Lasso(alpha=alpha).fit(X, y).coef_
    support = np.where(abs(lasso_coef) > ZERO_TOL)[0]
    target_set = np.where(abs(lasso_coef) <= ZERO_TOL)[0]
    print("Number of features in Supp: {}".format(len(support)))
    print("Number of target features: {}".format(len(target_set)))
    print(X.shape)
    print(y.shape)

    # Each trial seeds a different random partial view (partial_pct of the rows)
    # and yields one (target feature, residual dot product) pair.
    heuristic_tuples = [
        residue_heuristic(X, y, lasso_coef, support, seed, partial_pct)
        for seed in range(num_trials)
    ]

    # MAIN LOOP: binary-search the smallest number of poison rows that moves
    # each heuristic target into the support. This assumes the target
    # coefficient stays non-zero once enough rows are added.
    heuristic_k_dict = dict()
    for target, heuristic_dot_prod in heuristic_tuples:
        first = 0
        last = K_RANGE - 1
        found = False
        while first <= last and not found:
            mid = (first + last) // 2
            print(mid)
            coef, _ = poisoned_lasso(alpha, target, mid, X, y, heuristic_dot_prod, heuristic=1)
            if abs(coef) > ZERO_TOL:
                # mid is minimal if it is 0 (there is no k = -1 to try) or if
                # one fewer poison row leaves the target at zero.
                if mid == 0:
                    prev_is_zero = True
                else:
                    prev_coef, _ = poisoned_lasso(alpha, target, mid - 1, X, y,
                                                  heuristic_dot_prod, heuristic=1)
                    prev_is_zero = abs(prev_coef) <= ZERO_TOL
                if prev_is_zero:
                    found = True
                    print("FOUND! {}".format(mid))
                    heuristic_k_dict[target] = mid
                else:
                    last = mid - 1
            else:
                first = mid + 1
        if not found:
            print("No k < {} found for feature {}".format(K_RANGE, target))

    # Merge each found target into the existing k dict and write one file per target.
    path_prefix = "partial_info_data/" + dataset + "/"
    if heuristic_k_dict:
        os.makedirs(path_prefix, exist_ok=True)
        # k dict saved by earlier experiments. The file is local and trusted by
        # this script; pickle must not be used on untrusted input.
        with open("k_dicts/{}_dict.pkl".format(dataset), "rb") as base_file:
            base_k_dict = pickle.load(base_file)
        for key, k in heuristic_k_dict.items():
            # Each output file holds the base dict plus this single target.
            merged = dict(base_k_dict)
            merged[key] = k
            combined_k_dict = dict(sorted(merged.items(), key=lambda item: item[1]))  # Resort by k
            with open(path_prefix + "{}pct_feature_{}.pkl".format(partial_pct, key), 'wb') as file:
                pickle.dump(combined_k_dict, file)
