import json
from collections import defaultdict
from openai import OpenAI
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d
from tslearn.metrics import dtw
import joblib
import os
import pandas as pd
from datetime import datetime
from sklearn_extra.cluster import KMedoids
from tslearn.clustering import TimeSeriesKMeans
from tslearn.preprocessing import TimeSeriesScalerMeanVariance
# Option order used for each of the six permutation runs of a sample.
# The rest of the file treats ``permutation_list2[idx][0]`` as the letter of
# the ground-truth option in run ``idx``.
permutation_list2 = [
    list(order)
    for order in ('ABCD', 'ACBD', 'BACD', 'CABD', 'BCAD', 'CBAD')
]

# Inverse of each permutation above: maps a letter to its position in
# ``permutation_list2[idx]`` (keys kept in A, B, C, D order).
permutation_3 = [
    {letter: order.index(letter) for letter in 'ABCD'}
    for order in permutation_list2
]

# Common resampling grid on [0, 1] used to align probability curves of
# different lengths (fractions with denominators up to 8).
# NOTE(review): 4/5 is absent although 1/5 is present. Cluster centres are
# computed on this grid (one value per entry), so changing its length requires
# re-running the clustering.
x_list = [
    0,
    1 / 8, 1 / 7, 1 / 6, 1 / 5, 1 / 4, 2 / 7, 1 / 3, 3 / 8, 2 / 5, 3 / 7,
    1 / 2,
    4 / 7, 3 / 5, 5 / 8, 2 / 3, 5 / 7, 3 / 4, 5 / 6, 6 / 7, 7 / 8,
    1,
]

def interp(original_x,original_y,interp_x_list):
    """Linearly interpolate the curve ``(original_x, original_y)``.

    Args:
        original_x: sample x positions (list, tuple or ndarray), ascending.
        original_y: sample y values, same length as ``original_x``.
        interp_x_list: points at which to evaluate the interpolant; each must
            lie within ``[min(original_x), max(original_x)]`` or scipy raises
            ValueError (``interp1d`` default ``bounds_error``).

    Returns:
        A list with one interpolated value (0-d ndarray) per point in
        ``interp_x_list``.
    """
    # np.asarray leaves ndarrays untouched and converts sequences. It replaces
    # the old ``type(original_x) == np.array`` test, which compared against a
    # function and therefore never matched.
    linear_interp = interp1d(np.asarray(original_x), np.asarray(original_y), kind='linear')
    return [linear_interp(x) for x in interp_x_list]

def clustering(lines,model_name,line_type,correct_type, n_clusters=10, metric="dtw",save_results=False,save_dir="None"):
    """Cluster equal-length curves with tslearn's ``TimeSeriesKMeans``.

    Args:
        lines: sequence of equal-length 1-D curves (e.g. curves resampled
            onto ``x_list``).
        model_name, line_type, correct_type: only used to build output file
            names.
        n_clusters: number of clusters.
        metric: distance metric passed to ``TimeSeriesKMeans``.
        save_results: when True, write the fitted model (joblib), the centres,
            per-sample assignments, per-cluster stats, a timestamped metadata
            file and the raw input array into ``save_dir``.
        save_dir: output directory. NOTE(review): the default is the *string*
            ``"None"``, so saving without an explicit directory writes to
            ``./None``.

    Returns:
        ``(centers, counts)``: the model's ``cluster_centers_`` and an array
        with the number of curves in each cluster (length ``n_clusters``).
    """
    lines_np = np.array(lines)
    model = TimeSeriesKMeans(n_clusters=n_clusters, metric=metric, random_state=0)
    labels = model.fit_predict(lines_np)
    # minlength keeps counts aligned with centers even when the highest-numbered
    # clusters are empty; plain bincount would return a shorter array.
    counts = np.bincount(labels, minlength=n_clusters)
    centers = model.cluster_centers_  # shape = (n_clusters, length, 1)
    if save_results:
        os.makedirs(save_dir, exist_ok=True)
        now = datetime.now()
        timestamp = now.strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(save_dir, f"kmeans_model_{model_name}_{line_type}_{correct_type}.pkl")
        joblib.dump(model, model_path)
        # One CSV row per cluster. reshape (unlike squeeze) still yields a 2-D
        # table when there is a single cluster.
        centers_2d = centers.reshape(centers.shape[0], -1)
        centers_path = os.path.join(save_dir, f"cluster_centers_{model_name}_{line_type}_{correct_type}.csv")
        pd.DataFrame(centers_2d).to_csv(centers_path, index=False)
        assignments_path = os.path.join(save_dir, f"cluster_assignments_{model_name}_{line_type}_{correct_type}.csv")
        pd.DataFrame({
            "sample_id": range(len(lines)),
            "cluster": labels,
            "cluster_size": [counts[label] for label in labels]
        }).to_csv(assignments_path, index=False)
        print(f"聚类分配保存至: {assignments_path}")

        stats_path = os.path.join(save_dir, f"cluster_stats_{model_name}_{line_type}_{correct_type}.csv")
        pd.DataFrame({
            "cluster": range(len(counts)),
            "count": counts,
            "percentage": counts / counts.sum() * 100
        }).to_csv(stats_path, index=False)

        # Only the metadata file is timestamped; the other outputs are
        # overwritten on every run.
        metadata_path = os.path.join(save_dir, f"metadata_{model_name}_{line_type}_{correct_type}_{timestamp}.txt")
        with open(metadata_path, "w") as f:
            f.write(f"Time: {now}\n")
            f.write(f"Clusters: {n_clusters}\n")
            f.write(f"Metric: {metric}\n")
            f.write(f"Nums: {len(lines)}\n")
            f.write(f"Length: {lines_np.shape[1]}\n")
        print(f"Metadata: {metadata_path}")

        data_path = os.path.join(save_dir, f"original_data_{model_name}_{line_type}_{correct_type}.npy")
        np.save(data_path, lines_np)
        print(f"Save to: {data_path}")
    return centers, counts

