#include <math.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
#include "auc-iclr2024.h"

/* True positives: samples labelled 1 whose confidence is >= threshold.
 * Samples with any other label are ignored. */
int tp(int n, double confidence[n], int target[n], double threshold) {
  int hits = 0;
  for (int k = n - 1; k >= 0; k--) {
    hits += (target[k] == 1 && confidence[k] >= threshold);
  }
  return hits;
}

/* False positives: samples labelled 0 whose confidence is >= threshold.
 * Samples with any other label are ignored. */
int fp(int n, double confidence[n], int target[n], double threshold) {
  int false_alarms = 0;
  int i = 0;
  while (i < n) {
    if (target[i] == 0 && confidence[i] >= threshold) {
      ++false_alarms;
    }
    ++i;
  }
  return false_alarms;
}

/* False negatives: samples labelled 1 whose confidence is < threshold.
 * Samples with any other label are ignored. */
int fn(int n, double confidence[n], int target[n], double threshold) {
  int misses = 0;
  for (int k = 0; k < n; ++k) {
    if (target[k] != 1) {
      continue;
    }
    if (confidence[k] < threshold) {
      misses++;
    }
  }
  return misses;
}

/* True negatives: samples labelled 0 whose confidence is < threshold.
 * Samples with any other label are ignored. */
int tn(int n, double confidence[n], int target[n], double threshold) {
  int rejected = 0;
  for (int k = n - 1; k >= 0; --k) {
    rejected += (target[k] == 0 && confidence[k] < threshold);
  }
  return rejected;
}

// Confusion-matrix metrics below; see https://en.wikipedia.org/wiki/Matthews_correlation_coefficient

double recall(int tp, int fp, int fn, int tn) {
  /* Sensitivity / hit rate / true positive rate: TP / (TP + FN).
   * With no actual positives there is nothing to miss, so 1.0 is
   * reported instead of dividing by zero. */
  if (!(tp != 0 || fn != 0)) {
    return 1.0;
  }
  int actual_positives = tp + fn;
  return (double)tp / actual_positives;
}

double true_negative_rate(int tp, int fp, int fn, int tn) {
  // Specificity / selectivity: TN / (TN + FP).
  // With no actual negatives the ratio is undefined; report 1.0, mirroring
  // recall()'s convention for no actual positives, instead of dividing by
  // zero (which yields NaN, or UB without IEEE 754 semantics).
  if (tn==0&&fp==0) return 1.0;
  return ((double)tn)/(tn+fp);
}

double precision(int tp, int fp, int fn, int tn) {
  /* Positive predictive value: TP / (TP + FP).
   * When nothing is predicted positive there are no false alarms, so 1.0
   * is reported instead of dividing by zero. */
  if (tp != 0 || fp != 0) {
    int predicted_positives = tp + fp;
    return (double)tp / predicted_positives;
  }
  return 1.0;
}

double negative_predictive_value(int tp, int fp, int fn, int tn) {
  // TN / (TN + FN).
  // With no negative predictions the ratio is undefined; report 1.0,
  // mirroring precision()'s convention for no positive predictions, instead
  // of dividing by zero.
  if (tn==0&&fn==0) return 1.0;
  return ((double)tn)/(tn+fn);
}

double false_negative_rate(int tp, int fp, int fn, int tn) {
  // Miss rate: FN / (FN + TP).
  // With no actual positives there is nothing to miss; report 0.0
  // (= 1 - recall() in that case) instead of dividing by zero.
  if (fn==0&&tp==0) return 0.0;
  return ((double)fn)/(fn+tp);
}

double false_positive_rate(int tp, int fp, int fn, int tn) {
  // Fall-out: FP / (FP + TN).
  // With no actual negatives there is nothing to raise a false alarm on;
  // report 0.0 instead of dividing by zero, so a degenerate sample set does
  // not turn an ROC coordinate (and hence the AUC) into NaN.
  if (fp==0&&tn==0) return 0.0;
  return ((double)fp)/(fp+tn);
}

double false_discovery_rate(int tp, int fp, int fn, int tn) {
  // FP / (FP + TP).
  // With no positive predictions there are no false discoveries; report 0.0
  // (= 1 - precision() in that case) instead of dividing by zero.
  if (fp==0&&tp==0) return 0.0;
  return ((double)fp)/(fp+tp);
}

