import pandas as pd
from plotnine import ggplot, geom_line, aes, theme, element_text
import argparse
import re
import matplotlib.pyplot as plt
import pdb
def args(argv=None):
    """Parse command-line options for the plotting script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches the original behaviour.

    Returns:
        argparse.Namespace with string attributes ``input_csv1``,
        ``input_csv2`` and ``output_png``.
    """
    parser = argparse.ArgumentParser(
        description='Plot pass counts against thought number for two result CSVs.')
    parser.add_argument(
        '-input_csv1',
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/plot/table/nl.csv',
        type=str,
        help='CSV with "model" and "right" columns; drawn as the blue (nl) series.')
    parser.add_argument(
        '-input_csv2',
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/plot/table/prob.csv',
        type=str,
        help='CSV with "model" and "right" columns; drawn as the green (prob) series.')
    # NOTE(review): the default has no file extension. When savefig gets no
    # extension and no explicit format, matplotlib appends one from
    # rcParams["savefig.format"]. The result is "prob_output_png.png".
    # Confirm whether "prob_output.png" was intended.
    parser.add_argument(
        '-output_png',
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/plot/table/prob_output_png',
        type=str,
        help='Output image path passed to savefig.')
    return parser.parse_args(argv)

class Plotter:
    """Plot the ``right`` column against thought number for two result tables.

    Each table must provide a ``model`` column, whose first run of digits is
    taken as the thought number, and a ``right`` column, which is plotted on
    the y axis.
    """

    def __init__(self, args):
        """Remember the paths on *args* and load both CSVs.

        Args:
            args: Namespace with ``input_csv1`` (the "nl" table),
                ``input_csv2`` (the "prob" table) and ``output_png``
                (the path later passed to ``savefig``).
        """
        self.input_csv1 = args.input_csv1
        self.input_csv2 = args.input_csv2
        self.output_png = args.output_png
        self.data_nl = pd.read_csv(self.input_csv1)
        self.data_prob = pd.read_csv(self.input_csv2)

    def match(self, str):
        """Return the first run of digits in *str* as an int, or None if there is none.

        The parameter keeps its original name so keyword callers still work.
        It shadows the builtin ``str``, which this method does not need.
        """
        found = re.search(r'\d+', str)
        return int(found.group(0)) if found else None

    def _with_thought_number(self, df):
        """Return a copy of *df* with an int ``thought_number`` column, sorted by it.

        *df* itself is not modified.

        Raises:
            ValueError: if any ``model`` value contains no digits.
        """
        numbers = df['model'].map(self.match)
        missing = numbers.isna()
        if missing.any():
            bad = df.loc[missing, 'model'].tolist()
            raise ValueError(f"model names without a thought number: {bad}")
        return df.assign(thought_number=numbers.astype(int)).sort_values(by='thought_number')

    def plot(self, nl_baseline=283, prob_baseline=332):
        """Draw both series with dashed reference lines and save to ``self.output_png``.

        Args:
            nl_baseline: y value of the dashed blue line, colour-matched to
                the nl series. Presumably a baseline pass count; confirm.
            prob_baseline: y value of the dashed green line, colour-matched
                to the prob series. Presumably a baseline pass count; confirm.

        Raises:
            ValueError: if a ``model`` value in either table has no digits.
        """
        nl = self._with_thought_number(self.data_nl)
        prob = self._with_thought_number(self.data_prob)

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(nl['thought_number'], nl['right'],
                    marker='o', linestyle='-', color='b', label='nl')
            ax.plot(prob['thought_number'], prob['right'],
                    marker='o', linestyle='-', color='g', label='prob')
            # Reference lines share the colour of the series they belong to.
            ax.axhline(prob_baseline, linestyle='--', color='g')
            ax.axhline(nl_baseline, linestyle='--', color='b')
            ax.set_title('')
            ax.set_xlabel('Thought Number')
            ax.set_ylabel('Pass Number')
            ax.grid(True)
            ax.legend()
            fig.savefig(self.output_png)
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

        
if __name__ == '__main__':
    # Use a distinct name so the module-level args() parser is not shadowed.
    cli_args = args()
    Plotter(cli_args).plot()
   