import numpy as np
from gmpy2 import hamdist

import bit_vectors

from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression

#To optimize:
#There is no reason to run logistic regression when the number of points < min(number of dimensions, num samples * rashomon threshold)

def loss_score(a, b, weight):
    """Return the number of positions where ``a`` and ``b`` differ, divided by ``weight``.

    Parameters
    ----------
    a, b : array-like
        Label vectors that are compared element-wise. Plain Python sequences
        are accepted as well as NumPy arrays: both are coerced with
        ``np.asarray`` first, because ``list != list`` would produce a single
        ``bool`` instead of an element-wise mask.
    weight : number
        Divisor applied to the mismatch count (callers in this file pass the
        total number of training samples).

    Returns
    -------
    NumPy floating scalar
        ``mismatch_count / weight``.
    """
    mismatches = np.not_equal(np.asarray(a), np.asarray(b))
    return mismatches.sum() / weight

def compute_patterns_bin(X_train
                            , Y_train
                            , theta_threshold
                            , start_index
                            , fit_intercept_init = False
                            , verbose = False
                            , random_init = 2023):
    """Enumerate label patterns that a linear model fits within a loss budget.

    The first ``start_index`` labels are copied from ``Y_train`` into every
    candidate; each later position is branched on both 0 and 1, one sample
    per level (breadth first). A candidate prefix is kept when an
    unpenalised logistic regression reaches training accuracy 1.0 on it and
    its disagreement with ``Y_train`` does not exceed ``theta_threshold``.
    Here the stored sequences contain the fixed prefix themselves; compare
    ``compute_patterns_bin_fixed``, which keeps the prefix separate.

    Parameters
    ----------
    X_train : 2-D array
        Feature matrix; ``X_train.shape[0]`` is the sample count ``N`` and
        rows are sliced as ``X_train[:k, :]``.
    Y_train : sequence of labels
        Reference labels; an entry equal to 0 maps to bit 0, anything else
        to bit 1.
    theta_threshold : float
        Largest accepted prefix loss, where the loss is the number of labels
        that differ from ``Y_train`` divided by ``N``.
    start_index : int
        Number of leading samples whose labels are fixed to ``Y_train``.
        Must satisfy ``0 <= start_index < N``.
    fit_intercept_init : bool
        Forwarded to ``LogisticRegression(fit_intercept=...)``.
    verbose : bool
        Print per-level and per-candidate progress.
    random_init : int
        Forwarded to ``LogisticRegression(random_state=...)``.

    Returns
    -------
    list
        The ``bit_vectors.BinSequence`` objects that survived the last level.
        Constant (all-0 / all-1) candidates are carried through intermediate
        levels without fitting and are dropped at the last level.

    Raises
    ------
    ValueError
        If ``start_index`` is outside ``[0, N)``; such values would otherwise
        yield sequences that do not correspond to the training samples.
    """
    N = X_train.shape[0]
    if not 0 <= start_index < N:
        raise ValueError(
            "start_index must satisfy 0 <= start_index < {}, got {}".format(
                N, start_index))
    variable_sample_size = N - start_index

    # Seed sequence holding the fixed labels of the first start_index samples.
    # Assumes BinSequence.append_0/append_1 append one bit in place and
    # return the sequence itself -- TODO confirm against bit_vectors.
    init_seq = bit_vectors.BinSequence(value=0, length=0)
    for i in range(start_index):
        if Y_train[i] == 0:
            init_seq.append_0()
        else:
            init_seq.append_1()

    # First BFS level: the fixed prefix extended by 0 and by 1.
    # NOTE(review): these two seeds are never fitted themselves, so when
    # variable_sample_size == 1 they are returned unchecked -- confirm.
    queue = [init_seq.copy().append_0(), init_seq.append_1()]

    # counter for number of pruned branches
    prune_counter = 0

    for i_depth in range(1, variable_sample_size):
        if verbose:
            print("sample id", i_depth
                  , "number of items in queue", len(queue)
                  , "pruned branches", prune_counter)

        # Build the next level in a fresh list rather than popping from the
        # front of `queue` (list.pop(0) is O(len) per call). Order is the
        # same as the original pop-front/append-back traversal.
        next_level = []
        for elem in queue:
            # The copy gets the 1-branch first; `elem` itself is then reused
            # for the 0-branch.
            children = [elem.copy().append_1(), elem.append_0()]
            for e in children:
                # Constant labellings are not fitted; they are only carried
                # forward while deeper levels can still add the other class.
                if e.x == 0 or e.all_bits_set():
                    if i_depth < variable_sample_size - 1:
                        next_level.append(e)
                    continue

                x_s = X_train[:e.len, :]
                y_s = e.to_array()

                # NOTE(review): scikit-learn deprecated penalty='none' in
                # favour of penalty=None in newer releases -- confirm the
                # spelling against the pinned scikit-learn version.
                clf = LogisticRegression(random_state=random_init
                                         , penalty='none'
                                         , max_iter=10000
                                         , fit_intercept=fit_intercept_init).fit(x_s, y_s)

                acc = clf.score(x_s, y_s)

                # The prefix loss is only evaluated when the reference labels
                # of this prefix contain both classes; otherwise it stays 0.
                orig_y_s = Y_train[:e.len]
                positives = np.sum(orig_y_s)
                loss = 0
                if positives != e.len and positives != 0:
                    loss = loss_score(y_s, orig_y_s, N)

                # Candidates the model cannot fit perfectly are dropped
                # silently; only over-budget ones count as pruned.
                if acc == 1.0:
                    if loss <= theta_threshold:
                        next_level.append(e)
                        if verbose:
                            print("   ", i_depth, "kept. prefix loss", loss)
                    elif loss > theta_threshold:
                        prune_counter += 1
                        if verbose:
                            print("   ", i_depth, "pruned. prefix loss", loss)
        queue = next_level

    return queue

