import numpy as np
import matplotlib.pyplot as plt
import os
import pdb
import random
import time

def primal_eg1(T=100,alpha=0.001,p0=0.4,q0=0.4,eta=-0.006):
    """Iterate the primal update of Example 1 and return both trajectories.

    At each step a pair of directions ``(Qp, Qq)`` is chosen from the current
    point ``(p_t, q_t)``:

    * if ``p_t*q_t < 1/16 - eta/2``: ``(-q_t, -p_t)``;
    * else if ``(1-p_t)*(1-q_t) < 1/16 - eta/2``: ``(1-q_t, 1-p_t)``;
    * otherwise: ``(1-2*q_t, 1-2*p_t)``.

    Each coordinate is then updated as
    ``x_{t+1} = x_t / (x_t + (1-x_t)*exp(alpha*Q))``.

    Parameters
    ----------
    T : int
        Number of stored iterates (must be at least 1, since index 0 is
        written unconditionally).
    alpha : float
        Step size multiplying the direction inside the exponential.
    p0, q0 : float
        Initial values written to ``p[0]`` and ``q[0]``.
    eta : float
        Offset applied to the ``1/16`` threshold (as ``-eta/2``).

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The arrays ``p`` and ``q``, each of length ``T``.
    """
    p = np.zeros(T)
    q = np.zeros(T)
    p[0], q[0] = p0, q0

    # The threshold does not depend on t, so evaluate it once.
    threshold = 1/16 - eta/2

    def _step(x, direction):
        # One exponentiated update of a single coordinate.
        return x/(x + (1-x)*np.exp(alpha*direction))

    for t in range(T-1):
        pt, qt = p[t], q[t]
        if pt*qt < threshold:
            dir_p, dir_q = -qt, -pt
        elif (1-pt)*(1-qt) < threshold:
            dir_p, dir_q = 1-qt, 1-pt
        else:
            dir_p, dir_q = 1-2*qt, 1-2*pt
        p[t+1] = _step(pt, dir_p)
        q[t+1] = _step(qt, dir_q)
    return p, q

def primal_eg2(T=100,alpha=0.001,p0=0.4,q0=0.4,eta=0):
    """Iterate the primal update of Example 2 and return both trajectories.

    At each step the direction pair ``(Qp, Qq)`` is
    ``(1-2*q_t, 1-2*p_t)`` while
    ``p_t*q_t + (1-p_t)*(1-q_t) < 0.9 - eta/2`` and ``(-q_t, -p_t)``
    otherwise.  Each coordinate is then updated as
    ``x_{t+1} = x_t / (x_t + (1-x_t)*exp(alpha*Q))``.

    Parameters
    ----------
    T : int
        Number of stored iterates (must be at least 1, since index 0 is
        written unconditionally).
    alpha : float
        Step size multiplying the direction inside the exponential.
    p0, q0 : float
        Initial values written to ``p[0]`` and ``q[0]``.
    eta : float
        Offset applied to the ``0.9`` threshold (as ``-eta/2``).

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The arrays ``p`` and ``q``, each of length ``T``.
    """
    p = np.zeros(T)
    q = np.zeros(T)
    p[0], q[0] = p0, q0

    # Constant across iterations.
    threshold = 0.9 - eta/2

    for t in range(T-1):
        pt, qt = p[t], q[t]
        below_threshold = pt*qt + (1-pt)*(1-qt) < threshold
        dir_p, dir_q = (1-2*qt, 1-2*pt) if below_threshold else (-qt, -pt)
        p[t+1] = pt/(pt + (1-pt)*np.exp(alpha*dir_p))
        q[t+1] = qt/(qt + (1-qt)*np.exp(alpha*dir_q))
    return p, q

