import pandas as pd
import openpyxl
import os

# ---------------------------------------------------------------------------
# Run configuration.
# The collectors below read these module-level values to decide which summary
# CSV files to look for and how to format their columns.
# ---------------------------------------------------------------------------

# Benchmark selector: the __main__ guard only acts when this is 'MedHEval'.
# collect_amber_all() also uses it as the summaries sub-directory name.
validate_dataset='MedHEval'
# Model tag; first component of every CSV file name.
model='llava_med_v1.5'
# categories=("privacy" "bias" "toxicity" "hallucination" "legality")
# Data subset tag embedded in the MedHEval CSV file name.
c='xray'  # options: ['slake', 'rad', 'mimic_cxr', 'xray']
# Names the summaries sub-directory read by collect_medheval_all() and selects
# which CSV columns it formats.
# Options: visual_misinterpretation, knowledge_deficiency, context_misalignment
hallu_type='visual_misinterpretation'
# 'close' selects the closed-ended column layout in collect_medheval_all();
# any other value selects the alternative (7-column) layout.
answer_type='close'
# Probe dataset tag embedded in the MedHEval CSV file name.
probe_dataset='Mimic_Knowledge' # options: GEMeX, SLAKE, Mimic_Knowledge
# Primary setting pair; appears in the file name as "{neg};{pos}".
# pos='I+Q+RD'
# neg='I+Q_onlyr'
pos='I+Q+RD'
neg='I+Q_onlyr'

# Optional second setting pair; only when BOTH are non-empty is
# "_{neg2};{pos2}" appended to the file name.
pos2=''
neg2=''
# pos2='I+Q+RD'
# neg2='I+Q'

# Suffix placed immediately before ".csv" in every file name.
subfix='_len'
# subfix='_train_I+Q;C_p2+Q_end31_youare_YR'
# NOTE(review): head and alpha are not read anywhere in the visible code --
# confirm whether they are still needed.
head = 40
alpha = 7

# NOTE(review): loc_dict is not read or written anywhere in the visible code.
loc_dict = {}

def _medheval_metric_slice(hallu, answer):
    """Return the slice of CSV columns that is joined into one worksheet cell.

    Args:
        hallu: hallucination type name (the module-level ``hallu_type``).
        answer: answer type name (the module-level ``answer_type``).

    Raises:
        ValueError: if ``hallu`` is not one of the three supported types.
    """
    supported = (
        "visual_misinterpretation",
        "knowledge_deficiency",
        "context_misalignment",
    )
    if hallu not in supported:
        raise ValueError(
            f"unsupported hallu_type {hallu!r}; expected one of {supported}"
        )
    if answer != "close":
        # NOTE(review): this layout starts at column 1, so the dimension label
        # itself becomes part of the cell text -- confirm this is intended.
        return slice(1, 8)
    if hallu == "visual_misinterpretation":
        return slice(2, 8)
    # knowledge_deficiency / context_misalignment, closed-ended: two columns.
    return slice(2, 4)


def collect_medheval_all():
    """Gather the MedHEval summary CSVs into a grid in ``1.xlsx``.

    For every ``(x, y)`` pair of the sweep, one CSV is looked up under
    ``/root/project/summaries/{hallu_type}/``; its name is built from the
    module-level settings (``model``, ``answer_type``, ``c``,
    ``probe_dataset``, ``neg``/``pos``, optionally ``neg2``/``pos2``, and
    ``subfix``).  The path and whether it exists are printed; missing files
    are skipped.

    The first CSV line is skipped as a header.  Every following non-blank row
    is one "dimension".  Dimension ``k`` (0-based, in CSV order) owns a band of
    8 worksheet rows starting at row ``1 + 8 * k``:

    * the band's first row, column 1, receives the dimension label
      (CSV column 1);
    * the selected CSV columns, joined with ``/``, go to row offset
      ``x // 8`` inside the band and to column ``(y + 1) // 2``.

    The workbook is written once at the end, replacing any existing
    ``1.xlsx``.

    Raises:
        ValueError: if ``hallu_type`` is unsupported, or a data row has fewer
            columns than the selected layout needs.
    """
    # Function-scope import keeps this block self-contained.
    import csv

    output_file = '1.xlsx'
    dir_path = f'/root/project/summaries/{hallu_type}/'
    x_values = [8, 16, 24, 32, 40]
    y_values = [1, 3, 5, 7, 9]
    # Header row + up to 5 data rows (x // 8 is 1..5) + spacing.
    rows_per_dimension = 8

    # The column layout depends only on the configuration, so resolve it once
    # (and fail early on an unsupported hallu_type) instead of per CSV row.
    metric_slice = _medheval_metric_slice(hallu_type, answer_type)

    wb = openpyxl.Workbook()
    ws = wb.active

    for x in x_values:
        for y in y_values:
            if neg2 != '' and pos2 != '':
                csv_name = f'{model}_{answer_type}_{c}_{x}_{y}_{probe_dataset}_{neg};{pos}_{neg2};{pos2}{subfix}.csv'
            else:
                csv_name = f'{model}_{answer_type}_{c}_{x}_{y}_{probe_dataset}_{neg};{pos}{subfix}.csv'
            csv_filename = dir_path + csv_name

            found = os.path.exists(csv_filename)
            print(csv_filename, found)
            if not found:
                continue

            # newline='' is what the csv module expects from a text file.
            with open(csv_filename, 'r', newline='') as file:
                reader = csv.reader(file)
                next(reader, None)  # skip the header line
                row_offset = 1
                for values in reader:
                    if not values:
                        # Blank line: nothing to place in the sheet.
                        continue
                    if len(values) < metric_slice.stop:
                        raise ValueError(
                            f"{csv_filename}: expected at least "
                            f"{metric_slice.stop} columns, got "
                            f"{len(values)}: {values!r}"
                        )
                    # Dimension label heads the band in column 1.
                    ws.cell(row=row_offset, column=1, value=values[1])
                    ws.cell(
                        row=x // 8 + row_offset,
                        column=(y + 1) // 2,
                        value='/'.join(values[metric_slice]),
                    )
                    row_offset += rows_per_dimension

    wb.save(output_file)
    
