import re
import wandb 
import colorsys
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np


# W&B location the sweep was presumably logged under (wandb is imported above).
entity = "coder66-lab"
project = "diagonal-net-loss-trend"

# Problem setup: ambient dimension d, sparsity k, and delta.
valid_d, valid_k, valid_delta = 10_000, 50, 0.5

# Training-set sizes: step 10 on [300, 500), then step 50 on [500, 1000].
# A wider sweep used previously started at 50: range(50, 500, 10) + range(500, 1050, 50).
valid_n = [*range(300, 500, 10), *range(500, 1050, 50)]

# Learning rates (an earlier sweep also included 5e-3 and 5e-2).
valid_lr = [1e-2, 1e-1]

# Adam beta coefficients.
valid_beta1 = [0.9]
valid_beta2 = [0.95, 0.999]

# Sweep results; the plotting code reads the opt, lr, beta1, beta2, n and
# total_norm columns from this frame.
df = pd.read_csv("results.csv")

# Base colors from which two palettes of shades are derived:
# purple for the Adam configurations, gold/orange for the SGD ones.
purple_base, orange_base = "#8E7DBE", "#FDD400"

def generate_color_range(
    base_color: str,
    num_colors: int,
    saturation_range: tuple[float, float] = (0.8, 1.2),
    value_range: tuple[float, float] = (1.0, 1.2),
    hue_shift: float = 0.1,
    *,
    verbose: bool = False,
) -> list[str]:
    """
    Generate ``num_colors`` hex colors clustered around ``base_color``.

    The base color is converted to HSV. Three sequences of length
    ``num_colors`` are built and then zipped into colors:

    * hue: the base hue offset by values evenly spaced over
      ``[-hue_shift, +hue_shift]``, wrapping around the color wheel;
    * saturation: the base saturation scaled by factors evenly spaced
      from ``saturation_range[0]`` to ``saturation_range[1]``;
    * value: the base value scaled by factors evenly spaced from
      ``value_range[0]`` to ``value_range[1]``.

    The ranges are *multipliers* of the base color's S and V, not absolute
    values, so the scaled S/V can leave [0, 1]. The resulting RGB channels
    are clamped to [0, 255], so every returned string is a valid color.

    Args:
        base_color: ``"#RRGGBB"`` hex string. ``"#RGB"`` shorthand and
            ``"#RRGGBBAA"`` are also accepted; alpha is ignored and the
            ``#`` is optional.
        num_colors: number of colors to return; ``<= 0`` yields ``[]``.
        saturation_range: (start, end) multipliers for the base saturation.
        value_range: (start, end) multipliers for the base value.
        hue_shift: largest hue offset, as a fraction of the full hue circle.
        verbose: if True, print the base color's HSV components.

    Returns:
        A list of ``num_colors`` lowercase ``"#rrggbb"`` strings.

    Raises:
        ValueError: if ``base_color`` is not a valid hex color string.
    """
    def linspace(a, b, n):
        # n evenly spaced values from a to b inclusive ([a] if n == 1, [] if n <= 0).
        if n == 1:
            return [a]
        step = (b - a) / (n - 1)
        return [a + step * i for i in range(n)]

    # Parse the hex string into float RGB components in [0, 1].
    bad_color = f"expected a '#RRGGBB' (or '#RGB') hex color, got {base_color!r}"
    digits = base_color.lstrip('#')
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(bad_color)
    try:
        r, g, b = (int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4))
    except ValueError as err:
        raise ValueError(bad_color) from err

    base_h, base_s, base_v = colorsys.rgb_to_hsv(r, g, b)
    if verbose:
        print(f"base_h: {base_h}, base_s: {base_s}, base_v: {base_v}")

    # Build the hue offsets and the scaled saturation/value sequences.
    h_vals = [(base_h + dh) % 1.0 for dh in linspace(-hue_shift, hue_shift, num_colors)]
    s_vals = linspace(saturation_range[0] * base_s, saturation_range[1] * base_s, num_colors)
    v_vals = linspace(value_range[0] * base_v, value_range[1] * base_v, num_colors)

    def to_byte(channel):
        # S > 1 makes hsv_to_rgb return negative channels, and V > 1 returns
        # channels above 1. Clamp both ends so "{:02x}" always yields two hex
        # digits; a negative value would otherwise format as e.g. "-3c".
        return max(0, min(int(channel * 255), 255))

    output_colors = []
    for h, s, v in zip(h_vals, s_vals, v_vals):
        rr, gg, bb = colorsys.hsv_to_rgb(h, s, v)
        output_colors.append(
            "#{0:02x}{1:02x}{2:02x}".format(to_byte(rr), to_byte(gg), to_byte(bb))
        )

    return output_colors

# One palette per optimizer family (8 shades each); shades are handed out to
# configurations in groupby order.
purple_colors = generate_color_range(purple_base, 8, hue_shift=0.04)
orange_colors = generate_color_range(orange_base, 8, hue_shift=0.04)

## plot the results
plt.figure(figsize=(12, 8))
sgd_index, adam_index = 0, 0

# Restrict the plot to the sweep configured at the top of the file
# (valid_n / valid_lr / valid_beta1 / valid_beta2). SGD rows are not filtered
# on beta1/beta2, since only the Adam configurations use them.
# NOTE(review): d, k and delta are not filtered here. results.csv is assumed to
# contain only valid_d / valid_k / valid_delta runs, as the title states -- confirm.
in_sweep = (
    df["n"].isin(valid_n)
    & df["lr"].isin(valid_lr)
    & (
        (df["opt"] == "SGD")
        | (df["beta1"].isin(valid_beta1) & df["beta2"].isin(valid_beta2))
    )
)
# dropna=False keeps configurations whose beta1/beta2 are missing (NaN), as
# may be the case for SGD rows; groupby would otherwise drop them silently.
groups = df[in_sweep].groupby(["opt", "lr", "beta1", "beta2"], dropna=False)

# Reference line at y = sqrt(k).
plt.axhline(y=np.sqrt(valid_k), color="#DA498D", linestyle='--', label="sqrt(k)")

for (opt, lr, beta1, beta2), grp in groups:
    grp = grp.sort_values(by="n")

    # Index modulo the palette size so that more than 8 configurations in one
    # family reuse shades instead of raising IndexError.
    if opt == "SGD":
        color = orange_colors[sgd_index % len(orange_colors)]
        sgd_index += 1
        config_name = f"SGD_lr={lr}"
    else:
        color = purple_colors[adam_index % len(purple_colors)]
        adam_index += 1
        config_name = f"Adam_b1={beta1}_b2={beta2}_lr={lr}"

    print(f"Plotting {config_name} ...")
    plt.plot(grp["n"].values, grp["total_norm"].values, marker='o', label=config_name, color=color)

# Axis labels apply to the whole figure, so set them once, outside the loop.
plt.xlabel("n_train")
plt.ylabel("Norm")
plt.title(f"Norm vs n_train of SGD & Adam with d={valid_d}, k={valid_k}, delta={valid_delta}")
plt.legend()
plt.savefig("norm_vs_n_train_compare.png")
plt.close()