double accuracy(int tp, int fp, int fn, int tn) {
  // Fraction of correctly classified samples: (TP + TN) / (TP + TN + FP + FN).
  // An empty confusion matrix has no defined accuracy; report 0.0 (as
  // f1_score and mcc do for their degenerate cases) instead of dividing by
  // zero.
  if (tp==0&&fp==0&&fn==0&&tn==0) return 0.0;
  // Sum in double so large counts cannot overflow int.
  double correct = (double)tp+tn;
  return correct/(correct+fp+fn);
}

double f1_score(int tp, int fp, int fn, int tn) {
  // F1: harmonic mean of precision and recall, 2TP / (2TP + FP + FN).
  // tn does not enter the score.
  // With no positives predicted and none present the score is undefined;
  // report 0.0.
  if (tp==0&&fp==0&&fn==0) return 0.0;
  // Work in double: 2*tp and 2*tp+fp+fn can overflow int for large counts.
  double twice_tp = 2.0*tp;
  return twice_tp/(twice_tp+fp+fn);
}

double mcc(int tp, int fp, int fn, int tn) {
  double dtp = tp;
  double dfp = fp;
  double dfn = fn;
  double dtn = tn;
  double denominator = sqrt((dtp+dfp)*(dtp+dfn)*(dtn+dfp)*(dtn+dfn));
  if (denominator==0.0) return 0.0;
  return (dtp*dtn-dfp*dfn)/denominator;
}

/* Evaluates `measure` on the confusion matrix obtained by predicting
 * "positive" for every sample whose confidence is >= threshold. */
double score(double (*measure)(int, int, int, int),
             double threshold,
             int n,
             double confidence[n],
             int target[n]) {
  int true_pos = tp(n, confidence, target, threshold);
  int false_pos = fp(n, confidence, target, threshold);
  int false_neg = fn(n, confidence, target, threshold);
  int true_neg = tn(n, confidence, target, threshold);
  return measure(true_pos, false_pos, false_neg, true_neg);
}

/* qsort comparator for doubles in ascending order: returns -1, 0 or 1. */
int compare_doubles(const void *x, const void *y) {
  const double lhs = *(const double *)x;
  const double rhs = *(const double *)y;
  if (lhs == rhs) {
    return 0;
  }
  return lhs < rhs ? -1 : 1;
}

int compare_points(const void *x, const void *y) {
  struct point *xp = (struct point *)x;
  struct point *yp = (struct point *)y;
  if (xp->x<yp->x) return -1;
  if (xp->x==yp->x) {
    if (xp->y<yp->y) return -1;
    if (xp->y==yp->y) return 0;
    return 1;
  }
  return 1;
}

double auc(double (*measure1)(int, int, int, int),
	   double (*measure2)(int, int, int, int),
	   int n,
	   double confidence[n],
	   int target[n]) {
  double thresholds[n];
  struct point points[n-1];
  for (int i = 0; i<n; i++) thresholds[i] = confidence[i];
  qsort(&thresholds[0], n, sizeof(double), &compare_doubles);
  for (int i = 0; i<n-1; i++) {
    double lower = thresholds[i];
    double upper = thresholds[i+1];
    double threshold = (lower+upper)/2.0;
    points[i].x = score(measure1, threshold, n, confidence, target);
    points[i].y = score(measure2, threshold, n, confidence, target);
  }
  qsort(&points[0], n-1, sizeof(struct point), &compare_points);
  double area = 0.0;
  struct point p_left = {0.0, 0.0};  // hardwired to ROC
  struct point p_right = {1.0, 1.0}; // hardwired to ROC
  // This puts a threshold before the lowest threshold.
  area += ((p_left.y+points[0].y)/2)*(points[0].x-p_left.x);
  for (int i = 0; i<n-2; i++) {
    struct point p1 = points[i];
    struct point p2 = points[i+1];
    area += ((p1.y+p2.y)/2)*(p2.x-p1.x);
  }
  // This puts a threshold above the highest threshold.
  area += ((points[n-2].y+p_right.y)/2)*(p_right.x-points[n-2].x);
  return area;
}

