
from __future__ import division
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
from random import random
import pandas as pd
import random
import random
import heapq
from tqdm import tqdm
import scipy.stats as st
import seaborn as sns
from typing import Any
from numpy import *

# ---------------------------------------------------------------------------
# Module-level configuration and initial containers.
#
# N is the length of every per-process list below (state, belief, reward,
# action, ...); T_max is the number of time steps per run.
# ---------------------------------------------------------------------------
N = 2
M = 1
beta = 1
state = [0 for _ in range(N)]
belief = [0.5 for _ in range(N)]
reward = [0 for _ in range(N)]
action = [0 for _ in range(N)]
total_reward = []

# Index list 0..N-1 plus scratch containers.
l = list(range(N))
A = set()
B = set()
C = []

value_max = 0

T_max = 50000
R1 = []
R2 = []
x = y = z = w = 0

# Lower bounds applied to the two random parameters drawn below.
epison_1 = 0.1
epison_2 = 0.1

# One Beta(1, 1) draw each, rescaled from [0, 1] into [epison, 1].
a = (1 - epison_1) * st.beta.rvs(1, 1, size=1)[0] + epison_1
b = (1 - epison_2) * st.beta.rvs(1, 1, size=1)[0] + epison_2

# 2x2 transition tables.  The "active" and "passive" variants are built from
# the same value `a`, so they are numerically identical here.
transitionName_active = [["00", "01"], ["10", "11"]]
transitionMatrix_active = [[a, 1 - a], [1 - a, a]]
transitionName_passive = [["00", "01"], ["10", "11"]]
transitionMatrix_passive = [[a, 1 - a], [1 - a, a]]

transition_active = np.array(transitionMatrix_active)
transition_passive = np.array(transitionMatrix_passive)
q1 = np.array(transitionMatrix_active)
q2 = np.array(transitionMatrix_passive)

# Reward values with their sampling probabilities, both derived from `b`.
value_list_active = [10, 20]
probability_active = [1 - b, b]
value_list_passive = [-10, 10]
probability_passive = [b, 1 - b]

r1 = np.array([[b, 1 - b]])
r2 = np.array([[1 - b, b]])

# Per-process histories and (state, action) counters (4 slots each).
lists_action = [[] for _ in range(N)]
lists_state = [[] for _ in range(N)]
lists_action_state = [[0] * 4 for _ in range(N)]

Q_value = np.zeros((2, 2), dtype='float32')
value_function = np.zeros((1, 2), dtype='float32')

# Per-process learning rates, one per (state, action) slot.
alpha_value = [[0] * 4 for _ in range(N)]

# One list of per-process index values per time step; step 0 starts at zero,
# every later step starts empty and is filled during the simulation.
whittle_index = [[0] * N] + [[] for _ in range(T_max)]


def activity_forecast(state, events, transition, days):
    """Simulate a two-state chain for ``days`` steps and return the last state.

    Parameters
    ----------
    state : int
        Starting state.  ``0`` selects row 0 of the tables; any other value
        selects row 1.
    events : sequence of two sequences of str
        Transition labels per row.  Row 0 must use ``"00"`` for "stay in 0"
        and row 1 must use ``"10"`` for "move to 0"; every other label is
        interpreted as landing in state 1.
    transition : sequence of two probability vectors
        ``transition[row]`` is passed as ``p`` to ``np.random.choice`` for
        the labels in ``events[row]``.
    days : int
        Number of steps to simulate.  ``0`` (or a negative value) returns
        ``state`` unchanged.

    Returns
    -------
    int
        The state after the final step (``0`` or ``1`` when ``days >= 1``).
    """
    current = state
    for _ in range(days):
        row = 0 if current == 0 else 1
        change = np.random.choice(events[row], replace=True, p=transition[row])
        # The label meaning "next state is 0" differs per row.
        to_zero_label = "00" if row == 0 else "10"
        # Advance the chain so the next step samples from the row of the
        # state we just reached, not from the starting state's row.
        current = 0 if change == to_zero_label else 1
    return current


