import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
import statistics

# Root directory under which the pickled detection files are looked up
# (see pkl_filename and pooled_pkl_filename).
directory = "/aux/USER/fmri-datasets/predication/processed/experiments"
# Number of subjects; mean_auc loads one pickle per subject, from
# directories numbered subject-01 upward.
subjects = 12
# Embedded verbatim in every pickle filename.
# NOTE(review): presumably the hemodynamic-response delay used by the
# experiment -- confirm meaning and unit.
HRFdelay = 3
# Embedded in every pickle filename as the "nn-<layers>-<hidden>" tag.
layers = 1
hidden = 10
# Number of consecutive slices pooled_subject_mean_auc cuts the pooled
# detections into.
folds = 8

# Concept names used to build the pickle and significance-file names and to
# label the bars. A trailing "1" is stripped by munge() before display.
concept_files = ["Dan",
                 "Scott",
                 "pick-up",
                 "put-down",
                 "briefcase",
                 "chair",
                 "Dan-pick-up1",
                 "Dan-put-down1",
                 "Scott-pick-up1",
                 "Scott-put-down1",
                 "Dan-briefcase1",
                 "Dan-chair1",
                 "Scott-briefcase1",
                 "Scott-chair1",
                 "pick-up-briefcase1",
                 "pick-up-chair1",
                 "put-down-briefcase1",
                 "put-down-chair1"]

def confusion_matrix(detections, threshold):
    """Count true/false positives/negatives at a decision threshold.

    detections is an iterable of (confidence, target) pairs. A detection is
    predicted positive when confidence >= threshold and is actually positive
    when target == 1; every other target value counts as negative.

    Returns the tuple (tp, fp, fn, tn).
    """
    tp = fp = fn = tn = 0
    for confidence, target in detections:
        predicted_positive = confidence >= threshold
        actually_positive = target == 1
        if predicted_positive and actually_positive:
            tp += 1
        elif predicted_positive:
            fp += 1
        elif actually_positive:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn

# http://en.wikipedia.org/wiki/Matthews_correlation_coefficient

def recall(tp, fp, fn, tn):
    """Return the true positive rate tp / (tp + fn).

    Also known as sensitivity or hit rate. When there are no positives at
    all (tp and fn both zero) the result is defined as 1.
    """
    if tp != 0 or fn != 0:
        return float(tp) / (tp + fn)
    return 1

def true_negative_rate(tp, fp, fn, tn):
    """Return the true negative rate tn / (tn + fp).

    Also known as specificity or selectivity. When there are no negatives
    at all (tn and fp both zero) the ratio is undefined; 0 is returned so
    the result stays the complement of false_positive_rate, which returns 1
    in that case, instead of raising ZeroDivisionError.
    """
    if tn == 0 and fp == 0:
        return 0
    return float(tn) / (tn + fp)

def precision(tp, fp, fn, tn):
    """Return the positive predictive value tp / (tp + fp).

    When nothing was predicted positive (tp and fp both zero) the result is
    defined as 1.
    """
    if tp != 0 or fp != 0:
        return float(tp) / (tp + fp)
    return 1

def negative_predictive_value(tp, fp, fn, tn):
    """Return the negative predictive value tn / (tn + fn).

    When nothing was predicted negative (tn and fn both zero) the ratio is
    undefined; 1 is returned, mirroring how precision treats the case of no
    positive predictions, instead of raising ZeroDivisionError.
    """
    if tn == 0 and fn == 0:
        return 1
    return float(tn) / (tn + fn)

def false_negative_rate(tp, fp, fn, tn):
    """Return the miss rate fn / (fn + tp).

    When there are no positives at all (fn and tp both zero) the ratio is
    undefined; 0 is returned so the result stays the complement of recall,
    which returns 1 in that case, instead of raising ZeroDivisionError.
    """
    if fn == 0 and tp == 0:
        return 0
    return float(fn) / (fn + tp)

def false_positive_rate(tp, fp, fn, tn):
    """Return the fall-out fp / (fp + tn).

    When there are no negatives at all (fp and tn both zero) the result is
    defined as 1.
    """
    if fp != 0 or tn != 0:
        return float(fp) / (fp + tn)
    return 1

