import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.optimize import curve_fit

# Render all figure text produced by this module in DejaVu Sans.
plt.rcParams.update({"font.family": "DejaVu Sans"})

def custom_model(X, a, alpha, b, beta, c, s, e1, f1, k, h, es, fs):
    """Evaluate the 12-parameter loss model on every row of ``X``.

    ``X`` is indexed as a 2-D array with at least five columns. Using the
    symbols printed by ``fit_and_plot`` (N, Na, D, G, S for columns 0..4),
    the value returned for each row is::

        (e1*G + f1/G + es*S**2 + fs*S) * (N**alpha + k*Na**alpha + h*Na/N)
            + a*N**alpha + c*Na**alpha + b*D**beta + s

    Args:
        X: Array whose columns 0..4 hold the five input variables.
        a, alpha, b, beta, c, s, e1, f1, k, h, es, fs: Scalar coefficients
            and exponents of the model.

    Returns:
        One model value per row of ``X``.
    """
    n_col = X[:, 0]
    na_col = X[:, 1]
    d_col = X[:, 2]
    g_col = X[:, 3]
    s_col = X[:, 4]

    # Power terms shared by the multiplicative factor and the additive part.
    n_pow = np.power(n_col, alpha)
    na_pow = np.power(na_col, alpha)

    # First factor: depends only on columns 3 and 4.
    quadratic_in_s = es * np.power(s_col, 2) + fs * s_col
    gs_factor = e1 * g_col + f1 / g_col + quadratic_in_s

    # Second factor: depends only on columns 0 and 1.
    size_factor = n_pow + k * na_pow + h * (na_col / n_col)

    interaction = gs_factor * size_factor
    return interaction + a * n_pow + c * na_pow + b * np.power(d_col, beta) + s

