import numpy as np
import random
import csv
from copy import deepcopy
import heapq
import math
import matplotlib.pyplot as plt
import seaborn as sns


# Number of experts selected in every round by both algorithms; `regret`
# also averages over `m` slots and compares against the `m` experts with the
# smallest cumulative loss.
m=5
# Number of times each algorithm is re-run for every sampled set of experts.
num_repetitions = 10

def set_params(dataset, required_preds=284):
    """Load the forecast CSV and find forecasters with a full set of reports.

    Prints the number of distinct forecasters, distinct `team1` values,
    distinct games (unique `(week, team1, team2)` triples) and the number of
    forecasters having exactly `required_preds` rows.

    Args:
        dataset: Path to a CSV file with (at least) the columns `user_id`,
            `week`, `team1` and `team2`. `user_id` must parse as an integer.
        required_preds: Number of rows a forecaster must have in the file to
            be returned as an expert. Defaults to 284.

    Returns:
        A tuple `(teams_sort, experts)` where `teams_sort` is the list of CSV
        rows (dicts) stably sorted by `team1`, and `experts` is the ascending
        list of integer user ids with exactly `required_preds` rows.
    """
    # The context manager guarantees the file handle is released; newline=''
    # is the mode the csv module documents for reader input.
    with open(dataset, newline='') as handle:
        data = list(csv.DictReader(handle))

    teams_sort = sorted(data, key=lambda row: row['team1'])

    forecaster_names = {row['user_id'] for row in data}
    print("Num of diff forecasters is: %d" % len(forecaster_names))

    team_names = {row['team1'] for row in data}
    print("Num of diff teams is %d" % len(team_names))

    # A game is identified by its (week, team1, team2) triple.
    games = {(row['week'], row['team1'], row['team2']) for row in data}
    print("Num of diff nfl games is: %d" % len(games))

    # Count rows per forecaster keyed by the integer id itself, so ids do not
    # have to form the contiguous range 0..n-1.
    preds_per_user = {}
    for row in data:
        user = int(row['user_id'])
        preds_per_user[user] = preds_per_user.get(user, 0) + 1

    experts = sorted(
        user for user, count in preds_per_user.items()
        if count == required_preds
    )

    print("Number of forecasters reporting on all nfl games.")
    print(len(experts))

    return (teams_sort, experts)

def set_experts(teams_sort,experts,num_experts,sample_id):
    """Sample `num_experts` forecasters and gather their per-game reports.

    Reseeds NumPy's global RNG, seeds Python's `random` module with
    `sample_id`, and prints the sampled expert ids.

    Args:
        teams_sort: Iterable of row dicts with the keys `week`, `team1`,
            `team2`, `game_outcome`, `user_id` and `user_prob`.
        experts: List of integer user ids to sample from.
        num_experts: Number of experts to sample (without replacement).
        sample_id: Seed for `random`, so a given id yields the same sample.

    Returns:
        A tuple `(num_experts, outcomes, experts_reports, T)` where `T` is
        the number of distinct games, `outcomes[g]` is the outcome of game
        `g`, and `experts_reports[k][g]` is sampled expert `k`'s reported
        probability for game `g` (0 when no row exists). Games are numbered
        in order of first appearance in `teams_sort`.
    """
    # Choose K experts at random, among the candidates passed in `experts`.
    np.random.seed()
    random.seed(sample_id)
    elf_experts = random.sample(experts, num_experts)
    print("Sampled Experts")
    print(elf_experts)

    # Position of each sampled expert; setdefault keeps the first position,
    # which is what list.index() would report.
    expert_slot = {}
    for slot, user in enumerate(elf_experts):
        expert_slot.setdefault(user, slot)

    # Map every distinct game to its index (order of first appearance) and
    # record its outcome from the first row that mentions it.
    game_slot = {}
    outcomes = []
    for row in teams_sort:
        game = (row['week'], row['team1'], row['team2'])
        if game not in game_slot:
            game_slot[game] = len(outcomes)
            outcomes.append(float(row['game_outcome']))

    T = len(outcomes)
    experts_reports = [[0] * T for _ in range(num_experts)]

    for row in teams_sort:
        slot = expert_slot.get(int(row['user_id']))
        if slot is not None:
            game = (row['week'], row['team1'], row['team2'])
            experts_reports[slot][game_slot[game]] = float(row['user_prob'])
    return (num_experts, outcomes, experts_reports, T)

