import numpy as np
import matplotlib.pyplot as plt

from tasks.example import MoldovanGridworldMDP, ValueIteration

ticksize = 20  # font size for tick labels and the legend
textsize = 24  # font size for general text, titles and axis labels
figsize = (6, 6)  # figure size in inches for the square grid plots

# When True every figure is written to a PDF file; when False it is shown.
save_fig = True

import matplotlib
# The 'pgf' backend typesets text with LaTeX for the saved PDFs, but it is a
# non-interactive backend: plt.show() under it displays nothing. Select it
# only when figures are being saved, so that save_fig = False falls back to
# the default backend and the plots are actually shown.
if save_fig:
    matplotlib.use('pgf')
plt.rc('pdf', fonttype=42)  # embed TrueType (Type 42) fonts rather than Type 3
plt.rc('ps', fonttype=42)
plt.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
plt.rc('font', size=textsize)  # controls default text sizes
plt.rc('axes', titlesize=textsize)  # fontsize of the axes title
plt.rc('axes', labelsize=textsize)  # fontsize of the x and y labels
plt.rc('xtick', labelsize=ticksize)  # fontsize of the tick labels
plt.rc('ytick', labelsize=ticksize)  # fontsize of the tick labels
plt.rc('legend', fontsize=ticksize)  # legend fontsize

# Grid layout used for drawing, indexed as maze[row, column]:
#   '$' goal (drawn as 'G'), '#' trap drawn as 'X', '%' trap drawn as 'Y',
#   ' ' free cell. The start cell (row 4, column 0) is drawn separately.
# NOTE(review): presumably mirrors the layout built into MoldovanGridworldMDP
# (tasks.example); keep the two in sync.
maze = np.array([
    [' ', '$', ' ', ' ', ' '],
    ['#', ' ', ' ', ' ', ' '],
    ['#', ' ', '%', ' ', ' '],
    [' ', ' ', ' ', ' ', ' '],
    [' ', ' ', ' ', ' ', ' ']
])


def build_source_tasks():
    """Build the two source gridworld tasks.

    Returns:
        A list of two ``MoldovanGridworldMDP`` instances: the first with
        ``fail_cost`` and ``fail_cost2`` both 0, the second with both 20.
    """
    cost_settings = ((0., 0.), (20., 20.))
    return [MoldovanGridworldMDP(fail_cost=first, fail_cost2=second)
            for first, second in cost_settings]


def build_target_task():
    """Build the target gridworld task (``fail_cost=20``, ``fail_cost2=0.``)."""
    target_costs = dict(fail_cost=20, fail_cost2=0.)
    return MoldovanGridworldMDP(**target_costs)


def solve_policies():
    """Solve the source tasks and then the target task with value iteration.

    Every task is solved by ``ValueIteration`` with ``beta=-0.1``. Only the
    policy (the second value returned by ``solve()``) is kept.

    Returns:
        A list of policies ordered as [source task 1, source task 2, target].
    """
    def _policy_for(task):
        _, policy = ValueIteration(task, beta=-0.1).solve()
        return policy

    all_tasks = build_source_tasks() + [build_target_task()]
    return [_policy_for(task) for task in all_tasks]


def evaluate_policies(policies, beta):
    """Compute the target-task Q-values of each given policy.

    The target task is built once. For every policy, a fresh
    ``ValueIteration`` with risk parameter ``beta`` runs ``solve(policy)``,
    and only the first returned value is kept.

    Args:
        policies: iterable of policies, e.g. taken from ``solve_policies()``.
        beta: risk parameter forwarded to ``ValueIteration``. The script uses
            -0.1 for the "risk-aware" run and 0 for the "risk-neutral" run.

    Returns:
        List of Q-value results, one per policy, in input order.
    """
    target_task = build_target_task()

    def _policy_qvalues(policy):
        # NOTE(review): passing a policy to solve() presumably evaluates that
        # fixed policy instead of optimising; the second result is unused.
        qvalues, _ = ValueIteration(target_task, beta=beta).solve(policy)
        return qvalues

    return [_policy_qvalues(policy) for policy in policies]


