import pandas as pd
import openpyxl
import os
import copy

# Experiment configuration, read as module globals by the collectors below.
model = 'shikra_7B'  # prefix of every summary CSV filename
# categories=("privacy" "bias" "toxicity" "hallucination" "legality")
validate_dataset = 'POPE'  # picks the collector run under __main__ and the summaries/ sub-directory
dataset = 'coco'  # extra sub-directory used by collect_pope_all
subfix = '_train_I+Q;C_p2+Q_best_end31'  # suffix appended to summary CSV filenames
c = 'adversarial'  # setting embedded in collect_pope_all's CSV filenames
head = 40  # used in collect_pope_with_gamma's CSV filenames
alpha = 7  # used in collect_pope_with_gamma's CSV filenames

loc_dict = {}  # NOTE(review): not referenced elsewhere in this file as seen

def collect_mllmguard_all():
    """Collect MLLMGuard per-dimension acc/par summaries into ``1.xlsx``.

    For every (x, y) sweep point, reads
    ``{dir_path}{model}_{x}_{y}_{subfix}.csv`` if it exists. After the header,
    each row is parsed as ``<?>,<dimension>,<acc>,<par>,...``. This layout is
    assumed from the column indices used below; confirm against the producer.

    Sheet layout: each dimension owns a block of ``block_height`` rows. The
    dimension name goes in column 1 of the block's first row. The
    ``"acc/par"`` string (both scaled by 100, one decimal) goes in row
    ``x // 8`` of the block, column ``(y + 1) // 2``. A final ``average``
    block holds the mean over all dimensions of that file.
    """
    output_file = '1.xlsx'
    dir_path = f'/data/multimodal_alignment/mm_iti/summaries/{validate_dataset}/'
    x_values = [8, 16, 24]
    # x_values = [16, 32, 48, 64, 80, 96]
    y_values = [1, 3, 5, 7, 9]
    # Rows per dimension: the label row plus room for the largest x // 8
    # offset. A fixed stride of 8 lets x >= 64 spill into the next block.
    block_height = max(8, max(x_values) // 8 + 1)

    wb = openpyxl.Workbook()
    ws = wb.active
    for x in x_values:
        for y in y_values:
            # NOTE(review): subfix currently starts with '_', so this yields a
            # double underscore, unlike the f'{y}{subfix}' form used by other
            # collectors. Confirm the MLLMGuard summary naming.
            csv_filename = os.path.join(dir_path, f'{model}_{x}_{y}_{subfix}.csv')
            if not os.path.exists(csv_filename):
                continue
            with open(csv_filename, 'r') as file:
                # Drop the header and any blank lines (e.g. trailing newlines).
                rows = [line for line in file.readlines()[1:] if line.strip()]
            if not rows:
                # Header-only CSV: nothing to write, and the average is undefined.
                continue

            column = (y + 1) // 2
            total_acc = 0
            total_par = 0
            row_offset = 1
            for line in rows:
                values = line.strip().split(',')
                ws.cell(row=row_offset, column=1, value=values[1])
                acc = float(values[2])
                par = float(values[3])
                total_acc += acc
                total_par += par
                ws.cell(row=x // 8 + row_offset, column=column,
                        value="{:.1f}/{:.1f}".format(acc * 100, par * 100))
                row_offset += block_height

            # row_offset now points at the block after the last dimension.
            ws.cell(row=row_offset, column=1, value='average')
            ws.cell(row=x // 8 + row_offset, column=column,
                    value="{:.1f}/{:.1f}".format(
                        (total_acc / len(rows)) * 100,
                        (total_par / len(rows)) * 100,
                    ))

    wb.save(output_file)
    
def collect_mmsafetybench_all():
    """Collect MM-SafetyBench per-dimension ASR summaries into ``1.xlsx``.

    For every (x, y) sweep point, reads
    ``{dir_path}llamaguard_{model}_{x}_{y}_{subfix}.csv`` if it exists. After
    the header, each row is parsed as ``<?>,<dimension>,<asr>,...``. This
    layout is assumed from the column indices used below; confirm against
    the producer.

    Sheet layout: each dimension owns a block of ``block_height`` rows. The
    dimension name goes in column 1 of the block's first row. The ASR
    (scaled by 100, one decimal) goes in row ``x // 8`` of the block, column
    ``(y + 1) // 2``. A final ``average`` block holds the mean ASR.
    """
    output_file = '1.xlsx'
    dir_path = f'/data/multimodal_alignment/mm_iti/summaries/{validate_dataset}/'
    x_values = [8, 16, 24]
    # x_values = [16, 32, 48, 64, 80, 96]
    y_values = [1, 3, 5, 7, 9]
    # Rows per dimension: the label row plus room for the largest x // 8
    # offset. A fixed stride of 8 lets x >= 64 spill into the next block.
    block_height = max(8, max(x_values) // 8 + 1)

    wb = openpyxl.Workbook()
    ws = wb.active
    for x in x_values:
        for y in y_values:
            # NOTE(review): subfix currently starts with '_', so this yields a
            # double underscore. Confirm the summary naming.
            csv_filename = os.path.join(
                dir_path, f'llamaguard_{model}_{x}_{y}_{subfix}.csv')
            if not os.path.exists(csv_filename):
                continue
            with open(csv_filename, 'r') as file:
                # Drop the header and any blank lines (e.g. trailing newlines).
                rows = [line for line in file.readlines()[1:] if line.strip()]
            if not rows:
                # Header-only CSV: nothing to write, and the average is undefined.
                continue

            column = (y + 1) // 2
            total_asr = 0
            row_offset = 1
            for line in rows:
                values = line.strip().split(',')
                ws.cell(row=row_offset, column=1, value=values[1])
                asr = float(values[2])
                total_asr += asr
                ws.cell(row=x // 8 + row_offset, column=column,
                        value="{:.1f}".format(asr * 100))
                row_offset += block_height

            # row_offset now points at the block after the last dimension.
            ws.cell(row=row_offset, column=1, value='average')
            ws.cell(row=x // 8 + row_offset, column=column,
                    value="{:.1f}".format((total_asr / len(rows)) * 100))

    wb.save(output_file)

def collect_chair_all():
    """Collect CHAIR summary metrics for every (x, y) sweep point into ``1.xlsx``.

    For each pair, reads ``{dir_path}{model}_{x}_{y}{subfix}.csv`` when it
    exists. For every data row, columns 1-3 are scaled by 100 and column 4
    is kept as-is. They are written as an ``"a/b/c/d"`` string (one decimal
    each) at row ``1 + i * block_height + x // 8`` (``i`` = data row index)
    and column ``(y + 1) // 2``. No label column is written.
    """
    output_file = '1.xlsx'
    dir_path = f'/data/multimodal_alignment/mm_iti/summaries/{validate_dataset}/'
    x_values = [16, 24, 32, 48, 56]
    # x_values = [8, 16, 24 ]
    # x_values = [16, 32, 48, 64, 80, 96]
    y_values = [3, 5, 7, 9]
    # One block of rows per CSV data row, tall enough for the largest x // 8
    # offset. A fixed stride of 8 lets x >= 64 overwrite the next block.
    block_height = max(8, max(x_values) // 8 + 1)

    wb = openpyxl.Workbook()
    ws = wb.active
    ## both x and y
    for x in x_values:
        for y in y_values:
            csv_filename = os.path.join(dir_path, f'{model}_{x}_{y}{subfix}.csv')
            if not os.path.exists(csv_filename):
                continue
            with open(csv_filename, 'r') as file:
                # Drop the header and any blank lines (e.g. trailing newlines).
                rows = [line for line in file.readlines()[1:] if line.strip()]
            column = (y + 1) // 2
            for i, line in enumerate(rows):
                values = line.strip().split(',')
                # Columns 1-3 are shown scaled by 100; column 4 is shown unscaled.
                metrics = [float(values[j]) * 100 for j in range(1, 4)]
                metrics.append(float(values[4]))
                ws.cell(row=1 + i * block_height + x // 8, column=column,
                        value='/'.join(f'{m:.1f}' for m in metrics))

    wb.save(output_file)
    
def collect_pope_with_gamma():
    """Collect POPE summaries across a sweep of gamma values into ``1.xlsx``.

    For each gamma ``x`` (index ``idx``), reads
    ``{dir_path}{model}_{head}_{alpha}{subfix}_{x}.csv`` when it exists.
    Each data row is parsed as ``<?>,<dimension>,m1..m5,...``. This layout is
    assumed from the column indices used below; confirm against the producer.

    Sheet layout (all in column 1): each dimension owns a block of
    ``block_height`` rows. The block's first row is the dimension label, and
    row ``idx + 1`` of the block is the ``"m1/.../m5"`` string (each scaled
    by 100, one decimal) for gamma index ``idx``.
    """
    output_file = '1.xlsx'
    dir_path = f'/data/multimodal_alignment/mm_iti/summaries/{validate_dataset}/'
    x_values = [0.0, 0.2, 0.5, 1, 2]
    # x_values = [8, 16, 24 ]
    # x_values = [16, 32, 48, 64, 80, 96]
    # y_values = [3, 5, 7]
    # Label row plus one row per gamma value.
    block_height = max(8, len(x_values) + 1)

    wb = openpyxl.Workbook()
    ws = wb.active
    for idx, x in enumerate(x_values):
        csv_filename = os.path.join(
            dir_path, f'{model}_{head}_{alpha}{subfix}_{x}.csv')
        if not os.path.exists(csv_filename):
            continue
        with open(csv_filename, 'r') as file:
            # Drop the header and any blank lines (e.g. trailing newlines).
            rows = [line for line in file.readlines()[1:] if line.strip()]
        row_offset = 1
        for line in rows:
            values = line.strip().split(',')
            ws.cell(row=row_offset, column=1, value=values[1])
            metrics = [float(values[j]) * 100 for j in range(2, 7)]
            # +1 skips the label row. Writing at row_offset + idx made gamma
            # index 0 overwrite the dimension label in the same cell.
            ws.cell(row=row_offset + idx + 1, column=1,
                    value='/'.join(f'{m:.1f}' for m in metrics))
            row_offset += block_height

    wb.save(output_file)
    

def collect_pope_all(*, with_average=False):
    """Collect POPE per-dimension summaries for every (x, y) point into ``1.xlsx``.

    For each pair, reads ``{dir_path}{model}_{c}_{x}_{y}{subfix}.csv`` when it
    exists. Each data row is parsed as ``<?>,<dimension>,m1..m5,...``. This
    layout is assumed from the column indices used below; confirm against
    the producer.

    Sheet layout: each dimension owns a block of ``block_height`` rows. The
    dimension name goes in column 1 of the block's first row. The
    ``"m1/.../m5"`` string (each scaled by 100, one decimal) goes in row
    ``x // 8`` of the block, column ``(y + 1) // 2``.

    Args:
        with_average: if True, also write an ``average`` block holding the
            mean of each metric over all dimensions of a file. This replaces
            the previously commented-out averaging code.
    """
    output_file = '1.xlsx'
    dir_path = f'/data/multimodal_alignment/mm_iti/summaries/{validate_dataset}/{dataset}/'
    # x_values = [32, 40, 48, 56, 64]
    x_values = [16, 24, 32, 40, 48, 56, 64, 72, 80]
    # x_values = [16, 32, 48, 64]
    y_values = [3, 5, 7]
    # Rows per dimension: the label row plus room for the largest x // 8
    # offset. With the previous fixed stride of 8, x=80 (offset 10) of one
    # dimension overwrote x=16 (offset 2) of the next dimension.
    block_height = max(8, max(x_values) // 8 + 1)

    wb = openpyxl.Workbook()
    ws = wb.active
    ## both x and y
    for x in x_values:
        for y in y_values:
            csv_filename = os.path.join(dir_path, f'{model}_{c}_{x}_{y}{subfix}.csv')
            if not os.path.exists(csv_filename):
                continue
            with open(csv_filename, 'r') as file:
                # Drop the header and any blank lines (e.g. trailing newlines).
                rows = [line for line in file.readlines()[1:] if line.strip()]

            column = (y + 1) // 2
            totals = [0.0] * 5
            row_offset = 1
            for line in rows:
                values = line.strip().split(',')
                ws.cell(row=row_offset, column=1, value=values[1])
                metrics = [float(values[j]) * 100 for j in range(2, 7)]
                totals = [t + m for t, m in zip(totals, metrics)]
                ws.cell(row=x // 8 + row_offset, column=column,
                        value='/'.join(f'{m:.1f}' for m in metrics))
                row_offset += block_height

            # Guard on rows: the average is undefined for a header-only CSV.
            if with_average and rows:
                ws.cell(row=row_offset, column=1, value='average')
                ws.cell(row=x // 8 + row_offset, column=column,
                        value='/'.join(f'{t / len(rows):.1f}' for t in totals))

    wb.save(output_file)
    

def collect_mme_all():
    """Collect MME summary scores for every (x, y) sweep point into ``1.xlsx``.

    For each pair, reads ``{dir_path}{model}_{x}_{y}{subfix}.csv`` when it
    exists. Each data row contributes columns 1-5, written unscaled as an
    ``"a/b/c/d/e"`` string (one decimal each). Data row ``i`` lands in row
    ``1 + i * block_height + x // 8`` and column ``(y + 1) // 2``. No label
    column is written.
    """
    output_file = '1.xlsx'
    dir_path = f'/data/multimodal_alignment/mm_iti/summaries/{validate_dataset}/'
    # x_values = [32, 40, 48, 56, 64]
    x_values = [8, 16, 24, 32, 40, 48, 56, 64, 72]
    # x_values = [16, 32, 48, 64]
    y_values = [1, 3, 5, 7, 9]
    # One block of rows per CSV data row, tall enough for the largest x // 8
    # offset. With the previous fixed stride of 8, x=72 (offset 9) overwrote
    # x=8 (offset 1) of the next block.
    block_height = max(8, max(x_values) // 8 + 1)

    wb = openpyxl.Workbook()
    ws = wb.active
    ## both x and y
    for x in x_values:
        for y in y_values:
            csv_filename = os.path.join(dir_path, f'{model}_{x}_{y}{subfix}.csv')
            if not os.path.exists(csv_filename):
                continue
            with open(csv_filename, 'r') as file:
                # Drop the header and any blank lines (e.g. trailing newlines).
                rows = [line for line in file.readlines()[1:] if line.strip()]
            column = (y + 1) // 2
            for i, line in enumerate(rows):
                values = line.strip().split(',')
                metrics = [float(values[j]) for j in range(1, 6)]
                ws.cell(row=1 + i * block_height + x // 8, column=column,
                        value='/'.join(f'{m:.1f}' for m in metrics))

    wb.save(output_file)

def collect_mme_general_all():
    """Collect MME (general) summary scores for every (x, y) point into ``1.xlsx``.

    For each pair, reads ``{dir_path}{model}_{x}_{y}{subfix}.csv`` when it
    exists. Each data row contributes columns 1-11 (unscaled). They are
    reordered by ``column_order`` and written as an 11-part ``"a/b/..."``
    string (one decimal each). Data row ``i`` lands in row
    ``1 + i * block_height + x // 8`` and column ``(y + 1) // 2``. No label
    column is written.
    """
    output_file = '1.xlsx'
    dir_path = f'/data/multimodal_alignment/mm_iti/summaries/{validate_dataset}/'
    # x_values = [32, 40, 48, 56, 64]
    x_values = [8, 16, 24, 32, 40, 48, 56, 64, 72, 80]
    # x_values = [16, 32, 48, 64]
    y_values = [1, 3, 5, 7, 9]
    # One block of rows per CSV data row, tall enough for the largest x // 8
    # offset. With the previous fixed stride of 8, x=72/80 overwrote
    # x=8/16 of the next block.
    block_height = max(8, max(x_values) // 8 + 1)
    # Output position k shows input metric column_order[k]. This is the same
    # permutation the previous slice swaps produced.
    column_order = (0, 1, 4, 6, 7, 8, 2, 3, 5, 9, 10)

    wb = openpyxl.Workbook()
    ws = wb.active
    ## both x and y
    for x in x_values:
        for y in y_values:
            csv_filename = os.path.join(dir_path, f'{model}_{x}_{y}{subfix}.csv')
            if not os.path.exists(csv_filename):
                continue
            with open(csv_filename, 'r') as file:
                # Drop the header and any blank lines (e.g. trailing newlines).
                rows = [line for line in file.readlines()[1:] if line.strip()]
            column = (y + 1) // 2
            for i, line in enumerate(rows):
                values = line.strip().split(',')
                metrics = [float(values[j]) for j in range(1, 12)]
                reordered = [metrics[k] for k in column_order]
                ws.cell(row=1 + i * block_height + x // 8, column=column,
                        value='/'.join(f'{m:.1f}' for m in reordered))

    wb.save(output_file)
    
   
     
if __name__ == '__main__':
    # Pick the collector for the configured validate_dataset. Unknown values
    # used to fall through silently and produce no output; now they fail loudly.
    collectors = {
        'MLLMGuard': collect_mllmguard_all,
        'MM-SafetyBench': collect_mmsafetybench_all,
        'POPE': collect_pope_all,
        'MME': collect_mme_all,
        'MME_general': collect_mme_general_all,
        # 'POPE_gamma': collect_pope_with_gamma,
        'CHAIR': collect_chair_all,
    }
    try:
        collector = collectors[validate_dataset]
    except KeyError:
        raise SystemExit(
            f'Unknown validate_dataset {validate_dataset!r}; '
            f'expected one of {sorted(collectors)}'
        ) from None
    collector()