def regret(loss_lst, num_experts, algo_loss, T):
    """Compute the algorithm's regret against the `m` best experts so far.

    Args:
        loss_lst: `loss_lst[t][i]` is expert `i`'s loss in round `t`.
        num_experts: Number of experts whose losses are accumulated.
        algo_loss: `algo_loss[j]` is the list of losses incurred in slot `j`;
            the first `m` rows are used.
        T: Index of the last round to include; rounds `0..T` are summed.

    Returns:
        A tuple `(regret, best_fixed)` where `best_fixed` is the average
        cumulative loss of the `m` experts with the smallest cumulative loss
        and `regret` is the algorithm's average loss minus `best_fixed`.
    """
    # Accumulate every expert's loss round by round (same addition order per
    # expert as summing over rounds for one expert at a time).
    loss_per_expert = [0] * num_experts
    for t in range(T + 1):
        for i in range(num_experts):
            loss_per_expert[i] += loss_lst[t][i]
    heapq.heapify(loss_per_expert)

    per_slot_totals = [sum(algo_loss[slot]) for slot in range(m)]
    tot_algo_loss = (1/m)*sum(per_slot_totals)

    # Pop the m smallest cumulative losses off the heap and average them.
    min_loss_hindsight = 0
    for _ in range(m):
        smallest = heapq.heappop(loss_per_expert)
        min_loss_hindsight += smallest
    min_loss_hindsight /= m

    gap = tot_algo_loss - min_loss_hindsight
    print ("Algorithm's Loss: %f"%(tot_algo_loss))
    print ("Best fixed: %f"%min_loss_hindsight)
    print ("Regret:%f"%(gap))

    # returns (regret to best, best_fixed loss)
    return (gap, min_loss_hindsight)

def draw(probs_lst, num_experts):
    """Sample an index in `0..num_experts-1` according to `probs_lst`.

    NumPy's global RNG is reseeded on every call, then a uniform threshold in
    [0, 1) is drawn. The first index whose running total of `probs_lst`
    exceeds the threshold is returned; if none does, `num_experts - 1` is
    returned.
    """
    np.random.seed()
    threshold = np.random.uniform(0,1)
    running_mass = 0.0
    position = 0
    while position < num_experts:
        running_mass += probs_lst[position]
        if running_mass > threshold:
            break
        position += 1
    else:
        # The running total never exceeded the threshold: fall back to the
        # last index.
        position = num_experts - 1
    return position

def wsu_compute(exp_chosen, j, wagers, experts_reps, num_experts, outcomes, t):
    """Return the updated wagers for round `t`, given `j` already-chosen experts.

    Args:
        exp_chosen: Indices of the experts chosen in this round; only the
            first `j` entries are treated as already chosen.
        j: Number of leading entries of `exp_chosen` to exclude.
        wagers: Current weight of every expert.
        experts_reps: `experts_reps[i][t]` is expert `i`'s report in round `t`.
        num_experts: Number of experts.
        outcomes: `outcomes[t]` is the realised outcome of round `t`.
        t: Round index.

    Returns:
        A new list with one updated weight per expert.
    """
    # Wager-weighted score (1 - squared error) of every expert ...
    scores = []
    for i in range(num_experts):
        scores.append(wagers[i]*(1.0 - (outcomes[t] - experts_reps[i][t])**2))
    # ... with the already-chosen experts' contributions zeroed out.
    for k in range(j):
        scores[exp_chosen[k]] = 0
    tot = sum(scores)

    already_chosen = exp_chosen[:j]
    return [
        wagers[i]*(1 - tot)
        if i in already_chosen
        else wagers[i]*(1 + 1 - (outcomes[t] - experts_reps[i][t])**2 - tot)
        for i in range(num_experts)
    ]

