import re
import wandb 
import colorsys
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np


# Workspace identifiers (presumably for Weights & Biases; not referenced by
# the visible plotting code).
entity = "coder66-lab"
project = "diagonal-net-loss-trend"

# Experiment settings. Only `valid_n` is used by the visible plotting code,
# where it selects which training-set sizes are drawn.
valid_d = 10000
valid_k = 50
valid_delta = 0.5
# 300..490 in steps of 10, followed by 500..1000 in steps of 50.
valid_n = [*range(300, 500, 10), *range(500, 1050, 50)]
valid_lr = [1e-2]
valid_beta2 = [0.999]
valid_beta1 = [0.9]
valid_exponent = [0.001, 0.25, 0.75]

# AdamE sweep results; grouped and plotted further down.
df = pd.read_csv("results_AdamE.csv")

# Base colors of the figure. Alternative (muted) pair kept for reference:
#   purple_base = "#8D9FBD"
#   orange_base = "#E6AE7F"
purple_base = "#C68EFD"
orange_base = "#FDD400"

def generate_color_range(
    base_color: str,
    num_colors: int,
    saturation_range=(0.8, 1.2),
    value_range=(1.0, 1.2),
    hue_shift=0.1
) -> list[str]:
    """
    Generate a list of `num_colors` hex colors 'around' the given base_color.

    The base color is converted to HSV. The hue is spread linearly over
    [base_h - hue_shift, base_h + hue_shift] (wrapping around 1.0), while
    saturation and value are spread linearly over the base saturation/value
    multiplied by the bounds of `saturation_range` / `value_range`.

    Args:
        base_color: Color as 'RRGGBB' hex digits, with or without leading '#'.
        num_colors: Number of colors to produce. Values <= 0 give an empty
            list; 1 gives the low end of every range.
        saturation_range: (low, high) multipliers applied to the base saturation.
        value_range: (low, high) multipliers applied to the base value.
        hue_shift: Maximum hue offset (in [0, 1] hue units) on either side.

    Returns:
        A list of lowercase '#rrggbb' strings of length `num_colors`.

    Raises:
        ValueError: If `base_color` does not contain valid hex digit pairs.

    Note:
        Prints the base HSV components to stdout.
    """
    # simple linspace implementation
    def linspace(a, b, n):
        if n <= 0:
            return []
        if n == 1:
            return [a]
        step = (b - a) / (n - 1)
        return [a + step * i for i in range(n)]

    def to_byte(channel):
        # Scaled saturation/value may leave [0, 1]; hsv_to_rgb then returns
        # channels below 0 or above 1. Round (instead of truncating) and clamp
        # on both sides so the result is always a valid two-digit hex byte.
        return max(0, min(round(channel * 255), 255))

    # strip '#' and convert hex to float RGB in [0,1]
    hexstr = base_color.lstrip('#')
    r = int(hexstr[0:2], 16) / 255.0
    g = int(hexstr[2:4], 16) / 255.0
    b = int(hexstr[4:6], 16) / 255.0

    # convert to HSV
    base_h, base_s, base_v = colorsys.rgb_to_hsv(r, g, b)
    print(f"base_h: {base_h}, base_s: {base_s}, base_v: {base_v}")

    # generate arrays of shifted H, spaced S and V
    h_vals = [(base_h + x) % 1.0 for x in linspace(-hue_shift, hue_shift, num_colors)]
    s_vals = linspace(saturation_range[0] * base_s, saturation_range[1] * base_s, num_colors)
    v_vals = linspace(value_range[0] * base_v, value_range[1] * base_v, num_colors)

    # convert back to hex
    output_colors = []
    for h, s, v in zip(h_vals, s_vals, v_vals):
        rr, gg, bb = colorsys.hsv_to_rgb(h, s, v)
        output_colors.append(
            f"#{to_byte(rr):02x}{to_byte(gg):02x}{to_byte(bb):02x}"
        )

    return output_colors

# Palette for the AdamE curves (one entry per non-0.5 exponent, consumed in
# group order). Earlier candidates kept for reference:
#   purple_colors = generate_color_range(purple_base, 4, hue_shift=0.04)
#   orange_colors = generate_color_range(orange_base, 8, hue_shift=0.04)
#   purple_colors = ["#694F8E", "#B692C2", "#E3A5C7", "#BB9AB1"]
#   purple_colors = ["#80CBA4", "#A30543", "#FBDA83", "#E9F4A3"]
#   purple_colors = ["#8F87F1", "#FF9A00", "#A1E3F9", "#96CEB4"]
purple_colors = ["#33BBC5", "#FF9A00", "#E4B1F0", "#85E6C5"]
orange_color = orange_base