def load_cluster_stats(save_dir, model_name,line_type,correct_type):
    """Load cluster centres and sizes previously written by ``clustering``.

    Only the two files whose contents are returned are read. The pickled
    model and the assignments CSV are not loaded: their results were
    discarded, and unpickling a file just to drop it is needless work and
    risk.

    Args:
        save_dir: directory holding the CSV outputs.
        model_name, line_type, correct_type: parts of the file names.

    Returns:
        ``(centers, counts)``. ``centers`` is a 2-D ndarray with one row per
        cluster. ``counts`` is the ``count`` column (pandas Series indexed by
        cluster id) of the stats CSV.
    """
    centers = pd.read_csv(os.path.join(save_dir,f"cluster_centers_{model_name}_{line_type}_{correct_type}.csv")).values
    stats = pd.read_csv(os.path.join(save_dir,f"cluster_stats_{model_name}_{line_type}_{correct_type}.csv"))
    counts = stats['count']
    return centers,counts


def analyse_original_data_filtered(output_data,output_data_wo_cot,line_type,filter_type):
    """Plot raw per-step probability curves for one CoT / no-CoT outcome cell.

    Samples in ``output_data`` (with CoT) are paired positionally with those
    in ``output_data_wo_cot`` (without CoT); ``zip`` stops at the shorter
    list. For each permutation index ``idx``, the two correctness flags
    select the cell:

    - ``'tp'``: correct with CoT and correct without CoT
    - ``'fp'``: correct with CoT, wrong without CoT
    - ``'tn'``: wrong with CoT, correct without CoT
    - any other value (callers pass ``'fn'``): wrong in both

    Args:
        output_data: CoT results; each sample needs ``corrects``, ``choices``
            and ``history_probs`` (per-step ``{letter: prob}`` dicts).
        output_data_wo_cot: no-CoT results; only ``corrects`` is read.
        line_type: ``'gt'`` plots the ground-truth option's probability,
            ``'choice'`` the final choice's. Any other value plots no curves.
        filter_type: cell name, see above.
    """
    lines = []
    for sample, sample_no_cot in zip(output_data, output_data_wo_cot):
        cot_corrects = sample['corrects']
        no_cot_corrects = sample_no_cot['corrects']
        for idx in range(len(cot_corrects)):
            if filter_type == 'tn':
                condition = not cot_corrects[idx] and no_cot_corrects[idx]
            elif filter_type == 'tp':
                condition = cot_corrects[idx] and no_cot_corrects[idx]
            elif filter_type == 'fp':
                condition = cot_corrects[idx] and not no_cot_corrects[idx]
            else:
                condition = not cot_corrects[idx] and not no_cot_corrects[idx]
            if not condition:
                continue

            choice = sample['choices'][idx]
            # Only A-D have a probability entry; other choices yield no curve.
            if choice not in ('A', 'B', 'C', 'D'):
                continue
            history = sample['history_probs'][idx]
            prob_length = len(history) - 1
            # At least two steps are needed to spread them over x in [0, 1].
            if prob_length < 1:
                continue

            ground_truth = permutation_list2[idx][0]
            xs = [idy / prob_length for idy in range(len(history))]
            if line_type == 'gt':
                lines.append({'y': [probs[ground_truth] for probs in history], 'x': xs})
            elif line_type == 'choice':
                lines.append({'y': [probs[choice] for probs in history], 'x': xs})

    plt.xlim(0, 1.1)
    plt.ylim(0, 1.1)
    for line in lines:
        plt.plot(line['x'], line['y'], color='blue', linewidth=0.1)
    # No plt.legend(): none of the curves carries a label, so a legend call
    # would draw nothing and only emit a "No artists with labels" warning.
    plt.title(f"{filter_type} ({line_type}), total:{len(lines)}")
    plt.show()