def collect_amber_all():
    """Gather the per-(x, y) summary CSVs into a grid in ``1.xlsx``.

    For every ``(x, y)`` pair of the sweep, the file
    ``{model}_{x}_{y}{subfix}.csv`` is looked up under
    ``/root/wtb/multimodal_alignment/mm_iti/summaries/{validate_dataset}/``;
    missing files are skipped silently.

    The first CSV line is skipped as a header.  Every following non-blank row
    is one "dimension".  Each dimension owns a band of worksheet rows:

    * the band's first row, column 1, receives the dimension label
      (CSV column 1);
    * CSV columns 2-5 are parsed as floats, formatted with one decimal and
      joined with ``/``; the text goes to row offset ``x // 8`` inside the
      band and to column ``(y + 1) // 2``.

    The band height is derived from ``x_values`` so that the largest row
    offset still falls inside the band.  A fixed height of 8 is too small for
    nine x-values: the rows for x=64 and x=72 would land on the next
    dimension's header and first data row and overwrite them.

    The workbook is written once at the end, replacing any existing
    ``1.xlsx``.

    Raises:
        ValueError: if a data row has fewer than 6 columns, or one of
            columns 2-5 cannot be parsed as a float.
    """
    # Function-scope import keeps this block self-contained.
    import csv

    output_file = '1.xlsx'
    dir_path = f'/root/wtb/multimodal_alignment/mm_iti/summaries/{validate_dataset}/'
    x_values = [8, 16, 24, 32, 40, 48, 56, 64, 72]
    y_values = [1, 3, 5, 7, 9]
    # One header row plus data rows at offsets 1 .. max(x) // 8; never smaller
    # than the 8-row band used for shorter sweeps.
    rows_per_dimension = max(8, max(x_values) // 8 + 1)
    first_metric, last_metric = 2, 6  # CSV columns [2, 6) hold the metrics

    wb = openpyxl.Workbook()
    ws = wb.active

    for x in x_values:
        for y in y_values:
            csv_filename = dir_path + f'{model}_{x}_{y}{subfix}.csv'
            if not os.path.exists(csv_filename):
                continue

            # newline='' is what the csv module expects from a text file.
            with open(csv_filename, 'r', newline='') as file:
                reader = csv.reader(file)
                next(reader, None)  # skip the header line
                row_offset = 1
                for values in reader:
                    if not values:
                        # Blank line: nothing to place in the sheet.
                        continue
                    if len(values) < last_metric:
                        raise ValueError(
                            f"{csv_filename}: expected at least "
                            f"{last_metric} columns, got "
                            f"{len(values)}: {values!r}"
                        )
                    # Dimension label heads the band in column 1.
                    ws.cell(row=row_offset, column=1, value=values[1])
                    metrics = [
                        float(field)
                        for field in values[first_metric:last_metric]
                    ]
                    ws.cell(
                        row=x // 8 + row_offset,
                        column=(y + 1) // 2,
                        value="{:.1f}/{:.1f}/{:.1f}/{:.1f}".format(*metrics),
                    )
                    row_offset += rows_per_dimension

    wb.save(output_file)
    
    

# Script entry point: the MedHEval collector runs only when the file is
# executed directly and the configuration selects that benchmark.  No other
# collector is dispatched from here.
if __name__ == '__main__' and validate_dataset == 'MedHEval':
    collect_medheval_all()