
#%%
'''
Run in VS Code using the interactive cell (#%%) environment.
The first cell holds the imports and helper functions, the second runs the main experiment, the third plots the results, and the last computes the empirical phi values.
'''
from concurrent.futures import ProcessPoolExecutor
from submodular_data import *
import pickle
import heapq
import matplotlib.patheffects as pe
import math
import random
import json
import time
import numpy as np
from datetime import datetime
import os
from concurrent.futures import ThreadPoolExecutor
from numba import njit


def save_p_values(p_values, filename, data=None):
    """Pickle ``(p_values, data)`` to *filename*.

    This is the tuple that ``load_p_values`` returns and that the experiment
    cells unpack as ``p, saved_data``.

    :param p_values: Per-point sampling probabilities (e.g. from ``calc_p_values``).
    :param filename: Path of the pickle file to (over)write.
    :param data: Dataset object stored next to ``p_values``. When omitted, the
        module-level ``data`` variable is used. The earlier version relied on
        that global implicitly; the experiment cells bind it as their loop
        variable.
    :raises NameError: if ``data`` is omitted and no module-level ``data`` exists.
    """
    if data is None:
        # Backward compatibility: existing callers rely on the global `data`.
        try:
            data = globals()["data"]
        except KeyError:
            raise NameError(
                "save_p_values() needs the dataset: pass data=... explicitly"
            ) from None
    with open(filename, 'wb') as file:
        pickle.dump((p_values, data), file)

def load_p_values(filename):
    """Return the ``(p_values, data)`` tuple pickled at *filename*.

    A missing file is not an error: ``(None, None)`` is returned instead, so
    the caller can fall back to computing the values from scratch.
    """
    try:
        with open(filename, 'rb') as handle:
            saved = pickle.load(handle)
    except FileNotFoundError:
        saved = (None, None)
    return saved



@njit
def _compute_p(Fis, F_vals):
    """Per-column maximum of the row-normalised values of ``Fis``.

    For every column ``i`` this returns ``max_j Fis[j, i] / F_vals[j]`` taken
    over the rows ``j`` with ``F_vals[j] > 0``. The result is floored at 0.0,
    so a column with no positive ratio gets 0.0.

    :param Fis: 2-D array; row ``j`` belongs to candidate ``j``.
    :param F_vals: 1-D array; ``F_vals[j]`` is the divisor for row ``j``.
    :return: float64 array with one entry per column of ``Fis``.
    """
    n_rows, n_cols = Fis.shape
    best = np.zeros(n_cols, np.float64)
    # Row-major sweep: each qualifying row updates the running max of every column.
    for row in range(n_rows):
        total = F_vals[row]
        if not total > 0.0:
            continue  # also skips NaN divisors, like the `> 0.0` test it replaces
        for col in range(n_cols):
            ratio = Fis[row, col] / total
            if ratio > best[col]:
                best[col] = ratio
    return best

def calc_p_values(data, S=None, E_sample=None, max_workers=None):
    """Compute the per-point sampling probabilities ``p`` for ``data``.

    Each candidate ``e`` in ``E_sample`` (all of ``data.E`` by default) is
    evaluated with ``data.F(data.X, S + [e], get_Fis=True)``, which gives one
    row of per-point values. The stacked rows and their row sums are reduced
    by ``_compute_p``.

    :param data: Dataset exposing ``X``, ``E``, ``N`` and ``F(..., get_Fis=True)``.
    :param S: Elements already in the solution (default: empty list).
    :param E_sample: Candidates to evaluate; defaults to ``data.E``.
    :param max_workers: Thread-pool size passed to ``ThreadPoolExecutor``.
    :return: Whatever ``_compute_p`` returns for the stacked rows.
    """
    base_set = [] if S is None else S
    X, E, N = data.X, data.E, data.N
    candidates = E if E_sample is None else E_sample

    def per_point_and_total(candidate):
        # One row of per-point values plus its sum (the row's divisor).
        row = data.F(X, base_set + [candidate], get_Fis=True)
        return row, row.sum()

    # The data.F calls are independent of each other, so they are spread over threads.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        evaluated = list(pool.map(per_point_and_total, candidates))

    rows = np.vstack([row for row, _ in evaluated])
    totals = np.array([total for _, total in evaluated])
    return _compute_p(rows, totals)