def analyse_original_data(output_data, line_type):
    """Plot raw per-step probability curves for every incorrectly answered run.

    Runs with fewer than 5 probe steps, or with a final choice outside A-D,
    are skipped. Step ``idy`` of a run with ``n`` steps is placed at
    ``x = idy / (n - 1)``.

    Args:
        output_data: samples with ``corrects``, ``choices`` and
            ``history_probs`` (per-step ``{letter: prob}`` dicts).
        line_type: ``'gt'`` plots the ground-truth option's probability,
            ``'choice'`` the final choice's. Any other value plots no curves.
    """
    lines = []
    for sample in output_data:
        for idx in range(len(sample['corrects'])):
            if sample['corrects'][idx]:
                continue
            choice = sample['choices'][idx]
            if choice not in ('A', 'B', 'C', 'D'):
                continue
            history = sample['history_probs'][idx]
            # Require at least 5 steps; this also guarantees prob_length >= 1.
            if len(history) < 5:
                continue
            prob_length = len(history) - 1

            ground_truth = permutation_list2[idx][0]
            xs = [idy / prob_length for idy in range(len(history))]
            if line_type == 'gt':
                lines.append({'y': [probs[ground_truth] for probs in history], 'x': xs})
            elif line_type == 'choice':
                lines.append({'y': [probs[choice] for probs in history], 'x': xs})

    # pyplot exposes xlim/ylim. set_xlim/set_ylim exist only on Axes, so the
    # former plt.set_xlim call raised AttributeError.
    plt.xlim(0, 1.1)
    plt.ylim(0, 1.1)
    for line in lines:
        plt.plot(line['x'], line['y'], color='blue', linewidth=0.1)
    # No plt.legend(): the curves carry no labels.
    plt.show()
def analyse_rationale(data):
    """Print rationale steps whose text after ``'|'`` does not begin with A-D.

    Validity matches ``filter_error_data``: a step is valid when it contains
    ``'|'`` followed by A, B, C or D, with one optional space in between.
    For an invalid step, the text before the first ``'|'`` and the text after
    it are printed. A step with no ``'|'`` prints the whole string and an
    empty remainder.

    Args:
        data: samples whose ``history_thinkings`` is a list (one per run) of
            lists of step strings.
    """
    for sample in data:
        for thinkings in sample['history_thinkings']:
            for thinking in thinkings:
                # partition handles a missing bar cleanly. With find(), -1 made
                # the old code check the wrong substrings.
                ori_str, _, next_str = thinking.partition('|')
                letter = next_str[1:2] if next_str.startswith(' ') else next_str[:1]
                # Slicing yields '' (never an IndexError) for a trailing bar.
                if letter not in ('A', 'B', 'C', 'D'):
                    print(ori_str,'\n',next_str)
