# Usage: python3 plot_kde_subplot.py ca00221d-07e4-402c-a78a-bf5c740a5535 --self_train_start=7
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from IPython import embed

# External Modules
import os
import argparse
import ast
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from matplotlib.patches import Polygon
from matplotlib.patches import Patch





from IPython import embed
import numpy as np
from matplotlib.lines import Line2D

import os
import socket
import argparse
import seaborn as sns

# Name of the machine running the script (not read again in this file).
host_name = socket.gethostname()

# Command-line interface.  NOTE: parsing happens at import time, so importing
# this module without a positional UUID on sys.argv makes argparse exit.
argument_parser = argparse.ArgumentParser()
# The UUID selects the input files plots/<uuid>_flash.txt and plots/<uuid>_think.txt.
argument_parser.add_argument("uuid", type=str, help="UUID of the model to self train")
# -1 is a sentinel for "not provided"; main() refuses to run with it.
argument_parser.add_argument("--self_train_start", type=int, default=-1, help="Number of digits to start self training on")
args = argument_parser.parse_args()

def clamp(val, minimum=0, maximum=1):
    """Limit ``val`` to the closed interval ``[minimum, maximum]``.

    The upper bound is applied first and the lower bound second, so when the
    bounds are crossed (``minimum > maximum``) the result is ``minimum``.
    """
    upper_bounded = maximum if maximum < val else val
    return minimum if minimum > upper_bounded else upper_bounded

def main():
    """Validate the command line, load the accuracy curves and draw the figure.

    Exits through ``argument_parser.error`` (usage message on stderr, exit
    status 2) when ``--self_train_start`` was left at its ``-1`` sentinel.
    """
    # Explicit check rather than ``assert``: assertions are removed when
    # Python runs with ``-O``, which would silently skip this validation.
    if args.self_train_start == -1:
        argument_parser.error("Must specify self_train_start")

    flash_accs, think_accs = read_data()
    plot_pdf_data(flash_accs, think_accs)

def read_data():
    """Load the two accuracy files belonging to ``args.uuid``.

    Reads ``plots/<uuid>_flash.txt`` and then ``plots/<uuid>_think.txt``;
    every line of each file is parsed with ``ast.literal_eval``.

    Returns:
        A ``(flash_accs, think_accs)`` tuple of lists holding one parsed
        entry per line of the respective file.
    """
    def parse_lines(path_template):
        # One Python literal per line; the file is closed once parsing is done.
        with open(path_template.format(args.uuid), "r") as handle:
            return [ast.literal_eval(line) for line in handle]

    flash_accs = parse_lines("plots/{}_flash.txt")
    think_accs = parse_lines("plots/{}_think.txt")
    return flash_accs, think_accs

def pad(values, length):
    """Return a new list: ``values`` followed by zeros up to ``length`` items.

    When ``values`` already has ``length`` or more items, an unpadded copy is
    returned (nothing is truncated).
    """
    missing = length - len(values)
    zeros = [0] * missing  # a non-positive count yields an empty list
    return values + zeros

def strip_trailing_zeros(data):
    """Collapse the run of trailing near-zero values of every list to one entry.

    Values below ``1e-5`` count as zero.  Of a trailing run of such values
    only the first (left-most) one is kept, so a trimmed list still ends on a
    near-zero entry.  The inner lists are modified in place.

    Returns:
        A new outer list holding the same (now trimmed) inner list objects.
    """
    trimmed = []
    for series in data:
        # Index just past the last value that is not near zero.
        keep = len(series)
        while keep > 0 and series[keep - 1] < 1e-5:
            keep -= 1

        if keep < len(series):
            # Retain the first near-zero value, drop everything after it.
            del series[keep + 1:]

        trimmed.append(series)
    return trimmed

