import csv
import os
import random
import gzip
import copy
import math
from glob import glob
from pathlib import Path
from collections import defaultdict
from sklearn.model_selection import train_test_split
from scipy.stats import spearmanr, pearsonr
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm
import jsonlines
from scipy import ndimage

# Area threshold handed to binirize_desra via get_best_threshold_fscore in the
# __main__ section: connected components must be strictly larger than this.
MIN_AREA = 4000
# Placeholder locations; replace with real paths before running the script.
TEST_DATASET = '<path to test dataset>'
DESRA_DATASET = '<path to DeSRA dataset>'
HEATMAPS_PATH = '<path to precalculated heatmaps>'
# Not referenced by the code visible in this file.
RES_DIR = '<directory to save results>'

class Metric:
    """Per-image values of one metric, keyed by mask file name.

    Keys are normally mask names such as ``img@SR@RealESRGAN@MS@dist.png``.
    The token that follows ``SR`` in the ``'@'``-split name identifies the SR
    method, and that token is what the ``sr`` filters below select on.
    """

    def __init__(self):
        # mask_fn -> value
        self.data = {}

    def __setitem__(self, mask_fn, val):
        self.data[mask_fn] = val

    @staticmethod
    def _matches_sr(mask_fn, sr):
        """Return True if ``mask_fn`` belongs to SR method ``sr``.

        ``sr`` of None or ``'all'`` matches every key. Otherwise the token
        right after the first ``'SR'`` token of ``mask_fn.split('@')`` must
        equal ``sr``. A name without an ``'SR'`` token raises ValueError
        (from ``list.index``).
        """
        if sr is None or sr == 'all':
            return True
        parts = mask_fn.split('@')
        return sr == parts[parts.index('SR') + 1]

    def get_values(self, sr=None):
        """Return the stored values (insertion order), optionally filtered by SR method."""
        return [val for mask_fn, val in self.data.items() if self._matches_sr(mask_fn, sr)]

    def aggregate(self, agg=sum, sr=None):
        """Apply ``agg`` (default ``sum``) to the values selected by ``sr``."""
        return agg(self.get_values(sr=sr))

    def slice_sr(self, sr):
        """Return an independent copy holding only the entries of SR method ``sr``.

        Only the kept entries are deep-copied. The previous implementation
        deep-copied all of them and then discarded most.
        """
        sliced = copy.copy(self)
        sliced.data = copy.deepcopy(
            {mask_fn: val for mask_fn, val in self.data.items() if self._matches_sr(mask_fn, sr)}
        )
        return sliced
        

class MetricAggregator:
    """Named metrics, each split by binarization threshold.

    ``self.metrics[name]`` is a ``defaultdict(Metric)`` keyed by threshold.
    Threshold-independent metrics use the ``None`` key. Indexing:

    * ``agg['name']``            -> the ``Metric`` stored under threshold ``None``
    * ``agg['name', threshold]`` -> the ``Metric`` for that threshold

    A threshold that has not been seen yet is created on first access,
    because the per-name containers are defaultdicts.
    """

    def __init__(self, metric_names : list[str]):
        self.metrics = {metric_name: defaultdict(Metric) for metric_name in metric_names}

    def __getitem__(self, metric_key):
        if isinstance(metric_key, str):
            return self.metrics[metric_key][None]
        if isinstance(metric_key, tuple):
            metric_name, threshold = metric_key
            return self.metrics[metric_name][threshold]
        raise KeyError(metric_key)

    def __setitem__(self, metric_key, val):
        # Stores a single dataset-level value under mask key ``None`` of the
        # selected Metric (e.g. a precision computed over all images).
        self[metric_key][None] = val

    def __delitem__(self, metric_name):
        del self.metrics[metric_name]

    def add_metric(self, metric_name):
        """Register a new, empty metric; raise ValueError if it already exists."""
        if metric_name in self.metrics:
            raise ValueError(f"{metric_name} already was in MetricAggregator")
        self.metrics[metric_name] = defaultdict(Metric)

    def get_thresholds(self, metric_name):
        """Return a live keys view of the thresholds recorded for ``metric_name``."""
        return self.metrics[metric_name].keys()

    def aggregate(self, metric_name, agg=sum):
        """Aggregate a metric.

        Returns a scalar when the metric only has the ``None`` threshold.
        Otherwise returns a ``{threshold: aggregated value}`` dict.
        """
        if set(self.metrics[metric_name]) == {None}:
            return self.metrics[metric_name][None].aggregate(agg=agg)
        return {threshold: metric.aggregate(agg=agg) for threshold, metric in self.metrics[metric_name].items()}

    def slice_sr(self, sr):
        """Return a new aggregator that keeps only the entries of SR method ``sr``.

        ``Metric.slice_sr`` already returns independent copies, so the
        aggregator is not deep-copied first. The previous implementation
        copied everything twice.
        """
        sliced = copy.copy(self)
        sliced.metrics = {}
        for metric_name, per_threshold in self.metrics.items():
            sliced_per_threshold = defaultdict(Metric)
            for threshold, metric in per_threshold.items():
                sliced_per_threshold[threshold] = metric.slice_sr(sr)
            sliced.metrics[metric_name] = sliced_per_threshold
        return sliced
        