def save_results(results, filename_base="results", results_dir="results"):
    """Serialise one experiment's results to a timestamped JSON file.

    The file name is
    ``{filename_base}_{data_type}_{algs}_{betas}_{eps}_{YYYYmmdd_HHMMSS}.json``.
    NumPy scalars, booleans and arrays are converted to plain Python values.
    JSON turns the float beta keys of the per-algorithm dicts into strings;
    ``load_results`` converts them back.

    :param results: Dict with at least ``data_type``, ``algs``, ``betas`` and ``eps``.
    :param filename_base: Prefix of the generated file name.
    :param results_dir: Directory to write into; it is created if missing.
    :return: Path of the written file.
    :raises TypeError: if ``results`` contains a value JSON cannot encode.
    """
    def default_handler(obj):
        # json.dump calls this only for objects it cannot encode natively.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

    data_type = results["data_type"]
    algs_str = "_".join(results["algs"])
    betas_str = "_".join(map(str, results["betas"]))
    eps_str = str(results["eps"])
    date_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{filename_base}_{data_type}_{algs_str}_{betas_str}_{eps_str}_{date_time}.json"

    # exist_ok avoids the race of a separate exists() check followed by makedirs().
    os.makedirs(results_dir, exist_ok=True)
    path = os.path.join(results_dir, filename)
    with open(path, 'w') as file:
        json.dump(results, file, indent=4, default=default_handler)

    print(f"Results saved to {filename}")
    return path


def load_results(filename):
    """Load a results JSON file and restore the float beta keys.

    JSON object keys are always strings. For every algorithm listed in
    ``data["algs"]`` that has an entry in the file, its beta keys are turned
    back into floats, keeping their order.

    :return: The results dict, or ``None`` (after printing a notice) when the
        file does not exist.
    """
    try:
        with open(filename, 'r') as file:
            data = json.load(file)
    except FileNotFoundError:
        print(f"No saved results found at {filename}. Starting from scratch.")
        return None

    stored_algs = (name for name in data["algs"] if name in data)
    for name in stored_algs:
        data[name] = {float(beta): runs for beta, runs in data[name].items()}
    return data


def evaluate_F(data,i, batch, S, E, w):
    """Evaluate ``data.F`` on *batch* for the set ``S + [E[i]]`` with weights *w*.

    :return: ``(i, value)``, keeping the candidate index next to its value.
    """
    candidate_set = S + [E[i]]
    return i, data.F(batch, candidate_set, w)

def sample_batch(X,p,alpha,use_uniform=False):
    """Draw a non-empty importance sample of ``X`` with inverse-probability weights.

    Each point ``X[i]`` is kept independently with probability
    ``alpha_i = min(alpha * p[i], 1)``. A kept point gets weight ``1 / alpha_i``.
    The whole pass is repeated until at least one point is kept.

    :param X: Indexable sequence of points.
    :param p: Per-point probabilities, indexable like ``X``.
    :param alpha: Scale applied to ``p`` (see ``find_alpha``).
    :param use_uniform: Not used by this function; kept for call compatibility.
    :return: ``(w, batch)``: the weights and the kept points, in index order.
    :raises ValueError: if no point has a positive keep probability (including
        an empty ``X``). The retry loop would otherwise never terminate.
    """
    n = len(X)
    # Checked up front without drawing random numbers, so seeded runs produce
    # exactly the same batches as before. any() stops at the first positive entry.
    if not any(min(alpha * p[i], 1) > 0 for i in range(n)):
        raise ValueError(
            "sample_batch: every inclusion probability is zero; cannot draw a non-empty batch"
        )

    batch = []
    w = []
    while not batch:
        for i in range(n):
            alpha_i = min(alpha * p[i], 1)
            if random.random() < alpha_i:
                w.append(1 / alpha_i)
                batch.append(X[i])
    return w, batch