def build_gpi_policy(policies, qvalues):
    """Combine source policies cell by cell using their Q-values on the target.

    For every cell (y, x), the source task whose Q-values reach the largest
    value over all actions is selected. Ties go to the lowest task index,
    matching ``np.argmax``. The selected task's own policy action is then
    used in that cell.

    Args:
        policies: sequence of policy arrays indexed as ``policy[y, x]``,
            aligned with ``qvalues`` (``qvalues[i]`` belongs to
            ``policies[i]``).
        qvalues: non-empty sequence of Q-value arrays of identical shape,
            indexed as ``q[y, x, action]``. Any number of tasks and any grid
            size are accepted.

    Returns:
        (actions, task_index): integer arrays with the grid shape
        ``qvalues[0].shape[:2]``. ``actions[y, x]`` is the chosen action and
        ``task_index[y, x]`` the index of the policy it was taken from.
    """
    # Stack to (n_tasks, rows, cols, n_actions) instead of hard-coding two
    # tasks on a 5x5 grid.
    q = np.stack([np.asarray(qv) for qv in qvalues])
    # Best value over actions per task and cell, then the winning task per cell.
    task_index = np.argmax(np.max(q, axis=-1), axis=0).astype(int)
    n_rows, n_cols = task_index.shape
    actions = np.zeros((n_rows, n_cols), dtype=int)
    for y in range(n_rows):
        for x in range(n_cols):
            # NOTE(review): this reuses the selected policy's own action rather
            # than argmax_a of its Q-values; confirm this is the intended rule.
            actions[y, x] = policies[task_index[y, x]][y, x]
    return actions, task_index


def test_policy(policy, n_test=1000):
    """Roll ``policy`` out ``n_test`` times on a freshly built target task.

    Returns:
        A list holding the result of each ``policy_rollout(policy)`` call. The
        plot at the end of the script histograms these as episode returns.
    """
    rollout = build_target_task().policy_rollout
    return [rollout(policy) for _ in range(n_test)]

        
# plotting utilities
def _cell_label(x, y, text, color):
    """Write ``text`` centred at data point (x, y) in the given colour, font size 50."""
    plt.text(x, y, text, ha='center', va='center', color=color, fontsize=50)


def trap(x, y, text):
    """Label a trap cell at column ``x``, row ``y`` with ``text`` in dark red."""
    _cell_label(x, y, text, 'darkred')


def start(x, y):
    """Label the start cell at column ``x``, row ``y`` with a black 'S'."""
    _cell_label(x, y, 'S', 'black')


def goal(x, y):
    """Label the goal cell at column ``x``, row ``y`` with a black 'G'."""
    _cell_label(x, y, 'G', 'black')


# (row offset, column offset) per action. Actions may be letters or integer
# codes. Anything not listed here (e.g. 'D' or 3) moves down.
_STEP_OFFSETS = {
    'R': (0, 1), 0: (0, 1),
    'U': (-1, 0), 1: (-1, 0),
    'L': (0, -1), 2: (0, -1),
}


def take_step(x, y, d):
    """Return the (row, column) reached from column ``x``, row ``y`` by action ``d``.

    'R'/0 moves right, 'U'/1 moves up (row - 1) and 'L'/2 moves left. Any
    other value moves down (row + 1). Note that the arguments are ordered
    (x, y) while the result is ordered (y, x).
    """
    dy, dx = _STEP_OFFSETS.get(d, (1, 0))
    return y + dy, x + dx


