import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from openpyxl import Workbook
from openpyxl.styles import Font, Alignment
from openpyxl.utils import get_column_letter
from fixed_variables import *


def get_mmq(model, modality, question):
    """Return the mean of one question column as a rounded percentage.

    Reads ``comparison/<model>/<modality>.csv``, takes the mean of the
    column named ``question``, multiplies it by 100 and rounds the result
    to one decimal place.

    Args:
        model: Name of the sub-directory of ``comparison`` to read from.
        modality: CSV file name without the ``.csv`` extension.
        question: Name of the column to average.

    Returns:
        The column mean times 100, rounded to one decimal.
    """
    # NOTE(review): presumably the column holds 0/1 correctness flags, which
    # would make the result an accuracy in percent -- confirm against the CSVs.
    table = pd.read_csv(os.path.join("comparison", model, f"{modality}.csv"))
    percentage = table[question].mean() * 100
    return round(percentage, 1)


def e_modality_gap():
    """Write the model-averaged accuracy table to an Excel workbook.

    For every model in ``full_model_list`` the value returned by ``get_mmq``
    is collected for each (modality, question) pair.  The values are then
    averaged over the models and written to
    ``evaluation/modality_gap/e_modality_gap.xlsx`` with one row per question
    and one column per modality.  The output directory is created if it does
    not exist yet.
    """
    models = full_model_list
    output_file = "evaluation/modality_gap/e_modality_gap.xlsx"

    # accuracies[i, j, k]: value for model i, modality j, question k.
    accuracies = np.zeros((len(models), len(modalities), len(questions)))
    for i, model in enumerate(models):
        # get_mmq is given only the last "/"-separated part of the model id.
        model_name = model.split("/")[-1]
        for j, modality in enumerate(modalities):
            for k, question in enumerate(questions):
                accuracies[i, j, k] = get_mmq(model_name, modality, question)
    # Average over the model axis, then transpose to (question, modality).
    avg_accuracies = np.mean(accuracies, axis=0).T

    wb = Workbook()
    ws = wb.active
    # NOTE(review): the header labels are fixed; this assumes `modalities`
    # holds exactly these three entries in this order -- confirm against
    # fixed_variables.
    ws.cell(row=1, column=2, value="Real")
    ws.cell(row=1, column=3, value="Synthetic")
    ws.cell(row=1, column=4, value="Triple")
    for k, question in enumerate(questions):
        # Row 1 is the header, column 1 carries the question name.
        ws.cell(row=k + 2, column=1, value=question)
        for j in range(len(modalities)):
            cell = ws.cell(row=k + 2, column=j + 2, value=avg_accuracies[k][j])
            cell.number_format = "0.0"

    # Apply the same font and centering to every cell of the sheet.
    for row in ws.iter_rows():
        for cell in row:
            cell.font = Font(name='Times New Roman')
            cell.alignment = Alignment(horizontal="center", vertical="center")
    # First column (question names) is wider than the value columns.
    for col_idx, _ in enumerate(ws.columns, start=1):
        col_letter = get_column_letter(col_idx)
        ws.column_dimensions[col_letter].width = 15 if col_idx == 1 else 10

    # Saving fails with FileNotFoundError when the target directory is missing.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    wb.save(output_file)


def draw_bar(models, output_file, y_min=40, legend=True):
    """Draw a grouped bar chart of average accuracy per model and modality.

    For each model and each entry of ``modalities`` the values returned by
    ``get_mmq`` are averaged over all ``questions``.  The averages are printed
    to stdout and plotted as one group of bars per model, one bar per
    modality, each annotated with its height.

    Args:
        models: Sequence of model identifiers.  Each must be a key of
            ``model_map`` (used for the printed heading and the x tick
            label); the part after the last ``/`` is passed to ``get_mmq``.
        output_file: Path the figure is saved to.  Its parent directory is
            created if it does not exist.
        y_min: Lower limit of the y axis.  Must be an integer because it is
            also used as the start of ``range`` for the y ticks (step 10, up
            to 100).
        legend: Whether to draw the legend to the right of the axes.
    """
    accuracies = np.zeros((len(models), len(modalities)))
    for i, model in enumerate(models):
        print(f"### {model_map[model]}")
        model_name = model.split("/")[-1]
        for j, modality in enumerate(modalities):
            accuracies[i][j] = sum(get_mmq(model_name, modality, q) for q in questions) / len(questions)
            print(f"{modality}: {accuracies[i][j]:.2f}")
        print()

    width = 0.15
    spacing = 0.8
    # Gap between a bar's top and its label: 1% of the visible y range.
    bar_text_offset = (100 - y_min) * 0.01
    plt.figure(figsize=(len(models) * 4, 12))
    x = np.arange(len(models)) * spacing
    for j, modality in enumerate(modalities):
        # Shift each modality's bars by one bar width inside the model group.
        bars = plt.bar(x + j * width, accuracies[:, j], width=width, label=modality)
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2, height + bar_text_offset,
                     f"{height:.1f}", ha="center", va="bottom", fontsize=15)

    # Put each tick under the middle of its group of bars.
    plt.xticks(x + width * (len(modalities) - 1) / 2, [model_map[m] for m in models], fontsize=22)
    plt.ylim(y_min, 100)
    plt.yticks(list(range(y_min, 101, 10)), fontsize=15)
    plt.ylabel("Average Accuracy (%)", fontsize=24)
    plt.grid(axis="y", linestyle="--", alpha=0.7)
    if legend:
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=24)
    plt.tight_layout()

    # savefig raises FileNotFoundError when the target directory is missing;
    # a bare file name has an empty dirname, which makedirs would reject.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plt.savefig(output_file, bbox_inches="tight")
    plt.close()


