import jsonlines
from sklearn.metrics import auc

def calc_fscore(prec, rec, beta=1):
    """Return the F-beta score combining precision ``prec`` and recall ``rec``.

    Computes ``(1 + beta^2) * prec * rec / (beta^2 * prec + rec)``; with the
    default ``beta=1`` this is the harmonic mean of the two (F1). When the
    weighted denominator is zero the score is defined as 0 rather than
    dividing by zero.
    """
    beta_sq = beta ** 2
    denominator = beta_sq * prec + rec
    if denominator == 0:
        return 0
    return (1 + beta_sq) * prec * rec / denominator

# Experiment selection. These values name the results file below and also
# steer how each record is read in the reporting loop.
gt = 'RLFN'
dataset = 'desra'
SAN = 'NOSAN'
gauss = 'nogauss'
sr = 'all'
g05 = True


# (method name, threshold whose row is reported). A method may be listed more
# than once to report several thresholds.
# Alternative, shorter list previously used for the g05 setting:
#     ('dists', 0.25), ('LDL', 0.005), ('bd_jup', 0.1), ('ssim', 0.45),
#     ('ssm_jup', 0.1), ('desra', 0.3), ('nn-20250421-gtgt-e30', 0.15)
methods = [
    ('LDL', 0.005),
    ('ssim', 0.55),
    ('lpips', 0.25),
    ('erqa', 0.55),
    ('pal4inpaint', 0.5),
    ('pal4vst', 0.5),
    ('dists', 0.25),
    ('bd_jup', 0.10),
    ('ssm_jup', 0.2),
    ('desra', 0.3),
    ('nn-20250421-gtgt-e30', 0.15),
    ('nn-20250421-gtgt-e30', 0.3),
]

# Results file: ../results2/<dataset>_<gt>_<SAN>_<gauss>_desra-bin[_g05].jsonl
g05_suf = ''
if g05:
    g05_suf = '_g05'
jsonl_path = '../results2/{}{}.jsonl'.format(
    '_'.join((dataset, gt, SAN, gauss, 'desra-bin')), g05_suf
)


def _load_records_by_method(path):
    """Parse the JSONL results file once and group its records by method.

    Each line is read as a dict whose first key is taken as the method name;
    the value under that key is what the reporting loop consumes. Per-method
    lists keep file order, so "last matching threshold wins" below behaves
    exactly as a per-method rescan of the file would.
    """
    grouped = {}
    with jsonlines.open(path) as reader:
        for obj in reader:
            method_name = list(obj.keys())[0]
            grouped.setdefault(method_name, []).append(obj[method_name])
    return grouped


def _placeholder_row(use_g05):
    """Return a row of '-' cells with the same column count as a real row."""
    return "\t".join(["-"] * (4 if use_g05 else 3))


def _extract_point(entry, split, use_g05):
    """Return (threshold, precision, recall, metric) from one method record.

    ``entry`` is the value stored under the method name and ``split`` selects
    the sub-dict (the module-level ``sr``) that holds the statistics.
    """
    stats = entry[split]
    threshold = stats['best_threshold_desra']
    if use_g05:
        precision = stats['precision_desra']
        recall = stats['recall_desra']
        metric = stats['iou_desra']
    else:
        # NOTE(review): both values are scaled by a hard-coded 0.7 and the
        # reason is not visible here; confirm this factor is intended.
        precision = stats['precision_conf'] * 0.7
        recall = stats['recall_conf'] * 0.7
        metric = calc_fscore(precision, recall)
    return threshold, precision, recall, metric


def _format_row(precision, recall, metric, use_g05):
    """Format one tab-separated row; g05 rows also include precision*recall."""
    if use_g05:
        return f"{precision:.4f}\t{recall:.4f}\t{precision*recall:.4f}\t{metric:.4f}"
    return f"{precision:.4f}\t{recall:.4f}\t{metric:.4f}"


def _pr_auc(method_name, recalls, precisions):
    """Return the area under the precision-recall curve, or NaN if not computable."""
    if method_name in {'pal4inpaint', 'pal4vst'}:
        return float('NaN')
    if len(recalls) < 2:
        # sklearn's auc needs at least two points. Report NaN instead of
        # aborting the whole table for a method with a single threshold.
        return float('NaN')
    # Sort by recall so the x-coordinates are monotonic, as auc requires.
    xs, ys = zip(*sorted(zip(recalls, precisions)))
    return auc(xs, ys)


# Print one tab-separated row per (method, threshold) entry:
#   g05:     precision, recall, precision*recall, iou, PR-AUC
#   not g05: precision, recall, F1, PR-AUC
# The results file is read once instead of being re-parsed for every method.
records_by_method = _load_records_by_method(jsonl_path)

for method_name_res, threshold_res in methods:
    precisions = []
    recalls = []
    row = None
    if method_name_res == 'pal4vst' and gt in {"SPAN", "RLFN"}:
        # pal4vst is skipped for these ground truths; keep the columns aligned.
        row = _placeholder_row(g05)
    else:
        for entry in records_by_method.get(method_name_res, []):
            threshold, precision, recall, metric = _extract_point(entry, sr, g05)
            precisions.append(precision)
            recalls.append(recall)
            # Report the point at the requested threshold (last match wins).
            if abs(threshold_res - threshold) < 0.00001:
                row = _format_row(precision, recall, metric, g05)
    if row is None:
        # No record for this method matched the requested threshold.
        raise ValueError(method_name_res)
    auc_value = _pr_auc(method_name_res, recalls, precisions)
    print(f"{row}\t{auc_value:.4f}")