def find_alpha(p, beta, tolerance=1e-4, max_iter=200):
    """Find the scale ``alpha`` at which the expected sample size is ``beta * len(p)``.

    With scale ``alpha``, point ``i`` is kept with probability
    ``min(1, alpha * p[i])``, so the expected sample size is
    ``sum_i min(1, alpha * p[i])``. For non-negative ``p`` this is
    non-decreasing in ``alpha``. The search doubles from 1 to find an upper
    bracket, then bisects.

    :param p: Per-point probabilities (sequence or array), expected to be >= 0.
    :param beta: Target fraction of points; the target size is ``len(p) * beta``.
    :param tolerance: Absolute tolerance passed to ``np.isclose``. Its default
        ``rtol=1e-5`` also applies.
    :param max_iter: Maximum number of bisection steps.
    :return: ``alpha`` as a float.
    :raises ValueError: if fewer than ``len(p) * beta`` entries of ``p`` are
        positive (no ``alpha`` can reach the target), or if doubling overflows.
    """
    p = np.asarray(p, dtype=np.float64).ravel()
    N = len(p)
    target = N * beta

    def expected_samples(a):
        return float(np.minimum(1.0, a * p).sum())

    # Even as alpha -> inf only positive entries contribute (at most 1 each);
    # without this check the doubling loop below would never end.
    positive = np.count_nonzero(p > 0)
    if positive < target:
        raise ValueError(
            f"beta={beta} is unreachable: only {positive} of {N} points have a positive probability"
        )

    # Bracket: expected_samples(low) <= target <= expected_samples(high).
    low, high = 0.0, 1.0
    while expected_samples(high) < target:
        low, high = high, 2.0 * high
        if not math.isfinite(high):
            raise ValueError("find_alpha: could not bracket alpha (overflow while doubling)")

    # Bisection on the monotone expected-size function, bounded in steps.
    for _ in range(max_iter):
        mid = (low + high) / 2
        value = expected_samples(mid)
        if np.isclose(value, target, atol=tolerance):
            return mid
        if value < target:
            low = mid
        else:
            high = mid
    return (low + high) / 2


