import csv
import sys
import random
import numpy as np

def kendall_w(ratings):
	"""
	Computes Kendall's W statistic (coefficient of concordance), corrected for ties.

	Args:
		ratings: a sequence with one entry per item; entry i holds the ratings
			given to item i, where position j belongs to judge j. The number
			of judges is taken from the first entry. Ratings may be any
			mutually comparable values: only their order within a judge's
			column is used, so the rating scale does not have to be 1..5.

	Returns:
		W as a float.

	Raises:
		IndexError: if ratings is empty or an entry is shorter than the first one.
		ZeroDivisionError: if W is undefined (fewer than two items, or no
			judge distinguishes between the items).
	"""
	from collections import Counter

	n = len(ratings)     # number of items
	m = len(ratings[0])  # number of judges

	rank_totals = [0]*n
	tie_correction = 0
	for j in range(m):
		column = [item_ratings[j] for item_ratings in ratings]
		tally = Counter(column)

		# Items tied on a rating share the mean of the rank positions they
		# occupy: with `below` items rated lower and `size` items tied, the
		# positions are below+1 .. below+size.
		shared_rank = {}
		below = 0
		for value in sorted(tally):
			size = tally[value]
			shared_rank[value] = (2*below + size + 1)/2
			below += size
			# Each group of tied items contributes t^3 - t to the correction.
			tie_correction += size*size*size - size

		for i, value in enumerate(column):
			rank_totals[i] += shared_rank[value]

	total_ranks_sq = sum(r*r for r in rank_totals)
	numerator = 12*total_ranks_sq - 3*m*m*n*(n + 1)*(n + 1)
	denominator = m*m*n*(n*n - 1) - m*tie_correction
	if denominator == 0:
		raise ZeroDivisionError(
			"Kendall's W is undefined for these ratings (fewer than two items, "
			"or no judge distinguishes between the items)"
		)
	return numerator / denominator

def fleiss_kappa(ratings):
	"""
	Computes Fleiss' Kappa statistic

	Args:
		ratings: a sequence with one entry per item; entry i holds the
			category each rater assigned to item i. The number of raters per
			item is taken from the first entry and is not checked against the
			other entries. Categories are whatever distinct values occur, so
			the rating scale does not have to be 1..5.

	Returns:
		Kappa as a float.

	Raises:
		IndexError: if ratings is empty.
		ZeroDivisionError: if kappa is undefined because fewer than two
			categories are used or there are fewer than two ratings per item.
	"""
	from collections import Counter

	N = len(ratings)     # number of items
	n = len(ratings[0])  # number of ratings per item

	# How often each category was assigned, per item and overall.
	per_item = [Counter(item_ratings) for item_ratings in ratings]
	overall = Counter()
	for tally in per_item:
		overall.update(tally)

	if len(overall) < 2:
		# Chance agreement is 1 here, so kappa would be 0/0. Decided on the
		# counts because the float expression 1 - Pe may round to a tiny
		# non-zero value instead of exactly 0.
		raise ZeroDivisionError("Fleiss' kappa is undefined when fewer than two categories are used")

	# Share of all assignments that went to each category (sums to 1).
	# Sorted so the terms are always added in the same order.
	p = [1/(N*n)*overall[category] for category in sorted(overall)]

	# Per-item agreement: fraction of rater pairs that agree on the item.
	P = [1/(n*(n - 1))*sum(c*(c - 1) for c in tally.values()) for tally in per_item]

	P_avg = 1/N*sum(P)

	# Agreement expected by chance.
	Pe = sum(share*share for share in p)

	return (P_avg - Pe)/(1 - Pe)

def inter_rater_reliability(in_csv_filename, num_judges, from_end=False, random_seed=None):
	"""
	Computes Kendall's W and Fleiss' Kappa for the ratings stored in a CSV file.

	The first row of the file is a header. The header cell labelled 'ratings'
	marks the first rating column; every cell from that column to the end of
	a row is treated as an integer rating, so rows may hold different numbers
	of ratings.

	Args:
		in_csv_filename: path of the CSV file to read (opened as UTF-8).
		num_judges: number of ratings to keep for every row.
		from_end: when no seed is given, keep the last num_judges cells of
			each row instead of the first num_judges rating cells.
		random_seed: when not None, seeds the module-level random generator;
			each row's ratings are then shuffled, cut down to num_judges and
			sorted in ascending order. from_end is ignored in this mode.

	Returns:
		A (kendall_w, fleiss_kappa) tuple. If the data cannot be used (no
		'ratings' column, no data rows, or a row with fewer than num_judges
		ratings) a message is printed and (0, 0) is returned, so the result
		can always be unpacked into two values.
	"""
	if random_seed is not None:
		random.seed(random_seed)

	with open(in_csv_filename, 'r', newline='', encoding='utf-8') as in_file:
		reader = csv.reader(in_file)
		header = next(reader, []) # An empty file simply has no 'ratings' column
		rows = list(reader)

	# If several header cells are labelled 'ratings', the last one wins.
	ratings_col = -1
	for col, heading in enumerate(header):
		if heading == 'ratings':
			ratings_col = col

	if ratings_col == -1:
		print("No column labelled 'ratings' found. Cannot read data.")
		return 0, 0

	if not rows:
		print("No data rows found. Cannot compute statistics.")
		return 0, 0

	min_ratings = min(len(row) - ratings_col for row in rows)
	if num_judges > min_ratings:
		print(f"Some pair only has {min_ratings} ratings. Cannot compute kendall_w for {num_judges} judges.")
		return 0, 0

	if random_seed is None:
		if from_end:
			ratings = [list(map(int, row[len(row) - num_judges:])) for row in rows]
		else:
			ratings = [list(map(int, row[ratings_col:ratings_col+num_judges])) for row in rows]
	else:
		ratings = []
		for row in rows:
			ratings_list = list(map(int, row[ratings_col:]))
			random.shuffle(ratings_list)
			# NOTE(review): sorting puts every row's sample in ascending order
			# before kendall_w treats position j as judge j - confirm this is
			# the intended way to line up anonymous judges.
			ratings.append(sorted(ratings_list[:num_judges]))

	# Ratings of 5 and above are shifted down by one (5 -> 4, 6 -> 5).
	# NOTE(review): presumably this folds a wider raw scale into five
	# categories - confirm against the data.
	ratings = [[r if r < 5 else r - 1 for r in row] for row in ratings]

	return kendall_w(ratings), fleiss_kappa(ratings)

if __name__ == "__main__":
	# Command line: <in_csv_filename> [<num_runs>]
	cli_args = sys.argv[1:]
	if len(cli_args) == 1:
		# Single run on 3 judges with a fixed seed.
		kendall, kappa = inter_rater_reliability(cli_args[0], 3, random_seed=1729)
		print("W:    ", kendall)
		print("kappa:", kappa)
	elif len(cli_args) == 2:
		# Repeat the run with consecutive seeds starting at 1729 and report
		# the mean and standard deviation of both statistics.
		run_count = int(cli_args[1])
		kendall_values, kappa_values = [], []
		for seed in range(1729, 1729 + run_count):
			kendall, kappa = inter_rater_reliability(cli_args[0], 3, random_seed=seed)
			kendall_values.append(kendall)
			kappa_values.append(kappa)
		print("W mean:", np.mean(kendall_values), "std:", np.std(kendall_values))
		print("K mean:", np.mean(kappa_values), "std:", np.std(kappa_values))
	else:
		print(f"Usage: python {sys.argv[0]} <in_csv_filename> [<num_runs>]")
