import argparse
import csv
import os
from collections import defaultdict
from typing import List, Tuple

import matplotlib.pyplot as plt
import seaborn as sns


parser = argparse.ArgumentParser()
# All three options are mandatory string flags, registered in this order.
for _flag in ("--file", "--dataset", "--model"):
    parser.add_argument(_flag, type=str, required=True)

args = parser.parse_args()
file_path, dataset, model_name = args.file, args.dataset, args.model

# Model sizes (values of the CSV's `model_size` column) to leave out of the
# plot, keyed by the --model value.
_SKIP_SIZES_BY_MODEL = {
    "qwen": ["8B", "70B"],
    "llama": ["7B", "72B", "14B", "32B"],
}
if model_name not in _SKIP_SIZES_BY_MODEL:
    raise ValueError(f"Unknown model name: {model_name}")
skip_model_size_list = _SKIP_SIZES_BY_MODEL[model_name]

def process_csv(file_path: str, skip_sizes=None) -> defaultdict:
    """
    Read the results CSV and group its rows by model size.

    Each row contributes a ``(replace_count, jailbreak_success_rate)`` tuple to
    the list for its ``model_size``. Every list is then sorted by
    ``replace_count`` so it can be drawn directly as a line.

    Args:
        file_path: Path to a CSV file with the columns ``model_size``,
            ``replace_count`` and ``jailbreak_success_rate_%``.
        skip_sizes: Iterable of ``model_size`` values to leave out. When
            ``None`` (the default), the module-level ``skip_model_size_list``
            is used, which preserves the original behaviour.

    Returns:
        A ``defaultdict(list)`` mapping each kept model size (the raw string
        from the CSV) to its list of ``(int, float)`` points, sorted by the
        first element.

    Raises:
        KeyError: if one of the required columns is missing from the header.
        ValueError / TypeError: if a ``replace_count`` or success-rate cell
            cannot be parsed as ``int`` / ``float``.
    """
    if skip_sizes is None:
        skip_sizes = skip_model_size_list
    # Set membership keeps the per-row filter O(1).
    skipped = set(skip_sizes)

    model_data = defaultdict(list)
    with open(file_path, mode='r', newline='', encoding='utf-8') as file:
        for row in csv.DictReader(file):
            model_size = row['model_size']
            if model_size in skipped:
                continue
            replace_count = int(row['replace_count'])
            jailbreak_success_rate = float(row['jailbreak_success_rate_%'])
            model_data[model_size].append((replace_count, jailbreak_success_rate))

    for points in model_data.values():
        # Sort on replace_count only; the sort is stable, so ties keep file order.
        points.sort(key=lambda point: point[0])

    return model_data

# Load the CSV given by --file, grouped by model size.
model_data = process_csv(file_path)

# Number of model sizes that survived filtering.
data_len = len(model_data)
# (model_size, points) pairs in insertion order, plus the same pairs flipped
# to (points, model_size).
zipped_data = [(size, points) for size, points in model_data.items()]
swapped_zipped_data = [(points, size) for size, points in zipped_data]

# Seaborn theme: dark grid style on a light-grey axes background.
sns.set(style="darkgrid", rc={"axes.facecolor": "#f0f0f0", "grid.color": "gray"})

def unpack_points(points):
    """Return ``(xs, ys)``: lists of the first and second item of each point."""
    def column(index):
        return [pt[index] for pt in points]

    return column(0), column(1)

# High-contrast palette, cycled if there are more than six lines.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']

# One line per model size: x = replace_count, y = success rate, labelled by size.
for idx, (data, label) in enumerate(swapped_zipped_data):
    xs, ys = unpack_points(data)
    plt.plot(xs, ys, marker="o", markersize=6, linewidth=2, label=label,
             color=colors[idx % len(colors)])

# Fixed axis ranges.
plt.xlim(0, 26)
plt.ylim(0, 100)

# Axis labels.
plt.xlabel("Perturbation Level", fontsize=12)
plt.ylabel("Success Rate (%)", fontsize=12)

# Grid, plus a 70% reference line (its label also appears in the legend).
plt.grid(True, linestyle='-', color='gray', alpha=0.7)
plt.axhline(y=70, color='red', linestyle='-', linewidth=2, label='70%')

# Legend title names the model family and dataset; one lookup replaces the
# duplicated per-family branches.
legend_family = {"qwen": "Qwen", "llama": "Llama"}.get(model_name)
if legend_family is not None:
    plt.legend(title=f"{legend_family} on {dataset}", fontsize=9, loc="upper right")

plt.tight_layout()

# savefig does not create missing directories, so ensure the target exists
# before writing the PDF.
output_dir = "/home/wind/Desktop/新工作/2024_8_18/figures/iclr"
os.makedirs(output_dir, exist_ok=True)
plt.savefig(os.path.join(output_dir, f"figure_{model_name}_{dataset}.pdf"),
            format="pdf", bbox_inches="tight")