def policy_path(policy, start_state=(4, 0), goal_state=(0, 1)):
    """Follow ``policy`` from ``start_state`` and return the cells it visits.

    Args:
        policy: 2-D array of actions indexed as ``policy[row, column]``, in the
            action encoding understood by ``take_step``.
        start_state: (row, column) to start from. Defaults to the start cell (4, 0).
        goal_state: (row, column) at which to stop. Defaults to the goal cell (0, 1).

    Returns:
        List of (row, column) tuples beginning with ``start_state``. The list
        ends at ``goal_state`` if the policy reaches it. Otherwise it ends:
        - at the first repeated cell, if the policy cycles; or
        - at the last in-grid cell, if the next step would leave the grid.
    """
    n_rows, n_cols = np.shape(policy)[:2]
    goal_state = tuple(goal_state)
    s = tuple(start_state)
    coords = [s]
    visited = {s}
    while s != goal_state:
        s = take_step(s[1], s[0], policy[s[0], s[1]])
        # Without this guard a negative index would silently wrap around and
        # a too-large one would raise IndexError.
        if not (0 <= s[0] < n_rows and 0 <= s[1] < n_cols):
            break
        coords.append(s)
        # Without this guard a cyclic policy would loop forever.
        if s in visited:
            break
        visited.add(s)
    return coords


def plot_environment(policies):
    """Draw the gridworld and the start-to-goal paths of up to three policies.

    The figure contains:
    - an empty 5x5 grid;
    - the start 'S' at column 0 / row 4;
    - the trap and goal markers read from ``maze``;
    - the path each policy takes from the start (via ``policy_path``), drawn
      as a thick line ending in an arrow: red solid, blue solid and green
      dashed for the first, second and third policy. ``zip`` ignores any
      further policies.

    The figure is saved to 'envmap.pdf' when ``save_fig`` is set, otherwise
    it is shown.

    Args:
        policies: sequence of policy arrays indexed as ``policy[row, column]``.
    """
    # draw cells: an all-NaN image shows no colour (NaN is masked) but sets up
    # the 5x5 pixel coordinate system used by the grid lines and markers
    img = np.full((5, 5), np.nan)
    _, ax = plt.subplots(figsize=figsize)
    ax.set_aspect('equal')
    ax.set_xticks([])
    ax.set_yticks([])
    # Outline the axes background in black. Pass a number rather than the
    # string '5': matplotlib documents the width as a float, and some versions
    # forward the raw value into dash-pattern scaling.
    ax.patch.set_edgecolor('black')
    ax.patch.set_linewidth(5)
    ax.imshow(img)
    # dashed cell borders via minor ticks on the half-integer cell edges
    ax.set_xticks(np.arange(-.5, 5), minor=True)
    ax.set_yticks(np.arange(-.5, 5), minor=True)
    ax.grid(which='minor', color='grey', linestyle='--', linewidth=2)

    # draw start point, traps and destination
    start(0, 4)
    for y in range(5):
        for x in range(5):
            if maze[y, x] == '#':
                trap(x, y, 'X')
            elif maze[y, x] == '%':
                trap(x, y, 'Y')
            elif maze[y, x] == '$':
                goal(x, y)

    # plot the policy paths; the per-colour coordinate shifts below are
    # hand-tuned, presumably so that paths sharing cells stay distinguishable
    colours = ['red', 'blue', 'green']
    arrow_offsets = [(0.7, 0.0), (0.0, 0.7), (0.0, 0.7)]  # (dy, dx) added to the final arrow
    styles = ['solid', 'solid', 'dashed']
    for policy, col, offset, style in zip(policies, colours, arrow_offsets, styles):
        path = policy_path(policy)
        if col == 'red':
            path = [((p[0] - 0.2, p[1]) if p[0] == 4 else p) for p in path]
        elif col == 'blue':
            path = [((p[0] + 0.2, p[1]) if p[0] == 4 else p) for p in path]
        else:
            path = [((p[0] + 0.2, p[1]) if p[0] == 0 else p) for p in path]
        ys = [p[0] for p in path]
        xs = [p[1] for p in path]
        xs[0] += 0.3
        # every segment except the last is drawn as a plain line ...
        for i in range(len(path) - 2):
            plt.plot(xs[i:i + 2], ys[i:i + 2], linestyle=style, linewidth=8, color=col)
        # ... and the last one as an arrow, with ``offset`` added to its (dy, dx)
        plt.arrow(xs[-2], ys[-2], xs[-1] - xs[-2] + offset[1], ys[-1] - ys[-2] + offset[0],
                  linewidth=8, color=col, head_width=0.2)
    plt.tight_layout()
    if save_fig:
        plt.savefig('envmap.pdf', format='pdf')
    else:
        plt.show()


