import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import pandas as pd
from matplotlib.markers import MarkerStyle
from matplotlib.patches import FancyArrowPatch


def get_rows_where_col_equals(df, col, value):
    """Return an independent copy of the rows of ``df`` where ``df[col] == value``."""
    matches = df[col].eq(value)
    selected = df.loc[matches]
    return selected.copy()

def get_rows_where_col_in(df, col, values):
    """Return an independent copy of the rows of ``df`` whose ``col`` value is one of ``values``."""
    membership = df[col].isin(values)
    selected = df.loc[membership]
    return selected.copy()

def get_rows_where_col_greater(df, col, value):
    """Return an independent copy of the rows of ``df`` where ``df[col] > value`` (strict)."""
    above = df[col].gt(value)
    selected = df.loc[above]
    return selected.copy()

def get_rows_where_col_less(df, col, value):
    """Return an independent copy of the rows of ``df`` where ``df[col] < value`` (strict)."""
    below = df[col].lt(value)
    selected = df.loc[below]
    return selected.copy()

def intermediate_point(p1, p2, fraction=0.5):
    """Linearly interpolate between two 2-D points.

    Only the first two coordinates of ``p1`` and ``p2`` are used.
    ``fraction=0`` yields ``p1`` and ``fraction=1`` yields ``p2``.

    Returns:
        An ``(x, y)`` tuple.
    """
    x_from, y_from = p1[0], p1[1]
    x_to, y_to = p2[0], p2[1]
    x_mid = x_from + fraction * (x_to - x_from)
    y_mid = y_from + fraction * (y_to - y_from)
    return x_mid, y_mid

def compute_isoclines(delta, normx, y, lr, n_eff):
    """Evaluate the two nullclines of the (delta, lambda) flow at ``delta``.

    Returns:
        ``(delta_iso, lmbda_iso)``:

        - ``delta_iso = lr * normx**2 * delta * (delta + y) / n_eff`` is the
          lambda value where ``f`` vanishes (for ``delta != 0``).
        - ``lmbda_iso = 4 * (delta + y) / (lr * delta)`` is the lambda value
          where ``g`` vanishes (for ``delta != 0``). It divides by ``delta``.
    """
    shifted = delta + y
    delta_nullcline = lr * normx**2 * delta * shifted / n_eff
    lambda_nullcline = 4 * shifted / (lr * delta)
    return delta_nullcline, lambda_nullcline

def compute_forbidden(delta, y, n_eff, normx=1.0):
    """Return the lambda boundary of the forbidden region at ``delta``.

    The boundary is ``|2 * normx * (delta + y) / sqrt(n_eff)|``. Only this
    normx-scaled curve passes through the fixed points
    ``(+-2*sqrt(n_eff)/(normx*lr), 4/lr +- 2*normx*y/sqrt(n_eff))``.

    Args:
        delta: Scalar or array of delta values.
        y: Offset added to ``delta``.
        n_eff: Effective width; its square root scales the boundary down.
        normx: Multiplicative factor. The default 1.0 reproduces the previous
            normx-free formula exactly. Pass the data norm to obtain the
            boundary that matches the fixed points above.

    Returns:
        Non-negative scalar or array with the shape of ``delta``.
    """
    return np.abs(2 * normx * (delta + y) / np.sqrt(n_eff))

def compute_sharpness(delta, lr, y):
    """Return ``4 * (delta + y) / (lr * delta)``, the lambda value where ``g`` vanishes.

    This divides by ``delta``, so callers must avoid ``delta == 0``.
    """
    numerator = 4 * (delta + y)
    denominator = lr * delta
    return numerator / denominator

def f(delta, lmbda, lr, x, y, n):
    """Delta-component of the (delta, lambda) vector field.

    Computes ``delta * (-lr * lmbda + lr**2 * x**2 * delta * (delta + y) / n)``.
    The function is vectorised over numpy arrays.
    """
    drift = -lr * lmbda + lr**2 * x**2 * delta * (delta + y) / n
    return delta * drift

def g(delta, lmbda, lr, x, y, n):
    """Lambda-component of the (delta, lambda) vector field.

    Computes ``-4*lr*x**2*delta*(delta+y)/n + lr**2*x**2*delta**2*lmbda/n``.
    The function is vectorised over numpy arrays.
    """
    pull = (-4* lr * x**2 * delta * (delta + y) / n)
    push = (lr**2 * x**2 * delta**2 * lmbda / n)
    return pull + push