def binirize_desra(contrast_seg, contrast_threshold, area_threshold):
    """Binarize a heatmap with morphological clean-up and an area filter.

    Steps:
    1. Threshold ``contrast_seg > contrast_threshold``.
    2. Erode once with a 5x5 kernel.
    3. Dilate three times with the same kernel.
    4. Fill holes (3x3 structure).
    5. Keep only 8-connected components whose pixel area is strictly greater
       than ``area_threshold``.

    Args:
        contrast_seg: 2-D heatmap array.
        contrast_threshold: value a pixel must exceed to count as positive.
        area_threshold: minimum component area (exclusive). ``None`` keeps
            every component. Previously ``None`` raised TypeError as soon as
            any component existed.

    Returns:
        Boolean array with the shape of ``contrast_seg``.
    """
    contrast_seg_mask = (contrast_seg > contrast_threshold).astype(np.float64)

    kernel = np.ones((5, 5), np.uint8)
    erosion = cv2.erode(contrast_seg_mask, kernel, iterations=1)
    dilation = cv2.dilate(erosion, kernel, iterations=3)
    dst = ndimage.binary_fill_holes(dilation, structure=np.ones((3, 3))).astype('uint8')

    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(dst, connectivity=8)
    # Index by the component label itself. The old loop used the position in
    # np.unique(labels), which is off by one whenever label 0 (background) is
    # absent, i.e. when the whole image is one foreground component.
    if area_threshold is None:
        keep = np.ones(num_labels, dtype=bool)
    else:
        keep = stats[:, cv2.CC_STAT_AREA] > area_threshold
    keep[0] = False  # label 0 is the background
    return keep[labels]

def bool_confusion_matrix(y_true, y_pred):
    """Return pixel counts ``[tn, fp, fn, tp]`` for a binary truth/prediction pair.

    Any non-zero value counts as positive, so 0/1, 0/255 and boolean inputs
    all work. Inputs of any (equal-sized) shape are flattened. The previous
    version produced out-of-range codes for non-{0, 1} values, and the caller's
    4-way unpacking then failed.
    """
    truth = np.asarray(y_true, dtype=bool).ravel()
    pred = np.asarray(y_pred, dtype=bool).ravel()
    return np.bincount(truth.astype(np.intp) * 2 + pred, minlength=4)