def execute_alg(data,p,k, alpha, alg, eps=None):
    """Greedily select k elements of ``data.E`` using a lazily updated max-heap of gains.

    Each iteration keeps only the remaining candidates ``i`` for which
    ``data.independent(S + [E[i]])`` holds. It then picks the candidate with
    the largest estimated marginal gain
    ``data.F(batch, S + [e], w) - data.F(batch, S, w)`` on the current batch
    and weights, re-evaluating heap entries lazily.

    :param data: Dataset object; this function reads ``X``, ``E`` and ``N`` and
        calls ``F`` and ``independent`` on it.
    :param p: Per-point probabilities passed to ``sample_batch``.
    :param k: Number of greedy iterations, i.e. elements selected.
    :param alpha: Scale for ``p`` passed to ``sample_batch``.
    :param alg: "my", "lazy", "full" or "sparsifier", optionally with a
        "_uniform" suffix. The suffix is stripped here and sets ``use_uniform``,
        which skips the preprocessing charge and the batch-size bookkeeping.
        - "my" draws a new batch every iteration.
        - "sparsifier" draws one batch before the loop and reuses it.
        - "full" and "lazy" use all of ``data.X`` with ``w = None``.
    :param eps: None to consider every independent candidate each iteration.
        Otherwise only a random subset of ``int(len(E) / k * log2(1 / eps))``
        of them is considered (stochastic-greedy variant).
    :return: Dict with two keys:
        - "preproc_calls": ``N * len(E)`` for non-uniform "my"/"sparsifier", else 0.
        - "k": one dict per iteration holding "sol_val", the cumulative
          "func_calls" and the selected list "S".

    NOTE(review): an ``alg`` other than the four names above leaves
    ``full_batch`` unbound, so the first iteration raises UnboundLocalError.
    NOTE(review): ``batch_sizes`` is collected for non-uniform "my" but never
    returned.
    NOTE(review): if every refreshed gain in an iteration is negative,
    ``best_i`` is never set and the lazy loop re-evaluates forever.
    """
    S = []
    pq = []  # min-heap on negated keys: entries are (-estimated value, index into E, element)
    X=data.X
    E=data.E
    N=data.N

    sample_E = eps is not None
    if sample_E:
        # Per-iteration candidate sample size for the stochastic variant.
        lazy_sample_size = int((len(E)/k) *math.log2(1/eps))
        print("lazy_sample_size: ", lazy_sample_size)

    
    remaining = set(range(len(E)))  # indices into E not selected yet
    results = dict() #k vs result
    results["preproc_calls"]=0
    results["k"] = []

    # Cumulative query count: F evaluations on the batch are charged len(batch);
    # the sol_val evaluation on all of X below is not charged.
    func_calls = 0
    use_uniform = "uniform" in alg
    #remove "uniform" from alg
    alg = alg.replace("_uniform","")
    w = None
    batch = X

    if alg == "my":
        if not use_uniform:
            results["preproc_calls"] = N*len(E)
        full_batch = False  # a fresh batch is sampled every iteration
        
        if not use_uniform:
            batch_sizes = []
        #sample_E=False

    if alg == "lazy":
        #sample_E=True
        full_batch = True

    if alg == "full":
        #sample_E=False
        full_batch = True
        w = None
        batch = X
    
    if alg == "sparsifier":
        if not use_uniform:
            results["preproc_calls"] = N*len(E)
        full_batch = True  # the batch drawn here is reused in every iteration
        w,batch = sample_batch(X,p,alpha, use_uniform)

    #Only full alg should pay the below
    # (alg == "uniform" can only match if alg was passed literally as "uniform",
    # because the "_uniform" suffix was stripped above.)
    if not sample_E:
        if alg == "full" or alg == "uniform":
            func_calls += len(batch)* len(remaining)
        # Seed the heap with each candidate's singleton value F(batch, [e], w).
        for i in remaining:
            e = E[i]
            F_val = data.F(batch, [e], w)
            heapq.heappush(pq, (-F_val, i, e))

    print("preprocessing func calls: ", results["preproc_calls"])
    sol_val = 0
    for j in range(k):
        #start iter timer
        iter_time = time.time()
        print("iter: ",j, "out of ", k, "#remaining: ", len(remaining))
        
        # Candidates that can still be added to S according to data.independent.
        rem = [i for i in remaining if data.independent(S + [E[i]])]
        if sample_E:
            print(len(remaining))
            # Stochastic variant: restrict this iteration to at most lazy_sample_size random candidates.
            rem = random.sample(sorted(rem), min(lazy_sample_size, len(rem)))
            print("batch size: ",len(batch))

        if full_batch == False:
            
            w,batch = sample_batch(X,p,alpha,use_uniform)

            if not use_uniform:
                batch_sizes.append(len(batch))
        
            
        best_F = 0
        best_e = None
        best_i = -1
        
        if sample_E:
            # pq_orig keeps the full heap list; entries outside this round's
            # sample are merged back after the selection.
            pq_orig = pq
            to_handle = set(rem)
            pq = [item for item in pq if item[1] in to_handle]
            # Sampled candidates without a heap entry get one now, keyed by
            # F(batch, S + [e], w) (a full value, not a marginal gain), and the
            # entry is recorded in both lists.
            to_handle =  to_handle - set([item[1] for item in  pq]) 
            for i in to_handle:
                e = E[i]
                func_calls += len(batch)
                F_val = data.F(batch, S+[e], w)
                pq.append((-F_val, i, e))
                pq_orig.append((-F_val, i, e))
            heapq.heapify(pq)
        

        F0 = data.F(batch, S, w)  # value of the current solution on this batch
        func_calls += len(batch)
        # Lazy greedy: pop the largest stored key, refresh it as a marginal gain
        # on the current batch and push it back. Stop once the popped element
        # is the one with the best refreshed gain seen this round.
        # NOTE(review): for "my" the stored keys were computed on earlier
        # batches, so they are estimates rather than guaranteed upper bounds.
        # NOTE(review): `rem` is a list, so `i not in rem` is a linear scan.
        while len(pq) > 0:
            # Lazy update step
            #start timer (t is not used afterwards)
            t = time.time()
            top = heapq.heappop(pq)
            e_f, i, e = top
            if i not in rem:
                continue  # not a candidate this round (selected, dependent or unsampled); its entry is dropped from pq
            
            if i == best_i:
                break
            
            F_val = data.F(batch, S + [e], w) - F0
            heapq.heappush(pq, (-F_val, i, e))
            func_calls += len(batch)
            if F_val >= best_F:
                best_F = F_val
                best_e = e
                best_i = i


        if sample_E:
            # Merge back the heap entries of candidates that were not sampled this round.
            pq_new = [item for item in pq_orig if item[1] not in rem]
            pq = pq+pq_new
            heapq.heapify(pq)
        S = S + [best_e]

        assert(best_i >-1)
        remaining.remove(best_i)
        # Solution value on all of X; not added to func_calls.
        sol_val = data.F(X, S, orig=False)
        
        print("iter time: ", time.time()-iter_time)
        results["k"].append({"sol_val":sol_val, "func_calls":func_calls, "S":S})
        print("solution val:", results["k"][-1]["sol_val"], "func calls: ", results["k"][-1]["func_calls"])
    
    return results