def analyse_data(output_data,line_type,model_name,correct_type,ax):
    """Draw cached cluster centres of probability trajectories on ``ax``.

    For every (sample, permutation index) whose correctness matches
    ``correct_type``, the per-step probability of either the ground-truth
    option (``line_type == 'gt'``) or the final choice
    (``line_type == 'choice'``) is placed at ``x = step / (n_steps - 1)``.
    The curve is then linearly resampled onto the module-level ``x_list``
    grid. Those resampled curves are only counted here (printed as ``LEN``).
    What is drawn are the cluster centres loaded from ``./clustering`` via
    ``load_cluster_stats``, largest cluster first, with line width
    proportional to cluster size.

    Args:
        output_data: samples with ``corrects``, ``choices`` and
            ``history_probs`` (per run, a list of per-step ``{letter: prob}``
            dicts).
        line_type: ``'gt'`` or ``'choice'``. Also part of the cluster file
            names.
        model_name: part of the cluster file names and the panel title.
            ``'Kimi-VL-A3B-thinking'`` additionally gets hard-coded colours.
        correct_type: truthy selects correctly answered runs
            ("Successful Case", red palette); falsy selects wrong ones
            ("Error Case", blue palette).
        ax: the Axes to draw on.
    """

    # Histogram of (number of steps - 1) over the selected runs, printed at the end.
    len_set = defaultdict(int)
    lines_cpy = []
    lines = []
    for sample in output_data:
        for idx in range(len(sample['corrects'])):
            # print(len(sample['corrects']))
            ground_truth = 0
            ground_truth_line = {
                'y': [],
                'x': []
            }
            choice_line = {
                'y': [],
                'x': []
            }
            ground_truth_line_interp = {
                'y': []
            }
            # Keep correct runs when correct_type is truthy, wrong runs otherwise.
            if (correct_type and sample['corrects'][idx]) or (not correct_type and not sample['corrects'][idx]):
                # Letter of the ground-truth option under permutation idx.
                ground_truth = permutation_list2[idx][0]
                choice = sample['choices'][idx]
                # print(ground_truth,' | ',choice)
                prob_length = len(sample['history_probs'][idx]) - 1
                # At least two steps are needed to spread them over [0, 1].
                if prob_length < 1:
                    continue
                # Counted before the choice check below, so runs with a non-A-D
                # choice are included in len_set.
                len_set[prob_length] += 1
                # print(sample['id'],' ',len(sample['history_probs']))
                for idy,cnt_probs in enumerate(sample['history_probs'][idx]):
                    # print(choice)
                    if choice != 'A' and choice != 'B' and choice != 'C' and  choice != 'D':
                        continue
                    ground_truth_line['y'].append(cnt_probs[ground_truth])
                    ground_truth_line['x'].append((idy) / prob_length)
                    choice_line['y'].append(cnt_probs[choice])
                    choice_line['x'].append((idy) / prob_length)
                # if ground_truth_line['y'][-1] > choice_line['y'][-1]:
                #     print(sample['id'], f' ({idx})', ground_truth_line['y'][-1], ' ',choice_line['y'][-1])
                # lines.append(ground_truth_line)\
                # Empty only when the choice was not one of A-D.
                if len(ground_truth_line['y']) < 1:
                    continue
                # x_pts runs exactly from 0 to 1, so every x_list point is in the
                # interpolation range.
                if line_type == 'gt':
                    x_pts = np.array(ground_truth_line['x'])
                    y_pts = np.array(ground_truth_line['y'])
                    # print(f"{x_pts}, {len(x_pts)}, {y_pts}, {len(y_pts)}")
                    linear_interp = interp1d(x_pts, y_pts, kind='linear')
                    y_interp = [linear_interp(x) for x in x_list]
                    ground_truth_line_interp['y'] = y_interp
                    lines.append(ground_truth_line_interp)
                    lines_cpy.append(y_interp)
                elif line_type == 'choice':
                    x_pts = np.array(choice_line['x'])
                    y_pts = np.array(choice_line['y'])
                    linear_interp = interp1d(x_pts, y_pts, kind='linear')
                    y_interp = [linear_interp(x) for x in x_list]
                    choice_line_interp = {
                        'y': y_interp
                    }
                    lines_cpy.append(y_interp)
                    lines.append(choice_line_interp)
    # NOTE(review): overwritten by the palette selection below. Only the
    # commented-out raw-curve plot further down would have used this value.
    color_map = 'pink' if correct_type else 'lightblue'
    # Cached clustering results are used. The commented-out call below would
    # re-cluster lines_cpy and overwrite the cache instead.
    centers,counts = load_cluster_stats("./clustering",model_name,line_type,correct_type)
    # centers,counts = clustering(lines_cpy,model_name,line_type,correct_type,10,"dtw",True,"./clustering")

    # Cluster indices ordered by size, largest first.
    sorted_indices = np.argsort(counts)[::-1]
    sorted_centers = [centers[i] for i in sorted_indices]
    sorted_counts = [counts[i] for i in sorted_indices]

    ax.set_xlim(0, 1.1)
    ax.set_ylim(0, 1.1)

    print("LEN:", len(lines))
    # for idx, line in enumerate(lines):
    #     ax.plot(x_list, line['y'], color=color_map, linewidth=0.04)
    blue_colors = [
        "#001f3f",
        "#003366",
        "#004c8c",
        "#0066b3",
        "#0077cc",
        "#0088e5",
        "#0099ff",
        "#00aaff",
        "#55bbff",
        "#aaddff"
    ]
    red_colors = [
        "#8B0000",
        "#B22222",
        "#CD5C5C",
        "#DC143C",
        "#E32636",
        "#FF4D4D",
        "#FF6B6B",
        "#FF8C8C",
        "#FFA8A8",
        "#FFD0D0"
    ]

    # Not referenced by name; some of these hex values are hard-coded in the
    # Kimi overrides below.
    green_colors = [
        "#004d00",
        "#006633",
        "#00802b",
        "#009900",
        "#00cc44",
        "#00ff7f",
        "#66ff99",
        "#99ff99",
        "#ccffcc",
        "#e6ffe6"
    ]
    if correct_type == True:
        color_map = red_colors
    else:
        color_map = blue_colors

    for idx, (center, count) in enumerate(zip(sorted_centers, sorted_counts)):
        # Rank idx gets palette entry idx, clamped to the palette's last colour.
        color_idx = min(idx, len(color_map) - 1)
        color = color_map[color_idx]
        # Manual per-rank colour overrides for the Kimi panels (rank = position
        # after sorting clusters by size).
        if model_name == 'Kimi-VL-A3B-thinking':
            if correct_type:
                if idx == 0:
                    color = "#004d00"
                elif idx == 3:
                    color = "#006633"
                elif idx == 6:
                    color = "#00802b"
                elif idx == 7:
                    color = "#009900"
                elif idx == 8:
                    color = "#00cc44"
                elif idx == 9:
                    color = "#00ff7f"
            elif not correct_type:
                if idx == 7:
                    color = "#009900"
                elif idx == 8:
                    color = "#00cc44"
                elif idx == 9:
                    color = "#00802b"
        # Line width scales with cluster size.
        linewidth = count / 400

        ax.plot(x_list, center,
                 color=color,
                 linewidth=linewidth,
                 label=f'{count}')

    # NOTE(review): this legend is replaced by the ax.legend(fontsize=5) call
    # at the end of the function, so its settings have no visible effect.
    ax.legend(fontsize=2, loc='upper right', frameon=True, framealpha=0.8)
    ax.grid(True, linestyle='--', alpha=0.5)
    # print(labels)
    # NOTE(review): label_text is only used by the commented-out set_ylabel
    # below. It is unbound for other line_type values but is never read.
    if line_type == 'gt':
        label_text = "Probability of GT"
    elif line_type == 'choice':
        label_text = "Probability of Final Choice"
    title_text = model_name + " "
    if correct_type:
        title_text += "(Successful Case)"
    else:
        title_text += "(Error Case)"
    # Hides tick labels; a caller may set the label size again afterwards.
    ax.tick_params(axis='both', which='major', labelsize=0)
    # ax.set_ylabel(label_text, fontsize=8, fontweight='bold')
    # ax.set_xlabel("Thinking process", fontsize=8, fontweight='bold')
    ax.set_title(title_text, fontsize=8, fontweight='bold')
    ax.legend(fontsize=5)
    print(len_set)