def _grid(start, stop, step, max_points=20000):
    """Return ``np.arange(start, stop, step)``, resampled to at most ``max_points`` samples.

    The delta axis spans roughly +-2.1 * sqrt(n_eff) / (normx * lr). For small
    lr or large n_eff a fixed 0.001 step would allocate tens of millions of
    samples. Past ``max_points`` an evenly spaced grid (stop excluded, like
    ``arange``) is used instead. An empty range stays empty, as with ``arange``.
    """
    if stop <= start or (stop - start) / step <= max_points:
        return np.arange(start, stop, step)
    return np.linspace(start, stop, max_points, endpoint=False)


def _forbidden_boundary(delta, normx, y, n_eff):
    """Return ``2 * normx * |delta + y| / sqrt(n_eff)``, the upper edge of the 'Forbidden' region.

    Keeping ``normx`` is what makes this curve pass through fixed point IV, and
    through fixed point III whenever ``delta + y`` is negative there.
    """
    return 2 * normx * np.abs(delta + y) / np.sqrt(n_eff)


def _plot_fixed_points(ax, fixed_points):
    """Mark the fixed points (in the order II, IV, III) with black circles and labels."""
    labels_and_offsets = [('II', (20, -5)), ('IV', (-5, 15)), ('III', (5, 15))]
    for (a, b), (label, offset) in zip(fixed_points, labels_and_offsets):
        # All three, including the saddle II, are drawn as fully filled circles.
        ax.scatter(a, b, color='black', edgecolor='black', s=120, marker='o', zorder=2)
        ax.annotate(label, (a, b), textcoords="offset points", xytext=offset, ha='center')


def _plot_zero_loss_line(ax, lr, normx, y, n_eff, label_y, lmbda_max):
    """Draw the delta = 0 line (I) plus a horizontal dash-dot line at lambda = 2/lr.

    The line is solid between 2*normx*y/sqrt(n_eff) and 2/lr when that interval
    is non-empty, and dashed above it up to ``lmbda_max``. The dash-dot line
    presumably marks the classical 2/lr stability threshold; this is not
    confirmed by the code.
    """
    lower = 2 * normx * y / np.sqrt(n_eff)
    upper = 2 / lr
    if lower <= upper:
        ax.plot([0, 0], [lower, upper], color='black', linestyle='-', linewidth=2.0)
    ax.annotate('I', (0, label_y), textcoords="offset points", xytext=(-10, 0), ha='center')
    ax.plot([0, 0], [max(upper, lower), lmbda_max], color='black', linestyle='--', linewidth=2.0)
    ax.axhline(y=2 / lr, linestyle='dashdot', color='black')


def _plot_phase_regions(ax, lr, normx, y, n_eff, fixed_points,
                        delta_min, delta_max, lmbda_min, lmbda_max):
    """Shade the forbidden, divergent, sharpness-reduction and progressive-sharpening regions."""
    palette = sns.color_palette('colorblind', 20)
    delta_iv = fixed_points[1][0]
    delta_iii = fixed_points[2][0]
    # The divergence line lambda = slope * delta + 4/lr passes through (0, 4/lr)
    # and through fixed point IV.
    slope = lr * normx**2 * y / n_eff
    intercept = 4 / lr

    # Forbidden: everything below the boundary. The fill starts at lmbda_min,
    # the final bottom axis limit, rather than the transient autoscaled limits.
    # It is not masked by the not-yet-final top limit, which used to leave gaps;
    # the axes clip the fill anyway.
    delta = np.linspace(delta_min, delta_max + 0.1, 1000)
    ax.fill_between(delta, lmbda_min, _forbidden_boundary(delta, normx, y, n_eff),
                    color=palette[7], alpha=0.5, label='Forbidden')

    # Divergent: above both the divergence line and the forbidden boundary.
    delta = _grid(delta_min, delta_max, 0.001)
    lower = np.maximum(slope * delta + intercept, _forbidden_boundary(delta, normx, y, n_eff))
    ax.fill_between(delta, lmbda_max, lower, color=palette[0], alpha=0.5, label='Divergent')

    # Sharpness reduction, left part: between the forbidden boundary and the
    # g = 0 nullcline, from fixed point III up to delta = -y.
    delta = _grid(delta_iii, -y, 0.001)
    ax.fill_between(delta, _forbidden_boundary(delta, normx, y, n_eff),
                    compute_sharpness(delta, lr, y),
                    color=palette[2], alpha=0.5, label='Sharpness reduction')
    # Right part: between the forbidden boundary and the divergence line, from 0 to IV.
    delta = _grid(0.0, delta_iv + 0.001, 0.001)
    ax.fill_between(delta, _forbidden_boundary(delta, normx, y, n_eff),
                    slope * delta + intercept, color=palette[2], alpha=0.5)

    # Progressive sharpening: between the divergence line and the larger of the
    # g = 0 nullcline and the forbidden boundary, from III up to (excluding)
    # delta = 0, since compute_sharpness divides by delta.
    delta = _grid(delta_iii, 0.0, 0.001)
    upper = np.maximum(compute_sharpness(delta, lr, y),
                       _forbidden_boundary(delta, normx, y, n_eff))
    ax.fill_between(delta, slope * delta + intercept, upper,
                    color=palette[1], alpha=0.5, label='Progressive sharpening')