/* Mean of the per-subject AUCs: row `subject` of confidence/target holds
 * that subject's n samples. */
double mean_auc(double (*measure1)(int, int, int, int),
                double (*measure2)(int, int, int, int),
                int subjects,
                int n,
                double confidence[subjects][n],
                int target[subjects][n]) {
  double total = 0.0;
  int averaged = 0;
  for (int s = 0; s < subjects; ++s, ++averaged) {
    total += auc(measure1, measure2, n, confidence[s], target[s]);
  }
  return total / averaged;
}

/* Mean of the per-fold AUCs: row `fold` of confidence/target holds that
 * fold's n samples. */
double pooled_subject_mean_auc(double (*measure1)(int, int, int, int),
                               double (*measure2)(int, int, int, int),
                               int folds,
                               int n,
                               double confidence[folds][n],
                               int target[folds][n]) {
  double sum = 0.0;
  int f = 0;
  while (f < folds) {
    sum += auc(measure1, measure2, n, confidence[f], target[f]);
    f++;
  }
  return sum / folds;
}

/* There is a single pooled sample set (no per-subject or per-fold rows),
 * so the "mean" is simply the AUC of all n samples. */
double pooled_subject_cross_modal_mean_auc
(double (*measure1)(int, int, int, int),
 double (*measure2)(int, int, int, int),
 int n,
 double confidence[n],
 int target[n]) {
  const double pooled = auc(measure1, measure2, n, confidence, target);
  return pooled;
}

/*
 * Sampling-based significance test for mean_auc().
 *
 * Computes the observed mean AUC. Then draws m null data sets, in each of
 * which every confidence is replaced by rand()/RAND_MAX and the targets are
 * kept. Returns the fraction of null data sets whose mean AUC strictly
 * exceeds the observed one.
 *
 * The null confidences go into a subjects x n scratch VLA, so the caller's
 * confidence array is left unmodified. The rand() call sequence, and hence
 * the result, is the same as when the data was overwritten in place.
 *
 * Returns 1.0 (no evidence against the null) when m, subjects or n is not
 * positive. This avoids dividing by zero and declaring an empty VLA.
 *
 * NOTE(review): count/m can be exactly 0 and uses a strict '>'. The usual
 * Monte-Carlo estimate is (count+1)/(m+1) with '>='. Kept as-is so
 * reported results do not shift.
 */
double p_value_by_sampling(double (*measure1)(int, int, int, int),
                           double (*measure2)(int, int, int, int),
                           int subjects,
                           int n,
                           double confidence[subjects][n],
                           int target[subjects][n],
                           int m) {
  if (m<=0||subjects<=0||n<=0) return 1.0;
  double base_auc = mean_auc(measure1,
                             measure2,
                             subjects,
                             n,
                             confidence,
                             target);
  double null_confidence[subjects][n];
  int count = 0;
  for (int i = 0; i<m; i++) {
    for (int subject = 0; subject<subjects; subject++) {
      for (int k = 0; k<n; k++) {
        null_confidence[subject][k] = ((double)rand())/RAND_MAX;
      }
    }
    if (mean_auc(measure1,
                 measure2,
                 subjects,
                 n,
                 null_confidence,
                 target)>base_auc) {
      count++;
    }
  }
  return ((double)count)/m;
}

/*
 * Sampling-based significance test for pooled_subject_mean_auc().
 *
 * Computes the observed mean-over-folds AUC. Then draws m null data sets,
 * in each of which every confidence is replaced by rand()/RAND_MAX and the
 * targets are kept. Returns the fraction of null data sets whose mean AUC
 * strictly exceeds the observed one.
 *
 * The null confidences go into a folds x n scratch VLA, so the caller's
 * confidence array is left unmodified. The rand() call sequence, and hence
 * the result, is the same as when the data was overwritten in place.
 *
 * Returns 1.0 (no evidence against the null) when m, folds or n is not
 * positive. This avoids dividing by zero and declaring an empty VLA.
 *
 * NOTE(review): count/m can be exactly 0 and uses a strict '>'. The usual
 * Monte-Carlo estimate is (count+1)/(m+1) with '>='. Kept as-is so
 * reported results do not shift.
 */
