import os
import argparse

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec

import numpy as np
import pandas as pd
from fig_atscore_replace import do_plot as do_plot_r

from fig_atscore_curve import color_palette, model_list

# Output format switch for the saved figure: True writes a PDF, False a PNG.
is_pdf = True

# Display names for the model keys found in ``model_list``; used as the
# legend labels of the per-model color patches.
map_str = dict(
    glm='GLM4-9B',
    llama31='Llama3.1-8B',
    qwen='Qwen2-72B',
)

if __name__ == '__main__':
    # Build the attention-score figure: panels (a) and (b) from
    # fig_atscore_replace.do_plot, plus an optional case-study curve (c)
    # from fig_atscore_curve.do_plot on a second row.
    parser = argparse.ArgumentParser()
    parser.add_argument('--layer', '-l', type=int, default=-1, help='Middle(0) / final(1) Layer for the plot, -1 for all')
    parser.add_argument('--ignore_cache', '-c', action='store_true', help='Ignore cache')
    parser.add_argument('--case-str', type=str, default="0,231,25", help=r'Case study control string, formatted as "{layer_idx},{sid},{window_size}", empty for unplotted')
    args = parser.parse_args()

    if args.case_str:
        try:
            # ValueError covers both non-integer fields and a wrong field count.
            layer_idx, sid, window_size = map(int, args.case_str.split(','))
        except ValueError:
            # SystemExit(str) prints the message to stderr and exits with status 1.
            raise SystemExit('Invalid case string, should be formatted as "{layer_idx},{sid},{window_size}"') from None

        fig = plt.figure(figsize=(8, 3.8), dpi=300, tight_layout=True)
        # Two equal-height rows: panels (a)/(b) on top, case study (c) below.
        gs_base = gridspec.GridSpec(2, 1, height_ratios=[2, 2], hspace=0.05)
        ax0 = fig.add_subplot(gs_base[1])
        ax0.set_xlim(-1, 101)

        # Imported here because it is only needed when the case study is drawn.
        from fig_atscore_curve import do_plot as do_plot_c
        do_plot_c(ax0, plot_sid=sid, layer_idx=layer_idx, window_size=window_size)

        gs = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs_base[0],
                                              width_ratios=[1.8, 1], wspace=0.15)
    else:
        fig = plt.figure(figsize=(8, 2.5), dpi=300, tight_layout=True)
        gs = gridspec.GridSpec(1, 2, width_ratios=[1.8, 1], wspace=0.15)

    # The left cell is split into two touching panels (wspace=0) that share
    # both axes; together they form panel (a). Panel (b) shares their y axis.
    gs_l = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[0], wspace=0.0)
    ax1 = fig.add_subplot(gs_l[0])
    ax2 = fig.add_subplot(gs_l[1], sharey=ax1, sharex=ax1)
    ax3 = fig.add_subplot(gs[1], sharey=ax1)

    # ax2 abuts ax1, so its y tick labels would only duplicate ax1's.
    plt.setp(ax2.get_yticklabels(), visible=False)

    do_plot_r(ax1, plot_replace=False, plot_sid=-1, layer_idx=args.layer,
              round_related=1, ignore_cache=args.ignore_cache)
    do_plot_r(ax2, plot_replace=False, plot_sid=-1, layer_idx=args.layer,
              round_related=0, ignore_cache=args.ignore_cache)
    # NOTE(review): unlike ax1/ax2, panel (b) ignores --layer and always passes
    # layer_idx=-1; confirm this is intentional.
    do_plot_r(ax3, plot_sid=-1, layer_idx=-1, round_related=-1,
              ignore_cache=args.ignore_cache)

    ax1.set_xlim(0.5, 5.5)
    ax1.set_ylim(0, 0.45)

    ax2.text(5.1, 0.38, '(a)', ha='center', va='center', fontsize=7.5)
    ax3.text(4.95, 0.38, '(b)', ha='center', va='center', fontsize=7.5)

    if args.case_str:
        ax0.text(98.6, 0.65, '(c)', ha='center', va='center', fontsize=7.5)
        for ax in (ax1, ax2, ax3):
            # A negative pad moves the x tick labels up above the bottom axis.
            ax.tick_params(axis='x', length=3, pad=-12, labelsize=8)
            ax.tick_params(axis='y', length=3, labelsize=8)

    # Shared figure legend: one color patch per model, plus line styles.
    patches = [mpatches.Patch(color=color_palette[i], label=map_str[model])
               for i, model in enumerate(model_list)]
    lines = [
        mlines.Line2D([], [], color='black', linestyle='-', label='As System'),
        mlines.Line2D([], [], color='black', linestyle='--', label='As User'),
    ]
    fig.legend(handles=patches + lines, loc='upper center', ncol=5, fontsize=9,
               bbox_to_anchor=(0.50, 0.952), columnspacing=1.5, labelspacing=0.2,
               frameon=False, handletextpad=0.8)

    # Common title for panel (a): anchor at ax1 data point (5.38, 0.42),
    # converted data -> display -> figure coordinates.
    # NOTE(review): this conversion runs before the figure is drawn, while
    # tight_layout is applied at draw time, so it uses pre-layout axes
    # geometry. The anchor values were presumably tuned against that; re-check
    # them if the layout changes.
    x_disp, y_disp = ax1.transData.transform((5.38, 0.42))
    x_fig, y_fig = fig.transFigure.inverted().transform((x_disp, y_disp))
    fig.text(x_fig, y_fig, "System Message's Share of Total Attention Score",
             ha='center', va='center', fontsize=9)

    file_name = 'figures/atscore_v2' + ('.pdf' if is_pdf else '.png')
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    # Save this figure explicitly rather than pyplot's "current" figure,
    # which the plotting helpers could have changed.
    fig.savefig(file_name, bbox_inches='tight', pad_inches=0.0)
    print(f'Figure saved to {file_name}.')