# Usage
# python3 plot_consistency.py
# Generates fig:stg_and_consistency


import numpy as np
import re
import matplotlib.pyplot as plt
import seaborn as sns
from IPython import embed
import argparse
import ast
import pandas as pd

from collections import defaultdict
from matplotlib.patches import Patch
from matplotlib.patches import Polygon



from matplotlib.ticker import FuncFormatter
import matplotlib.colors as mcolors


# Command-line interface. No options are declared, so parsing only provides
# ``--help`` and rejects unexpected arguments; ``args`` is not read elsewhere
# in the visible code.
parser = argparse.ArgumentParser(description='Plot graphs')
args = parser.parse_args()

# Global default text size for figures drawn by this script.
plt.rcParams['font.size'] = 14


def generate_gradient_colors(base_color, light_color, number_of_colors):
    """Return ``number_of_colors`` hex colour strings blending light -> base.

    A two-stop linear colormap from ``light_color`` to ``base_color`` is
    sampled at ``number_of_colors`` evenly spaced positions in [0, 1], so the
    first entry is the light end and (for two or more colours) the last entry
    is the base colour.
    """
    colormap = mcolors.LinearSegmentedColormap.from_list("", [light_color, base_color])
    positions = np.linspace(0, 1, number_of_colors)
    hex_colors = []
    for rgba in colormap(positions):
        hex_colors.append(mcolors.to_hex(rgba))
    return hex_colors


class _GradientDiagonalLegendHandler:
    """Legend handler that draws its key as a square split along a diagonal.

    The lower-left triangle is filled with ``color1`` and the upper-right
    triangle with ``color2``, so a single legend entry can show the colour a
    label uses in each of the two model groups.
    """

    def __init__(self, color1, color2):
        self.color1 = color1
        self.color2 = color2

    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        """Draw the two triangles into ``handlebox`` and return one artist.

        ``legend``, ``orig_handle`` and ``fontsize`` are part of matplotlib's
        handler protocol and are not needed here.
        """
        x0, y0 = handlebox.xdescent, handlebox.ydescent
        width, height = handlebox.width, handlebox.height
        lower_left = Polygon([[x0, y0], [x0 + width, y0], [x0, y0 + height]],
                             closed=True, facecolor=self.color1, edgecolor='none')
        upper_right = Polygon([[x0, y0 + height], [x0 + width, y0], [x0 + width, y0 + height]],
                              closed=True, facecolor=self.color2, edgecolor='none')
        handlebox.add_artist(lower_left)
        handlebox.add_artist(upper_right)
        # The handler protocol expects a single artist back (it becomes the
        # legend handle); both triangles are already attached to handlebox.
        return lower_left


def main():
    """Build and save the error-rate comparison figure (fig:stg_and_consistency).

    For the 582M and 300M runs, the three mean accuracies returned by
    ``extract_data`` are converted to error rates (``1 - accuracy``). They are
    drawn as a grouped bar chart on a log-scale y-axis with one colour
    gradient per model. The figure is saved to
    ``plots/final_plots/consistency_all.png`` and ``.pdf``.
    """
    import os

    base_sg_decoding, base_commutative, base_flash = extract_data(
        "ca00221d-07e4-402c-a78a-bf5c740a5535", "plots/582M_1.log", 7)
    small_sg_decoding, small_commutative, small_flash = extract_data(
        "9aa66935-c664-4dbb-96f6-c1a0ac464325", "plots/300M_2.log", 9)

    # Error rate = 1 - mean accuracy, one value per setting, in `labels` order.
    labels = ['Fast Addition Only', '+ Simplify-and-Guess', '+ Commutativity Check']
    base_values = [1. - base_flash, 1. - base_sg_decoding, 1. - base_commutative]
    small_values = [1. - small_flash, 1. - small_sg_decoding, 1. - small_commutative]

    df_base = pd.DataFrame({
        'Labels': labels,
        'Values': base_values,
        'Model': ['582M'] * len(labels)
    })
    df_small = pd.DataFrame({
        'Labels': labels,
        'Values': small_values,
        'Model': ['300M'] * len(labels)
    })
    # 300M is concatenated first, so it takes the left x position; the bar
    # recolouring below relies on that order.
    df_combined = pd.concat([df_small, df_base])

    # One shade per label, from the light colour to the base colour.
    gradient_colors_base = generate_gradient_colors("darkorange", "peachpuff", len(labels))
    gradient_colors_small = generate_gradient_colors("purple", "plum", len(labels))

    plt.figure(figsize=(12, 6))
    plt.yscale('log')
    ax = sns.barplot(x='Model', y='Values', hue='Labels', data=df_combined)

    # Recolour bars left to right: the 300M group, then the 582M group.
    # NOTE(review): assumes ax.patches holds exactly the bars drawn by
    # barplot, laid out in hue order within each group. Confirm this for the
    # installed seaborn version.
    sorted_patches = sorted(ax.patches, key=lambda patch: (patch.get_x(), patch.get_height()))
    for patch, color in zip(sorted_patches, gradient_colors_small + gradient_colors_base):
        patch.set_facecolor(color)

    plt.title('Error Rate When Generating N+1 Digit Addition', fontsize=20)
    plt.xlabel('Model Size', fontsize=18)
    plt.ylabel('Error Rate (log scale)', fontsize=18)
    ax.tick_params(axis='both', labelsize=16)

    yticks = [1.0, 0.1, 0.01, 0.001]
    plt.yticks(yticks, [str(i) for i in yticks])
    plt.ylim(top=1.0)

    # Legend handles are plain ints; handler_map maps each one to a handler
    # that draws that label's 300M colour and 582M colour side by side.
    legend_elements = list(range(len(labels)))
    handler_map = {
        i: _GradientDiagonalLegendHandler(small_color, base_color)
        for i, (small_color, base_color) in enumerate(zip(gradient_colors_small, gradient_colors_base))
    }
    plt.legend(legend_elements, labels, title='Labels', handler_map=handler_map,
               loc='upper right', fontsize=14, title_fontsize='16')

    # Print each bar's height just above its top.
    for p in ax.patches:
        ax.annotate(f"{p.get_height():.4f}",
                    (p.get_x() + p.get_width() / 2., p.get_height()),
                    ha='center',
                    va='center',
                    fontsize=16,
                    color='black',
                    xytext=(0, 10),
                    textcoords='offset points')

    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs("plots/final_plots", exist_ok=True)
    plt.savefig("plots/final_plots/consistency_all.png")
    plt.savefig("plots/final_plots/consistency_all.pdf", dpi=1200)

