import os
import pandas as pd
import numpy as np
import argparse
import json
import matplotlib.pyplot as plt
from plotnine import ggplot, geom_point, aes, theme, element_text
import pandas as pd

class Plotter:
    """Scatter-plot the 'action' / 'predicates' values of a JSON dataset with plotnine.

    ``plot`` iterates over the parsed JSON and reads the 'action' and
    'predicates' entries of every element, so the input file is expected to
    hold a JSON array of objects carrying those two keys.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        """Load the dataset and remember where the image should be written.

        Args:
            args: Parsed options providing ``input_data`` (path of the JSON
                file to read) and ``output_img_path`` (path of the image to
                produce).

        Raises:
            OSError: If the input file cannot be opened.
            json.JSONDecodeError: If the file does not contain valid JSON.
        """
        self.input_data = args.input_data
        # Read the JSON file. The encoding is spelled out so the file is not
        # decoded with a platform-dependent default.
        with open(self.input_data, 'r', encoding='utf-8') as f:
            # NOTE: despite its name, ``df`` holds the raw parsed JSON, not a
            # DataFrame; the DataFrame is only built inside ``plot``.
            self.df = json.load(f)

        self.output_img_path = args.output_img_path

    def plot(self) -> None:
        """Render the scatter plot and save it to ``self.output_img_path``.

        Raises:
            KeyError: If a record (assumed to be a dict) lacks the 'action'
                or 'predicates' key.
        """
        # Pair up the two plotted values of every record.
        pairs = [(record['action'], record['predicates']) for record in self.df]

        # Convert the list of pairs into a two-column DataFrame.
        df_t = pd.DataFrame(pairs, columns=['action', 'predicates'])

        # Build the ggplot object: semi-transparent blue points.
        plot = (
            ggplot(df_t, aes(x='action', y='predicates'))
            + geom_point(color='blue', alpha=0.2)
            # NOTE(review): None is passed for both axis-title themeables;
            # confirm whether element_blank() (hiding the titles) was intended.
            + theme(axis_title_x=None, axis_title_y=None)
        )
        plot.save(self.output_img_path)
       
# Assumes df is a DataFrame containing 'action' and 'predicates' columns.
# output_img_path is the path the output image is written to.

def args():
    """Parse the command-line options of the plotting script.

    Returns:
        argparse.Namespace: Holds ``input_data`` (path of the JSON data file)
        and ``output_img_path`` (path of the image to write). Both are
        strings and fall back to the defaults listed below when omitted.
    """
    cli = argparse.ArgumentParser(description='Plotting difficulty')
    # (flag, default value, help text) for every supported option.
    option_table = (
        (
            '--input_data',
            '/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/ipc_bench/count_action.json',
            'Path to data',
        ),
        (
            '--output_img_path',
            '/lustre/fast/fast/txiao/zly/spatial_head/cot/plot/image/diffucity.png',
            'Path to output image',
        ),
    )
    for flag, fallback, description in option_table:
        cli.add_argument(flag, type=str, default=fallback, help=description)
    parsed = cli.parse_args()
    return parsed
    
    
class Diff_plot:
    """Scatter-plot the 'action' / 'predicates' values of a JSON dataset with matplotlib.

    ``plot`` iterates over the parsed JSON and reads the 'action' and
    'predicates' entries of every element, so the input file is expected to
    hold a JSON array of objects carrying those two keys.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        """Load the dataset and remember where the image should be written.

        Args:
            args: Parsed options providing ``input_data`` (path of the JSON
                file to read) and ``output_img_path`` (path of the image to
                produce).

        Raises:
            OSError: If the input file cannot be opened.
            json.JSONDecodeError: If the file does not contain valid JSON.
        """
        self.input_data = args.input_data
        # Read the JSON file. The encoding is spelled out so the file is not
        # decoded with a platform-dependent default.
        with open(self.input_data, 'r', encoding='utf-8') as f:
            # NOTE: despite its name, ``df`` holds the raw parsed JSON, not a
            # DataFrame.
            self.df = json.load(f)

        self.output_img_path = args.output_img_path

    def plot(self) -> None:
        """Draw the scatter plot, save it to ``self.output_img_path`` and show it.

        Each record contributes one point: its 'action' value on the x axis
        and its 'predicates' value on the y axis.

        Raises:
            KeyError: If a record (assumed to be a dict) lacks the 'action'
                or 'predicates' key.
        """
        actions = [record['action'] for record in self.df]
        predicates = [record['predicates'] for record in self.df]

        # Draw on a dedicated figure instead of pyplot's implicit current
        # one, so repeated calls do not pile points onto an earlier plot.
        fig, ax = plt.subplots()
        ax.scatter(actions, predicates, color='red', alpha=0.2)
        ax.set_title('Actions and Predicates Distribution Scatter Plot')
        ax.set_xlabel('Number of actions')
        ax.set_ylabel('Number of predicates')
        fig.savefig(self.output_img_path)
        plt.show()
        # Release the figure once it has been saved and shown.
        plt.close(fig)
        
       
if __name__ == '__main__':
    # Keep the parsed options under their own name so the module-level
    # ``args`` function is not rebound to the Namespace it returns.
    cli_args = args()
    # Diff_plot takes the same options and draws with matplotlib instead;
    # substitute it for Plotter here to use that renderer.
    plotter = Plotter(cli_args)
    plotter.plot()