#%%
'''
Cell that runs the experiments and saves the results to the results/ folder.
'''
if __name__ == '__main__':
    # --- experiment configuration ---
    N=None  # forwarded to the dataset constructors below
    datasets = [ImageData(type="cifar100", N=N), ImageData(type="fashion_mnist", N=N), BipartiteGraph(N=N), TaxiData(N=N), InfluenceMaxData()]
    algs = ["my", "sparsifier", "my_uniform", "sparsifier_uniform"]

    #betas = [0.1, 0.01, 0.001, 0.0001]
    epsilons = [None]  # None -> consider every candidate; a float -> stochastic greedy with that eps
    num_repeat = 1
    k = 20
    skip_full = False
    use_multiprocessing = True

    for data in datasets:
        #skip 0.0001 for influence max
        if data.__class__.__name__ == "InfluenceMaxData":
            betas = [0.1, 0.01, 0.001]
        else:
            betas = [0.01, 0.001, 0.0001]
        start_time = time.time()  # Start timing
        print("running dataset", str(data))

        # p values need one F evaluation per candidate, so they are cached on disk per dataset.
        filename = f'p_values_{str(data)}.pkl'
        p,saved_data = load_p_values(filename)
        if p is not None:
            data = saved_data
            print("Loaded p values from file.")
        else:
            print("init data")
            data.init()
            if data.__class__.__name__ != "Flywire":
                print("calc p values")
                p = calc_p_values(data)
                # save_p_values stores the module-level `data` (this loop variable) next to p.
                save_p_values(p, filename)
                print(f"Saved p values to {filename}.")
            else:
                print("Flywire data, setting p to 1/N")
                p = np.ones(len(data.X)) / len(data.X)

        X=data.X
        E=data.E
        N=data.N

        max_workers=12
        print("starting main experiments")
        for eps in epsilons:
            results = {
                "data_type": str(data),
                "num_repeat": num_repeat,
                "k": k,
                "betas": betas,
                #"epsilons": epsilons,
                "eps": eps,
                "algs": algs }
            if not skip_full:
                print("running baseline")
                if eps is None:
                    results["full"] = execute_alg(data,p,k,0,"full",eps)
                else:
                    # Stochastic baseline: 5 independent runs.
                    results["full"] = []
                    if use_multiprocessing:
                        with ProcessPoolExecutor(max_workers=max_workers) as executor:
                            future_results = []
                            for i in range(5):
                                future = executor.submit(execute_alg, data,p,k,0,"full",eps)
                                future_results.append(future)
                        for future in future_results:
                            results["full"].append(future.result())
                    else:
                        for i in range(5):
                            res = execute_alg(data,p,k,0,"full",eps)
                            results["full"].append(res)

            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                # Execute experiments in parallel and collect results
                future_results = []
                for beta in betas:
                    print("getting alpha for beta: ", beta)
                    alpha = find_alpha(p,beta)
                    print("alpha is ", alpha)
                    for alg in algs:
                        if alg not in results:
                            results[alg] = {}
                        if beta not in results[alg]:
                            results[alg][beta] = []
                        # p1/alpha1 are chosen per algorithm so that a uniform variant
                        # listed earlier in `algs` cannot leak its alpha into the
                        # importance-sampled ones.
                        if "uniform" in alg:
                            # Uniform baseline: p_i = 1/|X|, scaled so the expected batch is beta*|X|.
                            p1 = np.ones(len(X)) / len(X)
                            alpha1 = beta*len(X)
                        else:
                            p1 = p
                            alpha1 = alpha
                        for i in range(num_repeat):
                            if use_multiprocessing:
                                future = executor.submit(execute_alg, data, p1, k, alpha1, alg, eps)
                                future_results.append((alg, beta, future))
                            else:
                                res = execute_alg(data, p1, k, alpha1, alg, eps)
                                results[alg][beta].append(res)

            # Organize results
            if use_multiprocessing:
                for alg, beta, future in future_results:
                    results[alg][beta].append(future.result())

            save_results(results)
            elapsed_time = time.time() - start_time  # Calculate elapsed time
            print(f"Total execution time: {elapsed_time} seconds")

    
