import random
import numpy as np
from scipy.stats import norm 
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn import preprocessing
import sys


'''
Goal: Experiment on Incentivized Collaborative MAB

'''


def net_profit(x, mu, c0, c1, c2, kappa, kappa2, s):
	'''
	Expected net profit of an arm with capability x under the incentive.

	The arm gains the difference between the last entry of mu and its own
	capability x, and pays the expected fee
		c0 + c1 * P(Z < kappa) + c2 * P(Z > kappa2),   Z ~ N(x, s^2).

	x may be a scalar or a numpy array (evaluated element-wise).
	'''
	# probability that a realized reward falls below the penalty threshold
	low_end_prob = norm.cdf(kappa, loc=x, scale=s)
	# probability that a realized reward exceeds the bonus threshold
	high_end_prob = 1 - norm.cdf(kappa2, loc=x, scale=s)
	expected_fee = c0 + c1 * low_end_prob + c2 * high_end_prob
	return mu[-1] - x - expected_fee


def mab_incent(mu, s, M, T, eps):
	'''
	Simulate T rounds of the incentivized collaborative epsilon-greedy MAB.

	Each round t:
	1. every arm decides (randomly, via a logistic rule) whether to participate;
	   at t == 0 all arms participate
	2. the system selects one participating arm as the active arm a[t]
	   (uniformly at random with prob. eps[t] or at t == 0, otherwise the
	   participant with the best average realized reward so far)
	3. the active arm realizes a reward z[t], shared by every participant
	4. every participant i is charged c[i,t] = c0 + c1*1(z_i < kappa) + c2*1(z_i > kappa2),
	   where z_i is the participant's own realized draw
	5. the system's and each arm's profit/balance for the round are recorded
	If no arm participates, the round is skipped and its entries keep their
	initial values (a[t] = -1, everything else 0).

	input
	mu:  numpy array of the M arms' mean rewards ("capability"); mu[0] and mu[-1]
	     are used as the plotting range and mu[-1] as the reference gain in
	     net_profit, so mu is presumably sorted ascending -- verify against caller
	s:   std of the Gaussian reward noise
	M:   number of arms
	T:   number of time steps
	eps: per-round exploration probabilities, indexed by t

	output
	arm_profit: (M, T) per-round profit of each arm (0 when not participating)
	s_profit:   (T,)   per-round system profit z[t] + lam * total charge
	s_balance:  (T,)   per-round total charge collected by the system
	z:          (T,)   realized reward of the active arm
	a:          (T,)   index of the active arm, -1 for skipped rounds

	side effects: draws on matplotlib figure 0, saves 'fig_mu_incent.pdf',
	and prints the participant set every round.
	'''
	mu = np.asarray(mu)

	# hyperparameters of the incentive design
	lam = 0 # the pre-specified param in defining the sys objective
	c0 = 1 # basis cost
	c1 = 5 # penalty cost for low-end arm
	c2 = -10 # reward for high-end arm
	kappa = 2 # higher more excluded arms
	kappa2 = 4 # higher less rewarded arms

	# visualize the mu and the incentive design
	plt.figure(0)
	plt.subplot(1,2,1)
	plt.plot(np.arange(M), mu)
	plt.title("Entities' capability mu")
	plt.subplot(1,2,2)
	xs = np.linspace(mu[0], mu[-1], num=200)
	ys = net_profit(xs, mu, c0, c1, c2, kappa, kappa2, s)
	plt.plot(xs, ys, 'b.')
	plt.xlabel("mu")
	plt.axhline(y = 0, color = 'r', linestyle = '-')
	plt.title("Expected net profit under the incentive")
	plt.savefig('fig_mu_incent.pdf', dpi=300)

	a = - np.ones(T, dtype=int)
	z = np.zeros(T) # realized rewards
	c = np.zeros((M,T)) # charged prices
	# each arm's average realized reward while active; -inf until first active
	# (np.inf, since the np.Inf alias no longer exists in NumPy >= 2.0)
	pfmc = np.full(M, -np.inf)

	s_profit = np.zeros(T) # system's per-round profit
	s_balance = np.zeros(T) # system's per-round balance
	arm_profit = np.zeros((M,T))

	# each arm's expected participation cost depends only on its own mu,
	# so it is computed once instead of once per arm per round
	ept_cost = c0 + c1 * norm.cdf(kappa, loc=mu, scale=s) + \
		c2 * (1 - norm.cdf(kappa2, loc=mu, scale=s))

	for t in range(T):

		# every arm decides whether to participate
		if t == 0:
			par_set = list(range(M))
		else:
			# expected gain: best average observed so far minus own capability
			# (some arm was active at t == 0, so the finite set is non-empty)
			best_seen = np.max(pfmc[~np.isinf(pfmc)])
			ept_gain = best_seen - mu
			prob = 1/(1 + np.exp(-(ept_gain - ept_cost)))
			# one uniform draw per arm, in arm order
			par_set = [i for i in range(M) if np.random.uniform(0, 1) < prob[i]]

		print(f"par_set={par_set}")

		# if none participates, skip this round
		if not par_set:
			continue
		print(f"The number of arms in par_set is {len(par_set)}")

		# the uniform is drawn before the t == 0 test so the random stream
		# is consumed identically in every round
		u = np.random.uniform(0, 1)
		if u < eps[t] or t == 0:
			a[t] = random.choice(par_set)
		else:
			# due to the init of pfmc=-inf, never-active arms lose to any arm
			# that has been active; if no participant was active before,
			# argmax falls back to the first participant
			a[t] = par_set[np.argmax(pfmc[par_set])]

		at = a[t]

		# every participant draws its own realization around its mu;
		# the active arm's draw is the round's shared reward
		z_local = mu[par_set] + np.random.normal(0, s, len(par_set))
		z[t] = z_local[par_set.index(at)]

		# all participants' actual payment
		# assuming participants must realize & publicize
		c[par_set,t] = c0 + c1 * (z_local < kappa) + c2 * (z_local > kappa2)

		s_balance[t] = c[:,t].sum()

		# all participants' final profits -- shared reward minus the baseline
		# local reward and the payment
		arm_profit[par_set,t] = z[t] - mu[par_set] - c[par_set, t]

		# system's final profit
		s_profit[t] = z[t] + lam * c[:,t].sum()

		# update pfmc at the index of the active arm, namely a[t]
		pfmc[at] = np.sum((a[:t+1]==at) * z[:t+1]) / np.sum(a[:t+1]==at)

	return arm_profit, s_profit, s_balance, z, a