double pooled_subject_p_value_by_sampling(double (*measure1)(int, int, int, int),
                                          double (*measure2)(int, int, int, int),
                                          int folds,
                                          int n,
                                          double confidence[folds][n],
                                          int target[folds][n],
                                          int m) {
  if (m<=0||folds<=0||n<=0) return 1.0;
  double base_auc =
    pooled_subject_mean_auc(measure1, measure2, folds, n, confidence, target);
  double null_confidence[folds][n];
  int count = 0;
  for (int i = 0; i<m; i++) {
    for (int fold = 0; fold<folds; fold++) {
      for (int k = 0; k<n; k++) {
        null_confidence[fold][k] = ((double)rand())/RAND_MAX;
      }
    }
    if (pooled_subject_mean_auc(measure1,
                                measure2,
                                folds,
                                n,
                                null_confidence,
                                target)>base_auc) {
      count++;
    }
  }
  return ((double)count)/m;
}

/*
 * Sampling-based significance test for pooled_subject_cross_modal_mean_auc().
 *
 * Computes the observed AUC of the pooled sample set. Then draws m null
 * data sets, in each of which every confidence is replaced by
 * rand()/RAND_MAX and the targets are kept. Returns the fraction of null
 * data sets whose AUC strictly exceeds the observed one.
 *
 * The null confidences go into an n-element scratch VLA, so the caller's
 * confidence array is left unmodified. The rand() call sequence, and hence
 * the result, is the same as when the data was overwritten in place.
 *
 * Returns 1.0 (no evidence against the null) when m or n is not positive.
 * This avoids dividing by zero and declaring an empty VLA.
 *
 * NOTE(review): count/m can be exactly 0 and uses a strict '>'. The usual
 * Monte-Carlo estimate is (count+1)/(m+1) with '>='. Kept as-is so
 * reported results do not shift.
 */
double pooled_subject_cross_modal_p_value_by_sampling
(double (*measure1)(int, int, int, int),
 double (*measure2)(int, int, int, int),
 int n,
 double confidence[n],
 int target[n],
 int m) {
  if (m<=0||n<=0) return 1.0;
  double base_auc =
    pooled_subject_cross_modal_mean_auc
    (measure1, measure2, n, confidence, target);
  double null_confidence[n];
  int count = 0;
  for (int i = 0; i<m; i++) {
    for (int k = 0; k<n; k++) {
      null_confidence[k] = ((double)rand())/RAND_MAX;
    }
    if (pooled_subject_cross_modal_mean_auc(measure1,
                                            measure2,
                                            n,
                                            null_confidence,
                                            target)>base_auc) {
      count++;
    }
  }
  return ((double)count)/m;
}

/* Opens detections-iclr2024/<kind>-<concept>-<modality>.txt for reading.
 * Exits with a diagnostic if the pathname does not fit in the buffer or
 * fopen fails. */
static FILE *open_detections(const char *kind,
                             const char *concept,
                             const char *modality) {
  char pathname[1024];
  int length = snprintf(pathname, sizeof pathname,
                        "detections-iclr2024/%s-%s-%s.txt",
                        kind, concept, modality);
  if (length<0||(size_t)length>=sizeof pathname) {
    fprintf(stderr, "pathname too long\n");
    exit(EXIT_FAILURE);
  }
  FILE *f = fopen(pathname, "r");
  if (f==NULL) {
    perror(pathname);
    exit(EXIT_FAILURE);
  }
  return f;
}

/* Reads the leading sample count from f and exits unless it can be parsed
 * and equals `expected`. */
static void expect_sample_count(FILE *f, int expected) {
  int n1;
  if (fscanf(f, "%d", &n1)!=1) {
    fprintf(stderr, "malformed detections file\n");
    exit(EXIT_FAILURE);
  }
  if (n1!=expected) {
    fprintf(stderr, "wrong number of samples\n");
    exit(EXIT_FAILURE);
  }
}

/* Reads n "<confidence> <target>" pairs from f into the given row;
 * exits if any pair cannot be parsed. */