#%%
'''
Plotting code. Execute after running the experiments.
'''

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.ticker as mticker
from matplotlib.lines import Line2D

from matplotlib.ticker import LogFormatterMathtext

import matplotlib.pyplot as plt
import numpy as np



# Line colours for the beta values, paired in order with the sorted betas.
colors = [
    "olive",
    "red",
    "blue",
    "green",
]
def _pretty_dataset_label(data_type):
    """Return the display name for a stored ``data_type`` string (unknown names pass through)."""
    if data_type == "BipartiteGraph":
        return "Discogs"
    parts = data_type.split("_")
    if parts[0] == "TaxiData":
        return "Uber Pickups"
    if parts[0] == "InfluenceMaxData":
        return "Facebook"
    if parts[0] == "ImageData":
        if parts[1] == "cifar100":
            return "CIFAR100"
        if parts[1] == "fashion":
            return "Fashion MNIST"
    return data_type


def plot_results(results, betas, ratio=False, show_error_bars=True, large = False):
    """
    Plot all results in a grid with a single shared legend.

    Every entry of ``results`` gets a "Utility" panel (solution value vs. k)
    and a "Cost" panel (#queries including preprocessing).

    :param results: List of result dicts (e.g. from ``load_results``), one per dataset.
    :param betas: Beta values to draw; None uses every beta listed in the results.
    :param ratio: Plot values relative to each result's "full" baseline.
    :param show_error_bars: Whether to include error bars.
    :param large: Controls the grid layout.
        - True: one dataset per row (utility | cost).
        - False: two rows (utility, cost) with one column per dataset.
    :raises ValueError: if ``results`` is empty.
    """
    if not results:
        raise ValueError("plot_results: no results to plot")
    if betas is None:
        # `results` is a list of runs, so collect the betas stored in each one.
        betas = {beta for res in results for beta in res['betas']}
    betas = sorted(betas, reverse=True)
    plt.rcParams.update({'font.size': 18})

    beta_color_map = dict(zip(betas, colors))

    num_res = len(results)
    # squeeze=False keeps `axes` 2-D even for a single dataset, so the same
    # indexing works for every layout (previously large=True with one result failed).
    if large:
        fig, axes = plt.subplots(num_res, 2, figsize=(30, 8 * num_res), squeeze=False)
    else:
        fig, axes = plt.subplots(2, num_res, figsize=(7 * num_res, 10), squeeze=False)

    axes_font_size = 20
    for i, res in enumerate(results):
        print("plotting ", i)
        ax1, ax2 = axes[i] if large else axes[:, i]

        dataset_label = _pretty_dataset_label(res["data_type"])
        if res.get("eps") is not None:
            dataset_label += f" (\\epsilon={res['eps']})"

        plot_results_aux(ax=ax1, results=res, betas=betas, ratio=ratio, show_error_bars=show_error_bars, plot_func_calls=False, axes_font_size=axes_font_size)
        ax1.set_title("$\\bf{" + dataset_label + "}$ \n Utility")

        plot_results_aux(ax=ax2, results=res, betas=betas, ratio=ratio, show_error_bars=show_error_bars, plot_func_calls=True, count_preproc=True, axes_font_size=axes_font_size)
        ax2.set_title("Cost")

    # One shared legend: line style encodes the algorithm, colour encodes beta.
    legend_line_width = 7
    effects = [pe.withStroke(linewidth=legend_line_width, foreground="black")]
    legend_elements = [
        # Uniform variants are drawn as white lines with an outline stroke.
        Line2D([0], [0], color='white', lw=legend_line_width-2, linestyle='-', label='u-minibatch', path_effects=effects),
        Line2D([0], [0], color='white', lw=legend_line_width-2, linestyle='--', label='u-sparsifier', path_effects=effects),
        Line2D([0], [0], color='black', lw=legend_line_width, linestyle='-', label='w-minibatch'),
        Line2D([0], [0], color='black', lw=legend_line_width, linestyle='--', label='w-sparsifier'),
    ]
    for beta in betas:
        legend_elements.append(
            Line2D([0], [0], color=beta_color_map[beta], lw=legend_line_width,
                   label=f'$\\beta=10^{{{math.floor(math.log10(beta))}}}$')
        )

    fig.legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(0.5, 1.01), ncol=len(legend_elements), frameon=True, fontsize=25)

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()


