import json
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

# Hex color palette for plotted series; the plotting helpers in this file
# pick an entry by series position (colors[id]).
colors = [
    '#1f77b4',  # Blue
    '#ff7f0e',  # Orange
    '#2ca02c',  # Green
    '#9467bd',  # Purple
    '#d62728',  # Red
    '#8c564b',  # Brown
    '#e377c2',  # Pink
    '#7f7f7f',  # Gray
    '#bcbd22',  # Yellow-Green
    '#17becf',   # Cyan
]
# Folder prefix that plot() prepends to each result file name. It is a
# relative path, so it is resolved against the current working directory.
result_folder_path = '../../Results/'

def make_float_clear(value):
    """Return a compact string form of a number for use in labels.

    Args:
        value: Anything accepted by ``float()`` (number or numeric string).

    Returns:
        ``str(int(value))`` when the value is a whole number (``3.0`` -> ``'3'``),
        otherwise the float's ``str()`` form with any trailing zeros and a
        dangling decimal point removed (``'2.50'`` -> ``'2.5'``).

    Raises:
        ValueError / TypeError: if ``value`` cannot be converted by ``float()``.
    """
    value = float(value)
    if value.is_integer():
        return str(int(value))
    text = str(value)
    # Exponent notation (e.g. '1e-10') must be left alone: stripping trailing
    # zeros there would eat digits of the exponent and yield '1e-1'.
    if 'e' in text or 'E' in text:
        return text
    return text.rstrip('0').rstrip('.')

def load_json(file_path):
    """Load a file holding comma-separated JSON objects as a list of dicts.

    The file content is expected to be a sequence of JSON objects separated by
    commas, without enclosing brackets; a trailing comma and trailing
    whitespace are tolerated. The content is wrapped in ``[...]`` before being
    parsed.

    Args:
        file_path: Path of the file to read.

    Returns:
        A non-empty list whose first element is a dict.

    Raises:
        ValueError: if the file contains no records, or the first record is
            not a JSON object. ``json.JSONDecodeError`` (a ``ValueError``
            subclass) is raised for malformed content.
        OSError: if the file cannot be opened.
    """
    with open(file_path, 'r') as file:
        text = file.read()
    # rstrip() is safe on empty input, unlike indexing text[-1], and also
    # removes spaces / carriage returns left after the final comma.
    text = text.rstrip(', \t\r\n')
    # Wrapping in brackets guarantees the parsed value is a list.
    data = json.loads(f'[{text}]')
    if not data:
        raise ValueError("JSON data is empty.")
    if not isinstance(data[0], dict):
        raise ValueError("JSON data does not contain dictionaries.")
    return data

def in_log_scale():
    """Switch the current axes' x axis to a base-10 log scale.

    Acts on ``plt.gca()``: sets the x range to [1e-7, 1e7], removes minor
    ticks, places major ticks with a base-10 ``LogLocator`` (up to 13 ticks)
    and labels each major tick as a power of ten in mathtext.
    """
    def _power_of_ten_label(x, _pos):
        # Tick positions on a log axis are positive; guard anyway so a stray
        # non-positive value yields an empty label instead of an exception.
        if x <= 0:
            return ''
        # Round rather than truncate: log10 of a tick can land an ulp below
        # the integer exponent, and int() alone would label it one decade off.
        exponent = int(round(float(np.log10(x))))
        return f'$10^{{{exponent}}}$'

    ax = plt.gca()
    ax.set_xscale('log')
    ax.set_xlim(1e-7, 1e7)
    ax.xaxis.set_minor_locator(ticker.NullLocator())  # Remove minor ticks
    ax.xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=13))
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(_power_of_ten_label))

def default_scale():
    """Scale hook that does nothing, leaving the axes' scaling untouched."""
    return None

def get_data(file_path):
    """Index the records of a result file by series, z and d.

    Args:
        file_path: Path passed to ``load_json``.

    Returns:
        A tuple ``(data, possible_zs, possible_ds)`` where ``possible_zs`` and
        ``possible_ds`` are the sets of ``candidate_z`` / ``candidate_d``
        values found in the file, and ``data`` maps
        ``(test_case, algorithm_name)`` -> ``candidate_z`` -> ``candidate_d``
        -> ``{'number_of_rounds': ..., 'density': ...}``. Every series gets an
        entry (possibly empty) for every z in ``possible_zs``.
    """
    records = load_json(file_path)
    possible_zs = {record['candidate_z'] for record in records}
    possible_ds = {record['candidate_d'] for record in records}

    # One bucket per (test case, algorithm) pair, pre-populated with an empty
    # mapping for every z value seen anywhere in the file.
    data = {
        (record['test_case'], record['algorithm_name']): {z: {} for z in possible_zs}
        for record in records
    }

    for record in records:
        series = data[(record['test_case'], record['algorithm_name'])]
        measurement = {
            'number_of_rounds': record['number_of_rounds'],
            'density': record['density'],
        }
        series[record['candidate_z']][record['candidate_d']] = measurement

    return data, possible_zs, possible_ds
         
