import csv
import sys
from scipy.stats import kendalltau
import numpy as np
import random

random.seed(1729)

def kendalltau_individual(ratings, predictions, variant='c'):
	"""Kendall's tau between per-item predictions and every individual rating.

	Each item contributes one (prediction, rating) pair per rating it has:
	the item's prediction is repeated once for each of its ratings, and the
	flattened sequences are passed to ``scipy.stats.kendalltau``. Items past
	the shorter of ``ratings`` / ``predictions`` are ignored (``zip``).

	Args:
		ratings: Iterable of per-item rating sequences.
		predictions: One prediction per item, aligned with ``ratings``.
		variant: Tau variant forwarded to ``kendalltau``. For ``'b'`` the
			keyword is omitted, so scipy's own default variant applies.

	Returns:
		The result object returned by ``scipy.stats.kendalltau``.
	"""
	paired = list(zip(ratings, predictions))
	# Repeat each item's prediction once per rating so both lists line up.
	flat_predictions = [prediction for item_ratings, prediction in paired for _ in item_ratings]
	flat_ratings = [rating for item_ratings, _ in paired for rating in item_ratings]
	if variant == 'b':
		return kendalltau(flat_predictions, flat_ratings)
	return kendalltau(flat_predictions, flat_ratings, variant=variant)

def kendall_tau_ratings_subset(in_csv_filename, index, limit=None):
	"""Kendall's tau of one held-out rating per row against the row's other ratings.

	Reads a UTF-8 CSV whose header contains a ``ratings`` column. That column
	and every column after it are parsed as integer ratings for the row. In
	each row the rating at position ``index`` is removed and used as the
	"prediction", and the remaining ratings of the same row are the references
	it is compared against via ``kendalltau_individual`` (default variant).

	Args:
		in_csv_filename: Path to the CSV file.
		index: Position of the held-out rating within each row's rating list.
			Negative values count from the end, as with Python sequences.
		limit: If given, each row's ratings are shuffled using the module-level
			``random`` state, truncated to the first ``limit`` values and
			sorted ascending before ``index`` is applied. Because of the sort,
			``index`` then selects by rank within the subset rather than a
			particular rater's column.

	Returns:
		The ``correlation`` field of the ``kendalltau_individual`` result.

	Raises:
		ValueError: If the file is empty, the header has no ``ratings``
			column, or a rating cell is not an integer.
		IndexError: If ``index`` is out of range for some row's ratings.
	"""
	with open(in_csv_filename, 'r', newline='', encoding='utf-8') as in_file:
		reader = csv.reader(in_file)
		header = next(reader, None)
		if header is None:
			raise ValueError(f"{in_csv_filename}: empty CSV file (no header row)")
		rows = list(reader)

	ratings_col = None
	for col, heading in enumerate(header):
		if heading == 'ratings':
			ratings_col = col  # last matching heading wins, as before
	if ratings_col is None:
		# A -1 sentinel here would make row[-1:] silently treat only the final
		# column as ratings; fail loudly instead.
		raise ValueError(f"{in_csv_filename}: header has no 'ratings' column")

	predictions = []
	ratings = []
	for row in rows:
		ratings_list = [int(cell) for cell in row[ratings_col:]]
		if limit is not None:
			# Random subset of at most `limit` ratings. This draws from the
			# module-level `random` state, so successive calls generally pick
			# different subsets.
			random.shuffle(ratings_list)
			ratings_list = sorted(ratings_list[:limit])
		# pop() removes exactly the element used as the prediction, for any
		# valid (including negative) index. Slice arithmetic such as
		# r[:index] + r[index + 1:] duplicates most of the row when index < 0.
		predictions.append(ratings_list.pop(index))
		ratings.append(ratings_list)

	return kendalltau_individual(ratings, predictions).correlation

if __name__ == "__main__":
	# CLI: <in_csv_filename> <expert_index> [<limit>] [<num_trials>]
	# With <num_trials>, the subset correlation is recomputed that many times
	# (each call draws a new random subset) and the mean/std are reported x100.
	cli_args = sys.argv[1:]
	if not 2 <= len(cli_args) <= 4:
		# Usage errors go to stderr with a non-zero exit status so that
		# callers and shell scripts can detect them.
		print("Usage: python " + sys.argv[0] + " <in_csv_filename> <expert_index> [<limit>] [<num_trials>]", file=sys.stderr)
		sys.exit(2)

	csv_path = cli_args[0]
	expert_index = int(cli_args[1])
	subset_limit = int(cli_args[2]) if len(cli_args) >= 3 else None

	if len(cli_args) < 4:
		print(kendall_tau_ratings_subset(csv_path, expert_index, subset_limit))
	else:
		num_trials = int(cli_args[3])
		if num_trials < 1:
			# np.mean / np.std of an empty list would print NaN with warnings.
			sys.exit("num_trials must be a positive integer")
		results = [kendall_tau_ratings_subset(csv_path, expert_index, subset_limit) for _ in range(num_trials)]
		print(f"Mean: {100*np.mean(results):.1f}")
		print(f"STD:  {100*np.std(results):.2f}")