def plot_pdf_data(flash_accs, think_accs,
                  title='Length Generalization (582M)',
                  output_stem="plots/final_plots/582M_length_generalization"):
    """Draw one stacked subplot per accuracy list and save the figure.

    Entry ``i`` of ``flash_accs`` / ``think_accs`` is labelled ``i + 3``
    ("Number of Digits").  The value at 1-based position ``X`` of an entry is
    plotted at ``x = X`` and only positions ``X >= label`` are drawn.  The
    ``think`` curves are filled in red ('Generalization w/ CoT') and the
    ``flash`` curves in blue ('Generalization w/o CoT'); nothing is filled for
    the first label.  The points ``(label, 1.0)`` of consecutive subplots are
    joined by a dashed orange line.

    Args:
        flash_accs: List of numeric sequences, one per label.
        think_accs: List of numeric sequences, one per label.  The sequences
            are copied before a terminating ``0`` is appended, so the
            caller's lists are left untouched.
        title: Figure title.
        output_stem: Output path without extension; ``.pdf`` and ``.png``
            files are written and the parent directory is created if missing.
    """
    first_label = 3
    labels = list(range(first_label, len(flash_accs) + first_label))

    # Work on copies, and make every "think" curve end on zero.
    think_accs = [list(arr) for arr in think_accs]
    for arr in think_accs:
        if arr[-1] > 1e-5:
            arr.append(0)

    # Long-format frames: one row per (label, position, value).
    columns = ['Number of Digits', 'X', 'PDF_Value']
    flash_rows, think_rows = [], []
    for label, flash, think in zip(labels, flash_accs, think_accs):
        flash_rows.extend([label, x, y] for x, y in enumerate(flash, start=1))
        think_rows.extend([label, x, y] for x, y in enumerate(think, start=1))
    flash_df = pd.DataFrame(flash_rows, columns=columns)
    think_df = pd.DataFrame(think_rows, columns=columns)

    # One very flat facet row per label.
    pal = sns.color_palette("husl", len(labels))
    g = sns.FacetGrid(think_df, row="Number of Digits", hue="Number of Digits",
                      aspect=50, height=0.2, palette=pal)

    # (x, y, axes) of the previously highlighted point, if any.
    previous = None

    for i, label in enumerate(labels):
        think_subset = think_df[(think_df['Number of Digits'] == label) & (think_df['X'] >= label)]
        flash_subset = flash_df[(flash_df['Number of Digits'] == label) & (flash_df['X'] >= label)]
        # Rows are used in reverse: the largest label is drawn on the first row.
        ax = g.facet_axis(labels[-1] - label, 0)

        if i > 0:
            ax.fill_between(think_subset['X'], think_subset['PDF_Value'], color='red', edgecolor='none')
            ax.fill_between(flash_subset['X'], flash_subset['PDF_Value'], color='blue', edgecolor='none')

        # The point at X == label anchors the orange generalization line;
        # labels whose flash curve has no such point are skipped.
        point_to_highlight = flash_subset[flash_subset['X'] == label]
        if point_to_highlight.empty:
            continue

        x_val = point_to_highlight.iloc[0]['X']
        y_val = 1.0

        if previous is not None:
            prev_x, prev_y, prev_ax = previous

            # Dashed line from the previous subplot's anchor to this one.
            con = ConnectionPatch(xyA=(prev_x, prev_y), xyB=(x_val, y_val),
                                  coordsA="data", coordsB="data",
                                  axesA=prev_ax, axesB=ax,
                                  color="orange", linestyle="--", linewidth=2)
            prev_ax.add_artist(con)

            # Blue triangle between the previous anchor's x position and
            # this anchor, drawn on the current axes.
            poly_coords = [(prev_x, 0), (x_val, 0), (x_val, y_val)]
            poly = Polygon(poly_coords, closed=True, facecolor='blue', edgecolor='none')
            ax.add_patch(poly)

        previous = (x_val, y_val, ax)

    # Set the subplots to overlap and adjust layout
    g.figure.subplots_adjust(hspace=0)
    g.set(xlim=(2, max(len(arr) for arr in think_accs) - 1), ylim=(0, 1.1))
    g.set(yticklabels=[])
    g.set(yticks=[])
    g.figure.text(0.04, 0.5, 'Accuracy (per subplot)', va='center', rotation='vertical')
    g.set_axis_labels("Number of Digits", "")

    # Remove axes details that don't play well with overlap
    g.set_titles("")

    def annotate_row(data, color, label):
        """Write the label of the curve drawn on the current facet row."""
        ax = plt.gca()
        # Curves were drawn on reversed rows, so mirror the facet's own label.
        reversed_label = labels[-1] - int(label) + labels[0]

        # The first label has no filled curve and gets no annotation.
        if reversed_label == first_label:
            return

        ax.text(0, 0.2, " " + str(reversed_label), fontweight="bold", color="black",
                ha="left", va="center", transform=ax.transAxes, fontsize=8)

    # seaborn calls the function once per facet with that facet's hue label.
    g.map(annotate_row, "Number of Digits")

    # Custom legend elements
    legend_elements = [
        Line2D([0], [0], color='orange', lw=2, linestyle='--', label='Generalization Line'),
        Patch(facecolor='blue', edgecolor='none', label='Generalization w/o CoT'),
        Patch(facecolor='red', edgecolor='none', label='Generalization w/ CoT'),
    ]

    # Add the legend to the last subplot
    legend_ax = g.axes[-1, 0]
    legend_ax.legend(handles=legend_elements, loc='lower right', title='Legend')
    g.figure.suptitle(title, fontsize=16)

    # NOTE(review): both calls are kept in their original order so the
    # rendered spines do not change.  The second call leaves left/bottom at
    # seaborn's defaults, which may re-show the spines hidden by the first -
    # confirm the intended look.
    g.despine(left=True, bottom=True)
    g.despine(right=True, top=True)
    # Leave room on the right so the legend fits
    plt.tight_layout(rect=[0, 0, 0.85, 1])

    # savefig does not create missing directories.
    output_dir = os.path.dirname(output_stem)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plt.savefig(output_stem + ".pdf", dpi=1000)
    plt.savefig(output_stem + ".png", dpi=1000)



# Only draw the figure when run as a script; the command-line parsing near
# the top of the file still happens on import.
if __name__ == "__main__":
    main()