# Recompute the per-experiment AUBC scores (written to *-aubc.csv) from a *-detail.csv log.
import pandas as pd
from io import StringIO
import numpy as np

def AUBC_Zhan(quota, resseq, bsize=1):
    """Trapezoid-rule area under ``resseq``, normalised by ``len(quota)``.

    Each pair of consecutive scores contributes a trapezoid of width
    ``bsize``.  When ``len(quota)`` is not a multiple of ``bsize`` the last
    pair is instead weighted by the remainder ``len(quota) % bsize``.

    Args:
        quota: Any sized object; only its length is used.
        resseq: Indexable sequence of scores (supports ``resseq[i]`` and
            negative indices).
        bsize: Width assigned to each full interval.

    Returns:
        The accumulated area divided by ``len(quota)``, rounded to 5 decimals.

    Raises:
        ZeroDivisionError: If ``bsize`` is 0 or ``quota`` is empty.
        IndexError: If a partial last interval is needed but ``resseq`` has
            fewer than two entries.
    """
    n_queries = len(quota)
    leftover = n_queries % bsize

    # With a partial last batch, the final pair of scores is handled
    # separately below, so one fewer full-width interval is summed here.
    full_intervals = len(resseq) - 1 if leftover == 0 else len(resseq) - 2

    area = 0.0
    for idx in range(full_intervals):
        area += (resseq[idx + 1] + resseq[idx]) * bsize / 2

    if leftover != 0:
        area += (resseq[-1] + resseq[-2]) * leftover / 2

    return round(area / n_queries, 5)

def AUBC(budgets, metrics):
    """Trapezoidal area of ``metrics`` over ``budgets``, divided by the point count.

    Note that the normalisation uses the *number of budget points*
    (``budgets.shape[0]``), not the span of the budget axis.

    Args:
        budgets: 1-D array-like of x-coordinates (lists are accepted too).
        metrics: 1-D array-like of y-values, same length as ``budgets``.

    Returns:
        The normalised area rounded to 5 decimals (a NumPy float).
    """
    budgets = np.asanyarray(budgets)
    metrics = np.asanyarray(metrics)
    total_budget = budgets.shape[0]

    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of
    # ``np.trapezoid``; prefer the new name and fall back on older releases.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz

    ressum = integrate(metrics, x=budgets) / total_budget
    return np.round(ressum, 5)

def _repair_fields(fields, expected_fields, line_num):
    """Recover zero or more well-formed rows from one malformed line.

    Args:
        fields: The line already split on the separator.
        expected_fields: Number of fields every recovered row must have.
        line_num: 1-based line number in the file, used in error messages.

    Returns:
        A list of rows (each a list of ``expected_fields`` strings); empty
        when the line is to be dropped.

    Raises:
        ValueError: If the line matches no known pattern, or a repair yields
            a row with the wrong number of fields.
    """

    def checked(row):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(row) != expected_fields:
            raise ValueError(
                f"line {line_num}: repaired row has {len(row)} fields, "
                f"expected {expected_fields}: {row!r}"
            )
        return row

    # case 1: fields 1-2 are repeated at positions 4-5 (presumably a record
    # whose beginning was written twice); keep everything from index 3 on.
    if fields[1:3] == fields[4:6]:
        return [checked(fields[3:])]
    # 12 fields: split into two rows (presumably two records on one line);
    # the first lost its trailing empty field, the second its leading one.
    if len(fields) == 12:
        return [checked(fields[:6] + ['']), checked([''] + fields[6:])]
    # 10 fields without the duplicated prefix of case 1: dropped.
    if len(fields) == 10:
        return []
    # 15 fields: first six fields form one row, fields 8.. form another;
    # fields 6 and 7 are discarded.
    if len(fields) == 15:
        return [checked(fields[:6] + ['']), checked(fields[8:])]
    # First field contains 'INFO' (presumably a log prefix): blank it out
    # and append the missing trailing empty field.
    if 'INFO' in fields[0]:
        return [checked([''] + fields[1:] + [''])]
    # Nothing but separators on the line.
    if set(fields) == {''}:
        return []
    raise ValueError(
        f"line {line_num}: unrecognised format with {len(fields)} fields: "
        f"{fields!r}"
    )


