from sklearn import datasets
from sklearn import preprocessing
import numpy as np
from gmpy2 import mpz
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression

from compute_patterns import compute_patterns_bin_fixed, compute_diversity

from dataset import read_dataset
import argparse
import time
import pickle
from os import path
from utils import ensure_dir

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Computes patterns')
parser.add_argument('runid',
                    help='run identifier; used as the sub-directory name '
                         'under both the input and the output directory')
parser.add_argument('-d', '--inputdir', default='./reduce_out',
                    help='base input directory (default: %(default)s)')
parser.add_argument('-o', '--outputdir', default='./patterns_out',
                    help='base output directory (default: %(default)s)')
parser.add_argument('-v', '--verbose', action='store_true', default=False,
                    help='print extra progress information')
# Previously hard-coded to 0.03; the default keeps existing invocations
# behaving exactly as before.
parser.add_argument('-t', '--theta', type=float, default=0.03,
                    help='slack added to the reference model training error '
                         'to form the threshold (default: %(default)s)')

args = parser.parse_args()

# Per-run directories: <base>/<runid>.
inputdir = args.inputdir + "/" + args.runid
outputdir = args.outputdir + "/" + args.runid
dp_file = inputdir + '/all_discarded_points.csv'

ensure_dir(outputdir)

X, Y, Y1 = read_dataset(inputdir)

# The number of comma-separated integers in the discarded-points file (0 when
# the file does not exist) is handed to the pattern computation as its
# start index.
# NOTE(review): presumably the discarded points occupy the first rows of the
# dataset -- confirm against the step that writes this file.
discarded_points = np.zeros(0)
if path.exists(dp_file):
    discarded_points = np.fromfile(dp_file, dtype=int, sep=',')
start_index = len(discarded_points)

# The full dataset is used for training; no hold-out split is made here.
X_train = X
Y_train = Y

if args.verbose:
    print(X_train.shape, Y_train.shape)

# Unregularised logistic regression serves as the reference model.
# NOTE(review): scikit-learn deprecated the string 'none' in 1.2 in favour of
# penalty=None and rejects it from 1.4 on -- confirm the pinned version before
# upgrading.
clf = LogisticRegression(
    random_state=2023,
    penalty='none',
    max_iter=10000,
    fit_intercept=True,
).fit(X_train, Y_train)

theta = args.theta

# Threshold = training error of the reference model + theta.
optimal_model_acc = accuracy_score(Y_train, clf.predict(X_train))
threshold = 1 - optimal_model_acc + theta
print("Rashomon threshold", threshold)

#inds = np.random.RandomState(seed=2023).permutation(np.arange(X_train.shape[0]))
#X_train = X_train[inds, :]
#Y_train = Y_train[inds]

# Measure the wall-clock time spent in the pattern computation.
started_at = time.perf_counter()

patterns = compute_patterns_bin_fixed(
    X_train,
    Y_train,
    threshold,
    start_index,
    verbose=args.verbose,
    fit_intercept_init=True,
)

elapsed = time.perf_counter() - started_at
print(f"computed patterns in {elapsed:0.4f} seconds")

# print('patterns')
# for item in patterns:
#     print(item.to_array(), accuracy_score(clf.predict(X_train), item.to_array()))
#     #print(accuracy_score(clf.predict(X_train), item.to_array()))
    
    
print(f"Found {len(patterns)} patterns")

# Persist the patterns first so they are on disk even if the diversity
# computation below fails. The context manager closes the file even when
# pickling raises.
with open(outputdir + '/patterns_file', 'wb') as patterns_fh:
    pickle.dump(patterns, patterns_fh)

div = compute_diversity(patterns)

print("Diversity", div)

with open(outputdir + '/diversity_file', 'wb') as diversity_fh:
    pickle.dump(div, diversity_fh)