def plot_results_aux(ax, results, betas, ratio=False, show_error_bars=True, plot_func_calls=True, count_preproc=False, axes_font_size = 20):
    """Draw one panel (utility or cost vs. k) for a single results dict.

    :param ax: Matplotlib Axes to draw on.
    :param results: One results dict with "k", "num_repeat", an optional "eps",
        an optional "full" baseline and one entry per algorithm keyed by beta.
    :param betas: Beta values to draw; colours come from ``colors`` in this order.
    :param ratio: Divide every curve by the "full" baseline. This requires a
        "full" entry in ``results``.
    :param show_error_bars: Draw mean +/- std error bars instead of plain lines.
    :param plot_func_calls: Plot the cumulative "func_calls" instead of "sol_val".
    :param count_preproc: Add each run's "preproc_calls" to its cost (only
        used when plotting func_calls).
    :param axes_font_size: Font size of the axis labels.
    :raises ValueError: if ``ratio`` is requested but there is no "full" baseline.
    """
    beta_color_map = dict(zip(betas, colors))
    start_k = 1
    end_k = results['k']
    ks = range(start_k, end_k + 1)
    y_value = "func_calls" if plot_func_calls else 'sol_val'

    eps = results.get("eps")
    has_full = "full" in results
    sol_vals_full = None
    if has_full:
        if eps is None:
            # One baseline run.
            sol_vals_full = [item[y_value] for item in results['full']['k']]
        else:
            # Stochastic baseline: average over its repeated runs.
            full_runs = results["full"]
            sol_vals_full = [
                np.mean([run['k'][k][y_value] for run in full_runs])
                for k in range(start_k - 1, end_k)
            ]
    if ratio and sol_vals_full is None:
        # Previously this surfaced as a NameError deep inside calculate_stats.
        raise ValueError("ratio=True needs a 'full' baseline in results")

    if has_full and not ratio:
        ax.plot(ks, sol_vals_full[start_k-1:end_k], "o:", label="full", color="blue")

    def calculate_stats(res, alg, beta):
        """Per-k mean and std over the first ``num_repeat`` runs of ``alg`` at ``beta``."""
        means = []
        stds = []
        runs = res[alg][beta]
        repeats = range(res['num_repeat'])
        for k in range(start_k - 1, end_k):
            if y_value == "func_calls" and count_preproc:
                sol_vals = [runs[i]['preproc_calls'] + runs[i]['k'][k][y_value] for i in repeats]
            else:
                sol_vals = [runs[i]['k'][k][y_value] for i in repeats]
            if ratio:
                sol_vals = [v / sol_vals_full[k] for v in sol_vals]
            means.append(np.mean(sol_vals))
            stds.append(np.std(sol_vals))
        return means, stds

    # Fixed iteration order: the iteration order of a set of strings changes
    # between interpreter runs (hash randomisation), which would reshuffle the draw order.
    known_algs = {'my', 'sparsifier', 'lazy', 'my_uniform', "sparsifier_uniform"}
    for alg in sorted(known_algs & results.keys()):
        for beta in betas:
            if beta not in results[alg]:
                continue
            means, stds = calculate_stats(results, alg, beta)

            linestyle = '-' if 'my' in alg else '--'  # solid for minibatch ('my'), dashed otherwise
            marker = ""
            color = beta_color_map[beta]
            label_alg = 'minibatch' if alg == "my" else alg
            label = f'${label_alg} (\\beta=10^{{{math.floor(math.log10(beta))}}})$'

            if show_error_bars:
                ax.errorbar(ks, means, yerr=stds, fmt=marker, linestyle=linestyle, label=label,
                            color=color, elinewidth=2, capsize=5, alpha=0.7)
            elif "uniform" in alg:
                # Uniform variants: white line outlined in the beta colour.
                line, = ax.plot(ks, means, marker + linestyle, label=label, color="white", linewidth=2)
                line.set_path_effects([pe.withStroke(linewidth=4, foreground=color)])
            else:
                ax.plot(ks, means, marker + linestyle, label=label, color=color, linewidth=3)

    ax.set_xlabel('$k$', fontsize=axes_font_size)
    if plot_func_calls:
        # Query counts span orders of magnitude, so the cost axis is always logarithmic.
        ax.set_yscale("log")
        ylabel = '#Queries (log scale)'
        if ratio:
            ylabel = 'Relative ' + ylabel
        ax.set_ylabel(ylabel, fontsize=axes_font_size)
    else:
        ax.set_ylabel('Relative performance' if ratio else 'Solution Value', fontsize=axes_font_size)
    ax.grid(True)
    ax.set_xticks(range(min(ks) - 1, max(ks) + 5, 5))
    

    

