import numpy as np

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.pyplot import MultipleLocator
from matplotlib import rc

# Global matplotlib styling: serif text and 'large' axes titles.
# NOTE(review): with family='serif' matplotlib consults the 'serif' list, so the
# 'sans-serif' entry below has no visible effect; if Computer Modern was the
# intent it probably belongs elsewhere (e.g. mathtext.fontset='cm') — confirm.
rc('font', family='serif', **{'sans-serif': ['cm']})
rc('axes', titlesize='large')
 
class DrawSubjectBatch:
    """Line plots of subject batch-editing results for several editing methods.

    Every ``*_es_data`` / ``*_rs_data`` table holds one row per entry of
    ``methods`` (same order) and one column per entry of ``x_labels`` (the edit
    batch size). ``*_es_data`` tables are plotted as "Efficacy" and
    ``*_rs_data`` tables as "Relational Locality".
    """

    x_labels = ['1', '2', '4', '8']
    methods = ['FT', 'ROME', 'MEMIT', 'DiKE']

    gpt2xl_es_data = [
        [98.9, 97.9, 94.5, 93.1],
        [93.0, 62.0, 47.2, 39.5],
        [92.6, 56.1, 35.2, 29.6],
        [95.5, 96.0, 96.2, 95.6]
    ]
    gpt2xl_rs_data = [
        [49.0, 47.5, 48.2, 49.7],
        [51.7, 43.8, 40.0, 38.3],
        [51.2, 49.9, 49.7, 50.8],
        [58.1, 58.9, 58.9, 58.3]
    ]

    gptj_es_data = [
        [100.0, 100.0, 100.0, 100.0],
        [100.0, 69.6, 57.7, 57.2],
        [100.0, 95.2, 89.5, 85.2],
        [99.4, 99.5, 99.8, 99.7]
    ]
    gptj_rs_data = [
        [72.9, 68.1, 65.0, 62.5],
        [55.7, 52.5, 49.9, 51.2],
        [62.7, 66.4, 67.4, 69.2],
        [71.8, 73.6, 69.9, 68.1]
    ]

    llama_es_data = [
        [100.0, 100.0, 100.0, 100.0],
        [99.9, 60.5, 35.9, 26.4],
        [99.3, 76.6, 63.4, 62.7],
        [97.5, 97.7, 97.4, 96.9]
    ]
    # NOTE(review): 60.54 is the only two-decimal value in all tables —
    # confirm it is not a typo for 60.5.
    llama_rs_data = [
        [67.7, 65.3, 61.5, 60.54],
        [58.5, 46.2, 19.8, 10.5],
        [65.2, 68.4, 70.6, 71.8],
        [72.0, 73.7, 73.6, 73.2]
    ]

    def draw_demo(self):
        """Draw a single demo chart of GPT2-XL efficacy and save it to ``./filename.png``.

        One line is drawn per method; the figure is closed after saving.
        """
        x = np.arange(len(self.x_labels))

        fig = plt.figure(figsize=(5, 5))
        plt.grid(linestyle="--")  # dashed background grid
        ax = plt.gca()
        ax.spines['top'].set_visible(False)    # hide the top frame line
        ax.spines['right'].set_visible(False)  # hide the right frame line

        # Styling cheat-sheet: colors b/g/r/c/m/y/k/w; line styles '-', '--',
        # '-.', ':'; markers '.', ',', 'o', 'v', '<', '*', '+', '1'. Wrap math
        # labels in "$...$".
        for idx, row in enumerate(self.gpt2xl_es_data):
            plt.plot(x, row, marker='o', label=self.methods[idx], linewidth=1.5)

        plt.xticks(x, self.x_labels, fontsize=12, fontweight='bold')  # matplotlib default is 10
        plt.yticks(fontsize=12, fontweight='bold')
        plt.title("Example", fontsize=13, fontweight='bold')
        plt.xlabel("Edit Batch Size", fontsize=13, fontweight='bold')
        plt.ylabel("Score", fontsize=13, fontweight='bold')

        legend = plt.legend(loc=0, numpoints=1)  # loc=0 means 'best'
        plt.setp(legend.get_texts(), fontsize=12, fontweight='bold')

        # Tip: for Word documents, save as SVG and convert to EMF to keep a
        # vector figure.
        fig.savefig('./filename.png', format='png')
        plt.close(fig)

    def draw_subplot(self, data, ax=None):
        """Plot one line per method from ``data`` onto an axes.

        Args:
            data: rows in the same order as ``self.methods``, each holding one
                value per entry of ``self.x_labels``.
            ax: axes to draw on. Defaults to the current pyplot axes
                (``plt.gca()``), which preserves the original behavior.

        Raises:
            IndexError: if ``data`` has more rows than there are methods.
        """
        if ax is None:
            ax = plt.gca()
        x = np.arange(len(self.x_labels))

        for idx, row in enumerate(data):
            method = self.methods[idx]
            # Highlight DiKE with filled circles; baselines get hollow triangles.
            marker, fillstyle = ('o', 'full') if method == 'DiKE' else ('^', 'none')
            ax.plot(x, row, linestyle='-', marker=marker, label=method,
                    fillstyle=fillstyle, linewidth=2, markersize=8)

        ax.set_xticks(x)
        ax.set_xticklabels(self.x_labels, fontsize=12)
        ax.tick_params(axis='y', labelsize=12)

    def _model_panels(self):
        """Return ``(model name, efficacy table, relational-locality table)`` per model, in plot order."""
        return [
            ('GPT2-XL', self.gpt2xl_es_data, self.gpt2xl_rs_data),
            ('GPT-J', self.gptj_es_data, self.gptj_rs_data),
            ('LLaMA3', self.llama_es_data, self.llama_rs_data),
        ]

    def draw1(self):
        """Draw a 2x3 grid (rows: metric, columns: model) and save it as PNG.

        The top row shows efficacy and the bottom row relational locality.
        Output goes to ``./subject_batch_edit.png``; the figure is closed
        after saving.
        """
        font_size = 17
        fig = plt.figure(figsize=(15, 7))
        fig.subplots_adjust(hspace=0.2, wspace=0.4)
        axes = fig.subplots(2, 3)

        for col, (name, es_data, rs_data) in enumerate(self._model_panels()):
            top, bottom = axes[0, col], axes[1, col]
            self.draw_subplot(es_data, ax=top)
            top.set_title(name, fontsize=font_size, y=1)
            self.draw_subplot(rs_data, ax=bottom)
            bottom.set_xlabel("Edit Batch Size", fontsize=font_size)

        axes[0, 0].set_ylabel("Efficacy", fontsize=font_size)
        axes[1, 0].set_ylabel("Relational Locality", fontsize=font_size)

        # Legend anchored in the bottom-right axes' coordinates; every panel
        # plots the same methods, so its handles cover all of them.
        legend = axes[1, 2].legend(bbox_to_anchor=(0.88, 0.7), prop={'size': font_size})
        # The legend is built at font_size, but its label text is set to 15 pt.
        plt.setp(legend.get_texts(), fontsize=15)

        fig.subplots_adjust(left=0.1, right=0.9, top=0.95, bottom=0.09)

        fig.savefig('./subject_batch_edit.png', format='png')
        plt.close(fig)

    def draw2(self):
        """Draw a 3x2 grid (rows: model, columns: metric) and save it as PDF.

        The left column shows efficacy and the right column relational
        locality for the model named on each row. Output goes to
        ``./subject_batch_edit.pdf``; the figure is closed after saving.
        """
        font_size = 15
        fig = plt.figure(figsize=(7, 8))
        fig.subplots_adjust(hspace=0.25, wspace=0.2)
        axes = fig.subplots(3, 2)

        # Each row pairs a model's efficacy and locality tables, so a panel's
        # data always matches its row label and column title.
        for row, (name, es_data, rs_data) in enumerate(self._model_panels()):
            left, right = axes[row, 0], axes[row, 1]
            self.draw_subplot(es_data, ax=left)
            self.draw_subplot(rs_data, ax=right)
            left.set_ylabel(name, fontsize=font_size)

        axes[0, 0].set_title("Efficacy (%)", fontsize=font_size, y=1)
        axes[0, 1].set_title("Relational Locality (%)", fontsize=font_size, y=1)
        for ax in axes[-1]:
            ax.set_xlabel("Edit Batch Size", fontsize=font_size)

        # Legend anchored in the bottom-right axes' coordinates; y=4.0 lifts it
        # above the top row. Attach it to this figure's own axes rather than to
        # a global figure number.
        axes[-1, -1].legend(bbox_to_anchor=(1.15, 4.0), prop={'size': font_size}, ncol=4)

        fig.savefig('./subject_batch_edit.pdf', format='pdf')
        plt.close(fig)


# Script entry: build the plotter and render the 3x2 grid figure to
# ./subject_batch_edit.pdf via draw2().
# NOTE(review): this also runs when the module is imported; consider wrapping
# it in an `if __name__ == "__main__":` guard.
run = DrawSubjectBatch()
run.draw2()