def extract_data(uuid, logfile, self_train_begin):
    """Summarise per-digit-length accuracies for one training run.

    Two sources are read.

    ``plots/<uuid>_flash.txt`` (relative to the working directory): each
    non-blank line ``i`` is parsed with ``ast.literal_eval``. Element
    ``i + 3`` of that line is taken as the accuracy for digit length
    ``i + 3``.

    ``logfile``: every line containing
    ``"digit decomp problems, we accepted"`` contributes counts. Its integers
    are read positionally as ``[length, right_a, wrong_a, right_b, wrong_b]``
    and summed per length. From these sums:

    * the "commutative" accuracy is ``right_a / (right_a + wrong_a)``;
    * the "simplify-and-guess" accuracy is
      ``(right_a + right_b) / (right_a + wrong_a + right_b + wrong_b)``.

    NOTE(review): the positional reading assumes no other integers appear on
    the line before these five. Confirm against the logger.

    Each accuracy is averaged (unweighted) over digit lengths from
    ``self_train_begin`` up to the largest length seen in the log, inclusive.

    Args:
        uuid: Run identifier used to locate the flash-accuracy file.
        logfile: Path to the training log.
        self_train_begin: First digit length included in the averages.

    Returns:
        Tuple ``(simplify_and_guess_acc, commutative_acc, flash_acc)`` of mean
        accuracies as returned by ``np.mean``.

    Raises:
        ValueError: if the log has no matching lines, or a digit length in
            the averaged range has no recorded (non-zero) counts.
    """
    import warnings

    keyword = "digit decomp problems, we accepted"
    number_pattern = re.compile(r'\d+')

    generalization_accs = defaultdict(float)
    with open("plots/{}_flash.txt".format(uuid), "r") as f:
        for i, line in enumerate(f):
            if not line.strip():
                # ast.literal_eval cannot parse a blank line (e.g. a trailing
                # newline). Skipping it still advances i, so the remaining
                # lines keep their digit-length mapping.
                continue
            acc_array = ast.literal_eval(line)
            generalization_accs[i + 3] = acc_array[i + 3]

    commutative_right = defaultdict(float)
    commutative_wrong = defaultdict(float)
    simplify_and_guess_right = defaultdict(float)
    simplify_and_guess_wrong = defaultdict(float)

    with open(logfile, 'r') as f:
        for line in f:
            if keyword not in line:
                continue
            numbers = [int(num) for num in number_pattern.findall(line)]
            length = numbers[0]
            commutative_right[length] += numbers[1]
            commutative_wrong[length] += numbers[2]
            simplify_and_guess_right[length] += numbers[1] + numbers[3]
            simplify_and_guess_wrong[length] += numbers[2] + numbers[4]

    if not commutative_right:
        raise ValueError("no lines containing {!r} found in {}".format(keyword, logfile))

    lengths = range(self_train_begin, max(commutative_right) + 1)

    def accuracy(right, wrong, length):
        # Fail clearly instead of with a bare 0/0 ZeroDivisionError.
        total = right.get(length, 0.0) + wrong.get(length, 0.0)
        if total == 0:
            raise ValueError(
                "no counts recorded for digit length {} in {}".format(length, logfile))
        return right.get(length, 0.0) / total

    missing_flash = [n for n in lengths if n not in generalization_accs]
    if missing_flash:
        # Preserve the historical behaviour (missing entries count as 0.0),
        # but make it visible because it inflates the plotted error rate.
        warnings.warn(
            "flash accuracies missing for digit lengths {} in plots/{}_flash.txt; "
            "treating them as 0.0".format(missing_flash, uuid))

    with_simplify_and_guess = [accuracy(simplify_and_guess_right, simplify_and_guess_wrong, n) for n in lengths]
    with_commutative = [accuracy(commutative_right, commutative_wrong, n) for n in lengths]
    flash_guesses = [generalization_accs[n] for n in lengths]

    return np.mean(with_simplify_and_guess), np.mean(with_commutative), np.mean(flash_guesses)



if __name__ == "__main__":
    # Build and save the figure only when run as a script, not on import.
    main()