import numpy as np
import scipy

from sklearn.linear_model import LogisticRegression

import time

import warnings
# Silence every warning for the whole process: ``Warning`` is the base class of
# all warning categories, and this filter is added to the global warnings filter
# list at import time.
# NOTE(review): this also hides deprecation/convergence warnings emitted by
# scikit-learn's LogisticRegression (e.g. the deprecation of penalty="none");
# consider narrowing the category or scoping it with warnings.catch_warnings().
warnings.filterwarnings("ignore", category=Warning)

def acc(pred, labels):
    """Compare ``pred`` position-by-position against the leading entries of ``labels``.

    Both arguments are sequences of binary labels with
    ``len(pred) <= len(labels)``; only the first ``len(pred)`` entries of
    ``labels`` take part in the comparison. ``pred`` must be non-empty,
    because the accuracy divides by ``len(pred)``.

    Returns:
        (accuracy, num_wrong): the fraction of compared positions where the
        two sequences agree, and the count of positions where they differ.
    """
    assert len(pred) <= len(labels)

    n_compared = len(pred)
    n_match = sum(1 for guess, truth in zip(pred, labels) if guess == truth)
    return n_match / n_compared, n_compared - n_match

def int_to_list(x, num_samples):
    """Return the ``num_samples`` low-order binary digits of ``x``, most significant first.

    The least significant (parity) bit is at index -1. Bits of ``x`` above
    position ``num_samples - 1`` are dropped, and ``num_samples <= 0`` yields
    an empty list.
    """
    return [(x >> shift) & 1 for shift in range(num_samples - 1, -1, -1)]

def calc_predicted(n, d):
    """Return the conjectured number of patterns for ``n`` points in ``d`` dimensions.

    Computes ``2 * sum_{i=0}^{d} C(n - 1, i)`` with ``scipy.special.binom``, so
    the result is a float (an int ``0`` when ``d < 0``). For points in general
    position this is Cover's function-counting formula for affine separators
    in R^d.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.special`` is loaded on older SciPy releases.
    from scipy.special import binom

    # Same left-to-right summation order as a plain loop, so float results
    # are unchanged.
    return 2 * sum(binom(n - 1, i) for i in range(d + 1))

def _fit_unpenalized(x, y):
    """Fit and return an unregularized ``LogisticRegression`` on ``(x, y)``.

    scikit-learn 1.2 introduced ``penalty=None`` and deprecated the string
    ``"none"``, which 1.4 removed. Older releases reject ``None`` (with a
    ``ValueError`` naming the penalty), so only in that case retry with
    ``"none"``; any other error is re-raised unchanged.
    """
    try:
        return LogisticRegression(penalty=None).fit(x, y)
    except ValueError as err:
        if "penalt" not in str(err):
            raise
        return LogisticRegression(penalty="none").fit(x, y)


def _constant_pattern_params(dim, label):
    """Return ``[0] * dim + [+1 or -1]``: a hyperplane that labels every point ``label``.

    With zero coefficients the decision value equals the intercept for every
    point, so a positive intercept predicts 1 everywhere and a negative one 0.
    """
    params = np.zeros(dim + 1)
    params[-1] = 1.0 if label else -1.0
    return params


def find_patterns(
    data,
    num_samples=1000,
    max_wrong=-1,
    labels=None,
    verbose=False,
):
    """Enumerate binary labelings of the leading rows of ``data`` that logistic regression fits exactly.

    Breadth-first search: every surviving labeling of the first k rows is
    extended with a 0 and a 1 for row k. An extension survives if:

    * an unregularized ``LogisticRegression`` reproduces it with 100%
      training accuracy (all-0 / all-1 labelings are accepted without
      fitting); and
    * when ``labels`` is given, it disagrees with ``labels`` on at most
      ``max_wrong`` of those rows.

    A labeling is encoded as an integer whose binary digits, most significant
    first, are the labels of rows 0..k.

    Args:
        data: 2-D array-like, one row per point; converted with ``np.asarray``.
            Should have at least ``num_samples`` rows.
        num_samples: number of leading rows of ``data`` to label.
        max_wrong: maximum allowed disagreements with ``labels``; ``-1``
            means ``num_samples`` (effectively no limit).
        labels: optional reference labeling, of length >= ``num_samples``.
        verbose: print, per depth, the number of points, the number of
            surviving patterns, and ``calc_predicted`` for comparison.

    Returns:
        (queue, parameters, total_time, pattern_counts):

        * ``queue``: the integer-encoded labelings surviving on all
          ``num_samples`` rows.
        * ``parameters``: a float array with one row
          ``[coef_1, ..., coef_dim, intercept]`` per surviving labeling, in
          the same order as ``queue``. Constant labelings get the trivial
          hyperplane from ``_constant_pattern_params``.
        * ``total_time``: elapsed wall-clock seconds.
        * ``pattern_counts[j][w]``: the number of surviving labelings of the
          first ``j + 2`` rows having ``w`` disagreements with ``labels``.
    """
    X_train = np.asarray(data)
    dim = X_train.shape[1]

    if max_wrong == -1:
        max_wrong = num_samples

    start_time = time.time()

    if verbose:
        print('-----------------------')
        print('Num Features:', dim)
        print('# Points, # Rashomon Patterns, # Patterns:')

    pattern_counts = []
    final_params = []  # collected as a list and stacked once (avoids O(k^2) concatenation)
    queue = [0, 1]  # both labelings of the first point
    for i_depth in range(1, num_samples):
        n_points = i_depth + 1
        x_s = X_train[:n_points, :]
        all_ones = (1 << n_points) - 1
        is_last = i_depth == num_samples - 1

        new_queue = []  # breadth-first search
        pattern_count = [0] * (max_wrong + 1)
        for i_elem in queue:
            for e in (i_elem * 2, i_elem * 2 + 1):
                # e < 2**n_points, so n_points digits encode it exactly.
                y_s = int_to_list(e, n_points)

                if labels is not None:
                    _, num_wrong = acc(y_s, labels)
                else:
                    num_wrong = 0
                # Check the cheap label filter first so rejected labelings
                # never pay for a model fit.
                if num_wrong > max_wrong:
                    continue

                if e == 0 or e == all_ones:
                    # Single-class labelings cannot be fit by LogisticRegression
                    # (it needs two classes) but are trivially realizable.
                    coeffs = _constant_pattern_params(dim, e == all_ones)
                else:
                    clf = _fit_unpenalized(x_s, y_s)
                    accuracy, _ = acc(clf.predict(x_s), y_s)
                    if accuracy != 1.0:
                        continue
                    coeffs = np.append(clf.coef_, clf.intercept_)

                new_queue.append(e)
                pattern_count[num_wrong] += 1
                if is_last:
                    final_params.append(coeffs)

        queue = new_queue
        pattern_counts.append(pattern_count)

        if verbose:
            print(n_points, len(queue), calc_predicted(n_points, dim))

    parameters = np.array(final_params, dtype=float).reshape(len(final_params), dim + 1)
    total_time = time.time() - start_time
    return queue, parameters, total_time, pattern_counts