from matplotlib.gridspec import GridSpec
def create_combined_plot(configs, nrows=2, ncols=3, save_path="appendix_probe.pdf"):
    """Draw one ``analyse_data`` panel per config in a grid and save it.

    Args:
        configs: sequence of ``(output_data, line_type, model_name,
            correct_type)`` tuples, one per panel, in row-major order. Extra
            configs beyond ``nrows * ncols`` are ignored.
        nrows, ncols: grid shape (default 2 x 3, figure size 8 x 4 inches).
        save_path: file the figure is saved to before being shown.
    """
    # squeeze=False always returns a 2-D array of Axes, even for a 1x1 grid.
    _, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols / 3, 2 * nrows), squeeze=False)
    axes = axes.flatten()

    # Iterate the *parameter*; the old code read the global ``config`` instead.
    for ax, params in zip(axes, configs):
        analyse_data(*params, ax=ax)

    ticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    last_row_start = (nrows - 1) * ncols
    for i, ax in enumerate(axes):
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.tick_params(axis='both', which='major', labelsize=8)
        # Y label on the first column only, x label on the bottom row only.
        # NOTE(review): the y label reads "GT" even for 'choice' panels.
        if i % ncols == 0:
            ax.set_ylabel('Probability of GT', fontsize=8, fontweight='bold', labelpad=0)
        else:
            ax.set_ylabel('')
        if i >= last_row_start:
            ax.set_xlabel('Thinking process', fontsize=8, fontweight='bold', labelpad=0)
        else:
            ax.set_xlabel('')

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.show()


