import numpy as np
import pandas as pd
import math
from scipy.stats import bernoulli
from math import pi
from matplotlib import pyplot as plt


def minusList(l1, l2):
    """Return the element-wise difference ``l1[i] - l2[i]`` as a new list.

    Elements are paired with ``zip``, so the result is as long as the
    shorter of the two inputs.
    """
    differences = []
    for left, right in zip(l1, l2):
        differences.append(left - right)
    return differences


def addList(l1, l2):
    """Return the element-wise sum ``l1[i] + l2[i]`` as a new list.

    Pairing stops at the end of the shorter input, so any surplus elements
    of the longer one are ignored.
    """
    return list(map(lambda left, right: left + right, l1, l2))


def timesList(k, l):
    """Return a new list holding ``k * element`` for every element of ``l``."""
    scaled = []
    for value in l:
        scaled.append(k * value)
    return scaled


def minuList(k, l):
    """Return a new list holding ``k - element`` for every element of ``l``."""
    return list(map(lambda value: k - value, l))


# Hard-coded experiment results. Every list below has 13 entries, one per
# x-axis point used for plotting later in this file
# (np.linspace(0.03, 0.15, 13), axis labelled "μ2").
#
# Naming scheme as used by the plotting code further down:
#   prefix  C -> plotted on the "Cost" figure
#           T -> plotted on the "Chosen times" figure
#   suffix  o -> series labelled "Our attack" (after subtracting the t series)
#           n -> series labelled "Previous attack" (after subtracting the t series)
#           t -> series that is subtracted from the other two
# NOTE(review): presumably these are per-configuration means over repeated
# runs -- confirm against the script that produced them.
Co = [2488.6939999999986, 2070.001999999998, 1731.756, 1459.6720000000005, 1240.4040000000002, 1069.6700000000005,
      924.042, 807.1679999999999, 712.424, 630.4419999999999, 563.02, 504.578, 455.4340000000001]
Cn = [2492.4179999999974, 2070.333999999998, 1735.798, 1459.9080000000001, 1242.5259999999987, 1071.231999999999,
      925.832, 808.61, 713.75, 633.0420000000004, 563.4340000000002, 506.238, 456.884]
Ct = [2498.194000000002, 2073.1419999999976, 1734.1159999999998, 1461.7480000000003, 1247.2220000000004, 1071.0839999999996,
      926.406, 807.816, 713.406, 632.8719999999995, 564.5620000000009, 507.112, 458.174]

# "Chosen times" counterparts of the three series above.
To = [7073.747999999999, 7568.374, 7962.438, 8283.703999999989, 8541.56199999999, 8742.126, 8912.42200000001,
      9050.360000000004, 9161.999999999993, 9258.784000000001, 9337.475999999999, 9406.497999999994, 9463.39000000001]
Tn = [7068.44, 7563.146, 7958.07, 8282.747999999994, 8538.181999999997, 8739.826000000005, 8911.048000000003,
      9048.315999999999, 9160.014000000017, 9256.906000000003, 9336.495999999996, 9404.199999999997, 9462.100000000008]
Tt = [7060.638000000001, 7562.446, 7960.46, 8279.481999999989, 8531.856000000002, 8740.422000000002, 8909.588000000007,
      9049.677999999996, 9160.148000000003, 9255.383999999996, 9335.817999999988, 9403.536000000011, 9461.352000000004]

# Spread values matching Co / Cn / Ct entry by entry. They are multiplied by
# 1 / sqrt(500) later in this file before being used as band half-widths.
# NOTE(review): presumably sample standard deviations over 500 repetitions
# (which would make the scaled values standard errors) -- confirm.
Cstdo = [6.160670756622542, 5.862378223746388, 5.857056626192219, 6.260386497844648, 7.833410411662674,
         8.583612387702017, 3.1170149527448445, 2.776567117521514, 1.8311254276247106, 1.797654143109463,
         1.548058106881434, 2.614447486282122, 1.1083981026437342]
Cstdn = [11.031362913638787, 5.890398530396933, 7.987000642542697, 4.423809802988422, 4.0299993863562955,
         2.586646258589272, 6.0736554870263495, 1.9422124156893998, 5.058313410951816, 1.7703365736542036,
         1.6843356150688504, 3.1157985957252863, 2.539004550924256]
Cstdt = [10.954309717600012, 3.551706020899733, 6.0431070721600335, 5.376283379179794, 5.297669521330031,
         3.8083884508901256, 3.353879973056355, 3.7855003573240444, 2.3677070835896585, 4.487218916928761,
         2.6498127418228687, 1.9816628022062108, 2.1828791239915826]

# Spread values matching To / Tn / Tt entry by entry (same scaling applies).
Tstdo = [115.33343182269398, 96.47730367293646, 85.55086297636045, 74.05187630303503, 62.75815609145954,
         52.10725596306142, 45.735893081911065, 38.972431281612394, 35.66925847280821, 29.624674580491178,
         26.96860070526463, 21.586894079510373, 19.87606349355928]
