import re
import wandb 
import colorsys
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np


# Presumably the Weights & Biases entity/project the runs were logged under;
# neither name is referenced by the visible plotting code.
entity = "coder66-lab"
project = "diagonal-net-loss-trend"

# Problem setup shown in the plot title; valid_k also sets the sqrt(k) baseline.
valid_d = 10000
valid_k = 50
valid_delta = 0.5

# Training-set sizes: steps of 10 over [300, 500), then steps of 50 over [500, 1050).
valid_n = [*range(300, 500, 10), *range(500, 1050, 50)]

# Optimizer hyper-parameter grids.
valid_lr = [1e-2]
valid_beta1 = [0.9]
valid_beta2 = [0.999]
valid_exponent = [0.25, 0.5, 0.75, 1]

# Results table consumed by the plotting code below.
df = pd.read_csv("results_AdamE.csv")

# Base colors for the plot palette.
orange_base = "#FDD400"
purple_base = "#8E7DBE"

def generate_color_range(
    base_color: str,
    num_colors: int,
    saturation_range=(0.8, 1.2),
    value_range=(1.0, 1.2),
    hue_shift=0.1
) -> list[str]:
    """
    Generate `num_colors` hex colors spread around `base_color`.

    The base color is converted to HSV. Hue is shifted linearly across
    [-hue_shift, +hue_shift] (wrapping around the color wheel), while
    saturation and value are spaced linearly between the given range
    multipliers times the base color's own saturation / value.

    Args:
        base_color: hex color string such as "#RRGGBB" (leading '#' optional).
        num_colors: number of colors to produce; 0 or less yields [].
        saturation_range: (low, high) multipliers applied to the base saturation.
        value_range: (low, high) multipliers applied to the base value.
        hue_shift: maximum hue offset, as a fraction of the full hue circle.

    Returns:
        List of lowercase "#rrggbb" strings. RGB channels that leave [0, 1]
        after scaling (saturation > 1 or value > 1) are clamped to [0, 255].

    Raises:
        ValueError: from int() when the hex digits cannot be parsed.
    """
    def linspace(a, b, n):
        # n evenly spaced values from a to b inclusive; a single point is just a.
        if n == 1:
            return [a]
        step = (b - a) / (n - 1)
        return [a + step * i for i in range(n)]

    def to_byte(channel):
        # Clamp on BOTH sides: a saturation above 1 makes hsv_to_rgb return
        # negative components, which "{:02x}" would render as e.g. "-3d".
        # round() (not int()) avoids truncating 141.9999... down to 141.
        return min(max(round(channel * 255), 0), 255)

    # strip '#' and convert the three hex byte pairs to floats in [0, 1]
    hexstr = base_color.lstrip('#')
    r, g, b = (int(hexstr[i:i + 2], 16) / 255.0 for i in (0, 2, 4))

    base_h, base_s, base_v = colorsys.rgb_to_hsv(r, g, b)

    # shifted hues (wrapped into [0, 1)), plus linearly spaced saturation/value
    h_vals = [(base_h + x) % 1.0 for x in linspace(-hue_shift, hue_shift, num_colors)]
    s_vals = linspace(saturation_range[0] * base_s, saturation_range[1] * base_s, num_colors)
    v_vals = linspace(value_range[0] * base_v, value_range[1] * base_v, num_colors)

    output_colors = []
    for h, s, v in zip(h_vals, s_vals, v_vals):
        rr, gg, bb = colorsys.hsv_to_rgb(h, s, v)
        output_colors.append(
            "#{0:02x}{1:02x}{2:02x}".format(to_byte(rr), to_byte(gg), to_byte(bb))
        )
    return output_colors

# Hand-picked palettes. The commented calls show the generated alternative
# that generate_color_range would produce instead.
# purple_colors = generate_color_range(purple_base, 4, hue_shift=0.04)
# orange_colors = generate_color_range(orange_base, 8, hue_shift=0.04)
purple_colors = ["#694F8E", "#B692C2", "#E3A5C7", "#BB9AB1"]
orange_color = orange_base

plt.figure(figsize=(12, 8))
# Every curve is plotted relative to sqrt(k), so the dashed line at y = 0
# marks the sqrt(k) reference level.
baseline = np.sqrt(valid_k)
plt.axhline(y=0, color="#DA498D", linestyle='--', label="sqrt(k)")

# One curve per (opt, lr, beta1, beta2, exponent) combination in the CSV.
AdamE_index = 0
groups = df.groupby(["opt", "lr", "beta1", "beta2", "exponent"])
for config, grp in groups:
    grp = grp.sort_values(by="n")  # so each line connects points in n order
    opt, lr, beta1, beta2, exponent = config
    if exponent == 0.5:
        # exponent 0.5 is highlighted in orange
        color = orange_color
    else:
        # Cycle through the purple palette so that more than
        # len(purple_colors) other configs do not raise IndexError.
        color = purple_colors[AdamE_index % len(purple_colors)]
        AdamE_index += 1
    config_name = f"AdamE_b1={beta1}_b2={beta2}_lr={lr}_exp={exponent}"
    print(f"Plotting {config_name} ...")
    n_values = grp["n"].values
    norm_values = grp["total_norm"].values - baseline
    plt.plot(n_values, norm_values, marker='o', label=config_name, color=color)

# symlog shows both signs of (norm - sqrt(k)); it is linear for |y| < 0.1.
plt.yscale("symlog", linthresh=1e-1)
plt.xlabel("n_train")
# The plotted quantity is the offset from sqrt(k), not the raw norm.
plt.ylabel("Norm - sqrt(k)")
plt.title(f"Norm vs n_train of AdamE with d={valid_d}, k={valid_k}, delta={valid_delta}")
plt.legend()
plt.savefig("norm_vs_n_train_adame.png")