def read_csv_with_format_check(file_path, expected_fields=7, sep='|', header=None):
    """Load a separator-delimited log, repairing known malformed lines.

    Blank lines are ignored.  Lines with exactly ``expected_fields`` fields
    are taken as-is; other lines are repaired or dropped according to a fixed
    set of patterns (see ``_repair_fields``).  The resulting table is then
    cleaned: empty strings become NaN, column 3 is converted to float, rows
    whose column 3 lies outside [0.5, 1] (or is NaN) are removed, and
    column 2 is converted to int.

    Args:
        file_path: Path of the file to read.
        expected_fields: Number of fields in a well-formed line after
            splitting.  The repair patterns use fixed lengths (10, 12, 15)
            that only yield valid rows for the default of 7.
        sep: Field separator.
        header: Unused; kept for backward compatibility.

    Returns:
        A ``pandas.DataFrame`` with integer column labels ``0..expected_fields-1``.

    Raises:
        ValueError: If a line matches none of the known patterns, or a
            repair produces a row of the wrong width.
    """
    rows = []
    with open(file_path, 'r') as file:
        for line_num, raw_line in enumerate(file, start=1):
            line = raw_line.strip()
            if not line:
                continue
            fields = line.split(sep)
            if len(fields) == expected_fields:
                rows.append(fields)
            else:
                rows.extend(_repair_fields(fields, expected_fields, line_num))

    corrected_csv = pd.DataFrame(rows)
    # replace '' with np.nan
    corrected_csv = corrected_csv.replace('', np.nan)
    # convert column 3 to float
    corrected_csv[3] = corrected_csv[3].astype(float)
    # Remove rows whose column 3 is not in [0.5, 1]; take an explicit copy
    # so the assignment below does not write into a slice of the original
    # frame (SettingWithCopyWarning).
    corrected_csv = corrected_csv[corrected_csv[3].between(0.5, 1)].copy()
    # convert column 2 to int
    corrected_csv[2] = corrected_csv[2].astype(int)
    return corrected_csv

# deal with clean and breast
file_prefix = 'detail/'
# Experiment to process; alternatives are kept here for quick switching.
# file_path = 'australian-lal-zhan-RandomForest-RandomForest-RS_noFix_scale'
file_path = 'bioresponse-lal-XGBoost-XGBoost-RS_noFix_scale'
# NOTE: this assignment overrides the previous one, so spambase is what runs.
file_path = 'spambase-lal-XGBoost-XGBoost-RS_noFix_scale'
detail = read_csv_with_format_check(f'{file_prefix}{file_path}-detail.csv')

# Group rows by column 1 and compute one AUBC value per group, using
# column 2 as the budget axis and column 3 as the metric.
detail_groupby1 = detail.groupby(1)
aubc = (
    detail_groupby1
    .apply(lambda group: AUBC(group[2].values, group[3].values))
    .to_frame()
)
# Name the single score column and turn the group keys into an integer index.
aubc.columns = ['res_tst_score']
aubc.index = aubc.index.astype(int)

# Disabled: merging with a previously written aubc.csv.
# read original aubc.csv
# aubc = pd.read_csv(f'{file_prefix}/aubc/{file_path}-aubc.csv', index_col=0)

# merge aubc and detail_groupby1_aubc by index and outer join
# aubc = aubc.merge(detail_groupby1_aubc, left_index=True, right_index=True, how='outer')
# fill NAN in res_tst_score column by AUBC column
# aubc['res_tst_score'] = aubc['res_tst_score'].fillna(aubc['AUBC'])
# drop AUBC column

# Label the index and write the result to the current working directory.
aubc.index.name = 'res_expno'
aubc.to_csv(f'{file_path}-aubc.csv')