def false_discovery_rate(tp, fp, fn, tn):
    """Return the false discovery rate fp / (fp + tp).

    When nothing was predicted positive (fp and tp both zero) the ratio is
    undefined; 0 is returned so the result stays the complement of
    precision, which returns 1 in that case, instead of raising
    ZeroDivisionError.
    """
    if fp == 0 and tp == 0:
        return 0
    return float(fp) / (fp + tp)

def accuracy(tp, fp, fn, tn):
    """Return the fraction of detections classified correctly.

    Raises ZeroDivisionError when all four counts are zero.
    """
    correct = tp + tn
    total = correct + fp + fn
    return float(correct) / total

def f1_score(tp, fp, fn, tn):
    """Return the F1 score 2*tp / (2*tp + fp + fn).

    When tp, fp and fn are all zero the result is defined as 0.
    """
    if tp != 0 or fp != 0 or fn != 0:
        doubled_hits = 2 * tp
        return float(doubled_hits) / (doubled_hits + fp + fn)
    return 0

def mcc(tp, fp, fn, tn):
    """Return the Matthews correlation coefficient of a confusion matrix.

    The counts are converted to float before multiplying. When the
    denominator is zero (some row or column of the matrix is empty) the
    result is defined as 0.
    """
    tp, fp, fn, tn = (float(count) for count in (tp, fp, fn, tn))
    marginal_product = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    denominator = np.sqrt(marginal_product)
    if denominator == 0:
        return 0
    return (tp * tn - fp * fn) / denominator

def score(measure, threshold, detections):
    """Evaluate one metric on detections at a given threshold.

    measure is called as measure(tp, fp, fn, tn) with the counts produced by
    confusion_matrix(detections, threshold); its result is returned.
    """
    counts = confusion_matrix(detections, threshold)
    return measure(*counts)

def scores(measure1, measure2, detections):
    """Return the sorted (measure1, measure2) curve swept over thresholds.

    detections is a sequence of (confidence, target) pairs. One threshold is
    placed midway between each pair of consecutive sorted confidences, and
    both measures are evaluated there as measure(tp, fp, fn, tn).

    The returned list is sorted and always bracketed by the points
    (0.0, 0.0) and (1.0, 1.0); those end points are hard-wired for an ROC
    curve (false positive rate against recall).
    """
    confidences = sorted(detection[0] for detection in detections)
    result = []
    for lower, upper in zip(confidences, confidences[1:]):
        threshold = (lower + upper) / 2.0
        # Both measures share one confusion matrix, so build it only once
        # per threshold instead of once per measure.
        counts = confusion_matrix(detections, threshold)
        result.append((measure1(*counts), measure2(*counts)))
    result.sort()
    return [(0.0, 0.0)] + result + [(1.0, 1.0)]  # hardwired to ROC

def auc(detections):
    """Return the area under the ROC curve of detections.

    The curve comes from scores(false_positive_rate, recall, detections)
    and is integrated with the trapezoidal rule.
    """
    curve = scores(false_positive_rate, recall, detections)
    area = 0.0
    # Explicit accumulation keeps the plain left-to-right float summation.
    for (x0, y0), (x1, y1) in zip(curve, curve[1:]):
        area += ((y0 + y1) / 2) * (x1 - x0)
    return area

def pkl_filename(subject, concept, modality, kind, identifier):
    """Return the path of one subject's pickled detections.

    subject is zero-based; the directory name uses subject + 1, zero-padded
    to two digits. The layout is

        <directory>/subject-NN/<modality>/
            <concept>-<kind>-<identifier>-<HRFdelay>-nn-<layers>-<hidden>-whole-brain.pkl
    """
    network_tag = "nn-%d-%d" % (layers, hidden)
    basename = "%s-%s-%s-%d-%s-%s.pkl" % (
        concept, kind, identifier, HRFdelay, network_tag, "whole-brain")
    return "%s/subject-%02d/%s/%s" % (
        directory, subject + 1, modality, basename)

def pooled_pkl_filename(concept, modality, kind, identifier):
    """Return the path of the pooled-subject pickled detections.

    The layout is

        <directory>/pooled-subject/<modality>/
            <concept>-<kind>-<identifier>-<HRFdelay>-nn-<layers>-<hidden>-whole-brain.pkl
    """
    network_tag = "nn-%d-%d" % (layers, hidden)
    basename = "%s-%s-%s-%d-%s-%s.pkl" % (
        concept, kind, identifier, HRFdelay, network_tag, "whole-brain")
    return "%s/pooled-subject/%s/%s" % (directory, modality, basename)