# Call the function
if __name__ == '__main__':
    # Load every saved results file and draw them together in one figure.
    directory = "results/"
    betas_to_plot = [0.1, 0.01, 0.001, 0.0001]  # betas to draw

    # Keep only runs whose "eps" is listed here: None = lazy greedy, e.g. 0.1/0.2 = stochastic greedy.
    eps = [None]
    to_plot = []
    for filename in os.listdir(directory):
        if not filename.endswith(".json"):
            continue
        print(f"Processing file: {filename}")
        results = load_results(os.path.join(directory, filename))
        print("data type: ", results["data_type"])
        keep = eps is None or "eps" not in results.keys() or results["eps"] in eps
        if keep:
            to_plot.append(results)

    # Order the panels by dataset name (the sort is stable, so ties keep directory order).
    to_plot.sort(key=lambda run: run["data_type"])

    # Positional flags: ratio=False, show_error_bars=False, large=True.
    plot_results(to_plot, betas_to_plot, False, False, True)



#%%
'''
Calculates the empirical phi values that appear in the paper.
'''
def format_value(val):
    """Format *val* with two decimals, or as ``"m.mm 10^e"`` when 0 < |val| < 0.01.

    Zero, NaN/inf and values with magnitude >= 0.01 use plain ``"{:.2f}"``.
    Previously 0 came out as ``"0.00 10^0"`` and negative numbers always
    took the scientific form.

    >>> format_value(0.5)
    '0.50'
    >>> format_value(0.00123)
    '1.23 10^-3'
    """
    if val != 0 and abs(val) < 0.01:
        # Take the mantissa straight from the formatted string, so it matches the
        # rounding of the exponent (e.g. 0.0099999 -> "1.00 10^-2").
        mantissa, exponent = "{:.2e}".format(val).split("e")
        return "{} 10^{}".format(mantissa, int(exponent))
    return "{:.2f}".format(val)

if __name__ == '__main__':
    # NOTE(review): InfluenceMaxData is listed both first and last, so it is processed twice; confirm intent.
    datasets = [InfluenceMaxData(),ImageData(type="cifar100"), ImageData(type="fashion_mnist"), BipartiteGraph(), TaxiData(), InfluenceMaxData()] #, Flywire(sample=False)]

    All_Fis = []
    for data in datasets:
        filename = f'p_values_{str(data)}.pkl'
        p,saved_data = load_p_values(filename)
        if p is not None:
            data = saved_data
            print("Loaded p values from file.")
        else:
            # Initialise first, as the experiment cell does before calc_p_values.
            data.init()
            p = calc_p_values(data)
            save_p_values(p, filename)
            print(f"Saved p values to {filename}.")

        X=data.X
        E=data.E

        # Row j holds data.F(X, [e_j], get_Fis=True) for candidate e_j.
        Fis = np.array([data.F(X, [e], get_Fis=True) for e in E])
        All_Fis.append(Fis)

    # phi1/phi2: largest/smallest row mean of Fis, relative to the overall maximum entry.
    for dataset, Fis in zip(datasets, All_Fis):
        print("dataset: ", dataset)
        row_means = np.mean(Fis, axis=1)
        peak = np.max(Fis)
        phi1 = np.max(row_means) / peak
        phi2 = np.min(row_means) / peak
        print("phi1: ", format_value(phi1), "phi2: ", format_value(phi2))
# %%