def calculation_value_function(q1, q2, r1, r2, value):
    """Perform one backup of a 2-state / 2-action value table.

    ``q1`` and ``q2`` are 2x2 tables indexed ``[state][next]``; ``q1`` feeds
    column 0 of the returned Q-table and ``q2`` feeds column 1.  ``r1`` and
    ``r2`` are 1x2 weight rows.  ``value`` is a 1x2 table of current values;
    it is **updated in place** and also returned.

    A local discount of 0.9 is used (independent of any module-level
    ``beta``).

    Returns
    -------
    (q_value, value)
        ``q_value`` is a fresh 2x2 float32 array; ``value`` is the same
        object that was passed in, with ``value[0][s]`` replaced by the
        larger of ``q_value[s][0]`` and ``q_value[s][1]`` (column 0 wins
        ties).
    """
    beta = 0.9
    q_value = np.zeros((2, 2), dtype='float32')
    value_state = value

    # Snapshot the current values: all four Q entries are computed from the
    # values as they were on entry, before any of them is overwritten.
    v0 = value_state[0][0]
    v1 = value_state[0][1]

    # Bracketed term multiplied by the first kernel column (uses r1 and v0).
    branch0 = 1 * r1[0][1] + r1[0][1] * beta * v0 + (-1) * r1[0][0] + r1[0][0] * beta * v0
    # Bracketed term multiplied by the second kernel column (uses r2 and v1).
    branch1 = 1 * r2[0][0] + r2[0][0] * beta * v1 + (2) * r2[0][1] + r2[0][1] * beta * v1

    for s in range(2):
        for col, kernel in enumerate((q1, q2)):
            q_value[s][col] = kernel[s][0] * branch0 + kernel[s][1] * branch1

    for s in range(2):
        first, second = q_value[s][0], q_value[s][1]
        value_state[0][s] = first if first >= second else second

    return q_value, value_state


def calculate_transitions(b):
    """Estimate a 2x2 transition matrix from a sequence of values in [0, 1].

    Each consecutive pair ``(b[i], b[i + 1])`` contributes the products
    ``(1 - b[i]) * (1 - b[i + 1])``, ``(1 - b[i]) * b[i + 1]``,
    ``b[i] * (1 - b[i + 1])`` and ``b[i] * b[i + 1]`` to the four cells
    ``[0][0]``, ``[0][1]``, ``[1][0]`` and ``[1][1]``.  Each row is then
    normalised to sum to one.

    Parameters
    ----------
    b : sequence of float
        The value sequence (indexable, with a length).

    Returns
    -------
    numpy.ndarray
        A 2x2 float array.  A row whose accumulated mass is zero, and the
        whole matrix when fewer than two values are given, falls back to
        ``[0.5, 0.5]``.
    """
    # The accumulation deliberately avoids the bare name ``sum``: this module
    # star-imports numpy and later rebinds ``sum`` to a number, so relying on
    # that global would break the function.
    values = np.asarray(b, dtype=float)
    T = np.full((2, 2), 0.5)
    if values.size < 2:
        return T

    prev, nxt = values[:-1], values[1:]
    row_masses = (
        (np.dot(1 - prev, 1 - nxt), np.dot(1 - prev, nxt)),  # from state 0
        (np.dot(prev, 1 - nxt), np.dot(prev, nxt)),          # from state 1
    )
    for row, (to_zero, to_one) in enumerate(row_masses):
        total = to_zero + to_one
        if total != 0:
            T[row][0] = to_zero / total
            T[row][1] = to_one / total
        # else: keep the uniform 0.5 / 0.5 fallback for this row.
    return T


def count_state_action(state, action):
    """Count occurrences of each (state, action) pair.

    ``state`` and ``action`` are parallel sequences; the loop runs over the
    length of ``action``.  The returned list has four slots:

    ``[ (0, 0), (0, 1), (1, 0), (1, 1) ]``  as ``(state, action)``.

    Pairs whose state or action is neither 0 nor 1 are not counted.
    """
    tally = [0, 0, 0, 0]
    for pos, act in enumerate(action):
        observed = state[pos]
        if observed in (0, 1) and act in (0, 1):
            # Slot layout: two consecutive slots per state, action as offset.
            tally[2 * int(observed) + int(act)] += 1
    return tally


def belief_count_state(belief, action):
    """Accumulate fractional (state, action) counts weighted by ``belief``.

    For every position ``k`` (over the length of ``belief``) the weight
    ``1 - belief[k]`` is added to the state-0 slot and ``belief[k]`` to the
    state-1 slot of the action taken at ``k``.  Slot layout of the result:

    ``[ (0, 0), (0, 1), (1, 0), (1, 1) ]``  as ``(state, action)``.

    Positions whose action is neither 0 nor 1 contribute nothing.
    """
    soft = [0] * 4
    for k, weight_one in enumerate(belief):
        act = action[k]
        if act == 0:
            slot_zero, slot_one = 0, 2
        elif act == 1:
            slot_zero, slot_one = 1, 3
        else:
            continue
        soft[slot_zero] = soft[slot_zero] + (1 - weight_one)
        soft[slot_one] = soft[slot_one] + weight_one
    return soft