def main_wsu(num_experts, outcomes, experts_reports, T):
    """Run the WSU-based selection of `m` distinct experts per round.

    One probability vector is kept per slot (`m` in total), each starting
    uniform. In every round each slot draws an expert (redrawing until it
    differs from those already drawn this round), the squared losses are
    recorded, each slot's vector is mixed with the output of `wsu_compute`
    using step size `eta`, and the regret so far is evaluated.

    Args:
        num_experts: Number of experts.
        outcomes: `outcomes[t]` is the outcome of round `t`.
        experts_reports: `experts_reports[i][t]` is expert `i`'s report.
        T: Number of rounds.

    Returns:
        A tuple `(wsu_rep_regr, best_fixed_loss)` of per-round lists as
        returned by `regret`.
    """
    uniform_weight = 1.0/num_experts
    wsu_probs = [[uniform_weight]*num_experts for _ in range(m)]
    experts_loss_lst = [[] for _ in range(T)]
    wsu_loss = [[0]*T for _ in range(m)]
    wsu_rep_regr = []
    best_fixed_loss = []
    eta = np.sqrt(1.0*np.log(num_experts)/(1.0*T))

    for t in range(T):
        print ("Timestep t=%d for WSU"%t)

        # Draw one expert per slot, rejecting experts already picked.
        exp_chosen = []
        for j in range(m):
            while True:
                candidate = draw(wsu_probs[j], num_experts)
                if candidate not in exp_chosen:
                    break
            exp_chosen.append(candidate)

        round_losses = [(outcomes[t] - experts_reports[i][t])**2 for i in range(num_experts)]
        experts_loss_lst[t] = round_losses
        for j in range(m):
            wsu_loss[j][t] = round_losses[exp_chosen[j]]

        # probs update through wsu
        for j in range(m):
            previous = deepcopy(wsu_probs[j])
            proposed = wsu_compute(exp_chosen, j, wsu_probs[j], experts_reports, num_experts, outcomes, t)
            wsu_probs[j] = [
                eta*new_w + (1.0 - eta)*old_w
                for new_w, old_w in zip(proposed, previous)
            ]

        (regr_best, best_fixed) = regret(experts_loss_lst, num_experts, wsu_loss, t)
        wsu_rep_regr.append(regr_best)
        best_fixed_loss.append(best_fixed)

    return (wsu_rep_regr, best_fixed_loss)

def ftpl(num_experts, outcomes, experts_reports, T):
    """Run the FTPL selection of `m` experts per round.

    Every expert starts with a Laplace-distributed perturbation scaled by
    `eta`. In each round the `m` experts with the smallest perturbed
    cumulative loss are chosen, their squared losses are recorded, and every
    expert's perturbed total grows by the squared error of its report after
    adding uniform noise in `[-incent, incent]` and clipping to [0, 1].

    Args:
        num_experts: Number of experts.
        outcomes: `outcomes[t]` is the outcome of round `t`.
        experts_reports: `experts_reports[i][t]` is expert `i`'s report.
        T: Number of rounds.

    Returns:
        A tuple `(ftpl_rep_regr, best_fixed_loss)` of per-round lists as
        returned by `regret`.
    """
    experts_loss_lst = [[] for _ in range(T)]
    ftpl_loss = [[0]*T for _ in range(m)]
    best_fixed_loss = []
    ftpl_rep_regr = []
    eta = np.sqrt(T/(np.sqrt(2)*math.log(num_experts*math.e/m)))
    incent = 2/(eta-2)
    np.random.seed()
    noisy_loss_per_expert = eta*np.random.laplace(size=num_experts)

    for t in range(T):
        print ("Timestep t=%d for FTPL"%t)

        # Pick the m experts with the smallest perturbed cumulative loss.
        ranking = list(zip(noisy_loss_per_expert, range(num_experts)))
        heapq.heapify(ranking)
        exp_chosen = [heapq.heappop(ranking)[1] for _ in range(m)]

        round_losses = [(outcomes[t] - experts_reports[i][t])**2 for i in range(num_experts)]
        experts_loss_lst[t] = round_losses
        for slot_losses, chosen in zip(ftpl_loss, exp_chosen):
            slot_losses[t] += round_losses[chosen]

        # Update the perturbed totals using noised, clipped reports.
        for i in range(num_experts):
            an = np.random.uniform(-incent,incent)
            pn = min(max(experts_reports[i][t]+an,0),1)
            noisy_loss_per_expert[i] += (outcomes[t]-pn)**2

        (regr_best, best_fixed) = regret(experts_loss_lst, num_experts, ftpl_loss, t)
        ftpl_rep_regr.append(regr_best)
        best_fixed_loss.append(best_fixed)

    return (ftpl_rep_regr, best_fixed_loss)