def mean_auc(modality, concept_file, kind, identifier):
    """Return (mean, sample standard deviation) of the per-subject AUCs.

    One pickle of detections is loaded per subject (see pkl_filename) and
    scored with auc().

    Raises statistics.StatisticsError if fewer than two subjects are
    configured, because the sample standard deviation needs two values.
    """
    aucs = []
    for subject in range(subjects):
        path = pkl_filename(subject, concept_file, modality, kind, identifier)
        # NOTE: unpickling can execute arbitrary code; only point this at
        # result files produced by a trusted pipeline.
        with open(path, "rb") as handle:
            detections = pkl.load(handle)
        aucs.append(auc(detections))
    return statistics.mean(aucs), statistics.stdev(aucs)

def pooled_subject_mean_auc(modality, concept_file, identifier, fold_size=768):
    """Return (mean, sample standard deviation) of the per-fold AUCs.

    The pooled-subject pickle holds one sequence of detections. It is cut
    into `folds` consecutive slices of fold_size detections each, and every
    slice is scored with auc().

    fold_size defaults to 768, the value previously hard-coded here. If the
    sequence is shorter than folds * fold_size the trailing slices are
    short or empty; auc() scores an empty slice as 0.5.
    """
    path = pooled_pkl_filename(
        concept_file, modality, "pooled-subject", identifier)
    # NOTE: unpickling can execute arbitrary code; only point this at
    # result files produced by a trusted pipeline.
    with open(path, "rb") as handle:
        detections = pkl.load(handle)
    aucs = [auc(detections[fold * fold_size:(fold + 1) * fold_size])
            for fold in range(folds)]
    return statistics.mean(aucs), statistics.stdev(aucs)

def pooled_subject_cross_modal_mean_auc(modality, concept_file, identifier):
    """Return (AUC, 0) for the pooled-subject cross-modal detections.

    A single AUC is computed over the whole pickle, so there is no spread
    to report; the second element is always 0, which keeps the return shape
    the same as the other *_mean_auc functions.
    """
    path = pooled_pkl_filename(
        concept_file, modality, "pooled-subject-cross-modal", identifier)
    # NOTE: unpickling can execute arbitrary code; only point this at
    # result files produced by a trusted pipeline.
    with open(path, "rb") as handle:
        detections = pkl.load(handle)
    return auc(detections), 0

def _plot_auc_bars(stats, stars, title, filename):
    """Draw one AUC bar per concept and save the figure to filename.

    stats holds one (height, error) pair per entry of concept_files, in the
    same order; stars holds one truthy/falsy flag per concept, and a truthy
    flag puts a green "*" on top of that bar. A red horizontal line marks
    the chance level of 0.5.
    """
    plt.figure(dpi=300)
    x = np.arange(len(concept_files))
    y = [height for height, _ in stats]
    yerr = [error for _, error in stats]
    width = 0.75
    ax = plt.gca()
    rects = plt.bar(x,
                    y,
                    width,
                    yerr=yerr,
                    color="blue",
                    error_kw={"ecolor": "orange",
                              "elinewidth": 0.5,
                              "capsize": 2.5,
                              "capthick": 0.5})
    for rect, star in zip(rects, stars):
        ax.text(rect.get_x() + rect.get_width() / 2,
                rect.get_height(),
                "*" if star else "",
                fontsize=8,
                color="green",
                ha="center",
                va="bottom")
    chance = 0.5
    # The line runs one position past either end of the bars.
    plt.plot(range(-1, len(concept_files) + 1),
             (len(concept_files) + 2) * [chance],
             color="red",
             linestyle="solid",
             label="chance")
    ax.set_ylabel("AUC")
    ax.set_title(title)
    # Pass only the positions here: a second positional argument means
    # `minor` on matplotlib < 3.5 and `labels` on later versions. The labels
    # are set explicitly just below.
    ax.set_xticks(x)
    ax.set_xticklabels([munge(concept_file).replace("-", " ")
                        for concept_file in concept_files],
                       rotation=45,
                       fontsize=5,
                       ha="right",
                       va="top")
    ax.set_ylim(0, 1.2)
    plt.savefig(filename, bbox_inches="tight")


