import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt


# Define a cubic polynomial function
def cubic_polynomial(x, a, b, c, d):
    """Evaluate the cubic ``a*x**3 + b*x**2 + c*x + d`` at ``x``.

    Args:
        x: Point or points at which to evaluate. Anything that supports
            ``**``, ``*`` and ``+`` works (a plain number, a NumPy array, ...).
        a: Coefficient of the cubic term.
        b: Coefficient of the quadratic term.
        c: Coefficient of the linear term.
        d: Constant term.

    Returns:
        The value of the polynomial at ``x``.
    """
    cubic_term = a * x ** 3
    quadratic_term = b * x ** 2
    linear_term = c * x
    # Terms are summed left to right, highest power first.
    return cubic_term + quadratic_term + linear_term + d


# Results plotted under the 'Thought BFS' label: two parallel lists of 19
# entries each. The plotting loop treats entry i of 'num_output_tokens' (x)
# and entry i of 'Score' (y) as one point. Scores are listed in descending
# order.
ThoughtBFS = {
    'num_output_tokens': [
        23555, 28451, 29738, 59179, 18104, 31401, 19214, 46283, 40740,
        35337, 48874, 24451, 13493, 44768, 43921, 18850, 56051, 39004, 18198
    ],
    'Score': [
        3.1875, 1.46875, 0.25, 0.25, -0.1875, -0.65625, -1.15625, -1.21875,
        -1.40625, -1.5, -1.65625, -1.8125, -1.875, -1.875, -1.875, -2.21875,
        -2.21875, -3.1875, -4.65625
    ]
}

# Results plotted under the 'Our method' label: two parallel lists of 16
# entries each. The plotting loop treats entry i of 'num_output_tokens' (x)
# and entry i of 'Score' (y) as one point. Scores are listed in descending
# order.
Our_method = {
    'num_output_tokens': [
        37140, 39816, 15787, 58351, 34595, 28638, 31070, 26081, 58490, 31820,
        14892, 15767, 24466, 28810, 11955, 11761
    ],
    'Score': [
        2, 1.03125, 0.90625, 0.65625, 0.25, 0.0625, 0.03125, -0.40625, -1.0625,
        -1.0625, -1.375, -1.46875, -1.75, -2.09375, -2.6875, -2.84375
    ]
}

# Results plotted under the 'Line' label: two parallel lists of 21 entries
# each. The plotting loop treats entry i of 'num_output_tokens' (x) and
# entry i of 'Score' (y) as one point. Scores are listed in descending order.
Line = {
    'num_output_tokens': [
        14748, 19983, 15234, 68463, 11462, 42746, 45396, 31938, 24802, 3950,
        16082, 21166, 26626, 9545, 28581, 42540, 49388, 4999, 24446, 59421,
        10426
    ],
    'Score': [
        1.40625, 1.125, 0.875, 0.6875, -0.96875, -1.03125, -1.15625, -1.53125,
        -1.75, -1.75, -1.78125, -1.875, -1.875, -1.9375, -1.96875, -2.0625,
        -2.1875, -2.25, -2.8125, -2.84375, -3.75
    ]
}

# Results plotted under the 'Greedy' label: two parallel lists of 23 entries
# each. The plotting loop treats entry i of 'num_output_tokens' (x) and
# entry i of 'Score' (y) as one point. Scores are listed in descending order.
Greedy = {
    'num_output_tokens': [
        29001, 10818, 21152, 13128, 15064, 9633, 42027, 41432, 21406, 33584,
        33090, 22603, 23347, 23628, 10273, 28980, 11332, 19635, 32661, 8796,
        14119, 42676, 11074
    ],
    'Score': [
        3.125, 1.5625, 0, -0.34375, -0.4375, -0.46875, -0.78125, -1.0625,
        -1.09375, -1.21875, -1.5625, -1.75, -1.875, -1.96875, -2.34375, -2.5625,
        -2.5625, -2.625, -3, -3.03125, -3.53125, -3.53125, -3.78125
    ]
}

# Results plotted under the 'BFS' label: two parallel lists of 24 entries
# each. The plotting loop treats entry i of 'num_output_tokens' (x) and
# entry i of 'Score' (y) as one point. Scores are listed in descending order.
BFS = {
    'num_output_tokens': [
        26372, 39275, 19727, 37244, 30221, 22566, 45559, 20349, 13215, 32172,
        17345, 53086, 19268, 34611, 27351, 10844, 39485, 41563, 44590, 27530,
        40479, 9549, 38001, 16431
    ],
    'Score': [
        2.5, 1.375, 1.28125, 0.125, -0.15625, -0.5, -0.6875, -0.75, -1.03125,
        -1.09375, -1.4375, -1.53125, -1.75, -1.84375, -1.84375, -2.1875,
        -2.53125, -2.6875, -2.75, -2.9375, -3, -3.125, -3.53125, -3.53125
    ]
}

# Define datasets: legend label -> result dict (each holding the
# 'num_output_tokens' and 'Score' lists). The plotting loop zips this dict
# with `colors`, so the insertion order here decides which color each method
# gets.
methods_data = {
    'Greedy': Greedy,
    'Line': Line,
    'Thought BFS': ThoughtBFS,
    'Our method': Our_method,
    'BFS': BFS
}

# Colors for each method. They are paired positionally with the entries of
# `methods_data` (defined above), following that dict's insertion order.
colors = ['purple', 'blue', 'green', 'red', 'orange']

# Plotting the data: for every method draw its raw points as a scatter series
# and overlay a dashed least-squares cubic trend line in the same color.
plt.figure(figsize=(8, 8))

for (method, data), color in zip(methods_data.items(), colors):
    x = np.array(data['num_output_tokens'], dtype=float)
    y = np.array(data['Score'], dtype=float)
    plt.scatter(x, y, color=color, label=method)

    # Fit the cubic on a rescaled x axis instead of on raw token counts.
    # curve_fit starts from all-ones coefficients when no p0 is given; with
    # x around 1e4-1e5 the cubic term is then ~1e14, so the initial residuals
    # are enormous and the problem is badly conditioned. Dividing x by its
    # largest magnitude keeps every power of x at most 1. A cubic in
    # x / x_scale is still a cubic in x, so the model family is unchanged.
    x_scale = np.abs(x).max() or 1.0  # fall back to 1 if every x is zero
    popt, _ = curve_fit(cubic_polynomial, x / x_scale, y)

    # Evaluate the fitted curve over the observed x range. The same rescaling
    # must be applied here because popt refers to the rescaled axis.
    x_fit = np.linspace(x.min(), x.max(), 500)
    y_fit = cubic_polynomial(x_fit / x_scale, *popt)

    # Plot the fitted curve (no label, so only the scatter series appear in
    # the legend).
    plt.plot(x_fit, y_fit, color=color, linestyle='--')

# Adding legend and labels
plt.legend(fontsize=14)
plt.xlabel('Number of Output Tokens', fontsize=15)
plt.ylabel('Score', fontsize=15)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.title('Computational Budget vs Performance', fontsize=18)
plt.savefig('plot_fig3_fitted.png')
