from turtle import color
import numpy as np
from collections import defaultdict
from matplotlib import pyplot as plt
from matplotlib import font_manager as fm, rcParams
from matplotlib import rc
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
import os
import pandas as pd
import seaborn as sns
import argparse
import imageio


# Global plot styling: 8x8 inch figures, large axis/tick/legend fonts, seaborn "darkgrid".
s = 20  # font size shared by the x and y tick labels
rc_ = {
    'figure.figsize': (8, 8),
    'axes.labelsize': 30,
    'xtick.labelsize': s,
    'ytick.labelsize': s,
    'legend.fontsize': 20,
}
sns.set(rc=rc_, style="darkgrid")

# seaborn's "colorblind" palette is ordered blue, orange, green, red, ...;
# pick each named colour by its actual position (index 1 is orange, not green).
_colorblind = sns.color_palette("colorblind")
cblue = _colorblind[0]
cgreen = _colorblind[2]
cred = _colorblind[3]
# rc('text', usetex=True)

# Command-line interface: one optional ``--path`` option that defaults to './images'.
_path_option = {'default': './images', 'help': "path"}
parser = argparse.ArgumentParser()
parser.add_argument('--path', **_path_option)
args = parser.parse_args()

# Vary p [0 1], Vary n [0 2] in p^n, Vary rmin [-5 0) 

# #####################################################################################
# Per-step reward bounds used to build the reward tensors and the penalty bounds.
rmin = -1
rmax = 0

# Resolution of both sweep axes (probability grid and penalty grid).
n = 200

# Candidate penalties, evenly spaced and ascending from -10 to 0.
penalty = np.linspace(start=-10, stop=0, num=n)

# Rendered frames are accumulated here, one RGB array per frame.
images = list()

# Sweep mode -- 0: p1 = p2 swept together,
#               1: p1 pinned to the frame value `probi`, p2 swept,
#               2: p2 pinned to the frame value `probi`, p1 swept.
type_p = 1
N_STATES, N_ACTIONS = 4, 2


def _probability_axes(grid, fixed, mode):
    """Return the (p1, p2) sweep arrays for sweep mode `mode`.

    mode 0 sweeps p1 = p2 over `grid`; mode 1 pins p1 to `fixed` and sweeps
    p2 over `grid`; mode 2 pins p2 to `fixed` and sweeps p1 over `grid`.

    Raises:
        ValueError: if `mode` is not 0, 1 or 2.
    """
    pinned = np.full_like(grid, fixed)
    if mode == 0:
        return grid, grid
    if mode == 1:
        return pinned, grid
    if mode == 2:
        return grid, pinned
    raise ValueError("type_p must be 0, 1 or 2, got {!r}".format(mode))


def _transition_tensor(batch, p1, p2):
    """Build the (batch, S, S', A) transition probabilities of the 4-state chain.

    States 1 and 3 are absorbing. Under either action, state 2 stays put with
    probability p2 and moves to state 3 otherwise. From state 0, action 0
    reaches state 2 with probability p1 (state 1 otherwise) and action 1 swaps
    those two probabilities. p1 and p2 may be scalars or arrays of length
    `batch`.
    """
    P = np.zeros((batch, N_STATES, N_STATES, N_ACTIONS))
    P[:, 3, 3, :] = 1.0
    P[:, 1, 1, :] = 1.0
    for act in range(N_ACTIONS):
        P[:, 2, 2, act] = p2
        P[:, 2, 3, act] = 1 - p2
    P[:, 0, 2, 0] = p1
    P[:, 0, 2, 1] = 1 - p1
    P[:, 0, 1, 0] = 1 - p1
    P[:, 0, 1, 1] = p1
    return P


def _lookahead(P, R, values, state):
    """Return the (A, batch) expected one-step returns from `state`.

    `values` is the current (batch, S) array of state values; the result is
    sum over s' of P[s, s', a] * (R[s, s', a] + values[s']) for each action a.
    """
    return np.array([
        np.array([
            P[:, state, nxt, act] * (R[:, state, nxt, act] + values[:, nxt])
            for nxt in range(N_STATES)
        ]).sum(axis=0)
        for act in range(N_ACTIONS)
    ])


