import matplotlib.pyplot as plt
import numpy as np

# 假设的数据，模型名称和对应的准确率与模型大小
models = ['Ours', 'BRECQ', 'Unified', '2 bit', '4bit', 'ResNet-18', 'MobileNetV2']
accuracies = [72.5, 70.2, 70.52, 71.39, 69.94, 71.60, 70]  # 准确率
model_sizes = [3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 0.8]  # 模型大小（Mb）

# 创建一个画布和轴对象
fig, ax1 = plt.subplots()

# 绘制准确率的条形图
def plot_bars(ax, x, y, label, color):
    ax.bar(x, y, label=label, color=color, alpha=0.7)
    ax.set_ylabel('Accuracy (%)')
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=45, ha="right")

plot_bars(ax1, np.arange(len(models)), accuracies, 'Accuracy', 'b')

# 创建一个共享x轴的第二个轴用于绘制模型大小
ax2 = ax1.twinx()
plot_bars(ax2, np.arange(len(models)), model_sizes, 'Model Size (Mb)', 'r')

# 添加图例
fig.legend(loc='upper left', bbox_to_anchor=(0.1, 0.9))

# 显示图表
plt.show()

plt.savefig('/home/admin1/Syh/Training-free-quant/mixed_bit/picture/model_comparison.png', dpi=300)  # 可以调整DPI以改变图像质量

# 记得在保存图像后关闭图表，释放资源
plt.close(fig)

print("end")