def PD_eg1(T=100,alpha=0.1,beta=0.1,lambda0=(0,0)):
    """Run the primal-dual iteration of Example 1.

    At every step the policy is set from the current multipliers:
    ``p_t = q_t = 0`` when ``lambdas[t,1] > lambdas[t,0]`` and
    ``p_t = q_t = 1`` otherwise.  The multipliers then take a projected step

    ``lambdas[t+1,0] = max(lambdas[t,0] - beta*(2*p_t*q_t - 1/8), 0)``
    ``lambdas[t+1,1] = max(lambdas[t,1] - beta*(2*(1-p_t)*(1-q_t) - 1/8), 0)``

    Parameters
    ----------
    T : int
        Number of iterations (must be at least 1, since row 0 of the
        multiplier array is written unconditionally).
    alpha : float
        Unused by this routine; kept so the signature stays compatible
        with existing callers.
    beta : float
        Step size of the multiplier update.
    lambda0 : sequence of two floats
        Initial multipliers, copied into ``lambdas[0]``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray, numpy.ndarray)
        ``p`` and ``q`` of length ``T`` and ``lambdas`` of shape ``(T, 2)``.
    """
    p=np.zeros(T)
    q=np.zeros(T)
    lambdas=np.zeros((T,2))
    lambdas[0]=lambda0
    for t in range(T):
        # Policy for step t.  Iterating over all T steps (rather than T-1)
        # ensures the final entries p[T-1], q[T-1] are computed from
        # lambdas[T-1] instead of being left at their zero initial value.
        action=0.0 if lambdas[t,1]>lambdas[t,0] else 1.0
        p[t]=action
        q[t]=action
        # lambdas only has T rows, so there is no update after the last step.
        if t+1<T:
            lambdas[t+1,0]=max(lambdas[t,0]-beta*(2*p[t]*q[t]-1/8),0)
            lambdas[t+1,1]=max(lambdas[t,1]-beta*(2*(1-p[t])*(1-q[t])-1/8),0)
    return p,q,lambdas


def PD_eg2(T=100,alpha=0.1,beta=0.1,lambda0=0):
    """Run the primal-dual iteration of Example 2.

    The policy is held at ``p_t = q_t = 1`` for every step, and the single
    multiplier takes the projected step

    ``lambdas[t+1] = max(lambdas[t] - beta*(2*p_t*q_t + 2*(1-p_t)*(1-q_t) - 1.8), 0)``

    Parameters
    ----------
    T : int
        Number of iterations.
    alpha : float
        Unused by this routine.
    beta : float
        Step size of the multiplier update.
    lambda0 : float
        Initial multiplier, stored in ``lambdas[0]``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray, numpy.ndarray)
        ``p`` and ``q`` of length ``T`` (all ones) and ``lambdas`` of
        length ``T+1``.
    """
    # The policy never changes, so fill both arrays up front.
    p = np.ones(T)
    q = np.ones(T)

    lambdas = np.zeros(T+1)
    lambdas[0] = lambda0
    for step, (p_t, q_t) in enumerate(zip(p, q)):
        slack = 2*p_t*q_t + 2*(1-p_t)*(1-q_t) - 1.8
        lambdas[step+1] = max(lambdas[step] - beta*slack, 0)
    return p, q, lambdas

# One line colour per initial point; reused by every figure below.
colors=['red','black','blue','cyan','brown']

# Primal algorithm for Example 1
T=20000
alpha=0.001
eta=-0.006
p0_list=[0.45,0.2,0.3,0.25,0.35]
q0_list=[0.3,0.3,0.3,0.25,0.35]
N=len(p0_list)

# Run the algorithm once per initial point and derive, elementwise over time,
# V1=2pq, V2=2(1-p)(1-q) and V0=V1+V2 from each trajectory.
p_list=[]
q_list=[]
V0_list=[]
V1_list=[]
V2_list=[]
for p_init,q_init in zip(p0_list,q0_list):
    p_traj,q_traj=primal_eg1(T,alpha,p0=p_init,q0=q_init,eta=eta)
    V1_traj=2*p_traj*q_traj
    V2_traj=2*(1-p_traj)*(1-q_traj)
    p_list.append(p_traj)
    q_list.append(q_traj)
    V1_list.append(V1_traj)
    V2_list.append(V2_traj)
    V0_list.append(V1_traj+V2_traj)

label_size=14
lgd_size=12

folder='results/'
# exist_ok avoids the check-then-create race of a separate isdir() test.
os.makedirs(folder,exist_ok=True)