def _thinking_has_valid_choice(thinking):
    """Return True if ``thinking`` has ``'|'`` followed by A, B, C or D.

    One optional space between the bar and the letter is accepted. A missing
    bar, or a bar with nothing (or only a space) after it, is invalid.
    """
    split_index = thinking.find('|')
    if split_index == -1:
        return False
    after = thinking[split_index + 1:]
    if after.startswith(' '):
        after = after[1:]
    # Slicing yields '' at end of string instead of raising IndexError.
    return after[:1] in ('A', 'B', 'C', 'D')


def filter_error_data(output_file):
    """Load probe results and drop permutation runs with malformed rationales.

    A run (one of the six permutation indices of a sample) is dropped if any
    of its ``history_thinkings`` steps lacks a valid ``'|<letter>'`` marker
    (see ``_thinking_has_valid_choice``). The number of dropped runs is
    printed.

    NOTE(review): dropped runs are removed, not blanked. List positions in
    the result therefore no longer equal the permutation index, so code that
    maps position ``idx`` to ``permutation_list2[idx]`` is misaligned on
    filtered data.

    Args:
        output_file: path to a UTF-8 JSON list of sample dicts.

    Returns:
        A new list of sample dicts with the same metadata and only the
        surviving runs in the per-run lists.
    """
    per_run_keys = ('responses', 'scores', 'corrects', 'choices', 'history_probs', 'history_thinkings')
    remove_count = 0
    new_json_answer = []
    with open(output_file, 'r', encoding='utf-8') as f:
        json_answer = json.load(f)
    for sample in json_answer:
        kept = {key: [] for key in per_run_keys}
        for idx in range(6):
            if all(_thinking_has_valid_choice(k) for k in sample['history_thinkings'][idx]):
                for key in per_run_keys:
                    kept[key].append(sample[key][idx])
            else:
                remove_count += 1
        new_json_answer.append(
            {
                'id': sample['id'],
                'image_path': sample['image_path'],
                'question': sample['question'],
                'category': sample['category'],
                'options': sample['options'],
                'ground_truth': sample['ground_truth'],
                'responses': kept['responses'],
                'choices': kept['choices'],
                'scores': kept['scores'],
                'corrects': kept['corrects'],
                'history_probs': kept['history_probs'],
                'history_thinkings': kept['history_thinkings'],
            }
        )
    print(remove_count)
    return new_json_answer


def analyse_true_negative(output_data_wo_cot,output_data):
    """Plot every CoT / no-CoT outcome cell: ground-truth curves, then choices.

    Cells are drawn in the order tn, tp, fp, fn. Note the argument order: the
    no-CoT results come first here but are passed second to
    ``analyse_original_data_filtered``.
    """
    for cell in ('tn', 'tp', 'fp', 'fn'):
        for curve in ('gt', 'choice'):
            analyse_original_data_filtered(output_data, output_data_wo_cot, curve, cell)


def different(output_data):
    """Report runs whose recorded choice differs from the last-step argmax.

    For up to six permutation runs per sample (the number of permutations
    in ``permutation_list2``), the option with the highest probability in the
    final ``history_probs`` step is compared with ``choices``. Runs with no
    steps are skipped. Each mismatch prints ``"<id>(<idx>):"``, and a final
    ``"TOTAL: <n>"`` line gives the number of mismatches.
    """
    count = 0
    for sample in output_data:
        # Slicing instead of range(6) avoids an IndexError on samples with
        # fewer runs (e.g. after filter_error_data).
        for idx, history in enumerate(sample['history_probs'][:6]):
            if not history:
                continue
            last_probs_dict = history[-1]
            choice_from_probs = max(last_probs_dict, key=last_probs_dict.get)
            if sample['choices'][idx] != choice_from_probs:
                count += 1
                print(f"{sample['id']}({idx}):")
    print(f"TOTAL: {count}")