def _plot_nullclines(ax, lr, normx, y, n_eff, delta_min, delta_max):
    """Draw both nullclines (f = 0 in orange, g = 0 in white) on each side of delta = 0."""
    tab = sns.color_palette('tab10', 10)
    # delta = 0 is excluded from both branches because the g = 0 nullcline divides by delta.
    branches = (_grid(0.01, delta_max + 0.01, 0.01), _grid(delta_min, 0.0, 0.01))
    for delta in branches:
        delta_iso, lmbda_iso = compute_isoclines(delta, normx, y, lr, n_eff)
        ax.plot(delta, delta_iso, linestyle='--', color=tab[1], zorder=1, alpha=1, linewidth=2)
        ax.plot(delta, lmbda_iso, linestyle='--', color='white', zorder=1, linewidth=2)


def _draw_directed_segment(ax, anchor, far, towards_anchor, eps=0.1,
                           color='black', alpha=0.25):
    """Draw the segment anchor-far with a short arrow centred on its midpoint.

    The arrow covers a fraction ``eps`` of the segment. It points towards
    ``anchor`` when ``towards_anchor`` is true and away from it otherwise.
    """
    (x0, y0), (x1, y1) = anchor, far
    mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
    half_dx, half_dy = eps * (x1 - x0) / 2, eps * (y1 - y0) / 2
    near = (mid_x - half_dx, mid_y - half_dy)
    away = (mid_x + half_dx, mid_y + half_dy)
    tail, head = (away, near) if towards_anchor else (near, away)
    ax.add_patch(FancyArrowPatch(tail, head, mutation_scale=15, color=color,
                                 arrowstyle='-|>', alpha=alpha))
    ax.plot([x0, x1], [y0, y1], color=color, linestyle='-', linewidth=1.5, alpha=alpha)


def _plot_eigenvectors(ax, lr, normx, y, n_eff, delta_min, delta_max):
    """Draw the eigenvector lines through the saddle II and through (0, 4/lr)."""
    sqrt_n = np.sqrt(n_eff)
    saddle = (-y, 0)
    # 1. Rightwards from the saddle along lambda = 2*normx*(delta + y)/sqrt(n_eff)
    #    (also the EoS manifold); the arrow points away from the saddle.
    _draw_directed_segment(ax, saddle, (delta_max, 2 * normx * (delta_max + y) / sqrt_n),
                           towards_anchor=False)
    # 2. Leftwards from the saddle; the arrow points towards the saddle.
    _draw_directed_segment(ax, saddle, (delta_min, -2 * normx * (delta_min + y) / sqrt_n),
                           towards_anchor=True)
    # 3. The divergence line through (0, 4/lr), drawn on both sides; both
    #    arrows point towards (0, 4/lr).
    slope = lr * normx**2 * y / n_eff
    pivot = (0, 4 / lr)
    for delta_end in (delta_min, delta_max):
        _draw_directed_segment(ax, pivot, (delta_end, delta_end * slope + 4 / lr),
                               towards_anchor=True)


def _plot_vector_field(ax, lr, normx, y, n_eff, delta_min, delta_max,
                       lmbda_min, lmbda_max, num=10):
    """Overlay unit-length (f, g) arrows on a num x num grid, hidden in the forbidden region."""
    D, L = np.meshgrid(np.linspace(delta_min, delta_max, num),
                       np.linspace(lmbda_min, lmbda_max, num))
    F = f(D, L, lr, normx, y, n_eff)
    G = g(D, L, lr, normx, y, n_eff)
    magnitude = np.hypot(F, G)
    # Zero-length vectors (f and g both vanish on delta = 0) have no direction.
    # They are left as NaN, so no arrow is drawn, without a 0/0 RuntimeWarning.
    moving = magnitude > 0
    F_unit = np.divide(F, magnitude, out=np.full_like(F, np.nan), where=moving)
    G_unit = np.divide(G, magnitude, out=np.full_like(G, np.nan), where=moving)

    forbidden = L < _forbidden_boundary(D, normx, y, n_eff)
    F_unit[forbidden] = np.nan
    G_unit[forbidden] = np.nan

    ax.quiver(D, L, F_unit, G_unit, angles='xy', scale_units=None, color='black',
              alpha=0.4, linewidth=0.025, width=0.005)