# Legend entry for each run, e.g. "$(p_0,q_0)=(0.45, 0.3)$".
run_labels=[f'$(p_0,q_0)={init}$' for init in zip(p0_list,q0_list)]

# Every figure shares the same layout; only the plotted series, the dotted
# reference level with its legend text, the y-axis label and the output file
# name differ.
eg1_figures=[
    (p_list,0.25,r'$p^*=\frac{1}{4}$',r'$p_t$','p_eg1.png'),
    (q_list,0.25,r'$q^*=\frac{1}{4}$',r'$q_t$','q_eg1.png'),
    (V0_list,5/4,r'$V_0(\pi^*)=\frac{5}{4}$',r'$V_0(\pi_t)$','V0_eg1.png'),
    (V1_list,1/8,'safety threshold: '+r'$\xi_1=\frac{1}{8}$',r'$V_1(\pi_t)$','V1_eg1.png'),
    (V2_list,1/8,'safety threshold: '+r'$\xi_2=\frac{1}{8}$',r'$V_2(\pi_t)$','V2_eg1.png'),
]
for series_list,ref_level,ref_label,y_label,file_name in eg1_figures:
    plt.figure()
    for series,color,run_label in zip(series_list,colors,run_labels):
        plt.plot(range(T),series,color=color,label=run_label)
    plt.plot([0,T-1],[ref_level,ref_level],color='green',linestyle='dotted',label=ref_label)
    plt.xlabel(r'$t$',fontsize=label_size)
    plt.ylabel(y_label,fontsize=label_size)
    plt.legend(fontsize=lgd_size)
    plt.savefig(folder+file_name,dpi=200)


# Primal algorithm for Example 2
T=100
alpha=0.1
p0_list=[0,0.2,0.4,0.7,0.9]
q0_list=[1-p for p in p0_list]
N=len(p0_list)
p_list=[]
q_list=[]
V0_list=[]
V1_list=[]
# Not used for Example 2; reset as before so it no longer holds earlier data.
V2_list=[0]*N
for p_init,q_init in zip(p0_list,q0_list):
    p_traj,q_traj=primal_eg2(T,alpha,p0=p_init,q0=q_init)
    V0_traj=2*p_traj*q_traj
    p_list.append(p_traj)
    q_list.append(q_traj)
    V0_list.append(V0_traj)
    # The plotted quantity: 2pq + 2(1-p)(1-q).
    V1_list.append(2*(1-p_traj)*(1-q_traj)+V0_traj)

label_size=14
lgd_size=12

plt.figure()
for i in range(N):
    # 1-p leaves float noise (e.g. 0.30000000000000004), so round for the
    # legend.  Use the built-in round(): np.round returns NumPy scalars, and
    # under NumPy >= 2 a tuple of those is rendered as
    # "(np.float64(0.2), np.float64(0.8))" in the f-string below.
    p0=round(p0_list[i],1)
    q0=round(q0_list[i],1)
    plt.plot(range(T),V1_list[i],color=colors[i],label=f'$(p_0,q_0)={p0,q0}$')
plt.plot([0,T-1],[1.8,1.8],color='green',linestyle='dotted',label='safety threshold: '+r'$\xi_1=1.8$')
plt.xlabel(r'$t$',fontsize=label_size)
plt.ylabel(r'$V_1(\pi_t)$',fontsize=label_size)
plt.legend(fontsize=lgd_size,loc=4)
plt.savefig(folder+'V1_eg2.png',dpi=200)


# Primal-dual algorithms: run both examples and print the policy sequences.
T=100
alpha=0.1
p_PD1,q_PD1,lambda1=PD_eg1(T,alpha,beta=0.1,lambda0=[0,0])
p_PD2,q_PD2,lambda2=PD_eg2(T,alpha,beta=0.1,lambda0=0)

# Each heading is printed on its own line, followed by the matching array.
for heading,sequence in (
    ("p_t of primal-dual for example 1: ",p_PD1),
    ("\n q_t of primal-dual for example 1: ",q_PD1),
    ("\n p_t of primal-dual for example 2: ",p_PD2),
    ("\n q_t of primal-dual for example 2: ",q_PD2),
):
    print(heading,sequence,sep="\n")



