import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import sklearn.metrics as sk
import os
import csv
import json


def print_measures(log, auroc, fpr, method_name='Ours', recall_level=0.95):
    """Report FPR@recall and AUROC as percentages.

    Args:
        log: Logger-like object exposing ``debug``; when None the numbers are
            printed to stdout instead.
        auroc: AUROC as a fraction in [0, 1].
        fpr: False-positive rate at ``recall_level`` as a fraction.
        method_name: Label emitted as the first logged line (logger path only).
        recall_level: Recall at which ``fpr`` was measured; shown as e.g. FPR95.
    """
    level_pct = int(100 * recall_level)
    fpr_pct = 100 * fpr
    auroc_pct = 100 * auroc

    if log is not None:
        # Table-row layout ("& fpr & auroc") intended for pasting into LaTeX.
        log.debug('\t\t\t\t' + method_name)
        log.debug('  FPR{:d} AUROC'.format(level_pct))
        log.debug('& {:.2f} & {:.2f}'.format(fpr_pct, auroc_pct))
        return

    print('FPR{:d}:\t\t\t{:.2f}'.format(level_pct, fpr_pct))
    print('AUROC: \t\t\t{:.2f}'.format(auroc_pct))


def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
    """Cumulative sum computed in float64, verified against the plain total.

    Parameters
    ----------
    arr : array-like
        Values to accumulate; treated as flat.
    rtol : float
        Relative tolerance passed to ``np.allclose``.
    atol : float
        Absolute tolerance passed to ``np.allclose``.

    Returns
    -------
    numpy.ndarray
        float64 running sum of ``arr``.

    Raises
    ------
    RuntimeError
        If the last running-sum element disagrees with ``np.sum(arr)``.
    """
    running = np.cumsum(arr, dtype=np.float64)
    total = np.sum(arr, dtype=np.float64)
    is_stable = np.allclose(running[-1], total, rtol=rtol, atol=atol)
    if is_stable:
        return running
    raise RuntimeError('cumsum was found to be unstable: '
                       'its last element does not correspond to sum')


def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=None):
    """Return the false-positive rate at the threshold whose recall is closest to ``recall_level``.

    Despite the name, only the FPR is returned; the FDR expression is left
    commented out on the final line.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels. Compared element-wise with ``pos_label`` and then
        integer-array indexed, so presumably a NumPy array -- TODO confirm for
        callers other than ``get_measures``.
    y_score : numpy.ndarray
        Scores where larger values rank a sample as more likely positive. Must
        support integer-array indexing (``y_score[indices]``).
    recall_level : float
        Target recall (true-positive rate). The operating point selected is the
        one whose recall is *nearest* to this value, so its recall may be
        slightly below or above the target.
    pos_label : scalar or None
        Label treated as positive. If None, the labels must be one of
        {0, 1}, {-1, 1}, {0}, {-1}, {1}, and 1 is used as the positive label.

    Returns
    -------
    numpy floating scalar
        FP count at the selected threshold divided by the total number of
        negatives. With no negatives this is a 0-division (NaN/inf plus a NumPy
        RuntimeWarning). With no positives ``recall`` is 0/0 = NaN, so the
        result is not meaningful.

    Raises
    ------
    ValueError
        If ``pos_label`` is None and the labels are not one of the accepted
        binary encodings listed above.
    """
    classes = np.unique(y_true)
    if (pos_label is None and
            not (np.array_equal(classes, [0, 1]) or
                     np.array_equal(classes, [-1, 1]) or
                     np.array_equal(classes, [0]) or
                     np.array_equal(classes, [-1]) or
                     np.array_equal(classes, [1]))):
        raise ValueError("Data is not binary and pos_label is not specified")
    elif pos_label is None:
        pos_label = 1.

    # make y_true a boolean vector
    y_true = (y_true == pos_label)

    # sort scores and corresponding truth values
    # (stable ascending argsort, then reversed -> descending by score)
    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
    y_score = y_score[desc_score_indices]
    y_true = y_true[desc_score_indices]

    # y_score typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    # np.diff is non-zero at i when y_score[i] != y_score[i + 1], so each kept
    # index is the last position of a run of tied scores.
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

    # accumulate the true positives with decreasing threshold
    # tps[k]: positives among samples scored >= thresholds[k].
    tps = stable_cumsum(y_true)[threshold_idxs]
    # fps[k]: samples seen so far (index + 1) minus the positives among them.
    fps = 1 + threshold_idxs - tps      # add one because of zero-based indexing

    thresholds = y_score[threshold_idxs]

    # Recall (TPR) at each threshold; 0/0 -> NaN when there are no positives.
    recall = tps / tps[-1]

    # tps is non-decreasing, so this is the first threshold at which every
    # positive has been recovered (recall == 1).
    last_ind = tps.searchsorted(tps[-1])
    # Keep the curve up to that point, reversed so recall is non-increasing,
    # and append a sentinel point (recall=1, fps=0, tps=0). thresholds is not
    # extended and is not used below.
    sl = slice(last_ind, None, -1)      # [last_ind::-1]
    recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]

    # np.argmin returns the first minimum. recall[0] is already 1 when positives
    # exist, so the appended sentinel (also recall 1) is never selected over it.
    cutoff = np.argmin(np.abs(recall - recall_level))

    # Denominator is the total number of negatives (reordering does not change it).
    return fps[cutoff] / (np.sum(np.logical_not(y_true)))   # , fps[cutoff]/(fps[cutoff] + tps[cutoff])