def mab_nonincent(mu, s, M, T, eps):
	'''
	Calc non incentivized gain as baseline: plain epsilon-greedy over all M arms.

	input
	mu:  the M arms' mean rewards, indexed by arm
	s:   std of the Gaussian reward noise
	M:   number of arms
	T:   number of time steps
	eps: per-round exploration probabilities, indexed by t

	output
	z:         (T,) realized reward of the arm chosen at each round
	cumsum(z): (T,) cumulative realized reward
	'''
	a = - np.ones(T, dtype=int)
	z = np.zeros(T) # realized rewards
	# each arm's average realized gain up to time t; starting at +inf makes
	# the greedy step pick every not-yet-tried arm before any tried one
	# (np.inf, since the np.Inf alias no longer exists in NumPy >= 2.0)
	pfmc = np.full(M, np.inf)

	for t in range(T):

		u = np.random.uniform(0, 1)
		if u < eps[t]:
			# explore: uniformly random arm
			a[t] = random.choice(range(M))
		else:
			# exploit: best average so far (ties -> lowest index)
			a[t] = np.argmax(pfmc)
		at = a[t]

		# generate z from the Gaussian centered at the chosen arm's mu;
		# a scalar draw, because storing a size-1 array into z[t] is deprecated
		z[t] = np.random.normal(mu[at], s)

		# update pfmc at the index of the active arm, namely a[t]
		pfmc[at] = np.sum((a[:t+1]==at) * z[:t+1]) / np.sum(a[:t+1]==at)

	return z, np.cumsum(z)


'''
	visualize the system’s cumulative profit and the participation activities of each arm
'''

'''
Set up the arms.
The underlying reward distribution is assumed to be z ~ N(mu, s^2).
Generate the mu's from a log-normal distribution and sort them in ascending order.
'''

