# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import os
import pdb
import random
import time
from CCMARL_utils_grid import *

#state: s1*16+s2 where s1, s2\in \{0,1,...,15\} denote the positions of the agents 1 and 2.
#action: (a1,a2) where a1, a2\in \{0,1,2,3\} (up, down, left, right) denote the actions of the agents 1 and 2.
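#reward: -1 for every step an agent takes while not at the target (cell 11), -5 if its action leaves it in place (e.g. moving into a wall), and 0 once it is at the target.
#safety score: -1 for every step at which both agents occupy the same non-target cell.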


#Grid MDP: single-agent transition kernel, joint transition kernel and reward tensor.
#The cell index of grid point (x,y) is s=x+w*y. Actions: 0=up (y+1), 1=down (y-1), 2=left (x-1), 3=right (x+1).
#A move that would leave the grid keeps the agent in place.

#Transition kernel P
w=4   #width
h=4   #height
P_1agent=np.zeros((w*h,4),dtype=int) #P_1agent[s,a]=s' is the deterministic state transition of a single agent.
for x in range(w):
    for y in range(h):
        s=x+w*y
        P_1agent[s,0]=s if y==h-1 else s+w   #up
        P_1agent[s,1]=s if y==0 else s-w     #down
        P_1agent[s,2]=s if x==0 else s-1     #left
        P_1agent[s,3]=s if x==w-1 else s+1   #right
target=11
P_1agent[target]=target   #the target cell is absorbing: every action keeps the agent there

#Joint transition kernel transP[s,a1,a2,s'] over joint states s=s1*w*h+s2 (deterministic, so a 0/1 tensor).
num_states=(w*h)**2
transP=np.zeros((num_states,4,4,num_states))
for s1 in range(w*h):
    for s2 in range(w*h):
        s=s1*w*h+s2
        for a1 in range(4):
            s1_next=P_1agent[s1,a1]
            for a2 in range(4):
                s2_next=P_1agent[s2,a2]
                s_next=s1_next*w*h+s2_next
                transP[s,a1,a2,s_next]=1.0

#Reward r: reward[s,a1,a2,s',m,k] is signal k received by agent m (m=0: agent 1, m=1: agent 2),
#where k=0 is the reward r_0 and k=1 is the safety score r_1.
reward=np.zeros((num_states,4,4,num_states,2,2))
for s1 in range(w*h):
    for s2 in range(w*h):
        s=s1*w*h+s2
        if s1!=target:
            reward[s,:,:,:,0,0]=-1.0  #r_0^{(1)}=-1 for agent 1 moving at a non-target position
            for a1 in range(4):
                if P_1agent[s1,a1]==s1:
                    reward[s,a1,:,:,0,0]=-5.0   #r_0^{(1)}=-5 for agent 1 staying at a non-target position
        if s2!=target:
            reward[s,:,:,:,1,0]=-1.0  #r_0^{(2)}=-1 for agent 2 moving at a non-target position
            for a2 in range(4):
                #Mirror agent 1's rule: test agent 2's OWN position s2 and write agent 2's slot m=1.
                if P_1agent[s2,a2]==s2:
                    reward[s,:,a2,:,1,0]=-5.0   #r_0^{(2)}=-5 for agent 2 staying at a non-target position
            if s1==s2:
                reward[s,:,:,:,:,1]=-1.0  #safety score r_1=-1 for both agents colliding at a non-target position

#Optimal deterministic joint path. Both agents start at cell 3; agent 2 goes 3->7->11 while agent 1 goes 3->2->6->10->11,
#so the agents share a non-target cell only at t=0.
s1_opt=[3, 2, 6, 10, 11]   #agent 1 positions along the path
s2_opt=[3, 7, 11, 11, 11]  #agent 2 positions along the path
s_opt=[16*p1+p2 for p1,p2 in zip(s1_opt,s2_opt)]   #joint states s=s1*16+s2, i.e. [51, 39, 107, 171, 187]
#Per-step reward averaged over the two agents, discounted by 0.9: two steps with both agents moving,
#then two steps with only agent 1 moving. Approximately -2.6695.
V0rho_opt=-1-0.9-(0.9**2+0.9**3)/2
V1rho_opt=-1   #only the initial (t=0) collision contributes to the safety score

#Environment setup: both agents start at cell init_position, i.e. at joint state init_position*w*h+init_position.
init_position=3
init_s=init_position*w*h+init_position
rho=np.zeros(num_states)   #initial state distribution: a point mass at init_s
rho[init_s]=1.0
gamma=0.9   #discount factor
W=np.ones([2,2])/2   # NOTE(review): W is not passed to env_setup nor used elsewhere in this script as visible here — confirm it is needed
env_dict=env_setup(
    seed_init=1,
    state_space=range(num_states),
    action_spaces=[range(4) for _ in range(2)],   #four actions (up, down, left, right) for each of the two agents
    rho=rho,
    transP=transP,
    reward=reward,
    xi=[-1.0],   #threshold xi_1 on the safety value V_1
    gamma=gamma,
)

#Initialize policy: pi0[m][s] is agent m's action distribution at state s, one Dirichlet(1,...,1) sample per state.
set_seed(1)
pi0=[]
for m in range(env_dict['num_agents']):
    n_actions=env_dict['num_actions'][m]
    pi_m=np.zeros((env_dict['num_states'],n_actions))
    for s in range(env_dict['num_states']):
        pi_m[s]=np.random.dirichlet([1]*n_actions,1)   #sampled in the same (agent, state) order as before
    pi0.append(pi_m)

#Run the three algorithms for T iterations from the same initial policy pi0, printing progress and saving
#(is_save=True) into each run's own save_folder.
T=100
shared_kwargs=dict(pi0=pi0,is_print=True,is_save=True)

print("Begin our Primal Dual algorithm")
PD_dict_exact=PrimalDual_population(env_dict,unconstrained_alg="exact",T=T,T_Viter=50,alpha_pi=None,
                                    beta_lambda=1.0,lambda_kmax=None,
                                    save_folder="results/PrimalDual_population",**shared_kwargs)

print("\n\n Begin our Primal algorithm")
primal_dict_exact=primal_population(env_dict,T=T,alpha=1.0,tol=0.001,n_between_evals=1,
                                    save_folder="results/primal_population",**shared_kwargs)

print("\n\n Begin CNAC algorithm")
PD_dict_inexact=PrimalDual_population(env_dict,unconstrained_alg="inexact",T=T,T_Viter=None,alpha_pi=0.05,
                                      beta_lambda=1.0,lambda_kmax=None,
                                      save_folder="results/CNAC_population",**shared_kwargs)
    
#Plot results: one figure per value function (k=0: V_0, k=1: safety value V_1), comparing the three runs
#against reference lines, saved as <folder>/grid_V<k>.png.
label_size=14
lgd_size=12
folder='results/'
os.makedirs(folder,exist_ok=True)   #no-op if the folder already exists

for k in [0,1]:
    plt.figure()
    plt.scatter(range(T), PD_dict_exact['Vk_rho'][k][0:T],color="red",s=5,label="Our Primal-Dual Algorithm")
    plt.plot(range(T), primal_dict_exact['Vk_rho'][k][0:T],color="black",label="Our Primal Algorithm")
    plt.plot(range(T), PD_dict_inexact['Vk_rho'][k][0:T],color="blue",linestyle="-.",label="CNAC Algorithm")
    if k==1:
        a=env_dict['xi'][0]
        plt.plot([0,T],[a,a],color="cyan",linestyle=":",label="Threshold "+r'$\xi_1=$'+str(a))
        #The t=0 collision is unavoidable, so V1rho_opt (-1) is the MAXIMUM achievable safety value;
        #colliding forever drives V_1 far lower, hence \max rather than \min in the label.
        a=float(V1rho_opt)
        plt.plot([0,T],[a,a],color="green",linestyle=":",label=r'$\max_{\pi} V_1(\pi)='+str(V1rho_opt)+'$')
        plt.ylabel(r'$V_1(\pi_t)$',fontsize=label_size)
    else:
        a=V0rho_opt   #value of the optimal deterministic path (single source of truth, no duplicated literal)
        plt.plot([0,T],[a,a],color="green",linestyle=":",label=r'$V_0=$'+f"{a:.4f}"+" of the optimal path.")
        plt.ylabel(r'$V_0(\pi_t)$',fontsize=label_size)

    plt.xlabel(r'$t$',fontsize=label_size)
    plt.legend(fontsize=lgd_size,loc=4)
    plt.savefig(os.path.join(folder,'grid_V'+str(k)+'.png'),dpi=200)
    

# #Find best value
# def find_best_value(Vk_rho):
#     T=Vk_rho.shape[1]
#     indexes=np.where(Vk_rho[1]>=env_dict['xi'][0])[0]
#     opt_index=np.argmax(Vk_rho[0][indexes])
#     opt_index=indexes[opt_index]
#     return opt_index, Vk_rho[:,opt_index]

# opt_index,Vk_rho=find_best_value(PD_dict_exact['Vk_rho'])
# print("Primal Dual algorithm uses "+str(opt_index+1)\
#       +" iterations to reach the best solution along its trajectory, which has V0="\
#       +str(Vk_rho[0])+", V1="+str(Vk_rho[1]))
    
# opt_index,Vk_rho=find_best_value(primal_dict_exact['Vk_rho'])
# print("Primal algorithm uses "+str(opt_index+1)\
#       +" iterations to reach the best solution along its trajectory, which has V0="\
#       +str(Vk_rho[0])+", V1="+str(Vk_rho[1]))

# opt_index,Vk_rho=find_best_value(PD_dict_inexact['Vk_rho'])
# print("CNAC algorithm uses "+str(opt_index+1)\
#       +" iterations to reach the best solution along its trajectory, which has V0="\
#       +str(Vk_rho[0])+", V1="+str(Vk_rho[1]))


# def find_path(pim,init_s=51,final_s=187):
#     s_now=init_s
#     s1_now=int(s_now/(w*h))
#     s2_now=s_now-s1_now*w*h
#     s=[s_now]
#     s1=[s1_now]
#     s2=[s2_now]
#     t=0
#     while s_now!=final_s and t<15:
#         a1=np.where(pim[0][s_now]>0.5)[0][0]
#         a2=np.where(pim[1][s_now]>0.5)[0][0]
#         s_now=np.where(env_dict['transP'][s_now,a1,a2]>0.5)[0][0]
#         s1_now=int(s_now/(w*h))
#         s2_now=s_now-s1_now*w*h
#         s+=[s_now]
#         s1+=[s1_now]
#         s2+=[s2_now]
#         print("t="+str(t)+": s="+str(s_now)+": s1="+str(s1_now)+": s2="+str(s2_now))
#         t+=1
#     return s, s1, s2

# find_path(PD_dict_exact['pim'][-1])







    