import os
import pandas as pd
from openpyxl import Workbook
from openpyxl.styles import Font, Alignment
from openpyxl.utils import get_column_letter
from fixed_variables import *


def get_mmq(model, modality, question):
    """Return the mean of one CSV column as a percentage with one decimal.

    Loads ``comparison/<model>/<modality>.csv`` (relative to the current
    working directory), averages the column named ``question``, multiplies
    the mean by 100 and rounds the result to one decimal place.
    """
    scores = pd.read_csv(os.path.join("comparison", model, f"{modality}.csv"))
    percentage = scores[question].mean() * 100
    return round(percentage, 1)


def e_baseline():
    """Build the baseline comparison table and save it as an Excel workbook.

    Sheet layout:
      * Column A holds the model display name ("Model" header merged over
        rows 1-2).
      * Columns B onward are four groups (Real, Synthetic, Triple, Empty),
        each with the sub-headers ER, RU, KG, VR in row 2.
      * Data rows start at row 3, one per entry of ``full_model_list``; each
        value is the percentage returned by ``get_mmq``.

    A modality whose CSV file is missing is left blank, and the columns
    reserved for it are skipped so later modalities stay under their own
    headers.

    The workbook is written to ``evaluation/e_baseline.xlsx``; the
    ``evaluation`` directory is created if it does not exist.

    NOTE(review): ``full_model_list``, ``model_map`` and ``questions`` are
    not defined in this file; presumably they come from the star import of
    ``fixed_variables``. The header assumes ``questions`` has four entries
    in the order ER, RU, KG, VR -- TODO confirm.
    """
    output_file = "evaluation/e_baseline.xlsx"
    modalities = ["real", "synthetic", "triple", "empty"]
    question_labels = ['ER', 'RU', 'KG', 'VR']
    group_width = len(question_labels)

    wb = Workbook()
    ws = wb.active

    # Add table head
    ws.merge_cells("A1:A2")
    ws["A1"] = "Model"
    for group_idx, modality in enumerate(modalities):
        start_col = 2 + group_idx * group_width
        ws.merge_cells(
            start_row=1,
            start_column=start_col,
            end_row=1,
            end_column=start_col + group_width - 1,
        )
        # Group titles are the capitalized modality names (Real, Synthetic, ...).
        ws.cell(row=1, column=start_col, value=modality.capitalize())
        for offset, label in enumerate(question_labels):
            ws.cell(row=2, column=start_col + offset, value=label)

    # Add data
    # Model identifiers may look like "org/name"; only the part after the last
    # "/" is used as the directory name under "comparison".
    for row_idx, model in enumerate(full_model_list, start=3):
        ws.cell(row=row_idx, column=1, value=model_map[model])
        model_dir = model.split("/")[-1]
        col_idx = 2
        for modality in modalities:
            csv_path = os.path.join("comparison", model_dir, f"{modality}.csv")
            if not os.path.exists(csv_path):
                # Keep this modality's columns blank but still advance past
                # them; otherwise the next modality's values would be written
                # under the wrong header.
                col_idx += len(questions)
                continue
            for question in questions:
                val = get_mmq(model_dir, modality, question)
                cell = ws.cell(row=row_idx, column=col_idx, value=val)
                cell.number_format = "0.0"
                col_idx += 1

    # Change format
    table_font = Font(name='Times New Roman')
    centered = Alignment(horizontal="center", vertical="center")
    for row in ws.iter_rows():
        for cell in row:
            cell.font = table_font
            cell.alignment = centered
    for col_idx in range(1, ws.max_column + 1):
        # Wide first column for model names, narrow columns for the scores.
        width = 25 if col_idx == 1 else 6
        ws.column_dimensions[get_column_letter(col_idx)].width = width

    # Saving fails if the target directory is missing, so create it first.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    wb.save(output_file)


# Script entry point: generate the baseline table when run directly.
if __name__ == "__main__":
    e_baseline()


