from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn import datasets
from sklearn import preprocessing
import numpy as np

import shutil
import sys
import os
import time
import glob

import matplotlib.pyplot as plt
from dataset import read_dataset, save_dataset
import argparse
from utils import ensure_dir

from shuffle import shuffle_rotate

def plot_2d_discarded_indices(X, Y, discarded_indices, folder):
    """Scatter-plot a 2-D dataset and highlight the discarded points.

    Every point is coloured by indexing ``['r', 'b']`` with
    ``int((label + 1) / 2)`` (so label -1 -> red, label 1 -> blue); the
    discarded points are then over-drawn as larger green markers. The figure
    is written to ``<folder>/disc_ind_figure.pdf``.

    Args:
        X: 2-D array; only columns 0 and 1 are plotted.
        Y: per-point labels, one per row of ``X``. Presumably +/-1 -- TODO
            confirm against ``read_dataset``.
        discarded_indices: integer row indices into ``X`` to highlight.
        folder: existing directory the PDF is written into.
    """
    colors = ['r', 'b']
    labels = np.asarray(Y)
    # Same truncating mapping as int((y + 1) / 2) applied per point.
    color_idx = ((labels + 1) / 2).astype(int)
    point_colors = [colors[k] for k in color_idx]
    discarded = np.asarray(discarded_indices, dtype=int)

    fig = plt.figure()
    try:
        # One vectorised scatter call instead of one call per point.
        plt.scatter(X[:, 0], X[:, 1], c=point_colors)
        plt.scatter(X[discarded, 0], X[discarded, 1], c='g', s=100)
        plt.savefig(os.path.join(folder, "disc_ind_figure.pdf"), dpi=150)
    finally:
        # pyplot keeps every figure alive until closed; release it explicitly.
        plt.close(fig)

# main
# Combine the discarded-point indices produced by the individual jobs of a
# run, then write a reordered dataset: discarded points first (in the order
# they were reported), followed by the remaining points passed through
# shuffle_rotate.
parser = argparse.ArgumentParser(description='Combine discard points.')
parser.add_argument('datasetdir')
parser.add_argument('runid')
parser.add_argument('-d', '--inputdir', default='./jobs_out')
parser.add_argument('-o', '--outputdir', default='./reduce_out')
parser.add_argument('-fig', '--figuresdir', default='./figures')
parser.add_argument('-v', '--verbose', action='store_true', default=False)

args = parser.parse_args()
inputdir = os.path.join(args.inputdir, args.runid)
outputdir = os.path.join(args.outputdir, args.runid)
figuresdir = os.path.join(args.figuresdir, args.runid)

ensure_dir(outputdir)
ensure_dir(figuresdir)

X, Y, Y1 = read_dataset(args.datasetdir)

n = X.shape[0]

# Read every job's comma-separated index file. glob.glob returns files in an
# arbitrary, filesystem-dependent order, so sort for a reproducible result.
chunks = []
for file in sorted(glob.glob(os.path.join(inputdir, "discarded_points_*csv"))):
    print("processing file ", file)
    chunks.append(np.fromfile(file, dtype=int, sep=','))
# Concatenate once instead of np.append in the loop (which copies each time).
if chunks:
    discarded_points = np.concatenate(chunks).astype(int, copy=False)
else:
    discarded_points = np.array([], dtype=int)

# Negative indices would silently wrap around (and the point would then also
# land in the "left" partition); too-large ones would crash later with a
# less helpful error.
out_of_range = (discarded_points < 0) | (discarded_points >= n)
if out_of_range.any():
    raise ValueError(
        f"discarded point indices out of range [0, {n}): "
        f"{discarded_points[out_of_range]}")

# The same point may be reported by more than one job; keep only its first
# occurrence (original order preserved) so no row is duplicated in the output.
_, first_pos = np.unique(discarded_points, return_index=True)
discarded_points = discarded_points[np.sort(first_pos)]

num_discarded = len(discarded_points)
discarded_pct = 100.0 * num_discarded / n if n else 0.0

print("discarded_points", discarded_points)
print(f"\n Discarded {num_discarded} points out of {n}, {discarded_pct:.2f}%")

if X.shape[1] == 2:
    plot_2d_discarded_indices(X, Y, discarded_points, figuresdir)

discarded_points.tofile(
    os.path.join(outputdir, 'all_discarded_points.csv'),
    sep=',')

# Indices not discarded, in ascending order (deterministic, unlike iterating
# a set difference).
keep_mask = np.ones(n, dtype=bool)
keep_mask[discarded_points] = False
left_points = np.flatnonzero(keep_mask)

X_sort_1 = X[discarded_points, :]
X_sort_2 = X[left_points, :]

Y_sort_1 = Y[discarded_points]
Y_sort_2 = Y[left_points]

Y1_sort_1 = Y1[discarded_points]
Y1_sort_2 = Y1[left_points]

X_sort_3, Y_sort_3, Y1_sort_3 = shuffle_rotate(X_sort_2, Y_sort_2, Y1_sort_2)

Y1_sort = np.concatenate((Y1_sort_1, Y1_sort_3), axis=0)
Y_sort = np.concatenate((Y_sort_1, Y_sort_3), axis=0)
X_sort = np.concatenate((X_sort_1, X_sort_3), axis=0)

save_dataset(outputdir, X_sort, Y_sort, Y1_sort)