def e_sama_size():
    """Plot the bar chart for a fixed list of five models.

    The model ids below name 7B to 12B variants.  The chart is drawn by
    ``draw_bar`` with the y axis starting at 40 and saved to
    ``evaluation/modality_gap/e_sama_size.pdf``.
    """
    output_file = "evaluation/modality_gap/e_sama_size.pdf"
    model_ids = [
        "Qwen/Qwen2.5-VL-7B-Instruct",
        "llava-hf/llava-v1.6-vicuna-7b-hf",
        "meta-llama/Llama-3.2-11B-Vision-Instruct",
        "google/gemma-3-12b-it",
        "mistralai/Pixtral-12B-2409",
    ]
    draw_bar(model_ids, output_file, y_min=40)


def e_scale_law():
    """Plot one bar chart per model family, each over several model sizes.

    Three charts are produced through ``draw_bar``: Qwen2.5-VL (y axis from
    80, with legend), Gemma 3 and LLaVA v1.6 (y axis from 40, no legend).
    """
    # (model ids, output path, keyword options forwarded to draw_bar)
    charts = (
        (
            [
                "Qwen/Qwen2.5-VL-3B-Instruct",
                "Qwen/Qwen2.5-VL-7B-Instruct",
                "Qwen/Qwen2.5-VL-32B-Instruct",
                "Qwen/Qwen2.5-VL-72B-Instruct",
            ],
            "evaluation/modality_gap/e_scale_law_qwen.pdf",
            {"y_min": 80},
        ),
        (
            [
                "google/gemma-3-1b-it",
                "google/gemma-3-12b-it",
                "google/gemma-3-27b-it",
            ],
            "evaluation/modality_gap/e_scale_law_gemma.pdf",
            {"y_min": 40, "legend": False},
        ),
        (
            [
                "llava-hf/llava-v1.6-vicuna-7b-hf",
                "llava-hf/llava-v1.6-vicuna-13b-hf",
                "llava-hf/llava-v1.6-34b-hf",
            ],
            "evaluation/modality_gap/e_scale_law_llava.pdf",
            {"y_min": 40, "legend": False},
        ),
    )
    for family_models, output_file, options in charts:
        draw_bar(family_models, output_file, **options)


def e_all():
    """Plot the bar chart for every model in ``full_model_list``.

    The y axis starts at 0 and the figure is saved to
    ``evaluation/modality_gap/e_all.pdf``.
    """
    output_file = "evaluation/modality_gap/e_all.pdf"
    draw_bar(full_model_list, output_file, y_min=0)


def e_draft():
    """Scratch entry point for trying out a chart on a small model list.

    Currently a no-op: the list is built but the ``draw_bar`` call is
    commented out, so nothing is plotted or written.
    """
    draft_models = [
        "llava-hf/llava-v1.6-vicuna-7b-hf",
        "llava-hf/llava-v1.6-vicuna-13b-hf",
    ]
    # draw_bar(draft_models, "evaluation/modality_gap/e_draft.pdf", y_min=40)


if __name__ == '__main__':
    # Manual switchboard: uncomment the experiment(s) to run.  Only e_draft()
    # is active at the moment; the others are disabled.
    # e_modality_gap()
    # e_sama_size()
    # e_scale_law()
    # e_all()
    e_draft()

