import matplotlib.pyplot as plt
import numpy as np


def plot_radar():
    """Draw a radar (spider) chart comparing benchmark scores of several variants.

    Each benchmark score is min-max normalized against hand-picked
    per-benchmark bounds (``min_values`` / ``max_values``) so all axes share a
    common [0, 1] radial scale: the centre corresponds to ``min_values`` and
    the outer ring to ``max_values``. The figure is written to ``radar.svg``
    and then shown.

    Raises:
        ValueError: if the data, bounds and legend names disagree in shape, or
            a category's max bound is not strictly greater than its min bound.
    """
    plt.rcParams.update({'font.size': 15})  # Set the global default font size to 15
    # Scores: one row per variant (matching ``legend_names``), one column per category.
    data = [
        [45.40, 48.46, 77.35, 76.73, 80.03, 70.17, 78.50, 47.01],
        [45.20, 49.08, 77.36, 76.98, 80.20, 69.93, 78.20, 46.16],
        [45.60, 48.31, 77.29, 76.81, 80.09, 69.93, 77.89, 47.44],
        [44.60, 48.31, 77.34, 76.85, 79.98, 70.56, 77.80, 46.84]
    ]

    # Alternate dataset; pairs with the commented-out legend names below.
    # data = [
    #     [45.00, 48.46, 77.59, 77.02, 79.89, 70.01, 77.89, 45.05],
    #     [45.20, 47.54, 77.22, 76.64, 80.20, 70.32, 75.44, 47.53],
    #     [43.80, 46.32, 77.96, 75.97, 79.16, 70.32, 75.32, 46.16],
    # ]

    categories = ['openbookqa', 'siqa', 'hellaswag', 'arc_easy', 'piqa', 'winogrande', 'boolq', 'arc_challenge']

    # Per-category bounds used for normalization (same order as ``categories``).
    min_values = [44, 46, 76, 75, 79, 69, 75, 44]
    max_values = [48, 50, 78, 78, 81, 71, 79, 48]

    # legend_names = ['r=8,0,0', 'r=0,8,0', 'r=0,0,8']
    legend_names = ['r_1', 'r_2', '1_r', '2_r']
    num_vars = len(categories)

    scores = np.asarray(data, dtype=float)
    lower = np.asarray(min_values, dtype=float)
    upper = np.asarray(max_values, dtype=float)

    # Fail loudly instead of silently drawing a mislabelled or truncated chart.
    if scores.shape != (len(legend_names), num_vars):
        raise ValueError(
            f"data has shape {scores.shape}, expected ({len(legend_names)}, {num_vars})")
    if lower.shape != (num_vars,) or upper.shape != (num_vars,):
        raise ValueError("min_values/max_values must have one entry per category")
    if np.any(upper <= lower):
        raise ValueError("every max_values entry must exceed its min_values entry")

    # Normalize the data to a [0, 1] range based on the custom min/max bounds.
    normalized_data = (scores - lower) / (upper - lower)

    # Create a polar (radar) chart.
    fig, ax = plt.subplots(figsize=(8, 6), subplot_kw=dict(polar=True))

    # One evenly spaced axis angle per category, then repeat the first angle
    # so each polygon closes back on its starting point.
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
    angles += angles[:1]

    for name, row in zip(legend_names, normalized_data):
        closed_row = np.append(row, row[0])  # Repeat the first value to close the loop
        ax.plot(angles, closed_row, linewidth=2, linestyle='solid', label=name)
        ax.fill(angles, closed_row, alpha=0.25)

    # Pin the radial axis to the normalized range so the outer ring means
    # ``max_values`` for every category (autoscale would otherwise follow the data).
    ax.set_ylim(0, 1)

    # Category labels around the rim; radial tick labels are meaningless after
    # normalization, so they are hidden.
    ax.set_yticklabels([])
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, fontsize=18)

    # Push the category labels further away from the plot.
    ax.tick_params(pad=10)

    # Legend placed outside the upper-right corner of the axes.
    ax.legend(loc='upper right', bbox_to_anchor=(1.25, 1.15))

    # Export as SVG.
    fig.savefig('radar.svg', format='svg')

    # Show the plot.
    plt.show()