def heatmap2bboxes(heatmap):
    '''
    Return the bounding boxes of all contours of a binarized heatmap.

    Each box is a dict with keys "x1", "x2", "y1", "y2" (x2 = x + w,
    y2 = y + h). Boxes are ordered by their longest side, largest first.
    '''
    found = cv2.findContours(heatmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    # findContours returns 2 values in some OpenCV versions and 3 in others;
    # pick the contour list accordingly.
    contours = found[1] if len(found) != 2 else found[0]

    def _longest_side(box):
        return max(box["x2"] - box["x1"], box["y2"] - box["y1"])

    boxes = [
        {"x1": x, "x2": x + w, "y1": y, "y2": y + h}
        for x, y, w, h in (cv2.boundingRect(contour) for contour in contours)
    ]
    return sorted(boxes, key=_longest_side, reverse=True)

def get_conf_dict(conf_csv):
    """Read the annotation CSV into ``{mask_fn: {'ratio_bi', 'folder', 'sr_fn'}}``.

    The CSV must have columns ``mask_fn``, ``ratio_bi`` (parsed as float),
    ``folder`` and ``sr_fn``. The result is ordered by ``mask_fn``.
    """
    conf_dict = {}
    # newline='' as the csv module requires, so quoted fields that contain
    # line breaks are read verbatim.
    with open(conf_csv, newline='') as f:
        for row in csv.DictReader(f):
            conf_dict[row['mask_fn']] = {
                'ratio_bi': float(row['ratio_bi']),
                'folder': row['folder'],
                'sr_fn': row['sr_fn'],
            }
    return dict(sorted(conf_dict.items()))

def get_img_pathes(images_path, gt='gt', rf='bicubic'):
    """Enumerate ``(gt_file, sr_file, rf_file)`` triples under ``images_path``.

    ``images_path`` contains one sub-folder ``<name>`` per source image. The
    reference files inside it are:

    * GT: ``<name>.png``, or ``<name>@RF@SPAN_x4.png`` / ``<name>@RF@RLFN_x4.png``
      when ``gt`` is ``'SPAN'`` / ``'RLFN'``.
    * RF: ``<name>@RF@SPAN_x4.png`` when ``rf == 'SPAN'``, else
      ``<name>@RF@bicubic_x4.png``.

    Folders lacking either reference are skipped. Every file named
    ``...@SR@<method>`` (no further ``@`` after ``@SR@``) yields one triple.

    Returns:
        dict with consecutive integer keys ``0..N-1``.
    """
    img_pathes = {}
    for name in tqdm(sorted(os.listdir(images_path))):
        folder = os.path.join(images_path, name)
        if gt == 'SPAN':
            gt_file = os.path.join(folder, f"{name}@RF@SPAN_x4.png")
        elif gt == 'RLFN':
            # Use the "@RF@<ref>_x4" naming, as get_img_pathes_gt does. The old
            # "@SR@RLFN_x4.png" name would also match the SR filter below.
            gt_file = os.path.join(folder, f"{name}@RF@RLFN_x4.png")
        else:
            gt_file = os.path.join(folder, f"{name}.png")
        rf_name = f"{name}@RF@SPAN_x4.png" if rf == 'SPAN' else f"{name}@RF@bicubic_x4.png"
        rf_file = os.path.join(folder, rf_name)

        if not (os.path.isfile(gt_file) and os.path.isfile(rf_file)):
            continue
        for img_fn in sorted(os.listdir(folder)):
            # Only direct SR outputs: nothing tagged after the "@SR@<method>" part.
            if '@SR@' in img_fn and '@' not in img_fn.split('@SR@', 1)[1]:
                img_pathes[len(img_pathes)] = (gt_file, os.path.join(folder, img_fn), rf_file)
    return img_pathes

def get_desra_dataset(dataset_path, sr_names : str | list | None = None, rf_name='DESRA', amount=-1, shuffle=False, mask_suffix='dist'):
    """Collect ``{mask_fn: (gt_file, sr_file, rf_file)}`` for the DeSRA dataset.

    Each sub-folder ``<raw_name>`` of ``dataset_path`` is expected to hold
    ``<raw_name>@SR@<sr>.*`` outputs and ``<raw_name>@MS@<sr>.*`` masks. Only
    SR methods that have both files and are listed in ``sr_names`` are used.

    The triple slots are filled as follows:

    * "gt": the reference ``<raw_name>@RF@<rf_name>_x4.png``. When
      ``rf_name`` is ``'desra'`` (any case), ``<raw_name>@RF@<sr>_x4.png``
      is used instead.
    * "sr": ``<raw_name>@SR@<sr>.png``.
    * "rf": the bicubic upscale ``<raw_name>@RF@bicubic_x4.png``.

    Args:
        sr_names: one name or a list drawn from RealESRGAN/SwinIR/LDL;
            None selects all three.
        amount: number of folders to use. A negative value (the default)
            uses all of them. Previously ``[:-1]`` silently dropped the last
            folder.
        shuffle: shuffle folders (global ``random`` state) before truncating.
        mask_suffix: suffix of the generated mask key
            ``<raw_name>@SR@<sr>@MS@<mask_suffix>.png``.
    """
    DESRA_SR_NAMES = ['RealESRGAN', 'SwinIR', 'LDL']
    if sr_names is None:
        sr_names = DESRA_SR_NAMES
    elif isinstance(sr_names, str):
        sr_names = [sr_names]

    asserted_names = [x for x in sr_names if x not in DESRA_SR_NAMES]
    assert len(asserted_names) == 0, f'{sr_names} must be in {DESRA_SR_NAMES}'
    wanted = set(sr_names)

    raw_imgs_pathes = sorted(glob(f'{dataset_path}/*'))
    if shuffle:
        random.shuffle(raw_imgs_pathes)
    if amount is not None and amount >= 0:
        raw_imgs_pathes = raw_imgs_pathes[:amount]

    img_pathes = {}
    for raw_imgs_path in tqdm(raw_imgs_pathes):
        raw_name = Path(raw_imgs_path).stem
        srs = {Path(x).stem.split('@')[-1] for x in glob(f"{raw_imgs_path}/{raw_name}@SR@*.*")}
        mss = {Path(x).stem.split('@')[-1] for x in glob(f"{raw_imgs_path}/{raw_name}@MS@*.*")}

        # sorted(): set iteration order of strings changes between runs
        # (hash randomization), which made the output order nondeterministic.
        for sr_name in sorted(srs & mss & wanted):
            if rf_name.lower() == 'desra':
                gt_file = f"{raw_imgs_path}/{raw_name}@RF@{sr_name}_x4.png"
            else:
                gt_file = f"{raw_imgs_path}/{raw_name}@RF@{rf_name}_x4.png"
            sr_file = f"{raw_imgs_path}/{raw_name}@SR@{sr_name}.png"
            rf_file = f"{raw_imgs_path}/{raw_name}@RF@bicubic_x4.png"

            img_pathes[f"{raw_name}@SR@{sr_name}@MS@{mask_suffix}.png"] = (
                gt_file,
                sr_file,
                rf_file,
            )

    return img_pathes

def get_img_pathes_gt(images_path, conf_dict, subset='full', gt='gt', rf='bicubic'):
    img_pathes = {}
    conf_dict = dict(sorted(conf_dict.items()))
    folders = [x['folder'] for x in conf_dict.values()]
    if subset == 'train':
        folders, _ = train_test_split(folders, train_size=0.5, random_state=48)
    elif subset == 'test':
        _, folders = train_test_split(folders, train_size=0.5, random_state=48)
    folders = set(folders)
    
    for mask_fn in conf_dict:
        data = conf_dict[mask_fn]
        folder = data['folder']
        if folder not in folders:
            continue
        sr_fn = data['sr_fn']
        if gt == 'SPAN':
            gt_file = os.path.join(images_path, folder, f"{folder}@RF@SPAN_x4.png")
        elif gt == 'RLFN':
            gt_file = os.path.join(images_path, folder, f"{folder}@RF@RLFN_x4.png")
        else:
            gt_file = os.path.join(images_path, folder, f"{folder}.png")
        if rf == 'SPAN':
            rf_file = os.path.join(images_path, folder, f"{folder}@RF@SPAN_x4.png")
        else:
            rf_file = os.path.join(images_path, folder, f"{folder}@RF@bicubic_x4.png")
        sr_file = os.path.join(images_path, folder, sr_fn)
        img_pathes[mask_fn] = (
            gt_file,
            sr_file,
            rf_file,
        )

    return img_pathes


    
def _count_matched_components(labels, num_labels, other_mask, p_thres):
    """Count components of ``labels`` that are mostly covered by ``other_mask``.

    ``labels`` / ``num_labels`` come from ``cv2.connectedComponentsWithStats``
    (label 0 is background). Component ``k >= 1`` is counted when the fraction
    of its pixels that are non-zero in ``other_mask`` is at least ``p_thres``.
    Components are indexed by their label. The old per-component loop
    indexed by position in ``np.unique(labels)``, which is off by one when
    the background label is absent.
    """
    if num_labels <= 1:
        return 0
    flat_labels = labels.ravel()
    areas = np.bincount(flat_labels, minlength=num_labels)
    covered = np.bincount(flat_labels, weights=(other_mask.ravel() != 0).astype(np.float64),
                          minlength=num_labels)
    ratios = covered[1:] / (areas[1:] + 1e-9)
    return int(np.count_nonzero(ratios >= p_thres))


def compute_global_confusion_matrix(method, thresholds, img_pathes, conf_dict, min_area=None, p_thres=0.5):
    """Run ``method`` on every image and record per-image detection metrics.

    Args:
        method: callable ``(hr_path, sr_path, rf_path) -> heatmap`` (2-D array,
            presumably the same size as the GT mask; TODO confirm).
        thresholds: iterable of heatmap thresholds passed to ``binirize_desra``.
        img_pathes: ``{mask_fn: (hr_path, sr_path, rf_path)}``. The GT mask
            is read in grayscale from ``dirname(sr_path)/mask_fn``; pixels equal
            to 255 are positive.
        conf_dict: ``{mask_fn: {'ratio_bi': float, ...}}``.
        min_area: area threshold forwarded to ``binirize_desra``.
        p_thres: minimum overlap fraction for a component to count as matched.

    Returns:
        MetricAggregator with, per ``mask_fn``:

        * threshold-independent: ``total_gt_artifacts_nums``,
          ``pred_conf_gt_area``, ``gt_conf``.
        * per threshold: pixel ``tp/tn/fp/fn``, ``tp_conf``, ``pred_conf``,
          component counts ``total_detected_artifacts_nums``,
          ``real_artifacts_nums``, ``detected_gt_artifacts_nums``, and ``iou``.

    Raises:
        ValueError: if ``thresholds`` is None.
        FileNotFoundError: if a GT mask cannot be read.
    """
    if thresholds is None:
        raise ValueError("thresholds must be an iterable of thresholds, got None")

    metric_aggregator = MetricAggregator(['tp', 'fp', 'tn', 'fn', 'tp_conf', 'gt_conf', 'pred_conf', 'pred_conf_gt_area',
                                          'iou', 'real_artifacts_nums', 'detected_gt_artifacts_nums',
                                          'total_detected_artifacts_nums', 'total_gt_artifacts_nums'
                                          ])

    print("Iterate over heatmaps:")
    for mask_fn, (hr_path, sr_path, rf_path) in tqdm(img_pathes.items()):
        heatmap = method(hr_path, sr_path, rf_path)
        mask_path = os.path.join(os.path.dirname(sr_path), mask_fn)
        raw_mask = cv2.imread(mask_path, 0)
        if raw_mask is None:
            raise FileNotFoundError(f"cannot read GT mask: {mask_path}")
        gt = (raw_mask == 255).astype(np.uint8)
        # Ignore a 5-pixel border of both heatmap and mask.
        heatmap = heatmap[5:-5, 5:-5]
        gt = gt[5:-5, 5:-5]
        if np.count_nonzero(gt) == 0:
            print('empty GT:', mask_fn)

        gt_map = gt > 0
        gt_mask = (gt_map * 255).astype('uint8')
        gt_num_labels, gt_labels, _, _ = cv2.connectedComponentsWithStats(gt_mask, connectivity=8)
        metric_aggregator['total_gt_artifacts_nums'][mask_fn] = gt_num_labels - 1

        # Mean heatmap value inside the GT area (NaN, with a warning, for an empty GT).
        metric_aggregator['pred_conf_gt_area'][mask_fn] = np.mean(heatmap, where=gt_map)
        gt_conf = conf_dict[mask_fn]['ratio_bi']
        metric_aggregator['gt_conf'][mask_fn] = gt_conf
        gt_flat = gt.ravel()

        for threshold in thresholds:
            heatmap_bin = binirize_desra(heatmap, threshold, min_area)

            tn, fp, fn, tp = bool_confusion_matrix(gt_flat, heatmap_bin.ravel())
            metric_aggregator['tp', threshold][mask_fn] = tp
            metric_aggregator['tn', threshold][mask_fn] = tn
            metric_aggregator['fp', threshold][mask_fn] = fp
            metric_aggregator['fn', threshold][mask_fn] = fn
            # Confidence-weighted TP; the 0.3 offset is kept from the original code.
            metric_aggregator['tp_conf', threshold][mask_fn] = tp * (gt_conf - 0.3)

            pred_conf_map = heatmap[heatmap_bin]
            metric_aggregator['pred_conf', threshold][mask_fn] = np.mean(pred_conf_map) if pred_conf_map.size else 0

            input_mask = (heatmap_bin * 255).astype('uint8')
            det_num_labels, det_labels, _, _ = cv2.connectedComponentsWithStats(input_mask, connectivity=8)

            metric_aggregator['total_detected_artifacts_nums', threshold][mask_fn] = det_num_labels - 1
            # Detected components that lie mostly inside the GT mask.
            metric_aggregator['real_artifacts_nums', threshold][mask_fn] = _count_matched_components(
                det_labels, det_num_labels, gt_mask, p_thres)
            # GT components that are mostly covered by the detection.
            metric_aggregator['detected_gt_artifacts_nums', threshold][mask_fn] = _count_matched_components(
                gt_labels, gt_num_labels, input_mask, p_thres)

            union = np.logical_or(gt_mask, input_mask).sum()
            intersection = np.logical_and(gt_mask, input_mask).sum()
            metric_aggregator['iou', threshold][mask_fn] = float(intersection / (union + 1e-9))

    return metric_aggregator



def compute_prec_rec(metric_aggregator : MetricAggregator):
    """Add confidence-weighted precision/recall for every threshold of ``'tp'``.

    For each threshold, with sums taken over all images:

    * ``precision_conf = tp_conf / (tp + fp)``
    * ``recall_conf = tp_conf / (tp + fn)``

    A zero denominator or a NaN result gives 0. Values are stored as
    dataset-level entries and the same aggregator is returned.
    """
    metric_aggregator.add_metric('precision_conf')
    metric_aggregator.add_metric('recall_conf')

    def _safe_ratio(numerator, denominator):
        value = numerator / denominator if denominator else 0
        return 0 if math.isnan(value) else value

    for threshold in metric_aggregator.get_thresholds('tp'):
        fp = metric_aggregator['fp', threshold].aggregate()
        fn = metric_aggregator['fn', threshold].aggregate()
        tp = metric_aggregator['tp', threshold].aggregate()
        tp_conf = metric_aggregator['tp_conf', threshold].aggregate()

        metric_aggregator['precision_conf', threshold] = _safe_ratio(tp_conf, tp + fp)
        metric_aggregator['recall_conf', threshold] = _safe_ratio(tp_conf, tp + fn)
    return metric_aggregator

# NOTE(review): these module-level counters are not read or written by any
# function visible in this file; they may be leftovers.
(
    all_real_artifacts_nums,
    all_detected_gt_artifacts_nums,
    all_total_detected_artifacts_nums,
    all_total_gt_artifacts_nums,
) = (defaultdict(int) for _ in range(4))

    
def compute_desra_metrics(metric_aggregator : MetricAggregator):
    """Add component-level DeSRA precision/recall for every threshold of ``'iou'``.

    * ``precision_desra = sum(real_artifacts_nums) / sum(total_detected_artifacts_nums)``
    * ``recall_desra = sum(detected_gt_artifacts_nums) / sum(total_gt_artifacts_nums)``

    A 1e-9 epsilon in each denominator turns empty denominators into 0.
    Returns the same aggregator.
    """
    metric_aggregator.add_metric('precision_desra')
    metric_aggregator.add_metric('recall_desra')
    # The GT component total does not depend on the threshold; aggregate it once.
    total_gt = metric_aggregator['total_gt_artifacts_nums'].aggregate()
    for threshold in metric_aggregator.get_thresholds('iou'):
        real = metric_aggregator['real_artifacts_nums', threshold].aggregate()
        detected = metric_aggregator['total_detected_artifacts_nums', threshold].aggregate()
        detected_gt = metric_aggregator['detected_gt_artifacts_nums', threshold].aggregate()

        metric_aggregator['precision_desra', threshold] = real / (detected + 1e-9)
        metric_aggregator['recall_desra', threshold] = detected_gt / (total_gt + 1e-9)

    return metric_aggregator

def select_optimal_threshold(metric_aggregator : MetricAggregator):
    """Return the threshold with the highest aggregated ``'fscore_conf'``.

    NaN scores are ranked below every real score. With plain ``max``, a NaN
    in the first position would always be returned, because every comparison
    against NaN is False. Ties go to the first threshold.
    """
    def _score(item):
        value = item[1]
        return -math.inf if math.isnan(value) else value

    return max(metric_aggregator.aggregate('fscore_conf').items(), key=_score)[0]

def calc_fscore(prec, rec, beta=1):
    """Return the F-beta score of ``prec`` and ``rec``, or 0 if the denominator is 0.

    The denominator is checked explicitly. NumPy scalars do not raise
    ZeroDivisionError; they return NaN/inf with a warning, so the previous
    try/except did not catch 0/0 for them.
    """
    denominator = beta ** 2 * prec + rec
    if denominator == 0:
        return 0
    return (1 + beta ** 2) * prec * rec / denominator

def calc_all_fscore(metric_aggregator : MetricAggregator):
    """Add ``fscore_conf`` (F1 of precision_conf/recall_conf) for each threshold.

    Expects ``compute_prec_rec`` to have populated ``precision_conf`` and
    ``recall_conf``. Returns the same aggregator, like the other
    ``compute_*`` helpers. It previously returned None.
    """
    metric_aggregator.add_metric('fscore_conf')
    for threshold in metric_aggregator.get_thresholds('precision_conf'):
        metric_aggregator['fscore_conf', threshold] = calc_fscore(
            metric_aggregator['precision_conf', threshold].aggregate(),
            metric_aggregator['recall_conf', threshold].aggregate(),
        )
    return metric_aggregator

def _safe_correlation(corr_fn, x, y):
    """Return ``corr_fn(x, y).statistic``, or 0 when scipy raises ValueError (e.g. < 2 samples)."""
    try:
        return corr_fn(x, y).statistic
    except ValueError:
        return 0


def get_best_threshold_fscore(method, img_pathes, conf_dict, thresholds=None, min_area=None, sr_list=('all',)):
    """Evaluate ``method`` and summarize the metrics for each SR subset.

    Runs ``compute_global_confusion_matrix`` once. Then, for each entry of
    ``sr_list`` (``'all'`` or an SR method name), it slices the per-image
    results and:

    * picks the threshold with the best confidence-weighted F1, and reports
      F1/precision/recall at that threshold;
    * reports Pearson/Spearman correlation between GT confidence and
      predicted confidence (at that threshold and inside the GT area);
    * picks the threshold with the best mean IoU, and reports IoU and DeSRA
      component precision/recall there.

    ``sr_list`` defaults to an immutable tuple; any iterable of names works.

    Returns:
        ``{sr: {metric_name: value}}``.
    """
    metric_aggregator = compute_global_confusion_matrix(method, thresholds, img_pathes, conf_dict, min_area=min_area)

    full_res = {}
    for sr in sr_list:
        sliced = metric_aggregator.slice_sr(sr)
        compute_prec_rec(sliced)
        calc_all_fscore(sliced)
        best_threshold_conf = select_optimal_threshold(sliced)
        compute_desra_metrics(sliced)

        gt_conf = sliced['gt_conf'].get_values()
        pred_conf = sliced['pred_conf', best_threshold_conf].get_values()
        pred_conf_gt_area = sliced['pred_conf_gt_area'].get_values()

        best_threshold_iou, iou_score = max(sliced.aggregate('iou', agg=np.mean).items(), key=lambda item: item[1])

        full_res[sr] = {
            'best_threshold_conf': best_threshold_conf,
            'f1score': sliced['fscore_conf', best_threshold_conf].aggregate(),
            'precision_conf': sliced['precision_conf', best_threshold_conf].aggregate(),
            'recall_conf': sliced['recall_conf', best_threshold_conf].aggregate(),
            'pcc': _safe_correlation(pearsonr, gt_conf, pred_conf),
            'srcc': _safe_correlation(spearmanr, gt_conf, pred_conf),
            'pcc_gt_area': _safe_correlation(pearsonr, gt_conf, pred_conf_gt_area),
            'srcc_gt_area': _safe_correlation(spearmanr, gt_conf, pred_conf_gt_area),
            'best_threshold_desra': best_threshold_iou,
            'iou_desra': iou_score,
            'precision_desra': sliced['precision_desra', best_threshold_iou].aggregate(),
            'recall_desra': sliced['recall_desra', best_threshold_iou].aggregate(),
        }
    return full_res


def evaluate_method(method, threshold, img_pathes, vis_path=None, contours_path=None, heatmaps_path=None, heatmaps_npy_path=None, min_area=None, skip_exist=False):
    """Run ``method`` on every image triple, binarize, optionally save outputs, and yield.

    Args:
        method: callable ``(hr_path, sr_path, rf_path) -> heatmap``.
        threshold: pixels with ``heatmap > threshold`` are positive.
        img_pathes: ``{mask_fn: (hr_path, sr_path, rf_path)}``.
        vis_path: if set, save a 2x2 GT/SR/binary/contour figure per image.
        contours_path: if set, save the binary mask as a 0/255 PNG.
        heatmaps_path: if set, save the SR image next to the raw heatmap.
        heatmaps_npy_path: if set, save the heatmap as gzip-compressed
            ``np.save`` output named ``<sr stem>.npy``.
        min_area: if not None, drop contours with area below this value.
        skip_exist: skip images whose ``.npy`` file already exists in
            ``heatmaps_npy_path``.

    Yields:
        ``(mask_fn, sr_path, heatmap_bin, heatmap)``. ``heatmap_bin`` is a
        uint8 0/1 array.
    """
    for mask_fn, (hr_path, sr_path, rf_path) in img_pathes.items():
        print(hr_path, sr_path, rf_path, sep='\n')
        print()
        sr_raw_name = Path(sr_path).stem
        npy_file = os.path.join(heatmaps_npy_path, f'{sr_raw_name}.npy') if heatmaps_npy_path else None
        if skip_exist and npy_file and os.path.isfile(npy_file):
            print('skip', npy_file)
            continue
        try:
            heatmap = method(hr_path, sr_path, rf_path)
        except Exception:
            print(f"error evaluating method on {sr_raw_name} ({Path(hr_path).stem}, {Path(sr_path).stem}, {Path(rf_path).stem})")
            raise

        heatmap_bin = ((heatmap > threshold) * 1.0).astype('uint8')
        if min_area is not None:
            contours, _ = cv2.findContours(heatmap_bin, cv2.RETR_LIST, cv2.CHAIN_APPROX_TC89_L1)
            contours = [c for c in contours if cv2.contourArea(c) >= min_area]
            heatmap_bin = cv2.drawContours(np.zeros(heatmap_bin.shape), contours, -1, color=(255, 255, 255), thickness=cv2.FILLED).astype('uint8') // 255

        # Both figure outputs need the SR image. It was previously loaded only
        # for vis_path, so heatmaps_path alone failed with an undefined name.
        sr_mat = None
        if vis_path or heatmaps_path:
            sr_mat = cv2.cvtColor(cv2.imread(sr_path), cv2.COLOR_BGR2RGB)

        if vis_path:
            os.makedirs(vis_path, exist_ok=True)
            hr_mat = cv2.cvtColor(cv2.imread(hr_path), cv2.COLOR_BGR2RGB)
            contours, _ = cv2.findContours(heatmap_bin, cv2.RETR_LIST, cv2.CHAIN_APPROX_TC89_L1)
            contoured = cv2.drawContours(sr_mat.copy(), contours, -1, (255, 0, 0), 2)
            fig, ax = plt.subplots(2, 2, figsize=(6, 5))
            ax[0, 0].imshow(hr_mat)
            ax[0, 0].set_title("GT image", fontsize=10)
            ax[0, 1].imshow(sr_mat)
            ax[0, 1].set_title("SR image", fontsize=10)
            ax[1, 0].imshow(heatmap_bin)
            ax[1, 0].set_title(f"Binarized heatmap, threshold = {threshold:.3f}", fontsize=10)
            ax[1, 1].imshow(contoured)
            ax[1, 1].set_title("Contoured SR image", fontsize=10)
            fig.savefig(os.path.join(vis_path, f'{sr_raw_name}.jpg'), dpi=400)
            # Close only this figure. A following plt.clf() would create a
            # fresh, never-closed figure on every iteration.
            plt.close(fig)

        if contours_path:
            os.makedirs(contours_path, exist_ok=True)
            cv2.imwrite(os.path.join(contours_path, f'{sr_raw_name}.png'), heatmap_bin * 255)

        if heatmaps_path:
            os.makedirs(heatmaps_path, exist_ok=True)
            fig, ax = plt.subplots(1, 2, layout='constrained')
            ax[0].imshow(sr_mat)
            ax[1].imshow(heatmap, cmap='magma')
            fig.set_size_inches(9, 6)
            colorbar = fig.colorbar(mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(0, heatmap.max()), cmap='magma'),
                                    ax=ax[1], orientation='vertical', fraction=0.046, pad=0.04)
            colorbar.set_ticks([0, 1])
            fig.savefig(os.path.join(heatmaps_path, f'{sr_raw_name}.jpg'))
            # fig.clf() only emptied the figure; close it so figures do not accumulate.
            plt.close(fig)

        if heatmaps_npy_path:
            os.makedirs(heatmaps_npy_path, exist_ok=True)
            with gzip.GzipFile(npy_file, "w") as f:
                np.save(file=f, arr=heatmap)

        yield mask_fn, sr_path, heatmap_bin, heatmap