# Render one frame per value of `probi`. Each frame is a heat map, over the
# swept probability (x) and the penalty (y), of 1 - success, where success is
# the probability of reaching state 2 from state 0 under the greedy action
# found by Q-value iteration. The penalty bounds are drawn on top.
for probi in np.linspace(1,0,50):
    print("Image: ", probi)

    # With p2 pinned to 1, state 2 never exits and its value grows by one on
    # every sweep, so the exact-convergence loop below would never stop.
    if type_p == 2 and probi == 1:
        continue

    grid = np.linspace(1,0,n)
    probs1, probs2 = _probability_axes(grid, probi, type_p)

    convergences = []
    successes = np.zeros((n,n))
    minmax_line = [[],[]]
    dc_minmax_line = [[],[]]
    d_minmax_line = [[],[]]
    c_minmax_line = [[],[]]

    # ---- Penalty bounds, computed for every grid point except the first ----
    p1_tail, p2_tail = probs1[1:], probs2[1:]
    m = p1_tail.shape[0]
    C = np.maximum(1 - p1_tail, p1_tail)

    # D: largest state value when every step out of states 0 and 2 pays
    # reward 1 (undiscounted value iteration, run until no entry changes).
    P = _transition_tensor(m, p1_tail, p2_tail)
    R = np.ones((m, N_STATES, N_STATES, N_ACTIONS))
    R[:, [1, 3], :, :] = 0.0
    V = np.zeros((m, N_STATES))
    while True:
        V_prev = V.copy()
        for state in range(N_STATES):
            V[:, state] = _lookahead(P, R, V, state).max(axis=0)
        if np.abs(V_prev - V).max() <= 0:
            break
    D = V.max(axis=1)

    # Each bound is capped from above at rmin.
    floor = rmin + np.zeros(m)
    penaltyR = np.minimum(floor, (rmin-rmax)*(D/C))
    penaltyD = np.minimum(floor, (rmin-rmax)*D)
    penaltyC = np.minimum(floor, (rmin-rmax)*(1/C))
    bound_lines = (
        (penaltyR, dc_minmax_line),
        (penaltyD, d_minmax_line),
        (penaltyC, c_minmax_line),
    )

    # ---- Sweep the probability axis; all penalties are solved in one batch ----
    n_pen = penalty.shape[0]
    for ip, (p1, p2) in enumerate(zip(probs1, probs2)):
        if ip%100==0:
            print(p1,p2)

        if p1 != 1 and p2 != 1:
            # The bound arrays lack the first grid point, hence the ip-1 offset.
            for bound, line in bound_lines:
                mm = bound[ip-1]
                if penalty[0] <= mm <= penalty[-1]:
                    line[0].append(ip)
                    line[1].append(np.abs(penalty - mm).argmin())

        P = _transition_tensor(n_pen, p1, p2)
        R = rmin*np.ones((n_pen, N_STATES, N_STATES, N_ACTIONS))
        R[:, [1, 3], :, :] = 0.0
        # Entering state 1 from state 0 costs the swept penalty instead of rmin.
        R[:, 0, 1, 0] = penalty
        R[:, 0, 1, 1] = penalty
        Q = np.zeros((n_pen, N_STATES, N_ACTIONS))

        step = 0
        maxstep = 10000
        # Placeholder: per-penalty convergence steps are not recorded, so this stays zero.
        convergence = np.zeros(n_pen)
        while True:
            step += 1
            Q_prev = Q.copy()
            for state in range(N_STATES):
                target = _lookahead(P, R, Q.max(axis=2), state)
                Q[:, state, :] = (Q[:, state, :] + (target.T - Q[:, state, :]))
            # maxstep bounds the loop when Q keeps decreasing (e.g. p2 == 1).
            if np.abs(Q_prev - Q).max() <= 1e-10 or step > maxstep:
                break

        # Normalise to [0, 1]; skip when all entries are equal to avoid 0/0 -> NaN.
        span = convergence.max() - convergence.min()
        if span > 0:
            convergence = (convergence - convergence.min())/span
        convergences.append(convergence)

        # Probability of reaching state 2 under the greedy action in state 0.
        greedy = Q[:, 0].argmax(axis=1)
        success = np.where(greedy == 1, 1 - p1, p1)
        # Penalty indices at which that probability changes.
        flips = np.flatnonzero(success[1:] != success[:-1]) + 1
        minmax_line[0].extend([ip]*len(flips))
        minmax_line[1].extend(flips.tolist())
        successes[ip,:] = success

    convergences = np.array(convergences)
    minmax_line = np.array(minmax_line)
    dc_minmax_line = np.array(dc_minmax_line)
    d_minmax_line = np.array(d_minmax_line)
    c_minmax_line = np.array(c_minmax_line)
    # Re-orient so that image row i shows penalty index n-1-i and image
    # column j shows probability index n-1-j.
    successes = np.flipud((np.rot90(successes,k=3,axes=(0,1))))

    # ---- Plot the frame ----
    cmap = "RdYlBu_r"

    print("Plotting")
    fig = plt.figure(dpi=60, clear=True)

    overlays = (
        (minmax_line, r'$R_{Minmax}$', {"color": "black", "linestyle": "--"}),
        (dc_minmax_line, r'$\bar R_{MIN}$', {"color": "blue"}),
        (d_minmax_line, r'$\bar R_{MAX}$', {"color": "red"}),
    )
    for line, label, style in overlays:
        plt.plot(
            np.arange(0, len(line[0])),
            n - line[1][::-1],
            label=label,
            linewidth=2,
            # White outline keeps the curve readable on top of the heat map.
            path_effects=[pe.Stroke(linewidth=4, foreground='w'), pe.Normal()],
            **style
        )

    c = plt.imshow(1-successes, cmap=cmap, vmin=0, vmax=1)
    fig.colorbar(c,fraction=0.045)
    plt.ylabel(r"Penalty $\in [-10 ~ 0]$")
    if type_p==0:
        plt.xlabel(r"$p_1=p_2\in [0 ~ 1]$")
    if type_p==1:
        plt.xlabel(r"$p_1={{pi}},p_2\in [0 ~ 1]$".replace("pi", str(round(probi,3))))
    if type_p==2:
        plt.xlabel(r"$p_2={{pi}},p_1\in [0 ~ 1]$".replace("pi", str(round(probi,3))))
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    legend = plt.legend(loc='lower left', labelcolor='white', fancybox=True, framealpha=0.35, frameon=True)
    legend.get_frame().set_facecolor((0, 0, 0, 1))
    fig.tight_layout()
    fig.tight_layout(pad=0)
    fig.gca().margins(0)

    # Rasterise and keep an (height, width, 3) uint8 copy of the canvas.
    # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib removed.
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    images.append(rgba[..., :3].copy())
    # Release the figure; otherwise every frame stays registered with pyplot.
    plt.close(fig)

# Write the collected frames as a 10 fps video inside the directory given by
# --path (default './images'), creating that directory if it does not exist.
os.makedirs(args.path, exist_ok=True)
video_file = os.path.join(args.path, 'prop_vs_penalty_bounds_p{0}.mp4'.format(type_p))
imageio.mimsave(video_file, images, fps=10)