def quiver(policy):
    """Build quiver-plot arrays and the start-to-goal path for ``policy``.

    Args:
        policy: array of actions indexed as ``policy[y, x]`` covering the
            ``maze`` grid. Actions may be letters ('R', 'U', 'L', 'D') or their
            integer codes (0-3).

    Returns:
        (X, Y, U, V, path):
        - ``X``/``Y``: column/row index grids.
        - ``U``/``V``: integer arrow components, +1 meaning right/up. They are
          zero on every non-blank ``maze`` cell (traps and goal).
        - ``path``: float array, 1.0 on the cells visited when following the
          policy from the start (4, 0), NaN elsewhere.

        The walk stops on reaching the goal (0, 1), on revisiting a cell, or
        when the next step would leave the grid.
    """
    n_rows, n_cols = maze.shape
    X, Y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    U = np.zeros_like(X)
    V = np.zeros_like(Y)
    for y in range(n_rows):
        for x in range(n_cols):
            # No arrow on trap/goal cells; read them from ``maze`` instead of
            # duplicating their coordinates here.
            if maze[y, x] != ' ':
                continue
            action = policy[y, x]
            U[y, x] = 1 if action in {'R', 0} else (-1 if action in {'L', 2} else 0)
            V[y, x] = -1 if action in {'D', 3} else (1 if action in {'U', 1} else 0)

    path = np.full((n_rows, n_cols), np.nan)
    s = (4, 0)
    taboo = {s}
    while s != (0, 1):
        path[s[0], s[1]] = 1.
        s = take_step(s[1], s[0], policy[s[0], s[1]])
        # Stop at the grid edge: a negative index would otherwise wrap around
        # silently and a too-large one would raise IndexError.
        if not (0 <= s[0] < n_rows and 0 <= s[1] < n_cols):
            break
        # Stop if the policy cycles without reaching the goal.
        if s in taboo:
            break
        taboo.add(s)
    return X, Y, U, V, path


def quiver_color_rgb(c, p):
    """Return per-cell colour values marking which source task each path cell uses.

    Args:
        c: integer array of task indices, e.g. from ``build_gpi_policy``.
        p: path array, e.g. from ``quiver``: 1 on path cells, NaN elsewhere.
            Must have the same shape as ``c`` (any grid size is accepted).

    Returns:
        Float array shaped like ``p``:
        - +1.0 on path cells where ``c == 0``;
        - -1.0 on path cells with any other task index;
        - NaN off the path.

        Despite the name this is a single scalar channel, meant for a
        colormap on a [-1, 1] scale (``plot_quiver`` uses vmin=-1, vmax=1).
    """
    c = np.asarray(c)
    p = np.asarray(p)
    # NaN == 1 is False, so NaN path cells stay NaN.
    return np.where(p == 1, np.where(c == 0, 1., -1.), np.nan)