plt.figure(figsize=(6, 4))

# Draw one loss-vs-n curve per AdamE configuration found in the CSV.
AdamE_index = 0
groups = df.groupby(["opt", "lr", "beta1", "beta2", "exponent"])
for config, grp in groups:
    opt, lr, beta1, beta2, exponent = config

    # Order by training-set size, then keep only the sizes listed in valid_n.
    grp = grp.sort_values(by="n")
    grp = grp.loc[grp["n"].isin(valid_n)]

    # exponent 0.5 gets the orange color; every other exponent takes the
    # next unused palette entry.
    if exponent != 0.5:
        color = purple_colors[AdamE_index]
        AdamE_index += 1
    else:
        color = orange_color

    config_name = f"AdamE_lambda={exponent}"
    # config_name = f"AdamE_b1={beta1}_b2={beta2}_lr={lr}_exp={exponent}"
    print(f"Plotting {config_name} ...")

    # Truncate the curve right after the first point whose loss is below 1
    # (that point itself is kept).
    n_values = grp["n"].values
    loss_values = grp["loss"].values
    first_below = next(
        (pos for pos, loss in enumerate(loss_values) if loss < 1),
        None,
    )
    if first_below is not None:
        stop = first_below + 1
        n_values = n_values[:stop]
        loss_values = loss_values[:stop]

    plt.plot(
        n_values,
        loss_values,
        marker='o',
        label=config_name,
        color=color,
        markersize=4,
    )


###### plotting sgd and adam

# Baseline runs (SGD and Adam) come from a separate results file and are drawn
# on the current figure.
df = pd.read_csv("results.csv")
groups = df.groupby(["opt", "lr", "beta1", "beta2"])
for config, grp in groups:
    opt, lr, beta1, beta2 = config

    # Keep only lr = 1e-2 for Adam and SGD, and only beta2 = 0.999 for Adam.
    if opt in ("Adam", "SGD") and lr != 1e-2:
        continue
    if opt == "Adam" and beta2 != 0.999:
        continue

    # Order by training-set size, then keep only the sizes listed in valid_n.
    grp = grp.sort_values(by="n")
    grp = grp[grp["n"].isin(valid_n)]

    # Anything that is not SGD is drawn and labelled as Adam.
    if opt == "SGD":
        color = orange_base
        config_name = "SGD"
    else:
        color = purple_base
        config_name = "Adam"

    print(f"Plotting {config_name} ...")

    # Truncate the curve right after the first point whose loss is below 1
    # (that point itself is kept).
    n_values = grp["n"].values
    loss_values = grp["loss"].values
    first_below = next(
        (pos for pos, loss in enumerate(loss_values) if loss < 1),
        None,
    )
    if first_below is not None:
        n_values = n_values[:first_below + 1]
        loss_values = loss_values[:first_below + 1]

    plt.plot(n_values, loss_values, marker='o', label=config_name, color=color, markersize=4)

# Axis setup is done once, after the loop, so it is applied even when no
# baseline group passes the filters above.
# symlog with linthresh=1: linear scale for |y| <= 1, logarithmic above it.
# plt.yscale("log")
plt.yscale("symlog", linthresh=1)

# plot a horizontal line at y=1
plt.axhline(y=1, color='gray', linestyle='--', label="Loss=1")

plt.xlabel("n_train")
plt.ylabel("Loss")
plt.title("Loss vs n_train of AdamE, Adam and SGD")
# with d={valid_d}, k={valid_k}, delta={valid_delta}")

# Build the legend in a fixed order. Labels that are not listed here (such as
# "Loss=1") are left out of the legend.
handles, labels = plt.gca().get_legend_handles_labels()
desired_order = [
    'AdamE_lambda=0.25',
    'Adam',
    'AdamE_lambda=0.75',
    'AdamE_lambda=0.9',
    'AdamE_lambda=0.001',
    'SGD'
]

# Map each label to its first handle (same choice list.index() would make).
handle_by_label = {}
for handle, label in zip(handles, labels):
    handle_by_label.setdefault(label, handle)

# A desired entry may be absent from the data; skip it instead of raising
# ValueError and losing the whole figure.
new_labels = []
for label in desired_order:
    if label in handle_by_label:
        new_labels.append(label)
    else:
        print(f"Legend entry {label!r} has no plotted curve; skipping.")
new_handles = [handle_by_label[label] for label in new_labels]

plt.legend(new_handles, new_labels, loc='lower center', fontsize=8.5)
plt.savefig("loss_vs_n_train_adame.png")