def _plot_trajectory(ax, df, df_init):
    """Plot the logged trajectory (residual_step, ntk_step) with its start and direction arrows."""
    ax.scatter(df_init['residual_step'], df_init['ntk_step'], marker='*', color='brown',
               s=200, label='Initialization', zorder=2)
    ax.plot(df['residual_step'], df['ntk_step'], color='black', linewidth=3,
            label='Trajectory', zorder=1)
    # Arrowheads on the segments joining the first (up to) 10 logged points.
    head = list(zip(df['residual_step'].iloc[:10], df['ntk_step'].iloc[:10]))
    for prev_pt, next_pt in zip(head, head[1:]):
        start = intermediate_point(prev_pt, next_pt, fraction=0.45)
        end = intermediate_point(prev_pt, next_pt, fraction=0.55)
        ax.add_patch(FancyArrowPatch(start, end, mutation_scale=20, color='black',
                                     arrowstyle='-|>'))


def plot_flow(config, df):
    """Plot the (Delta f, lambda) phase portrait with a training trajectory on top.

    The plot contains:

    - the fixed points (I-IV) and the shaded phase regions;
    - both nullclines and the eigenvector lines;
    - a normalised vector field of (f, g);
    - the trajectory stored in ``df``.

    It then calls ``plt.show()``.

    Parameters
    ----------
    config
        Object providing ``normx``, ``ymean``, ``eff_width`` (used as n_eff)
        and ``width`` (only shown in the title).
    df : pandas.DataFrame
        Must contain ``step``, ``lr``, ``residual_step`` (x coordinate) and
        ``ntk_step`` (y coordinate). Rows with ``step == 1.0`` are marked as
        the initialization.

    Raises
    ------
    ValueError
        If ``df['lr']`` does not hold exactly one distinct value.
    """
    normx = config.normx
    y = config.ymean
    n_eff = config.eff_width

    df_init = get_rows_where_col_equals(df, 'step', 1.0)
    lrs = df['lr'].unique()
    if len(lrs) != 1:
        # Fixed points, limits and regions below are only defined for one lr.
        raise ValueError(f'plot_flow expects exactly one learning rate, got {list(lrs)}')
    lr = float(lrs[0])

    _, ax = plt.subplots(1, 1, figsize=(6, 4))

    sqrt_n = np.sqrt(n_eff)
    # Saddle II followed by the fixed points IV and III.
    fixed_points = [
        (-y, 0),
        (2 * sqrt_n / (normx * lr), 4 / lr + 2 * normx * y / sqrt_n),
        (-2 * sqrt_n / (normx * lr), 4 / lr - 2 * normx * y / sqrt_n),
    ]

    # Axis limits. Delta extends 5% beyond fixed point IV, symmetric about 0.
    # Lambda extends 12.5% above IV's lambda and slightly below zero.
    delta_max = fixed_points[1][0]
    delta_max += delta_max / 20
    delta_min = -delta_max
    lmbda_max = fixed_points[1][1]
    lmbda_max += lmbda_max / 8
    lmbda_min = -lmbda_max / 40

    _plot_fixed_points(ax, fixed_points)
    _plot_zero_loss_line(ax, lr, normx, y, n_eff, fixed_points[1][1], lmbda_max)
    _plot_phase_regions(ax, lr, normx, y, n_eff, fixed_points,
                        delta_min, delta_max, lmbda_min, lmbda_max)
    _plot_nullclines(ax, lr, normx, y, n_eff, delta_min, delta_max)
    _plot_eigenvectors(ax, lr, normx, y, n_eff, delta_min, delta_max)
    _plot_vector_field(ax, lr, normx, y, n_eff, delta_min, delta_max, lmbda_min, lmbda_max)
    _plot_trajectory(ax, df, df_init)

    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=10)

    ax.set_xlim(delta_min, delta_max)
    ax.set_ylim(lmbda_min, lmbda_max)
    # Raw strings: '\D', '\l' and '\e' are invalid Python escape sequences.
    ax.set_xlabel(r'$\Delta f$')
    ax.set_ylabel(r'$\lambda$')
    ax.set_title(rf'$n=${config.width}, $n_{{eff}}=${n_eff}, $y=${y:0.1f}, $\eta=${lr:0.2f}')
    plt.show()