def plot_single_algorithm(ax, dataset, algorithm_name, y_name, id, data, possible_zs, possible_ds):
    """Plot, for one series, the minimum of ``y_name`` over d against z squared.

    Args:
        ax: Axes-like object providing ``plot`` and ``scatter``.
        dataset: First component of the series key in ``data``.
        algorithm_name: Second component of the series key; also the legend label.
        y_name: Key of the measurement to plot (e.g. ``'density'``).
        id: Series index used to pick a color from ``colors`` (wraps around).
        data: Nested mapping as returned by ``get_data``.
        possible_zs: Iterable of z values to consider.
        possible_ds: Iterable of d values to minimise over.

    Raises:
        KeyError: if ``(dataset, algorithm_name)`` is not present in ``data``.
    """
    per_z = data[(dataset, algorithm_name)]
    x, y = [], []
    # Sort z: possible_zs is typically a set, and plotting in set-iteration
    # order would connect the points in an arbitrary sequence.
    for z in sorted(possible_zs):
        by_d = per_z.get(z, {})
        # A series need not have a measurement for every (z, d) combination.
        values = [by_d[d][y_name] for d in possible_ds if d in by_d]
        if not values:
            continue
        x.append(z**2)
        y.append(min(values))
    # Wrap around the palette so more series than colors does not raise IndexError.
    color = colors[id % len(colors)]
    ax.plot(x, y, label=algorithm_name, linestyle='-', linewidth=1, color=color)
    ax.scatter(x, y, marker='o', s=20, color=color)

def single_plot(ax, algorithm_name, id, x, y):
    """Draw one series as a thin solid line.

    Args:
        ax: Axes-like object providing ``plot``.
        algorithm_name: Legend label for the series.
        id: Series index used to pick a color from ``colors`` (wraps around).
        x: X coordinates.
        y: Y coordinates, same length as ``x``.
    """
    # Wrap around the palette so more series than colors does not raise IndexError.
    color = colors[id % len(colors)]
    ax.plot(x, y, label=algorithm_name, linestyle='-', linewidth=1, color=color)

def get_x_y(file_path, x_name, y_name):
    """Return, for each distinct x in a result file, the largest y observed.

    Args:
        file_path: Path passed to ``load_json``.
        x_name: Record key holding the x value.
        y_name: Record key holding the y value.

    Returns:
        ``(x_list, y_list)``: the distinct x values as floats in ascending
        order, and for each the maximum y (as float) among records with that x.
    """
    # Track the running maximum per x in a dict: a single pass, and no
    # sentinel start value that would be wrong when every y is below it.
    best = {}
    for record in load_json(file_path):
        x = float(record[x_name])
        y = float(record[y_name])
        if x not in best or y > best[x]:
            best[x] = y
    x_list = sorted(best)
    y_list = [best[x] for x in x_list]
    return x_list, y_list

def get_x2_y(file_path, x_name, y_name):
    """Return squared x values and, for each, the largest y observed.

    Unlike ``get_x_y``, the file is parsed as-is with ``json.loads``, so it
    must already be a valid JSON document that iterates as records
    (e.g. an array of objects).

    Args:
        file_path: Path of the JSON file to read.
        x_name: Record key holding the x value.
        y_name: Record key holding the y value.

    Returns:
        ``(x2_list, y_list)``: the squares of the distinct x values, ordered
        by ascending x, and for each the maximum y (as float) among records
        with that x.
    """
    # Context manager so the file handle is closed deterministically.
    with open(file_path) as file:
        records = json.loads(file.read())
    # Track the running maximum per x in a dict: a single pass, and no
    # sentinel start value that would be wrong when every y is below it.
    best = {}
    for record in records:
        x = float(record[x_name])
        y = float(record[y_name])
        if x not in best or y > best[x]:
            best[x] = y
    x_list = sorted(best)
    y_list = [best[x] for x in x_list]
    x2_list = [x**2 for x in x_list]
    return x2_list, y_list

def plot(dataset, algorithms, result_paths, params, x_name = 'z', y_name='density', limit=100000000, scale_func=in_log_scale):
    """Plot one line per algorithm from its result file and show the figure.

    Args:
        dataset: Dataset name, used in the title.
        algorithms: Algorithm names; each becomes a legend entry.
        result_paths: File names (appended to ``result_folder_path``), one per
            entry of ``algorithms`` and in the same order.
        params: Mapping; ``params['epsilon']`` is shown in the title.
        x_name: Record key for the x value (squared before plotting).
        y_name: Record key for the y value; also the y-axis label.
        limit: Points whose squared x exceeds this value are dropped.
        scale_func: Zero-argument callable invoked after drawing to configure
            the axis scale of the current axes.

    Raises:
        KeyError: if ``params`` has no ``'epsilon'`` entry.
        IndexError: if ``result_paths`` is shorter than ``algorithms``.
    """
    epsilon = params['epsilon']

    fig, ax = plt.subplots()
    ax.set_facecolor('white')
    ax.tick_params(colors='black', which='both')
    for spine in ax.spines.values():
        spine.set_edgecolor('black')

    for i, algorithm in enumerate(algorithms):
        x_list, y_list = get_x2_y(result_folder_path + result_paths[i], x_name, y_name)
        # Filter x and y together so the pairs stay aligned even when the
        # kept x values are not a prefix of the list.
        kept = [(x, y) for x, y in zip(x_list, y_list) if x <= limit]
        xs = [x for x, _ in kept]
        ys = [y for _, y in kept]
        single_plot(ax, algorithm, i, xs, ys)
    ax.set_xlabel('Z^2')
    ax.set_ylabel(y_name)
    ax.set_title(f'{dataset} - Density vs Z^2,  epsilon = {make_float_clear(epsilon)}')
    ax.legend()
    scale_func()

    plt.show()

if __name__ == "__main__":
    # Script entry point: compare two algorithms on a single dataset.
    dataset = 'web-Google'
    # Algorithm name -> result file name (looked up under result_folder_path).
    runs = {
        'MPCBahmaniAlgorithm': 'MPCBahmaniAlgorithm___web-Google___0.8___0.6.json',
        'OurMPCAlgorithm': 'OurMPCAlgorithm___web-Google___0.8___0.2.json',
    }
    algorithms = list(runs)
    result_paths = list(runs.values())
    params = dict(epsilon=0.2, delta=None)
    plot(dataset, algorithms, result_paths, params, x_name='candidate_z', y_name='density')