# ---- problem setup ----
M = 50 # total arm number
T = 150 # total time steps
time = np.arange(T)
eps = np.full(T, 0.1) # constant exploration rate; alternative: 0.1 / np.arange(1,T+1)
_mu, _sigma = 0, 1 # how to gen mu
# log-normal capabilities, sorted from small to large
# alternative: mu = np.random.standard_cauchy(M) * _sigma
mu = np.sort(np.exp(np.random.normal(_mu, _sigma, M)))
s = 1 # noise level of the reward conditional on the underlying mu ("capability")


# ---- run experiments ----
nRep = 2 # number of replications

# one row per replication
z_cum, z_base_cum, s_profit_cum, s_balance_cum, active_arms = (
	np.zeros((nRep, T)) for _ in range(5)
)
# one (arm, time) slab per replication
arm_profit_cum, arm_profit_record = (
	np.zeros((nRep, M, T)) for _ in range(2)
)

for r in range(nRep):
	# incentivized run first, then the non-incentivized baseline
	arm_profit, s_profit, s_balance, z, a = mab_incent(mu, s, M, T, eps)
	z_cum[r] = z.cumsum()
	_, z_base_cum[r] = mab_nonincent(mu, s, M, T, eps)

	# per-round and cumulative profit of every arm
	arm_profit_record[r] = arm_profit
	arm_profit_cum[r] = arm_profit.cumsum(axis=1)

	# system-level cumulative series
	s_profit_cum[r] = s_profit.cumsum()
	s_balance_cum[r] = s_balance.cumsum()

	# trace of active arms
	active_arms[r] = a




import matplotlib.pylab as pl
colors = pl.cm.jet(np.linspace(0,1,M))

# Figure 1: 2x2 summary of the experiment, averaged over the replications
plt.figure(1, figsize=(12, 12))

# (1) system's cumulative balance
plt.subplot(2,2,1)
# plt.plot(time, np.mean(s_profit_cum,axis=0))
plt.plot(time, np.mean(s_balance_cum,axis=0))
plt.title("System cummulative balance")

# (2) distribution over rounds of each arm's per-round profit
plt.subplot(2,2,2)

# Commented: cumulative rewards
# for i in range(M):
# 	plt.plot(np.mean(arm_profit_cum,axis=0)[i,:], color=colors[i])
# plt.legend([f'{i}' for i in range(M)], loc='upper center', bbox_to_anchor=(0.5, 1.2),
#           ncol=5, fancybox=True, shadow=True)
# plt.title("Arm's cummulative profit")

# The panel is titled "round-wise", so plot the per-round profits
# (arm_profit_record) instead of the cumulative ones (arm_profit_cum).
mean_round_profit = np.mean(arm_profit_record, axis=0)
data = [mean_round_profit[i,:] for i in range(M)]
plt.boxplot(data)
# boxplot draws box i at x = i + 1, so ticks run over 1..M; label every
# tick_step-th arm with its index (tick_step >= 1 also covers M < 10)
tick_step = max(M // 10, 1)
plt.xticks(np.arange(1, M + 1), [f'{i}' if i % tick_step == 0 else '' for i in range(M)], rotation=-15)
plt.title("Arm's round-wise profit")

# (3) cumulative realized gains, incentivized vs. non-incentivized
plt.subplot(2,2,3)
# plt.plot(time, a)
# plt.title("Active arm")
# plt.hist(a, bins=M)
# plt.title("Arm chosen count")
# plt.plot(time, np.mean(arm_profit_cum,axis=0)[-5,:],'g-', label='best arm')
# plt.plot(time, np.mean(arm_profit_cum,axis=0)[25,:],'y-.', label='medium arm')
# plt.plot(time, np.mean(arm_profit_cum,axis=0)[0,:],'r-.', label='worst arm')
# plt.title("Particular arms' cummulative profit")
plt.plot(time, np.mean(z_cum,axis=0), label="incentivized")
plt.plot(time, np.mean(z_base_cum,axis=0), label="non-incentivized")
plt.title("System's cummulative realized gains")
plt.legend()

# (4) index of the active arm per round, averaged over replications
# (rounds skipped by mab_incent are recorded as -1 and enter the average)
plt.subplot(2,2,4)
plt.scatter(time, np.mean(active_arms,axis=0), label="trace of active arm")
plt.legend()
# plt.show()
plt.savefig('fig_comparison.pdf', dpi=300)