def fit_and_plot(data, points_val, output_path):
    """Fit ``custom_model`` to ``data`` and save a predicted-vs-true scatter plot.

    Both ``data`` and ``points_val`` are converted to arrays and read as
    rows of six values: columns 0..4 are the model inputs and column 5 is
    the observed loss. The model is fitted on ``data`` only; ``points_val``
    is evaluated with the fitted parameters and drawn as red stars.

    Side effects:
        * prints the fitted formula (LaTeX-style text),
        * prints ``mean_abs_error total_abs_error count`` for the fitting
          points and then for the validation points,
        * writes the figure to ``output_path``.

    Args:
        data: Fitting points, shape (n, >=6).
        points_val: Validation points, shape (m, >=6).
        output_path: Destination passed to ``plt.savefig``.

    Returns:
        None.
    """
    data = np.array(data)
    points_val = np.array(points_val)

    x_data, y_data = data[:, :5], data[:, 5]
    x_data_val, y_data_val = points_val[:, :5], points_val[:, 5]

    # Order: a, alpha, b, beta, c, s, e1, f1, k, h, es, fs
    initial_guess = [69.2343, -0.2368, 25970.0621, -0.5162, 69.2343, 1.0, 1.0, 1.0, 0.01, 0.01, 1.0, 1.0]

    bounds = (
        [-np.inf, -1.0, 0.0001, -1.0, 0.0001, 0.0001, -np.inf, -np.inf, 0.0001, -np.inf, -np.inf, -np.inf],
        [np.inf, 0, np.inf, 0, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf])

    # The covariance estimate is not used anywhere below.
    params, _ = curve_fit(
        custom_model,
        x_data,
        y_data,
        p0=initial_guess,
        bounds=bounds,
        maxfev=50000)

    a, alpha, b, beta, c, s, e1, f1, k, h, es, fs = params

    # Raw f-string: "\c" is not a valid escape sequence in a normal string.
    # The S-term is printed as es*S**2 + fs*S, which is what custom_model
    # actually evaluates.
    print(
        rf'$Loss = ({e1:.4f} \cdot G + {f1:.4f} / G + {es:.4f} \cdot S**2 + {fs:.4f} \cdot S)'
        rf' \cdot (N**{alpha:.4f} + {k:.4f} \cdot Na**{alpha:.4f} + {h:.4f} * Na/N)'
        rf' + {a:.4f} \cdot N**{alpha:.4f} + {c:.4f} \cdot Na**{alpha:.4f}'
        rf' + {b:.4f} \cdot D**{beta:.4f} + {s:.4f}$'
    )

    y_pred = custom_model(x_data, *params)
    y_pred_val = custom_model(x_data_val, *params)

    fig = plt.figure(figsize=(10, 10))
    try:
        N_values = x_data[:, 0]
        D_values = x_data[:, 2]
        val_D_values = x_data_val[:, 2]

        # One colour per distinct N; n_index maps every point to its colour.
        unique_N, n_index = np.unique(N_values, return_inverse=True)
        # NOTE(review): positions below 0 fall outside the colormap's [0, 1]
        # range, so the smallest N groups probably share one colour - confirm
        # this is intended.
        color_range = np.linspace(-0.4, 1.0, len(unique_N))
        colors = plt.cm.coolwarm(color_range)
        light_colors = np.clip(colors * 0.9 + 0.05, 0, 1)
        N_color_map = {n: color for n, color in zip(unique_N, light_colors)}

        # Marker sizes grow with D. Min and max are both taken over the
        # fitting and validation points so the two sets share one scale.
        all_D = np.concatenate([D_values, val_D_values])
        D_min, D_max = np.min(all_D), np.max(all_D)
        D_span = D_max - D_min
        if D_span == 0:
            # All D identical: avoid 0/0 and draw every marker at base size.
            D_span = 1.0
        D_scaled = 40 + 60 * (D_values - D_min) / D_span
        val_D_scaled = (60 + 80 * (val_D_values - D_min) / D_span) * 2

        # Fitting points: a single scatter call instead of one per point.
        plt.scatter(y_data, y_pred, marker='o', c=light_colors[n_index],
                    s=D_scaled, alpha=0.8, edgecolors='w', linewidth=0.2)
        abs_err = np.sum(np.abs(y_data - y_pred))
        count = len(y_data)
        print(abs_err / count if count else float('nan'), abs_err, count)

        # Validation points.
        plt.scatter(y_data_val, y_pred_val, marker='*', color='#FF3333',
                    s=val_D_scaled, alpha=1.0, edgecolors='w', linewidth=0.05)
        abs_err_val = np.sum(np.abs(y_data_val - y_pred_val))
        count_val = len(y_data_val)
        print(abs_err_val / count_val if count_val else float('nan'),
              abs_err_val, count_val)

        # Square axes covering every true and predicted value plus 5% margin.
        all_y = np.concatenate([y_data, y_data_val, y_pred, y_pred_val])
        global_min = np.min(all_y)
        global_max = np.max(all_y)

        margin = (global_max - global_min) * 0.05
        axis_start = global_min - margin
        axis_end = global_max + margin

        # Diagonal reference: points on it are predicted exactly.
        x_smooth = np.linspace(axis_start, axis_end, 100)
        plt.plot(x_smooth, x_smooth, 'k--', linewidth=2, label='y = x')

        plt.xlim(axis_start, axis_end)
        plt.ylim(axis_start, axis_end)
        plt.gca().set_aspect('equal', adjustable='box')

        # Legend is built by hand: one entry per N colour, then the two
        # marker shapes.
        handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color,
                              markersize=10, label=f'N={int(round(n/1e6))}M')
                   for n, color in N_color_map.items()]
        handles.append(plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='gray',
                                  markersize=10, label='Training'))
        handles.append(plt.Line2D([0], [0], marker='*', color='w', markerfacecolor='#FF3333',
                                  markersize=15, label='Validation'))

        plt.legend(handles=handles, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=14)

        plt.xlabel('True Loss', fontsize=18)
        plt.ylabel('Pred Loss', fontsize=18)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)

        # Leave the right 15% of the canvas for the legend placed outside
        # the axes.
        plt.tight_layout(rect=[0, 0, 0.85, 1])
        plt.savefig(
            output_path,
            dpi=600,
            bbox_inches='tight',
            pad_inches=0.1,
        )
    finally:
        # Close the figure even if drawing or saving raised.
        plt.close(fig)
    
if __name__ == "__main__":
    # Read both point tables from the working directory as plain arrays,
    # then fit on the first and validate on the second.
    fitting_points = pd.read_csv('points_fitting.csv').to_numpy()
    validation_points = pd.read_csv('points_val.csv').to_numpy()

    fit_and_plot(fitting_points, validation_points, output_path='./ours.pdf')
    