if __name__ == '__main__':
    import argparse
    from functools import partial

    from scipy.ndimage import gaussian_filter

    parser = argparse.ArgumentParser()
    parser.add_argument('--gt', type=str, choices=['gt', 'DESRA', 'SPAN', 'RLFN'], required=True)
    parser.add_argument('--dataset', type=str, choices=['test', 'desra'], required=True)
    parser.add_argument('--gauss', action='store_true', help='Apply Gaussian filter to final heatmap')
    parser.add_argument('--san', action='store_true', help='Use semantic masking')
    parser.add_argument('--g05', action='store_true', help='Use prominent subset (prominence >= 0.5)')
    args = parser.parse_args()

    SAN = 'SAN' if args.san else 'NOSAN'
    gt = args.gt
    gauss = 'gauss' if args.gauss else 'nogauss'

    # Sub-folder of HEATMAPS_PATH holding the heatmaps computed for this GT choice.
    nn_gt_name = 'MSESR' if gt.lower() == 'desra' else gt.upper() if gt.lower() != 'gt' else 'gt'
    if args.dataset in {'desra', 'desra-post'}:
        nn_folder_name = f"desra_heatmaps_gt{nn_gt_name}_rfbicubic"
    else:
        nn_folder_name = f"heatmaps_gt{nn_gt_name}_rfbicubic"

    def postprocess(method):
        """Wrap ``method`` with optional Gaussian smoothing and SAN masking."""
        def wrapper(hr_path, sr_path, rf_path):
            res = method(hr_path, sr_path, rf_path)
            if gauss == 'gauss':
                res = gaussian_filter(res, sigma=3)
            return res
        if SAN == 'SAN':
            # Imported lazily so runs without --san do not need this package.
            from pipeline.metrics.SAN.mask import SAN_mask_zeroing
            return SAN_mask_zeroing(wrapper)
        return wrapper

    if args.dataset == 'test':
        conf_dict = get_conf_dict('gt_conf.csv')
        img_pathes_test = get_img_pathes_gt(TEST_DATASET, conf_dict, subset='test', gt=gt)
    elif args.dataset == 'desra':
        conf_dict = get_conf_dict('gt_conf_desra.csv')
        img_pathes_test = get_desra_dataset(DESRA_DATASET, rf_name=gt, mask_suffix='dist')
    else:
        raise NotImplementedError()

    if args.g05:
        img_pathes_test = {k: v for k, v in img_pathes_test.items() if conf_dict[k]['ratio_bi'] >= 0.5}

    def npz_method(npz_path, hr_path, sr_path, rf_path):
        """Load the precomputed gzip-compressed heatmap ``<sr stem>.npy.gz`` from ``npz_path``."""
        npz_heatmap_path = os.path.join(npz_path, Path(sr_path).stem + '.npy.gz')
        with gzip.GzipFile(npz_heatmap_path, "r") as f:
            res = np.load(file=f)
        return res.squeeze()

    method_names = [
        'dists', 'LDL', 'bd_jup', 'lpips', 'erqa',
        'ssim', 'ssm_jup', 'desra', 'nn-20250421-gtgt-e30', 'pal4inpaint', 'pal4vst',
    ]
    # os.path.join per component: the previous hard-coded '\' separators only
    # worked on Windows.
    methods_dct = {
        name: partial(npz_method, os.path.join(HEATMAPS_PATH, nn_folder_name, name))
        for name in method_names
    }

    threshold_dct = {
        'dists': [0.25],
        'LDL': [0.005],
        'bd_jup': [0.1],
        'lpips': [0.25],
        'erqa': [0.55],
        'ssim': [0.55],
        'ssm_jup': [0.15],
        'desra': [0.3],
        "nn-20250421-gtgt-e30": [0.15, 0.3],
        'pal4inpaint': [0.5],  # binary
        'pal4vst': [0.5],  # binary
    }

    g05_suf = '_g05' if args.g05 else ''
    if args.dataset in {'desra', 'desra-post'}:
        sr_list = ['all', 'RealESRGAN', 'LDL', 'SwinIR']
    else:
        sr_set = {Path(x['sr_fn']).stem.rsplit('@', 1)[-1] for x in conf_dict.values()}
        sr_list = ['all', *sorted(sr_set)]

    jsonl_path = f'{args.dataset}_{gt}_{SAN}_{gauss}_desra-bin{g05_suf}.jsonl'
    # (method, threshold) pairs already present in the results file are skipped.
    processed = set()
    if os.path.isfile(jsonl_path):
        with jsonlines.open(jsonl_path) as reader:
            for obj in reader:
                method_name = list(obj.keys())[0]
                threshold = obj[method_name]['all']['best_threshold_desra']
                processed.add((method_name, threshold))

    with jsonlines.open(jsonl_path, 'a', flush=True) as res_file:
        for method_name, method in reversed(methods_dct.items()):
            for thres in threshold_dct[method_name]:
                if (method_name, thres) in processed:
                    print(f'Skip {method_name} (t={thres})')
                    continue
                print(f"Start {method_name} (t={thres})")
                result = {method_name: get_best_threshold_fscore(
                    postprocess(method),
                    img_pathes_test,
                    conf_dict,
                    thresholds=[thres],
                    min_area=MIN_AREA,
                    sr_list=sr_list,
                )}
                print(result)
                res_file.write(result)
                print()


