import os

import matplotlib
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy import optimize

# Apply seaborn's default theme (context, style, palette, font) to matplotlib.
# NOTE(review): in seaborn >= 0.11 sns.set() is a legacy alias of
# sns.set_theme(); the style and palette chosen here are replaced later by the
# sns.set_style / sns.set_palette calls made before plotting.
sns.set()

def f_relu(x, emp_theta_1, emp_theta_2):
    """Return the larger of the two clients' squared errors at ``x``.

    This is the worst-case (min-max) objective minimized for the ``'ReLU'``
    rule in ``get_ipr``; since it equals the max of two parabolas, its
    minimizer is the midpoint of the two estimates.

    Args:
        x: Candidate global model, a scalar or an array
            (``scipy.optimize.minimize`` passes a 1-element array).
        emp_theta_1: First client's empirical estimate.
        emp_theta_2: Second client's empirical estimate.

    Returns:
        ``max((emp_theta_1 - x)**2, (emp_theta_2 - x)**2)`` evaluated
        element-wise, shaped like ``x`` broadcast against the estimates.
    """
    # np.maximum is element-wise. The builtin max only worked for scipy's
    # 1-element x because it took the truth value of a comparison array,
    # and it would raise for any longer array.
    return np.maximum((emp_theta_1 - x) ** 2, (emp_theta_2 - x) ** 2)


def sigm(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, element-wise through NumPy."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


def f_sigm(x, emp_theta_1, emp_theta_2):
    """Sum of the sigmoid-squashed squared errors of ``x`` to both estimates.

    Objective minimized for the ``'Sigm'`` rule in ``get_ipr``: each client's
    squared error ``(x - emp_theta_k)**2`` is passed through ``sigm`` and the
    two squashed terms are added.
    """
    squashed = [sigm((x - theta) ** 2) for theta in (emp_theta_1, emp_theta_2)]
    return squashed[0] + squashed[1]


def get_ipr(alg, sigma_g_range, sigma, T, *, rng=None, verbose=True):
    """Monte-Carlo estimate of the participation rate (IPR) of a rule.

    For each heterogeneity level ``sigma_g`` two clients have true means
    ``theta_1 = 0`` and ``theta_2 = 2 * sqrt(sigma_g)``. In each of ``T``
    trials every client draws a noisy estimate ``N(theta_k, sigma)``, the
    rule ``alg`` produces a global model ``w``, and a client counts as
    participating when ``w`` is strictly closer (squared error) to its true
    mean than its own estimate is. A trial's IPR is the fraction of the two
    clients that participate.

    Args:
        alg: ``'FedAvg'`` (average of the two estimates), ``'ReLU'``
            (minimizer of ``f_relu``) or ``'Sigm'`` (minimizer of
            ``f_sigm``). Both optimizers start at 0 and use scipy's default
            ``optimize.minimize`` method.
        sigma_g_range: Sequence of heterogeneity levels. They are used under
            a square root, so they should be non-negative.
        sigma: Standard deviation of each client's estimate.
        T: Number of trials per heterogeneity level.
        rng: Optional ``numpy.random.Generator``. If ``None`` (default), the
            global ``np.random`` state is used, exactly as before.
        verbose: If true (default), print each level's index as progress.

    Returns:
        ``(ipr_avg, ipr_std)``: arrays of length ``len(sigma_g_range)`` with
        the mean and standard deviation of the per-trial IPR.

    Raises:
        ValueError: If ``alg`` is not one of the supported names.
    """
    if alg not in ('FedAvg', 'ReLU', 'Sigm'):
        # Fail fast: otherwise `w` would be unbound at the first comparison.
        raise ValueError(
            f"unknown alg {alg!r}; expected 'FedAvg', 'ReLU' or 'Sigm'")

    normal = np.random.normal if rng is None else rng.normal

    # arr[t, i] holds the IPR of trial t at heterogeneity level i.
    arr = np.zeros((T, len(sigma_g_range)))

    for i, sigma_g in enumerate(sigma_g_range):
        if verbose:
            print(i)

        theta_1 = 0
        theta_2 = 2 * np.sqrt(sigma_g)

        for t in range(T):
            # Draw order (client 1, then client 2) is kept so results for a
            # given seed match the original implementation.
            emp_theta_1 = normal(theta_1, sigma)
            emp_theta_2 = normal(theta_2, sigma)

            if alg == 'FedAvg':
                w = (emp_theta_1 + emp_theta_2) / 2
            else:
                objective = f_relu if alg == 'ReLU' else f_sigm
                h = optimize.minimize(objective, x0=0,
                                      args=(emp_theta_1, emp_theta_2))
                w = h.x[0]

            # int(): NumPy bools add as logical OR (True + True == True),
            # so convert before summing.
            c1 = int((w - theta_1) ** 2 < (emp_theta_1 - theta_1) ** 2)
            c2 = int((w - theta_2) ** 2 < (emp_theta_2 - theta_2) ** 2)

            arr[t, i] = (c1 + c2) / 2

    ipr_avg = np.mean(arr, axis=0)
    ipr_std = np.std(arr, axis=0)

    return ipr_avg, ipr_std


# Figure styling: seaborn white grid with the bright palette, large fonts,
# and Type 42 (TrueType) fonts embedded in PDF/PS output.
sns.set_style('whitegrid')
sns.set_palette("bright")

ftsize, littleft = 30, 22  # label/title size, tick/legend size
params = {
    'legend.fontsize': littleft,
    'axes.labelsize': ftsize,
    'axes.titlesize': ftsize,
    'xtick.labelsize': littleft,
    'ytick.labelsize': littleft,
}
lw = 5  # NOTE(review): unused; the plotting loop hard-codes linewidth=4

matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
pylab.rcParams.update(params)
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["legend.handlelength"] = 1.5

g, ax1 = plt.subplots(1, 1, figsize=(6, 5))

# One entry per aggregation rule; the three lists are read in lockstep.
algs = ['FedAvg', 'ReLU', 'Sigm']
linest = ['-.', '--', '-']
colors = ['C0', 'C3', 'C2']

sigma = 1
sigma_g_range = list(range(20))

x_axis = sigma_g_range

for alg, color, style in zip(algs, colors, linest):
    # 'ReLU' is simulated with 1000 trials per level, the others with 10000.
    T = 1000 if alg == 'ReLU' else 10000

    pr, prstd = get_ipr(alg, sigma_g_range, sigma, T)

    # Legend label: the sigmoid rule is presented as IncFL; FedAvg keeps
    # its own name.
    typeis = {'Sigm': 'IncFL', 'ReLU': 'IncFL (ReLU)'}.get(alg, alg)

    ax1.plot(x_axis, pr, color, ls=style, linewidth=4, label=typeis)
    # Shaded band spans +/- one quarter of the per-trial standard deviation.
    ax1.fill_between(x_axis, y1=pr - prstd / 4, y2=pr + prstd / 4,
                     color=color, alpha=.05)


# A single legend call. A second ax1.legend() replaces the first, so the
# previous loc=3 was silently discarded. This keeps the placement that was
# actually rendered (rcParams['legend.loc'], 'best' by default).
ax1.legend()
ax1.set_xlabel('Data Heterogeneity    ')
ax1.set_ylabel('Inc. Participation Rate \n(IPR)', multialignment='center')
g.tight_layout()
# savefig does not create missing directories; ensure fig/ exists so the
# long simulation above is not lost at the final step.
os.makedirs('fig', exist_ok=True)
g.savefig('fig/mean_est.pdf', bbox_inches='tight')





