import argparse
import os

import matplotlib.pyplot as plt
import seaborn as sns

def _build_arg_parser():
    """Return the CLI parser selecting the processed log and the self-training start length."""
    cli = argparse.ArgumentParser(description='Plot graphs')
    cli.add_argument(
        '--name',
        type=str,
        default="processed_582M_1.log",
        help='name of the processed log file',
    )
    cli.add_argument(
        '--self_train_begin',
        type=int,
        default=7,
        help='the number of digits where self-training begins',
    )
    return cli


# Parsed at import time; `args` is consumed by the script entry point.
parser = _build_arg_parser()
args = parser.parse_args()

def _data_rows(section):
    """Split a log section into comma-separated fields, skipping blank and ``num_digits`` header lines."""
    # splitlines() also handles CRLF files, so a stray "\r" line is not parsed.
    return [
        line.split(',')
        for line in section.splitlines()
        if line.strip() and not line.startswith("num_digits")
    ]


def _parse_trajectories(text, self_train_begin):
    """Parse processed-log text into supervised and self-training curves.

    The text is split on the ``BEGINNING SELF-TRAINING`` marker. Data rows
    carry the problem length in column 0 and a step count in column 3.
    Supervised rows are kept only when their length is strictly below
    ``self_train_begin``. Self-training step counts are accumulated on top of
    the last supervised total, so both curves share one cumulative y-axis.

    Returns:
        ``(supervised_num_digits, supervised_total_steps, self_num_digits,
        self_total_steps)`` as lists of ints. The supervised lists are extended
        with the first self-training point so the two plotted lines join up.

    Raises:
        ValueError: if the marker does not occur exactly once, or if either
            section has no usable data rows.
    """
    marker = "BEGINNING SELF-TRAINING"
    sections = text.split(marker)
    if len(sections) != 2:
        raise ValueError(
            "expected exactly one {!r} marker in the log, found {}".format(
                marker, len(sections) - 1))
    pre_training, post_training = sections

    # Keep only supervised rows whose num_digits is strictly below the
    # self-training start length.
    supervised_rows = [
        row for row in _data_rows(pre_training) if int(row[0]) < self_train_begin
    ]
    self_rows = _data_rows(post_training)
    if not supervised_rows:
        raise ValueError(
            "no supervised rows with num_digits < {}".format(self_train_begin))
    if not self_rows:
        raise ValueError("no self-training rows after the {!r} marker".format(marker))

    supervised_num_digits = [int(row[0]) for row in supervised_rows]
    supervised_total_steps = [int(row[3]) for row in supervised_rows]

    # Self-training step counts are accumulated on top of the last supervised
    # total. The value is built directly from column 3; the previous version
    # rewrote the row's *last* column and then read column 3 back, which only
    # matched for 4-column rows.
    running_total = supervised_total_steps[-1]
    self_num_digits = []
    self_total_steps = []
    for row in self_rows:
        running_total += int(row[3])
        self_num_digits.append(int(row[0]))
        self_total_steps.append(running_total)

    # Bridge the supervised curve to the first self-training point.
    supervised_num_digits.append(self_num_digits[0])
    supervised_total_steps.append(self_total_steps[0])

    return supervised_num_digits, supervised_total_steps, self_num_digits, self_total_steps


def enhanced_plot_graph_for_publication(filename, self_train_begin):
    """Plot supervised vs. self-training learning trajectories from a processed log.

    Reads ``filename`` and plots cumulative training examples against
    addition-problem length. The plot is saved as PNG (300 dpi) and PDF under
    ``final_plots/``, and that directory is created if missing.

    Args:
        filename: path to the processed log. Its basename is expected to look
            like ``<prefix>_<model size>_<version>.<ext>`` (e.g.
            ``processed_582M_1.log``). The size and version are used in the
            title and the output file names.
        self_train_begin: supervised rows with ``num_digits`` at or above this
            value are dropped.

    Raises:
        ValueError: if the log cannot be split into non-empty supervised and
            self-training sections.
        IndexError: if the basename has fewer than three ``_``-separated parts.
    """
    # Seaborn styling; the "poster" context enlarges fonts for publication.
    sns.set_style("whitegrid")
    sns.set_context("poster")

    with open(filename, 'r') as f:
        data = f.read()

    (supervised_num_digits, supervised_total_steps,
     self_num_digits, self_total_steps) = _parse_trajectories(data, self_train_begin)

    # Derive tags from the basename so underscores in directory names do not
    # shift the indices, e.g. "processed_582M_1.log" -> "582M", "582M_1".
    name_parts = os.path.basename(filename).split('_')
    model_size = name_parts[1]
    model_size_and_version = model_size + '_' + name_parts[2].split('.')[0]

    fig = plt.figure(figsize=(14, 7))
    try:
        plt.plot(self_num_digits, self_total_steps, '-o', color='red',
                 linewidth=2.5, markersize=8, label="Self-training")
        plt.plot(supervised_num_digits, supervised_total_steps, '-o', color='blue',
                 linewidth=2.5, markersize=8, label="Supervised training")
        plt.title("Learning Trajectory of the {} ByT5 Model".format(model_size))
        plt.xlabel("Addition Problem Length")
        plt.ylabel("Total Training Examples Seen")
        # Start the y-axis at 0 with a little headroom above the highest point.
        plt.ylim(0, 1.03 * max(self_total_steps))
        plt.xticks(range(min(supervised_num_digits), max(self_num_digits) + 1, 1))

        plt.legend(loc='lower right')
        plt.tight_layout()

        os.makedirs("final_plots", exist_ok=True)
        plt.savefig("final_plots/learning_trajectory_{}.png".format(model_size_and_version), dpi=300)
        plt.savefig("final_plots/learning_trajectory_{}.pdf".format(model_size_and_version), dpi=600)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: plot the log selected on the command line.
    enhanced_plot_graph_for_publication(
        filename=args.name,
        self_train_begin=args.self_train_begin,
    )