import os
import re
import pickle
import sys
import gmpy2
import numpy as np
from sklearn.metrics import accuracy_score

# Maps an experiment key (the run directory name with its last "_"-separated
# token removed) to the list of per-pattern scores gathered from every run
# directory that shares that key.
files_per_level = {}

# bit_vectors is not referenced by name below; presumably the pickled pattern
# objects are instances of classes defined there, so it has to be importable
# before unpickling -- TODO confirm.
sys.path.append('../code/')
import bit_vectors

for dir_name in os.listdir('./patterns_out/'):
    # Everything before the final underscore identifies the experiment.
    key = "_".join(dir_name.split('_')[:-1])

    # Only run directories that actually contain a pickled pattern list are
    # processed; stray files and incomplete runs are skipped instead of
    # aborting the whole script.
    patterns_path = os.path.join('./patterns_out/', dir_name, 'patterns_file')
    if not os.path.isfile(patterns_path):
        continue

    # SECURITY: pickle executes arbitrary code while loading. Only run this
    # script on pattern files produced by a trusted pipeline.
    with open(patterns_path, 'rb') as patterns_fh:
        patterns = pickle.load(patterns_fh)
    print(dir_name, len(patterns))

    reduce_dir = os.path.join('./reduce_out/', dir_name)
    # The number of comma-separated entries in all_discarded_points.csv is
    # used as the offset at which the labels matching the patterns start.
    len_discard_points = len(np.fromfile(
        os.path.join(reduce_dir, 'all_discarded_points.csv'),
        dtype=int, sep=','))
    Y_true = np.load(os.path.join(reduce_dir, 'Y_data0.npy'))
    y_true_pattern_part = Y_true[len_discard_points:]

    for pattern in patterns:
        y_pattern = pattern.to_array()
        # NOTE(review): accuracy_score returns the fraction of matching
        # labels (higher is better), yet the value is stored as a "loss" and
        # the code further down keeps scores <= min + theta. Confirm whether
        # `1 - accuracy_score(...)` was intended here.
        loss = accuracy_score(y_pattern, y_true_pattern_part)
        files_per_level.setdefault(key, []).append(loss)
                
# For every experiment key, count how many collected scores lie within
# `theta` of the smallest score, and file that count under the dataset name.
# Each dataset gets a four-slot list; slots that are never filled stay at -1.
res_dic = {}

theta = 0.03
for key, scores in files_per_level.items():
    min_loss = np.min(scores)
    threshold = min_loss + theta
    n_in_set = sum(1 for score in scores if score <= threshold)

    # The first "_"-separated token of the key names the dataset, the last
    # token is a 1-based slot number.
    key_tokens = key.split("_")
    key_res = key_tokens[0] + ".csv"
    position = int(key_tokens[-1]) - 1

    print(key, position, n_in_set, min_loss)
    res_dic.setdefault(key_res, [-1, -1, -1, -1])[position] = n_in_set


for entry in res_dic.items():
    print(entry)
    
import pandas as pd
# Legend labels and the matching CSV file names; the two lists are parallel
# (entry i of one belongs to entry i of the other).
data_names = [
    'Wine 4',
    "Seeds 4",
    "Penguin 4",
    "Digits 0-4 4",
    "Immunotherapy 4",
]
new_data = [
    'wine4.csv',
    "seeds4.csv",
    'penguin4.csv',
    'digits044.csv',
    "immunotherapy4.csv",
]
data_folder = "../data_real/datasets_real_normalized/"

res = res_dic

# Per-dataset table dimensions, keyed by CSV file name.
num_features = {}
num_samples = {}

for data in new_data:
    df = pd.read_csv(data_folder + data)
    n_rows, n_cols = df.shape
    num_samples[data] = n_rows
    # One column is left out of the feature count -- presumably the label
    # column; verify against the dataset files.
    num_features[data] = n_cols - 1
    print(data, df.shape)
    
from scipy.special import comb

def compute_linear_size(n, d):
    """Return 2 * sum_{k=0}^{d} C(n - 1, k).

    The binomial coefficients come from ``scipy.special.comb`` with its
    default arguments. When ``d`` is negative the sum is empty and the
    result is 0.

    NOTE(review): this matches the form of the classical count of linear
    dichotomies of ``n`` points; confirm that is the intended meaning.
    """
    return sum(comb(n - 1, k) for k in range(d + 1)) * 2

import matplotlib.pyplot as plt
import matplotlib as mpl


# One curve per dataset: x is the slot's value from depth_arr, y is the
# natural log of (stored count / compute_linear_size(sample count, depth)).
plt.figure()
depth_arr = [1, 2, 3, 4]
for i_data, key in enumerate(new_data):
    series_color = mpl.cm.tab10(i_data)
    counts = res[key]
    y = []
    for depth_id, count in enumerate(counts):
        depth = depth_arr[depth_id]
        y_val = np.log(count / compute_linear_size(num_samples[key], depth))
        y.append(y_val)
        # Markers sit above the connecting line (zorder=2).
        plt.scatter(depth, y_val, s=100, alpha=1, zorder=2,
                    color=series_color)
    plt.plot(depth_arr[:len(counts)], y, alpha=1, c=series_color,
             linewidth=1, label=data_names[i_data])

plt.xlabel("Number of non-zero coefficients", size=20)
plt.ylabel("log pattern Rratio, %", size=20)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.locator_params(axis='x', nbins=4)
plt.legend(loc='best', fontsize=16, labelspacing=0.1)
plt.savefig('rratio.pdf', bbox_inches='tight', dpi=200)