def plot_bar():
    """Draw a grouped bar chart of rank values per shape-transformation method.

    Within each method cluster (``categories``) one bar is drawn per weight
    group (``groups``). The x tick for a cluster is placed at the centre of
    its bars. The figure is written to ``bar.svg`` and then shown.

    Raises:
        ValueError: if ``data``, ``groups`` and ``colors`` differ in length.
    """
    plt.rcParams.update({'font.size': 25})  # Set the global default font size to 25

    # Method clusters along the x axis.
    categories = ['SS', 'GT', 'KE']
    # NOTE(review): these labels use U+25B3 (white up-pointing triangle), not the
    # Greek capital delta; confirm that is the intended glyph.
    groups = ["△Wq", "△Wk", "△Wv", "△Wup", "△Wdown"]
    # One row per entry of ``groups``; one column per entry of ``categories``.
    data = [
        [22, 4, 21.1875],
        [22, 4, 21.09375],
        [22, 4, 21.21875],
        [22, 4, 21.4375],
        [22, 4, 21.34375]
    ]

    colors = ["#DEECF9", "#F7CCAD", "#A3D4D5", "#D2E1D4", "#9DC3E6"]

    if not (len(data) == len(groups) == len(colors)):
        raise ValueError("data, groups and colors must have the same length")

    # Width of a single bar.
    bar_width = 0.18

    # Left reference x position of each cluster.
    index = np.arange(len(categories))

    fig, ax = plt.subplots(figsize=(12, 8))

    # Bar k (1-based) of every cluster is shifted right by k bar widths.
    for offset, (label, values, color) in enumerate(zip(groups, data, colors), start=1):
        ax.bar(index + offset * bar_width, values, bar_width, label=label, color=color,
               edgecolor='black', linewidth=1)

    # Offsets run from 1 to len(data), so the cluster centre is their mean,
    # (len(data) + 1) / 2 bar widths (3 for the five groups above).
    center_offset = (len(data) + 1) / 2 * bar_width

    # Axis labels and cluster ticks.
    ax.set_xlabel('Shape Transformation Methods', fontsize=25, labelpad=10)
    ax.set_ylabel('Rank Values', fontsize=25)
    ax.set_xticks(index + center_offset)
    ax.set_xticklabels(categories, rotation=0, fontsize=20)
    # Increase the gap between the x tick labels and the axis line.
    ax.tick_params(axis='x', which='major', pad=10)

    ax.legend()

    # Export as SVG.
    plt.savefig('bar.svg', format='svg')

    # Display the chart.
    plt.show()


def plot_memory():
    """Plot memory usage (MB) against the number of LoRA adapters being served.

    Compares plain LoRA with Bi-Share LoRA. Every x where the LoRA series
    reaches the memory cap is marked with a red "Out Of Memory" cross. The
    figure is written to ``serving.pdf`` and then shown.
    """
    plt.rcParams.update({'font.size': 15})  # Set the global default font size to 15

    # Memory ceiling in MB. A series value equal to it denotes a run that ran
    # out of memory (presumably an 80 GB device; confirm).
    memory_cap = 81920

    # Number of served adapters: 300, 600, ..., 3900 (13 points).
    x = range(300, 4200, 300)
    lora = [
        28964.150390625,
        45014.150390625,
        61064.150390625,
        77114.150390625,
        81920.000000000,
        81920.000000000,
        81920.000000000,
        81920.000000000,
        81920.000000000,
        81920.000000000,
        81920.000000000,
        81920.000000000,
        81920.000000000,
    ]

    share = [
        18105.853515625,
        23337.103515625,
        28568.353515625,
        33799.603515625,
        39030.853515625,
        44262.103515625,
        49493.353515625,
        54724.603515625,
        59955.853515625,
        65187.103515625,
        70418.353515625,
        75649.603515625,
        81920.000000000,
    ]

    plt.figure(figsize=(8, 5))

    # Derive the OOM points from the LoRA series rather than hard-coding a
    # second range, so the markers stay in sync if the data changes.
    oom = [n for n, mem in zip(x, lora) if mem >= memory_cap]
    # zorder=2 draws the OOM markers on top of the lines (zorder=1).
    plt.scatter(oom, [memory_cap] * len(oom), color='red', marker='x', s=200, label='Out Of Memory',
                linewidths=2.5, zorder=2)

    plt.plot(x, lora, marker='o', label='LoRA', linestyle='-', color='blue', zorder=1)
    plt.plot(x, share, marker='s', label='Bi-Share LoRA', linestyle='--', color='orange', zorder=1)

    plt.xlabel("Number of Serving LoRA")
    plt.ylabel("Memory (MB)")
    plt.title("Memory Usage In Multi-lora Serving")
    plt.legend()
    plt.grid(alpha=0.5)
    # Leave 2000 MB of headroom above the cap so the OOM markers are not clipped.
    plt.ylim(0, memory_cap + 2000)
    # Ticks every 10000 MB up to the top of the axis.
    plt.yticks(range(0, memory_cap + 2001, 10000))
    plt.tight_layout()

    # Export as PDF.
    plt.savefig('serving.pdf', format='pdf')

    # Display the plot.
    plt.show()


def plot_speed():
    """Render a single-series vertical bar chart of time cost per LoRA variant.

    The chart is saved as ``speed.svg`` and then displayed.
    """
    # Global default font size (25) applied to every text element drawn below.
    plt.rcParams.update({'font.size': 25})

    # NOTE(review): these heights and the generic title / x-label look like
    # placeholders; confirm them against real measurements before publishing.
    methods = ['LoRA', 'VeRA', 'VB-LoRA', 'Bi-Share LoRA']
    heights = [25, 30, 40, 15]

    fig = plt.figure(figsize=(12, 8))
    ax = fig.gca()
    ax.bar(methods, heights, color='lightcoral', edgecolor='black', linewidth=2)

    ax.set_title("Simple Vertical Bar Chart")
    ax.set_xlabel("Categories")
    ax.set_ylabel("Time Cost")
    ax.grid(linestyle='--', alpha=0.7)
    fig.tight_layout()

    # Write the SVG before showing, so the file exists even if the window is closed.
    plt.savefig('speed.svg', format='svg')
    plt.show()


if __name__ == "__main__":
    # Select which chart(s) to regenerate. Each call saves its figure to disk
    # and then opens it with plt.show(). Guarded so that importing this module
    # does not render or block.
    # plot_bar()
    # plot_radar()
    # plot_speed()
    plot_memory()