# Experiment driver: for K=20 and K=100 sampled experts, run both algorithms
# repeatedly, then plot the mean regret with a 20%-80% quantile band.


def _collect_regret(teams_sort, experts, K, num_samples=5):
    """Run WSU and FTPL for `num_samples` expert samples of size `K`.

    Each sample is run `num_repetitions` times. Returns two arrays
    `(wsu, ftpl)` with one row of per-round regret per run; rows are ordered
    by sample id, then repetition.
    """
    wsu_runs = []
    ftpl_runs = []
    for sample_id in range(num_samples):
        (num_experts, outcomes, experts_reports, T) = set_experts(teams_sort, experts, K, sample_id)
        for _rep in range(num_repetitions):
            (wsu_rep_regr, _best_fixed) = main_wsu(num_experts, outcomes, experts_reports, T)
            (ftpl_rep_regr, _best_fixed_2) = ftpl(num_experts, outcomes, experts_reports, T)
            wsu_runs.append(wsu_rep_regr)
            ftpl_runs.append(ftpl_rep_regr)
    return (np.array(wsu_runs, dtype=float), np.array(ftpl_runs, dtype=float))


def _regret_band(regr):
    """Return `(mean, 20% quantile, 80% quantile)` of `regr` across runs."""
    return (
        np.mean(regr, axis=0),
        np.quantile(regr, q=0.2, axis=0),
        np.quantile(regr, q=0.8, axis=0),
    )


def _plot_regret(wsu_band, ftpl_band, filename):
    """Plot both algorithms' regret bands on a fresh figure and save it."""
    (wsu_mean, wsu_lo, wsu_hi) = wsu_band
    (ftpl_mean, ftpl_lo, ftpl_hi) = ftpl_band
    rounds = np.arange(len(wsu_mean))
    # Start from a new figure so curves from a previously saved plot are not
    # drawn into this one as well.
    plt.figure()
    plt.plot(wsu_mean, label = 'ODG', color = 'green')
    plt.plot(ftpl_mean, label = 'FTPL', color = 'red')
    plt.fill_between(rounds, wsu_lo, wsu_hi, color = 'green', alpha = 0.3)
    plt.fill_between(rounds, ftpl_lo, ftpl_hi, color = 'red', alpha = 0.3)
    plt.legend()
    plt.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
    plt.xlabel("Number of Rounds T", weight = 'bold')
    plt.ylabel("Regret", weight = 'bold')
    plt.savefig(filename, bbox_inches = "tight")
    plt.close()


(teams_sort,experts)=set_params('raw_user_forecasts.csv')

(wsu_regr, ftpl_regr) = _collect_regret(teams_sort, experts, 20)
(wsu_regr2, ftpl_regr2) = _collect_regret(teams_sort, experts, 100)

(wsu_regr_mean, wsu_regr_lo, wsu_regr_hi) = _regret_band(wsu_regr)
(ftpl_regr_mean, ftpl_regr_lo, ftpl_regr_hi) = _regret_band(ftpl_regr)
(wsu_regr2_mean, wsu_regr2_lo, wsu_regr2_hi) = _regret_band(wsu_regr2)
(ftpl_regr2_mean, ftpl_regr2_lo, ftpl_regr2_hi) = _regret_band(ftpl_regr2)

font = {'family' : 'normal',
        'size'   : 13,
        'weight': 'bold'
        }
plt.rc('font', **font)

_plot_regret(
    (wsu_regr_mean, wsu_regr_lo, wsu_regr_hi),
    (ftpl_regr_mean, ftpl_regr_lo, ftpl_regr_hi),
    "K20",
)
_plot_regret(
    (wsu_regr2_mean, wsu_regr2_lo, wsu_regr2_hi),
    (ftpl_regr2_mean, ftpl_regr2_lo, ftpl_regr2_hi),
    "K100",
)