static void read_detections(FILE *f, int n, double confidence[n], int target[n]) {
  for (int k = 0; k<n; k++) {
    if (fscanf(f, "%lf %d", &confidence[k], &target[k])!=2) {
      fprintf(stderr, "malformed detections file\n");
      exit(EXIT_FAILURE);
    }
  }
}

/*
 * Usage: <program> kind modality concept
 *
 * Loads detections-iclr2024/<kind>-<concept>-<modality>.txt. The file holds
 * a sample count followed by "<confidence> <target>" pairs. The program
 * computes the ROC AUC (false_positive_rate vs. recall) and a sampled
 * p-value for the layout selected by `kind`:
 *   pooled-subject              FOLDS rows of 512*SUBJECTS/FOLDS samples
 *   cross-subject               SUBJECTS rows of 512 samples
 *   pooled-subject-cross-modal  one row of 512*SUBJECTS samples
 * Prints "kind modality concept AUC". A trailing '*' marks p < 0.005.
 */
int main(int argc, char *argv[]) {
  // Each subject contributes this many detections in every file layout.
  enum { DETECTIONS_PER_SUBJECT = 512 };
  if (argc<4) {
    fprintf(stderr, "usage: %s kind modality concept\n",
            argc>0 ? argv[0] : "auc");
    return EXIT_FAILURE;
  }
  char *kind = argv[1];
  char *modality = argv[2];
  char *concept = argv[3];
  int subjects = SUBJECTS;
  int folds = FOLDS;
  int samples = SAMPLES;
  double base_auc, p_value;
  if (!strcmp(kind, "pooled-subject")) {
    // Integer division: if it truncates, the sample-count check rejects the file.
    int n = DETECTIONS_PER_SUBJECT*subjects/folds;
    double confidence[folds][n];
    int target[folds][n];
    FILE *f = open_detections(kind, concept, modality);
    expect_sample_count(f, folds*n);
    for (int fold = 0; fold<folds; fold++) {
      read_detections(f, n, confidence[fold], target[fold]);
    }
    fclose(f);
    base_auc = pooled_subject_mean_auc(&false_positive_rate,
                                       &recall,
                                       folds,
                                       n,
                                       confidence,
                                       target);
    p_value = pooled_subject_p_value_by_sampling(&false_positive_rate,
                                                 &recall,
                                                 folds,
                                                 n,
                                                 confidence,
                                                 target,
                                                 samples);
  }
  else if (!strcmp(kind, "cross-subject")) {
    int n = DETECTIONS_PER_SUBJECT;
    double confidence[subjects][n];
    int target[subjects][n];
    FILE *f = open_detections(kind, concept, modality);
    expect_sample_count(f, subjects*n);
    for (int subject = 0; subject<subjects; subject++) {
      read_detections(f, n, confidence[subject], target[subject]);
    }
    fclose(f);
    base_auc = mean_auc(&false_positive_rate,
                        &recall,
                        subjects,
                        n,
                        confidence,
                        target);
    p_value = p_value_by_sampling(&false_positive_rate,
                                  &recall,
                                  subjects,
                                  n,
                                  confidence,
                                  target,
                                  samples);
  }
  else if (!strcmp(kind, "pooled-subject-cross-modal")) {
    int n = DETECTIONS_PER_SUBJECT*subjects;
    double confidence[n];
    int target[n];
    FILE *f = open_detections(kind, concept, modality);
    expect_sample_count(f, n);
    read_detections(f, n, confidence, target);
    fclose(f);
    base_auc = pooled_subject_cross_modal_mean_auc(&false_positive_rate,
                                                   &recall,
                                                   n,
                                                   confidence,
                                                   target);
    p_value = pooled_subject_cross_modal_p_value_by_sampling
      (&false_positive_rate,
       &recall,
       n,
       confidence,
       target,
       samples);
  }
  else {
    // Diagnostics go to stderr so they never mix with result lines on stdout.
    fprintf(stderr, "unrecognized kind\n");
    return EXIT_FAILURE;
  }
  printf("%s %s %s ", kind, modality, concept);
  // '*' flags results whose sampled p-value is below 0.005.
  if (p_value<0.005) printf("%.6lf*\n", base_auc);
  else printf("%.6lf\n", base_auc);
  return EXIT_SUCCESS;
}