def number_of_certain_probability(sequence, probability):
    """Draw one item from ``sequence`` according to ``probability``.

    A single uniform number in [0, 1) is drawn with ``np.random.uniform``
    and compared against the running total of ``probability``; the first
    item whose cumulative probability exceeds the draw is returned.

    Parameters
    ----------
    sequence : iterable
        Candidate items.
    probability : iterable of float
        Weight of each item, paired positionally with ``sequence`` (extra
        entries on either side are ignored).

    Returns
    -------
    object
        The selected item.  If the cumulative total never exceeds the draw
        (e.g. the weights sum to less than one), the last item is returned.

    Raises
    ------
    ValueError
        If there is no (item, probability) pair to choose from.
    """
    candidates = list(zip(sequence, probability))
    if not candidates:
        # Fail with a clear message instead of leaking an unbound local.
        raise ValueError("sequence and probability must both be non-empty")

    x = np.random.uniform(0, 1)
    cumulative_probability = 0.0
    for item, item_probability in candidates:
        cumulative_probability += item_probability
        if x < cumulative_probability:
            return item
    # Draw was not covered by the cumulative weights: fall back to the last.
    return candidates[-1][0]


def b_update(reward_active, reward_passive, transition1, transition2, reward, b):
    """Return the updated value of ``b`` given the observed ``reward``.

    The update rule depends only on the numeric reward:

    * ``20``  -> ``1``
    * ``-10`` -> ``0``
    * ``0``   -> ``b`` pushed one step through ``transition2``
      (``b * transition2[1][1] + (1 - b) * transition2[0][1]``)
    * ``10``  -> a ratio built from ``reward_active[0][0]``,
      ``reward_passive[0][1]`` and ``transition1``; ``0`` if the
      denominator is zero
    * anything else -> ``0``

    Callers in this file pass a per-process belief as ``b`` (presumably the
    probability of being in state 1 -- confirm against the model).
    """
    if reward == 20:
        return 1
    if reward == -10:
        return 0
    if reward == 0:
        return b * transition2[1][1] + (1 - b) * transition2[0][1]
    if reward == 10:
        weight_active = reward_active[0][0]
        weight_passive = reward_passive[0][1]
        numerator = (weight_active * transition1[1][1] * b
                     + weight_active * transition1[0][1] * (1 - b))
        denominator = (numerator
                       + weight_passive * transition1[0][0] * (1 - b)
                       + weight_passive * transition1[1][0] * b)
        if denominator != 0:
            return numerator / denominator
        return 0
    return 0


def find_max_matrix(T1):
    """Return the largest of the four entries of the 2x2 table ``T1``.

    Entries are compared pairwise with ``>=``: first within each row, then
    between the two row winners.  On a tie the earlier operand is kept.
    """
    def _pick(first, second):
        # Keep `first` on ties; mirrors a plain `>=` comparison.
        return first if first >= second else second

    top_row_best = _pick(T1[0][0], T1[0][1])
    bottom_row_best = _pick(T1[1][0], T1[1][1])
    return _pick(top_row_best, bottom_row_best)


# ---------------------------------------------------------------------------
# Simulation driver.
#
# Runs 100 independent runs of T_max steps each.  Every run starts from
# fresh state, beliefs, Q-tables, visit counters and index table, and appends
# its list of per-step total rewards to `sum_duo`.
#
# The action selection below is written for exactly two processes: it
# compares index entries 0 and 1 and produces the literal action vectors
# [1, 0] / [0, 1].
# ---------------------------------------------------------------------------
sum_duo = []  # one list of per-step total rewards per run