def get_measures(_pos, _neg, recall_level=0.95):
    """Compute AUROC and FPR@``recall_level`` for a positive/negative score split.

    ``_pos`` scores are labelled 1 and ``_neg`` scores 0. Both are flattened and
    concatenated (positives first). Larger scores rank a sample as more positive.

    Returns:
        tuple: ``(auroc, fpr)`` -- ``sklearn.metrics.roc_auc_score`` and
        ``fpr_and_fdr_at_recall`` evaluated on the concatenated scores.
    """
    pos_column = np.array(_pos[:]).reshape((-1, 1))
    neg_column = np.array(_neg[:]).reshape((-1, 1))
    n_positive = len(pos_column)

    all_scores = np.squeeze(np.vstack((pos_column, neg_column)))

    # Positives occupy the leading entries of all_scores.
    labels = np.zeros(len(all_scores), dtype=np.int32)
    labels[:n_positive] += 1

    auroc_value = sk.roc_auc_score(labels, all_scores)
    fpr_value = fpr_and_fdr_at_recall(labels, all_scores, recall_level)
    return auroc_value, fpr_value


def get_and_print_results(args, in_score, out_score, auroc_list, fpr_list):
    """Evaluate detection performance for one OOD test set and print FPR95 / AUROC.

    Both score arrays are negated before being passed to ``get_measures``, which
    treats its first argument as the positive (in-distribution) class and ranks
    larger values as more positive. As a result, samples with *lower* raw scores
    are treated as more in-distribution-like.

    Args:
        args: Unused; kept for interface compatibility with existing callers.
        in_score: In-distribution scores. They must support unary negation and
            slicing (presumably NumPy arrays -- confirm against callers).
        out_score: OOD scores, with the same requirements as ``in_score``.
        auroc_list (list): Accumulator appended to in place with this set's AUROC,
            so callers can average across several OOD test sets.
        fpr_list (list): Accumulator appended to in place with this set's FPR95.

    Returns:
        dict: ``{'auroc': float, 'fpr95': float}`` for this OOD test set.
    """
    auroc, fpr = get_measures(-in_score, -out_score)
    # Plain Python floats keep the result type independent of what the metric
    # backends return (the original np.mean over a 1-element list was a no-op).
    auroc, fpr = float(auroc), float(fpr)
    print(f'in score samples (first 3): {in_score[:3]}, out score samples: {out_score[:3]}')

    # Used to compute the average over multiple OOD test sets.
    auroc_list.append(auroc)
    fpr_list.append(fpr)
    print("FPR:{}, AUROC:{}".format(fpr, auroc))

    return {'auroc': auroc, 'fpr95': fpr}


def add_results(results_data, method, results, out_dataset):
    """Append one per-dataset row to ``results_data`` and return the same list.

    Args:
        results_data (list): Rows collected so far; mutated in place.
        method: Method label stored under 'Method'.
        results (dict): Must contain 'auroc' and 'fpr95' keys.
        out_dataset: OOD dataset name stored under 'Dataset'.
    """
    auroc = results['auroc']
    fpr95 = results['fpr95']
    row = {'Dataset': out_dataset, 'Method': method, 'AUROC': auroc, 'FPR95': fpr95}
    results_data.append(row)
    return results_data


def add_overall_results(results_data, method, auroc_list, fpr_list):
    """
    Append the mean AUROC / FPR95 across all OOD datasets as an 'Overall' row.

    The averages are also printed as a short summary banner.

    Args:
        results_data (list): List of result dicts; the 'Overall' row is appended
            to it in place.
        method: Method label stored under 'Method' and shown in the printout.
        auroc_list (sequence of float): Per-dataset AUROC values.
        fpr_list (sequence of float): Per-dataset FPR95 values.

    Returns:
        list: The same ``results_data`` list.

    Note:
        Empty ``auroc_list`` / ``fpr_list`` yield NaN averages (``np.mean`` of an
        empty sequence, with a NumPy RuntimeWarning).
    """
    # Average each metric over all datasets.
    avg_auroc = np.mean(auroc_list)
    avg_fpr = np.mean(fpr_list)

    results_data.append({
        'Dataset': 'Overall',
        'Method': method,
        'AUROC': avg_auroc,
        'FPR95': avg_fpr
    })

    # Print overall results
    print("\n" + "="*50)
    print("OVERALL RESULTS")
    print("="*50)
    print(f"{method}    - AUROC: {avg_auroc:.4f}, FPR95: {avg_fpr:.4f}")
    print("="*50)

    return results_data


def save_results_to_csv(results_data, output_dir):
    """Write ``results_data`` to ``<output_dir>/results.csv``, overwriting it.

    Each dict becomes one row. The header is the union of all row keys in
    first-seen order. Values are matched to columns by key, so rows whose keys
    are ordered differently still land in the right columns, and keys missing
    from a row are written as empty cells.

    Args:
        results_data (list[dict]): Rows to write. An empty list produces an
            empty file instead of raising ``IndexError``.
        output_dir (str): Existing directory in which ``results.csv`` is created.
    """
    csv_path = os.path.join(output_dir, "results.csv")
    # Ordered union of keys across rows (dict preserves insertion order).
    fieldnames = list(dict.fromkeys(key for row in results_data for key in row))

    # newline="" is required by the csv module. An explicit utf-8 encoding avoids
    # platform-dependent encoding of dataset/method names.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        if not fieldnames:
            return
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results_data)


def save_results_to_json(results_data, output_dir, filename):
    """Serialize ``results_data`` as compact JSON to ``<output_dir>/<filename>``.

    Any existing file at that path is overwritten.
    """
    destination = os.path.join(output_dir, filename)
    with open(destination, "w") as out_file:
        json.dump(results_data, out_file)