import os
import argparse

import numpy as np
from sklearn.linear_model import LogisticRegression

from compute_patterns import find_patterns, acc

def shuffle(X, y):
    """Return ``(X, y)`` with rows permuted by a fixed-seed random order.

    Shuffling the dataset has been observed empirically to reduce the
    downstream computation. The generator is seeded with 3 so the order is
    reproducible across runs. ``X`` and ``y`` are permuted with the same
    order; fancy indexing produces new arrays, so the inputs are untouched.
    """
    generator = np.random.default_rng(3)
    order = generator.permutation(np.arange(len(X)))
    return X[order], y[order]

def shuffle_distance(X, y, clf):
    """Reorder samples by ascending absolute decision score.

    Despite the name this is a deterministic sort, not a random shuffle:
    rows are ordered by ``|clf.decision_function(X)|`` from smallest to
    largest, so samples nearest the decision boundary come first. ``X`` and
    ``y`` are reordered together.
    """
    margin = np.abs(clf.decision_function(X))
    order = np.argsort(margin)
    return X[order], y[order]

def shuffle_rotate(X, y, clf):
    """Order samples by decision margin, interleaving confusion-matrix groups.

    Samples are first sorted by ascending ``|clf.decision_function(X)|``.
    They are then split, preserving that order, into true positives, false
    positives, true negatives and false negatives (with respect to
    ``clf.predict``). The groups are emitted round-robin:
    TP[0], FP[0], TN[0], FN[0], TP[1], ... ; exhausted groups are skipped.

    The positive class is ``clf.classes_[1]`` and the negative class is
    ``clf.classes_[0]`` (scikit-learn's binary convention). For 0/1 labels
    this is identical to comparing against 1 and 0. Other encodings such as
    -1/1 no longer drop samples.

    Args:
        X: sample matrix, indexed along its first axis.
        y: labels aligned with ``X``.
        clf: fitted binary classifier exposing ``decision_function``,
            ``predict`` and ``classes_``.

    Returns:
        The reordered ``(X, y)``; every input sample appears exactly once.

    Raises:
        ValueError: if ``clf`` is not binary, or if some label in ``y`` is not
            one of ``clf.classes_`` (such a sample would otherwise be lost).
    """
    if len(clf.classes_) != 2:
        raise ValueError(
            f"shuffle_rotate requires a binary classifier, got classes {clf.classes_!r}"
        )
    neg, pos = clf.classes_

    by_margin = np.argsort(np.abs(clf.decision_function(X)))
    X = X[by_margin]
    y = y[by_margin]

    preds = clf.predict(X)
    # Positions (in margin order) of each confusion-matrix group; nonzero()
    # returns them ascending, so each group keeps the margin ordering.
    groups = [
        ((preds == pos) & (y == pos)).nonzero()[0],  # true positives
        ((preds == pos) & (y == neg)).nonzero()[0],  # false positives
        ((preds == neg) & (y == neg)).nonzero()[0],  # true negatives
        ((preds == neg) & (y == pos)).nonzero()[0],  # false negatives
    ]

    # Round-robin over the groups, one element from each per round.
    longest = max(len(g) for g in groups)
    order = np.asarray(
        [g[i] for i in range(longest) for g in groups if i < len(g)],
        dtype=np.intp,
    )

    # Guard against silently discarding samples whose label is not a class.
    if order.size != X.shape[0]:
        raise ValueError(
            "y contains labels that are not in clf.classes_; "
            f"{X.shape[0] - order.size} sample(s) could not be grouped"
        )

    return X[order], y[order]

def calculate_rashomon(data_dir, output_dir, theta, verbose, ordering='rotate'):
    """Enumerate prediction patterns within ``theta`` of the best logistic fit.

    Loads ``X_data.npy`` and ``y_data.npy`` from ``data_dir`` and fits an
    unregularized logistic regression. The second value returned by
    ``acc(preds, y)`` is taken as the baseline (reported as "Minimum Wrong";
    presumably the number of misclassified samples, so confirm against
    ``compute_patterns.acc``). The samples are then reordered and passed to
    ``find_patterns`` with an error budget of
    ``opt + int(theta * num_samples)``.

    Writes ``patterns.npy``, ``counts.npy`` and ``y_data.npy`` to
    ``output_dir``. The saved ``y_data`` is in the *reordered* sample order
    used for the search.

    Args:
        data_dir: directory containing ``X_data.npy`` and ``y_data.npy``.
        output_dir: output directory; created if it does not exist.
        theta: extra errors allowed above the baseline, as a fraction of the
            number of samples.
        verbose: print progress; also forwarded to ``find_patterns``.
        ordering: sample order handed to ``find_patterns``. ``'rotate'``
            (default) uses ``shuffle_rotate``; ``'distance'`` uses
            ``shuffle_distance``; ``'random'`` uses ``shuffle``; ``'none'``
            keeps the on-disk order.

    Raises:
        ValueError: if ``ordering`` is not one of the values above.
    """
    valid_orderings = ('rotate', 'distance', 'random', 'none')
    if ordering not in valid_orderings:
        raise ValueError(
            f"ordering must be one of {valid_orderings}, got {ordering!r}"
        )

    X = np.load(os.path.join(data_dir, 'X_data.npy'))
    y = np.load(os.path.join(data_dir, 'y_data.npy'))

    num_samples = X.shape[0]

    # penalty=None disables regularization. The string "none" was deprecated
    # in scikit-learn 1.2 and is rejected from 1.4 onward.
    clf = LogisticRegression(penalty=None).fit(X, y)
    preds = clf.predict(X)
    _, opt = acc(preds, y)

    # Ordering only changes how fast the search runs, not which samples are
    # included; num_samples above is unaffected.
    if ordering == 'rotate':
        X, y = shuffle_rotate(X, y, clf)
    elif ordering == 'distance':
        X, y = shuffle_distance(X, y, clf)
    elif ordering == 'random':
        X, y = shuffle(X, y)

    max_wrong = opt + int(theta * num_samples)

    if verbose:
        print('Minimum Wrong:', opt)
        print('Theta * Samples:', int(theta * num_samples))
        print('Maximum Wrong:', max_wrong)

    patterns, _params, _total_time, pattern_counts = find_patterns(
        X,
        verbose=verbose,
        labels=y,
        num_samples=num_samples,
        max_wrong=max_wrong,
    )

    if verbose:
        print('Number of Patterns:', len(patterns))

    os.makedirs(output_dir, exist_ok=True)

    np.save(os.path.join(output_dir, 'patterns'), arr=patterns)
    np.save(os.path.join(output_dir, 'counts'), arr=pattern_counts)
    # Save y in the same order the patterns were computed against.
    np.save(os.path.join(output_dir, 'y_data'), arr=y)

if __name__ == '__main__':
    # Command-line entry point: every option maps onto a keyword parameter
    # of calculate_rashomon with the same name.
    cli = argparse.ArgumentParser(
        description='Calculates the size of the Rashomon set for a given dataset',
    )
    cli.add_argument('-d', '--data_dir', default='./datasets/wine_pca')
    cli.add_argument('-o', '--output_dir', default='./rashomon_sets/wine_pca')
    cli.add_argument('-t', '--theta', default=0.02, type=float)
    cli.add_argument('-v', '--verbose', action='store_true')

    # Destinations (data_dir, output_dir, theta, verbose) match the function's
    # parameter names one-to-one, so the parsed namespace is forwarded as-is.
    calculate_rashomon(**vars(cli.parse_args()))