def find_shift(output_data):
    """Print probe steps where the ground-truth probability jumps by > 0.2.

    For up to six permutation runs per sample, each step's ground-truth
    probability is compared with the previous step's. A jump larger than 0.2
    within the first half of the run's steps is printed with its context and
    counted as an increase or a decrease. A run containing both an increase
    and a decrease adds one to ``BLOW_COUNT``. Runs with no steps are
    skipped. Totals are printed at the end.
    """
    increase_count = 0
    decrease_count = 0
    mixed_direction_count = 0
    for sample in output_data:
        # Slicing instead of range(6) tolerates samples with fewer runs.
        for idx, history in enumerate(sample['history_probs'][:6]):
            if not history:
                continue
            gt = permutation_list2[idx][0]
            cnt_increase_count = 0
            cnt_decrease_count = 0
            cnt = history[0][gt]
            for idy, history_prob in enumerate(history):
                # "First half" refers to the step index idy. The old code
                # compared the permutation index idx here by mistake.
                if abs(history_prob[gt] - cnt) > 0.2 and idy < len(history) / 2:

                    if history_prob[gt] > cnt:
                        increase_count += 1
                        cnt_increase_count += 1
                    else:
                        decrease_count += 1
                        cnt_decrease_count += 1

                    print(f"""
##########################
{sample['id']}({idx}) : {sample['question']} 
Ground truth: {sample['ground_truth']}

current thinking: {sample['history_thinkings'][idx][idy]}
BEFORE:{cnt} -> AFTER:{history_prob[gt]}
##########################
                    """)
                cnt = history_prob[gt]
            if cnt_increase_count > 0 and cnt_decrease_count > 0:
                mixed_direction_count += 1
    print(f"TOTAL: INCREASE:{increase_count} , DECREASE:{decrease_count}\n BLOW_COUNT:{mixed_direction_count}")

def repair(output_data):
    """Recompute ``choices`` and ``corrects`` from the last probe step, in place.

    For up to six permutation runs per sample, the choice becomes the option
    with the highest probability in the final ``history_probs`` step. It is
    marked correct iff it equals ``permutation_list2[idx][0]``. Runs with no
    steps are left unchanged.

    Args:
        output_data: list of sample dicts; mutated in place.

    Returns:
        The same ``output_data`` object, for convenient chaining.
    """
    for sample in output_data:
        # Slicing instead of range(6) tolerates samples with fewer runs.
        for idx, history in enumerate(sample['history_probs'][:6]):
            if not history:
                continue
            last_probs_dict = history[-1]
            choice_from_probs = max(last_probs_dict, key=last_probs_dict.get)
            sample['choices'][idx] = choice_from_probs
            sample['corrects'][idx] = choice_from_probs == permutation_list2[idx][0]

    return output_data

if __name__ == '__main__':
    # (model name, probe-result file) for each model in the figure. The model
    # name is also part of the cached cluster file names, so it must match the
    # spelling used when clustering (the old "Qwen2.5-Vl-32B" typo did not).
    model_runs = [
        ("InternVL3-38B", "experiment3/internvl3-38b_probe3_0822.json"),
        ("InternVL3-14B", "experiment3/internvl3-14b_probe3_0822.json"),
        ("Qwen2.5-VL-32B", "experiment3/qwen2.5-vl-32b_probe3_0822.json"),
        ("Qwen2.5-VL-7B", "experiment3/qwen2.5-vl-7b_probe3_0822.json"),
        ("Kimi-VL-A3B-thinking", "experiment3/kimi-vl-thinking-a3b-2506_probe3_0828.json"),
    ]

    # Kept as the module-level name ``config`` for compatibility with code
    # that reads it as a global.
    config = []
    for model_name, path in model_runs:
        with open(path, 'r', encoding='utf-8') as f:
            # repair() re-derives choices/corrects from the final probe step.
            output_repair = repair(json.load(f))
        # Per model: GT curves of successful runs, then final-choice and GT
        # curves of error runs.
        config += [
            (output_repair, "gt", model_name, True),
            (output_repair, "choice", model_name, False),
            (output_repair, "gt", model_name, False),
        ]

    # NOTE(review): 15 panels are configured, but create_combined_plot draws
    # only as many as its grid has axes. The former extra call made before
    # ``config`` was filled rendered an empty figure (overwriting the PDF) and
    # has been removed.
    create_combined_plot(config)