def plot_auc(modality, kind, identifier, title):
    """Plot the per-concept mean AUC across subjects to "<kind>-<modality>.png".

    Bar heights and error bars are the mean and sample standard deviation
    returned by mean_auc(); stars come from significant().
    """
    # One mean_auc call per concept yields both the height and the error.
    stats = [mean_auc(modality, concept_file, kind, identifier)
             for concept_file in concept_files]
    stars = [significant(kind, modality, concept_file)
             for concept_file in concept_files]
    _plot_auc_bars(stats, stars, title, "%s-%s.png" % (kind, modality))


def plot_pooled_subject_auc(modality, identifier):
    """Plot the per-concept pooled-subject AUC to "pooled-subject-<modality>.png".

    Bar heights and error bars are the mean and sample standard deviation
    over folds returned by pooled_subject_mean_auc(); stars come from
    significant().
    """
    stats = [pooled_subject_mean_auc(modality, concept_file, identifier)
             for concept_file in concept_files]
    stars = [significant("pooled-subject", modality, concept_file)
             for concept_file in concept_files]
    _plot_auc_bars(stats,
                   stars,
                   "Pooled Subject %s" % modality.capitalize(),
                   "pooled-subject-%s.png" % modality)


def plot_pooled_subject_cross_modal_auc(modality, identifier):
    """Plot the cross-modal AUC to "pooled-subject-cross-modal-<modality>.png".

    Bar heights come from pooled_subject_cross_modal_mean_auc(), whose error
    term is always 0.
    """
    stats = [pooled_subject_cross_modal_mean_auc(
                 modality, concept_file, identifier)
             for concept_file in concept_files]
    # NOTE(review): mann_whitney_significant is not defined in the visible
    # part of this file; unless it is provided elsewhere this call raises
    # NameError. Confirm where it is meant to come from.
    stars = [mann_whitney_significant(
                 "pooled-subject-cross-modal", modality, concept_file)
             for concept_file in concept_files]
    _plot_auc_bars(stats,
                   stars,
                   "Pooled Subject Cross Modal %s" % modality.capitalize(),
                   "pooled-subject-cross-modal-%s.png" % modality)

def significant(kind, modality, concept_file):
    """Return True if the first line of the AUC report contains a "*".

    The report is read from "auc-txt/<kind>-<modality>-<concept_file>-auc.txt"
    relative to the current working directory. Only the first line is read,
    and the file is closed before returning. An empty file yields False.

    Raises FileNotFoundError (an OSError) if the report does not exist.
    """
    path = "auc-txt/%s-%s-%s-auc.txt" % (kind, modality, concept_file)
    with open(path, "r") as report:
        return "*" in report.readline()

def munge(concept_file):
    """Return concept_file without its trailing "1", if it has one.

    Only a single trailing "1" is removed. Names that do not end in "1",
    including the empty string, are returned unchanged.
    """
    # endswith() is safe on an empty string, unlike indexing [-1].
    if concept_file.endswith("1"):
        return concept_file[:-1]
    return concept_file

# Print the per-concept AUC of every analysis, then render the bar charts.
# NOTE(review): the printed numbers are loaded with identifier None (so the
# filenames contain the text "None"), while the plots use "iclr2024" --
# confirm that this difference is intended.
for modality in ("video", "text"):
    for concept_file in concept_files:
        pooled_mean, _ = pooled_subject_mean_auc(modality, concept_file, None)
        print("pooled-subject %s %s %.6f"
              % (modality, munge(concept_file), pooled_mean))

for modality in ("video", "text"):
    for concept_file in concept_files:
        cross_subject_mean, _ = mean_auc(
            modality, concept_file, "cross-subject", None)
        print("cross-subject %s %s %.6f"
              % (modality, munge(concept_file), cross_subject_mean))

for modality in ("video", "text"):
    for concept_file in concept_files:
        cross_modal_auc, _ = pooled_subject_cross_modal_mean_auc(
            modality, concept_file, None)
        print("pooled-subject-cross-modal %s %s %.6f"
              % (modality, munge(concept_file), cross_modal_auc))

# One PNG per analysis and modality.
for modality in ("video", "text"):
    cross_subject_title = "Cross Subject %s" % modality.capitalize()
    plot_pooled_subject_auc(modality, "iclr2024")
    plot_auc(modality, "cross-subject", "iclr2024", cross_subject_title)
    plot_pooled_subject_cross_modal_auc(modality, "iclr2024")