Tstdn = [114.36703371164262, 104.71974352527799, 88.6440358963873, 75.39101071082679, 63.65831348692801,
         53.776237540385814, 45.88526665499504, 38.825818008124436, 33.47485330811772, 29.109674749127652,
         26.239702437337204, 22.431406554204308, 20.894831896906947]
Tstdt = [114.83944860543347, 97.24913924554808, 83.62176989277374, 76.62541142467033, 66.21476620815028,
         52.396487630374615, 48.416826620504565, 37.04468712244713, 31.242832842109564, 27.52322793569097,
         27.062110117283908, 22.509458545242712, 20.07013811611669]

# Element-wise differences of the "o" and "n" series against the "t" series.
Cot, Cnt = minusList(Co, Ct), minusList(Cn, Ct)
Tot, Tnt = minusList(To, Tt), minusList(Tn, Tt)

# Scale each spread list by 1 / sqrt(500) to obtain the band half-widths.
Cso, Csn, Cst, Tso, Tsn, Tst = (
    timesList(1 / math.sqrt(500), spread)
    for spread in (Cstdo, Cstdn, Cstdt, Tstdo, Tstdn, Tstdt)
)

# Upper ("u") and lower ("l") band edges: centre line plus/minus half-width.
# NOTE(review): the bands around the difference curves use only the scaled
# spread of the first series; the spread of the subtracted "t" series is not
# combined in -- confirm that this is intended.
Cou, Col = addList(Cot, Cso), minusList(Cot, Cso)
Cnu, Cnl = addList(Cnt, Csn), minusList(Cnt, Csn)
# The "t" bands are centred on the raw series rather than on a difference and
# are not used by any active plotting call in this file.
Ctu, Ctl = addList(Ct, Cst), minusList(Ct, Cst)

Tou, Tol = addList(Tot, Tso), minusList(Tot, Tso)
Tnu, Tnl = addList(Tnt, Tsn), minusList(Tnt, Tsn)
Ttu, Ttl = addList(Tt, Tst), minusList(Tt, Tst)


def _plot_attack_comparison(x, curves, y_label, out_path, y_limits=None):
    """Draw centre curves with shaded bands and save the figure to disk.

    A new 8x6 inch figure is created, a horizontal reference line is drawn at
    y = 0, and every entry of ``curves`` is rendered as a thick centre line,
    two thin band edges and a translucent fill between those edges.

    Args:
        x: Shared x coordinates for every curve.
        curves: Iterable of ``(label, centre, upper, lower, line_color,
            band_color)`` tuples, drawn in the given order. Only the centre
            line receives a legend label.
        y_label: Text for the y axis.
        out_path: File name handed to ``plt.savefig``.
        y_limits: Optional ``(bottom, top)`` pair passed to ``plt.ylim``
            before any curve is drawn. When ``None``, ``plt.ylim`` is not
            called.
    """
    plt.figure(figsize=(8, 6), dpi=600)
    plt.grid(True)
    plt.axhline(0, color='black', lw=2)
    if y_limits is not None:
        plt.ylim(*y_limits)

    for label, centre, upper, lower, line_color, band_color in curves:
        plt.plot(x, centre, label=label, color=line_color, lw=2.4)
        plt.plot(x, upper, color=band_color, lw=0.8)
        plt.plot(x, lower, color=band_color, lw=0.8)
        plt.fill_between(x, upper, lower, alpha=0.25, color=band_color)

    plt.xlabel("μ2", fontsize=28)
    plt.ylabel(y_label, fontsize=28)
    # Scientific notation on both axes, applied regardless of magnitude.
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    plt.legend(fontsize=24)
    plt.tick_params(labelsize=24)
    plt.savefig(out_path, dpi=600, bbox_inches='tight')


# x positions shared by both figures: 13 evenly spaced values, 0.03 .. 0.15.
# A previously disabled annotation described the setup as
# "mu1 = 0.85, T = 10000, repeat 500 times".
x = np.linspace(0.03, 0.15, 13)

# Figure 1: cost curves, with a fixed y range.
_plot_attack_comparison(
    x,
    [
        ("Our attack", Cot, Cou, Col, 'red', 'pink'),
        ("Previous attack", Cnt, Cnu, Cnl, 'blue', 'skyblue'),
    ],
    y_label="Cost",
    out_path="2arm_Cost.png",
    y_limits=(-10, 4),
)

# Figure 2: "chosen times" curves, with the y range left to matplotlib.
_plot_attack_comparison(
    x,
    [
        ("Our attack", Tot, Tou, Tol, 'red', 'pink'),
        ("Previous attack", Tnt, Tnu, Tnl, 'blue', 'skyblue'),
    ],
    y_label="Chosen times",
    out_path="2arm_Times2.png",
)