def plot_quiver(X, Y, U, V, p, cmap):
    """Draw policy arrows over a shaded grid in a new 5x5 figure.

    Args:
        X, Y, U, V: arrow positions and components, e.g. as returned by ``quiver``.
        p: per-cell values shaded under the arrows with ``cmap`` on a fixed
            [-1, 1] colour scale at 25% opacity, e.g. from ``quiver_color_rgb``.
            NaN cells are masked by ``imshow``.
        cmap: matplotlib colormap (name or object) used for the shading.

    Trap and goal markers are added from ``maze``. The figure is left open;
    the caller saves or shows it.
    """
    _, ax = plt.subplots(figsize=figsize)
    ax.quiver(X, Y, U, V, pivot='middle', scale=9., headwidth=6., headlength=6.)
    ax.imshow(p, cmap=cmap, vmin=-1., vmax=1., alpha=.25)
    ax.set_xticks([])
    ax.set_yticks([])
    # Outline the axes background in black. Pass a number rather than the
    # string '5': matplotlib documents the width as a float, and some versions
    # forward the raw value into dash-pattern scaling.
    ax.patch.set_edgecolor('black')
    ax.patch.set_linewidth(5)
    # dashed cell borders via minor ticks on the half-integer cell edges
    ax.set_xticks(np.arange(-.5, 5), minor=True)
    ax.set_yticks(np.arange(-.5, 5), minor=True)
    ax.grid(which='minor', color='grey', linestyle='--', linewidth=2)
    for y in range(5):
        for x in range(5):
            if maze[y, x] == '#':
                trap(x, y, 'X')
            elif maze[y, x] == '%':
                trap(x, y, 'Y')
            elif maze[y, x] == '$':
                goal(x, y)
    plt.tight_layout()


def _finish_figure(filename):
    """Save the current figure as PDF to ``filename`` if ``save_fig``, else show it."""
    if save_fig:
        plt.savefig(filename, format='pdf')
    else:
        plt.show()


# get source policies: [source task 1, source task 2, target task]
print('solving source policies')
policies = solve_policies()
print('done')
plot_environment(policies)

# GPI with risk-aware criterion: evaluate the two source policies on the
# target task with beta = -0.1, the value the tasks were solved with
print('\nevaluating source policies in target task risk-aware')
qvalues_ra = evaluate_policies(policies[:2], beta=-0.1)
print('done')
gpi_action_ra, task_indices_ra = build_gpi_policy(policies[:2], qvalues_ra)
X, Y, U, V, p = quiver(gpi_action_ra)
plot_quiver(X, Y, U, V, quiver_color_rgb(task_indices_ra, p), cmap='seismic')
_finish_figure('risk-aware-gpi.pdf')

# GPI with risk neutral criterion (beta = 0)
print('\nevaluating source policies in target task risk-neutral')
qvalues = evaluate_policies(policies[:2], beta=0)
print('done')
gpi_action, task_indices = build_gpi_policy(policies[:2], qvalues)
X, Y, U, V, p = quiver(gpi_action)
plot_quiver(X, Y, U, V, quiver_color_rgb(task_indices, p), cmap='seismic')
_finish_figure('risk-neutral-gpi.pdf')

# rollouts: 5000 test episodes of each GPI policy on the target task
print('\ntesting risk-aware GPI')
rollouts_ra = test_policy(gpi_action_ra, 5000)
print('done')

print('\ntesting risk-neutral GPI')
rollouts = test_policy(gpi_action, 5000)
print('done')

# overlaid, density-normalised histograms of the rollout results
plt.figure(figsize=(8, 4))
datas = [rollouts_ra, rollouts]
agent_names = ['risk-aware', 'risk-neutral']
cols = ['darkblue', 'darkred']
# NOTE(review): values outside [-32, 16] fall outside these bins and are not counted
bins = np.linspace(-32., 16., 40)
for data, name, col in zip(datas, agent_names, cols):
    plt.hist(data, density=True, bins=bins, alpha=0.5, color=col, label='{} GPI'.format(name))
ax = plt.gca()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.xlabel('episode return')
plt.ylabel('density')
plt.legend(frameon=False)
plt.tight_layout()
_finish_figure('gpi_test_comparison.pdf')