def compute_patterns_bin_fixed(X_train
                            , Y_train
                            , theta_threshold
                            , start_index
                            , fit_intercept_init = False
                            , verbose = False
                            , random_init = 2023):
    """Enumerate label patterns for the samples after a fixed labelled prefix.

    The labels of the first ``start_index`` samples are taken from
    ``Y_train`` and kept outside the searched sequences: each returned
    sequence covers only the remaining samples, and the fixed labels are
    prepended when fitting and scoring. Positions are branched on 0 and 1,
    one sample per level (breadth first). A candidate is kept when an
    unpenalised logistic regression reaches training accuracy 1.0 on
    prefix + candidate and the disagreement with ``Y_train`` does not exceed
    ``theta_threshold``.

    Parameters
    ----------
    X_train : 2-D array
        Feature matrix; ``X_train.shape[0]`` is the sample count ``N`` and
        rows are sliced as ``X_train[:k, :]``.
    Y_train : sequence of labels
        Reference labels; ``Y_train[0:start_index]`` is the fixed prefix.
    theta_threshold : float
        Largest accepted prefix loss, where the loss is the number of labels
        that differ from ``Y_train`` divided by ``N``.
    start_index : int
        Number of leading samples whose labels are fixed to ``Y_train``.
        Must satisfy ``0 <= start_index < N``.
    fit_intercept_init : bool
        Forwarded to ``LogisticRegression(fit_intercept=...)``.
    verbose : bool
        Print per-level and per-candidate progress.
    random_init : int
        Forwarded to ``LogisticRegression(random_state=...)``.

    Returns
    -------
    list
        The ``bit_vectors.BinSequence`` objects (variable part only) that
        survived the last level.

    Raises
    ------
    ValueError
        If ``start_index`` is outside ``[0, N)``; such values would otherwise
        yield sequences that do not correspond to the training samples.
    """
    N = X_train.shape[0]
    if not 0 <= start_index < N:
        raise ValueError(
            "start_index must satisfy 0 <= start_index < {}, got {}".format(
                N, start_index))
    variable_sample_size = N - start_index

    # Fixed labels, prepended to every candidate before fitting.
    y_init = Y_train[0:start_index]

    # The searched sequences start empty: the fixed prefix is not stored in
    # them. Assumes BinSequence.append_0/append_1 append one bit in place and
    # return the sequence itself -- TODO confirm against bit_vectors.
    init_seq = bit_vectors.BinSequence(value=0, length=0)

    # First BFS level: the single-bit sequences 0 and 1.
    # NOTE(review): these two seeds are never fitted themselves, so when
    # variable_sample_size == 1 they are returned unchecked -- confirm.
    queue = [init_seq.copy().append_0(), init_seq.append_1()]

    # counter for number of pruned branches
    prune_counter = 0

    for i_depth in range(1, variable_sample_size):
        if verbose:
            print("sample id", i_depth
                  , "number of items in queue", len(queue)
                  , "pruned branches", prune_counter)

        # Build the next level in a fresh list rather than popping from the
        # front of `queue` (list.pop(0) is O(len) per call). Order is the
        # same as the original pop-front/append-back traversal.
        next_level = []
        for elem in queue:
            # The copy gets the 1-branch first; `elem` itself is then reused
            # for the 0-branch.
            children = [elem.copy().append_1(), elem.append_0()]
            for e in children:
                # NOTE(review): this shortcut inspects only the variable part
                # `e`. When the fixed prefix already contains the other
                # class, prefix + e is not single-class, yet the candidate
                # still skips the fit/loss check and is dropped at the last
                # level -- confirm this is intended.
                if e.x == 0 or e.all_bits_set():
                    if i_depth < variable_sample_size - 1:
                        next_level.append(e)
                    continue

                prefix_len = e.len + start_index
                x_s = X_train[:prefix_len, :]
                y_s = np.concatenate((y_init, e.to_array()))

                # NOTE(review): scikit-learn deprecated penalty='none' in
                # favour of penalty=None in newer releases -- confirm the
                # spelling against the pinned scikit-learn version.
                clf = LogisticRegression(random_state=random_init
                                         , penalty='none'
                                         , max_iter=10000
                                         , fit_intercept=fit_intercept_init).fit(x_s, y_s)

                acc = clf.score(x_s, y_s)

                # The prefix loss is only evaluated when the reference labels
                # of this prefix contain both classes; otherwise it stays 0.
                orig_y_s = Y_train[:prefix_len]
                positives = np.sum(orig_y_s)
                loss = 0
                if positives != prefix_len and positives != 0:
                    loss = loss_score(y_s, orig_y_s, N)

                # Candidates the model cannot fit perfectly are dropped
                # silently; only over-budget ones count as pruned.
                if acc == 1.0:
                    if loss <= theta_threshold:
                        next_level.append(e)
                        if verbose:
                            print("   ", i_depth, "kept. prefix loss", loss)
                    elif loss > theta_threshold:
                        prune_counter += 1
                        if verbose:
                            print("   ", i_depth, "pruned. prefix loss", loss)
        queue = next_level

    return queue


def compute_diversity(patterns):
    """Return a normalised pairwise Hamming diversity for ``patterns``.

    ``hamdist`` is summed over every unordered pair of ``.x`` values; the
    total is then multiplied by 2 and divided by the squared number of
    patterns and by ``patterns[0].len``.

    Parameters
    ----------
    patterns : sequence
        Non-empty, index-addressable collection of objects exposing ``.x``
        (passed to ``hamdist``) and ``.len``.

    Returns
    -------
    float
        The diversity score; 0.0 when only one pattern is given.
    """
    count = len(patterns)
    assert count > 0
    # Float start value keeps the accumulation in floating point.
    pairwise_total = sum(
        (hamdist(patterns[a].x, patterns[b].x)
         for a in range(count)
         for b in range(a + 1, count)),
        0.,
    )
    return 2.0 * pairwise_total / count / count / patterns[0].len