for run in tqdm(range(100)):
    # ---- per-run reset ---------------------------------------------------
    state = [0] * N
    belief = [0.5] * N
    reward = [0] * N
    action = [0] * N
    total_reward = []
    value_max = [[0] for _ in range(N)]
    Q_value = [np.array([[0, 0], [0, 0]], dtype='float32') for _ in range(N)]
    gap1 = 0
    gap2 = 0

    lists_action = [[] for _ in range(N)]   # action history per process
    lists_state = [[] for _ in range(N)]    # thresholded-belief history
    # Visit counters per process, slots ordered (s0,a0), (s0,a1), (s1,a0),
    # (s1,a1) -- the same layout count_state_action() produces.
    lists_action_state = [[0] * 4 for _ in range(N)]
    value_function = np.array([[0, 0]], dtype='float32')
    alpha_value = [[0] * 4 for _ in range(N)]

    # The index table must start empty for every run.  Without this reset
    # the entries appended by earlier runs stay at positions [0] and [1], so
    # later runs would select actions from the first run's indices.
    whittle_index = [[0] * N] + [[] for _ in range(T_max)]

    for t in range(T_max):
        # ---- action selection: explore with probability N / (N + t) ------
        epsilon = N / (N + t)
        if np.random.uniform() < epsilon:
            coin = random.random()
            if coin > 0.5:
                action = [1, 0]
            else:
                action = [0, 1]
        else:
            # Exploit: activate the process with the larger stored index.
            if whittle_index[t][0] > whittle_index[t][1]:
                action = [1, 0]
            else:
                action = [0, 1]

        # ---- state transition (one step per process) ---------------------
        for i in range(N):
            if action[i] == 1:
                state[i] = activity_forecast(state[i], transitionName_active, transitionMatrix_active, 1)
            else:
                state[i] = activity_forecast(state[i], transitionName_passive, transitionMatrix_passive, 1)

        # ---- reward: only an activated process yields a non-zero reward --
        for i in range(N):
            if action[i] == 1 and state[i] == 1:
                reward[i] = number_of_certain_probability(value_list_active, probability_active)
            elif action[i] == 1 and state[i] == 0:
                reward[i] = number_of_certain_probability(value_list_passive, probability_passive)
            elif action[i] == 0:
                reward[i] = 0

        # ---- belief update, then hard-threshold to 0 / 1 ------------------
        for i in range(len(reward)):
            updated = b_update(r2, r1, q1, q2, reward[i], belief[i])
            belief[i] = 1 if updated >= 0.5 else 0

        # ---- bookkeeping ---------------------------------------------------
        for i in range(N):
            lists_action[i].append(action[i])
            lists_state[i].append(belief[i])
            # Incremental equivalent of recounting the whole history with
            # count_state_action(); belief and action are both 0/1 here, so
            # the slot is 2 * state + action.  Recounting every step made
            # each run quadratic in T_max.
            lists_action_state[i][2 * belief[i] + action[i]] += 1

        # ---- Q update and index computation ------------------------------
        for i in range(N):
            value_max[i] = find_max_matrix(Q_value[i])
            act = action[i]
            if act in (0, 1):
                # The chosen action's column is updated for BOTH state rows,
                # each with its own visit-count-based learning rate.
                for s in (0, 1):
                    slot = 2 * s + act
                    alpha_value[i][slot] = 1 / (lists_action_state[i][slot] + 1)
                    Q_value[i][s][act] = (
                        (1 - alpha_value[i][slot]) * Q_value[i][s][act]
                        + alpha_value[i][slot] * (reward[i] + value_max[i])
                    )

            # Index = larger of the two per-state advantages of action 1
            # over action 0.
            gap1 = (Q_value[i][1][1] - Q_value[i][1][0])
            gap2 = (Q_value[i][0][1] - Q_value[i][0][0])
            if gap1 >= gap2:
                whittle_index[t + 1].append(gap1)
            else:
                whittle_index[t + 1].append(gap2)

        # ---- total reward of this step ------------------------------------
        # Accumulated by hand under a name that does not shadow `sum`.
        step_reward = 0
        for single_reward in reward:
            step_reward = step_reward + single_reward
        total_reward.append(step_reward)

    sum_duo.append(total_reward)

# ---------------------------------------------------------------------------
# Post-processing of the recorded rewards.
# ---------------------------------------------------------------------------

# Replace each run's per-step rewards, in place, by the running total of
# beta**j-weighted rewards up to and including step j.
for i in range(len(sum_duo)):
    temp = 0
    for j in range(len(sum_duo[i])):
        temp = temp + np.power(beta, j) * sum_duo[i][j]
        sum_duo[i][j] = temp

# Average the running totals over all runs, step by step.
final = [0] * T_max
for i in range(T_max):
    for j in range(len(sum_duo)):
        final[i] = final[i] + sum_duo[j][i]
    final[i] = final[i] / len(sum_duo)

# Persist the averaged running totals, one value per line.  The context
# manager guarantees the file is closed even if a write fails.
with open("Q-learning.txt", 'w') as file2:
    file2.writelines(str(value) + '\n' for value in final)

# Convert running totals into per-step averages (total / number of steps).
for i in range(len(final)):
    final[i] = final[i] / (i + 1)

# NOTE(review): the slice [-5:-1] shows four values and omits the very last
# step -- confirm whether final[-5:] was intended.
print(final[-5:-1])