# multi-agent inverse RL
# for the chicken task reported in Ong, Nature Neuroscience, 2020

import os
import numpy as np
import pickle
import matplotlib.pyplot as plt
import torch
from torch import softmax
from plot_utils.generate_colormap import generate_colormap
import string



def marginal_estimate(V_straight,V_coop, V_safe, p1_prediction, p2_prediction):
    """Model each agent's probability of going straight in the chicken game.

    Each agent scores the four joint outcomes with a value matrix indexed as
    ``V[a1, a2]`` (action 0 = straight, 1 = yield). The crash outcome (both
    straight) is fixed at 0. The three remaining values are shared by the two
    agents and mirrored, so that each agent's own V_straight sits where it
    goes straight while the other yields. A softmax over the four entries
    gives each agent's joint distribution. That distribution is conditioned
    on the counterpart's action and then averaged under the agent's own
    prediction of that action.

    Args:
        V_straight: value of going straight while the counterpart yields.
        V_coop: value of both agents yielding.
        V_safe: value of yielding while the counterpart goes straight.
        p1_prediction: length-2 logits of agent 1's prediction of agent 2's
            action (softmax gives [P(a2_hat=0), P(a2_hat=1)]).
        p2_prediction: length-2 logits of agent 2's prediction of agent 1's
            action (softmax gives [P(a1_hat=0), P(a1_hat=1)]).

        The three values are assumed to be single-element tensors sharing a
        dtype and device. The ``__main__`` block passes float32 tensors of
        shape (1,).

    Returns:
        ``(p1_estimate, p2_estimate)``, each a (1, 1) tensor holding the
        modeled P(a1=0) and P(a2=0), respectively.
    """
    # Fixed zero payoff for the crash outcome. zeros_like keeps dtype and
    # device aligned with the learned values. A bare torch.tensor([0]) is an
    # int64 CPU tensor and breaks vstack when the values live on GPU.
    zero = torch.zeros_like(V_safe)
    V1 = torch.hstack([torch.vstack([zero, V_safe]), torch.vstack([V_straight, V_coop])])
    V2 = torch.hstack([torch.vstack([zero, V_straight]), torch.vstack([V_safe, V_coop])])

    # Softmax over all four joint outcomes, then back to a 2x2 grid.
    p_joint1 = softmax(V1.reshape((4,1)), dim=0).reshape((2,2)) # [i,j] = P1(a1=i, a2=j)
    p_joint2 = softmax(V2.reshape((4,1)), dim=0).reshape((2,2)) # [i,j] = P2(a1=i, a2=j)

    # Condition on the counterpart's action. keepdim=True lets the marginal
    # broadcast across the other axis.
    p1_conditioned = p_joint1 / torch.sum(p_joint1, dim=0, keepdim=True) # [i,j] = P1(a1=i | a2=j)
    p2_conditioned = p_joint2 / torch.sum(p_joint2, dim=1, keepdim=True) # [i,j] = P2(a2=j | a1=i)

    p1_prediction_probability = softmax(p1_prediction, dim=0) # [i] = P1(a2_hat=i), agent 1's prediction of agent 2
    p2_prediction_probability = softmax(p2_prediction, dim=0) # [i] = P2(a1_hat=i), agent 2's prediction of agent 1

    # P(a1=0) = P1(a1=0|a2=0)*P1(a2_hat=0) + P1(a1=0|a2=1)*P1(a2_hat=1)
    p1_estimate = p1_conditioned[0:1,:] @ p1_prediction_probability[:,None]
    # P(a2=0) = P2(a2=0|a1=0)*P2(a1_hat=0) + P2(a2=0|a1=1)*P2(a1_hat=1)
    p2_estimate = p2_conditioned[:,0:1].T @ p2_prediction_probability[:,None]

    return p1_estimate, p2_estimate


if __name__=='__main__':

    REC_DIR_NAME = 'recovered_parameters/chicken_task/'
    REC_DIR_NAME = REC_DIR_NAME + 'fit_independent_control_w_prediction/'

    max_iters = 100 # number of Adam steps used to fit the values and mutual predictions
    seed = 1 # random seed (the initial guess below is fixed; seeded for reproducibility)
    lr = 0.01 # Adam learning rate

    torch.manual_seed(seed)

    # Observed joint choice frequencies: Obs[i,j] = P(a1=i, a2=j), where
    # action 0 = straight (S) and 1 = yield (Y). The entries sum to 0.99,
    # presumably because the reported percentages were rounded.
    Obs = torch.tensor([0.12,0.52,0.11,0.24]).reshape((2,2))
    p1 = torch.sum(Obs, dim=1) # p1[i] = P(a1=i), agent 1's observed marginal
    p2 = torch.sum(Obs, dim=0) # p2[j] = P(a2=j), agent 2's observed marginal

    # Initial guess. Build float32 leaf tensors directly. Calling .float()
    # after construction returns a non-leaf copy (which Adam rejects)
    # whenever the default dtype is not float32.
    V_straight = torch.tensor([28.], dtype=torch.float32, requires_grad=True)
    V_coop = torch.tensor([10.], dtype=torch.float32, requires_grad=True)
    V_safe = torch.tensor([3.], dtype=torch.float32, requires_grad=True)
    p1_prediction = torch.tensor([0.1,0.1], dtype=torch.float32, requires_grad=True) # logits: agent 1's prediction of agent 2
    p2_prediction = torch.tensor([0.1,0.1], dtype=torch.float32, requires_grad=True) # logits: agent 2's prediction of agent 1

    optimizer = torch.optim.Adam([V_straight, V_coop, V_safe, p1_prediction, p2_prediction], lr=lr)

    # Fit by matching each agent's modeled P(straight) to the observed
    # marginal under a squared-error loss.
    losses = []
    for step in range(max_iters):
        p1_estimate, p2_estimate = marginal_estimate(V_straight, V_coop, V_safe, p1_prediction, p2_prediction)
        loss = (p1_estimate-p1[0])**2 + (p2_estimate-p2[0])**2
        losses.append(loss.item())
        # if step % 10 == 0: print(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Fitted value matrices (rows: agent 1's action, cols: agent 2's action)
    # and prediction probabilities, used only for plotting.
    with torch.no_grad():
        zero = torch.zeros_like(V_safe)
        V1 = torch.hstack([torch.vstack([zero, V_safe]), torch.vstack([V_straight, V_coop])])
        V2 = torch.hstack([torch.vstack([zero, V_straight]), torch.vstack([V_safe, V_coop])])
        p1_prediction_probability = softmax(p1_prediction, dim=0) # [i] = P1(a2_hat=i), agent 1's prediction of agent 2
        p2_prediction_probability = softmax(p2_prediction, dim=0) # [i] = P2(a1_hat=i), agent 2's prediction of agent 1


    # ---------------------------------------------------------
    # Begin Plotting
    # ---------------------------------------------------------
    save_dir = REC_DIR_NAME + 'plots/'
    os.makedirs(save_dir, exist_ok=True)

    LEGEND_SIZE = 10
    SMALL_SIZE = 15
    BIGGER_SIZE = 20

    plt.rc('font', family='Helvetica')          # font family for all text
    plt.rc('font', size=LEGEND_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=SMALL_SIZE)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=LEGEND_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=LEGEND_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
    plt.rcParams.update({"text.usetex": True}) # requires a LaTeX install; hence the escaped \% below
    colors = ['steelblue', '#D85427', 'tab:green', 'k']

    # Plot the observed choices, mutual predictions and fitted value maps
    fig, axs = plt.subplots(1,4,figsize=(10,3))
    plt.subplots_adjust(left=0.4, bottom=0.3, right=0.9, top=0.9, wspace=0.7, hspace=0.5)

    cmap = plt.get_cmap('viridis')
    new_cmap = generate_colormap(cmap, 0.5, 1)

    # (agent 1 payoff, agent 2 payoff) for each joint outcome, annotated on the choice panel
    payoffs = [[(0,0),(28,3)],[(3,28),(18,18)]]

    ax = axs[0]
    im = ax.imshow(Obs.numpy(), cmap=new_cmap)
    plt.colorbar(im,ax=ax,fraction=0.046, pad=0.04, ticks=[])
    ax.set_title('Monkey choices')
    for i in range(2):
        for j in range(2):
            ax.text(j, i, r'{:.0f}\%'.format(100*Obs[i, j].item()), ha='center', va='center', color='black')
            ax.text(j, i-0.25, str(payoffs[i][j]), ha='center', va='center', color='black')
    ax.axhline(0.5, color='black', linewidth=1)
    ax.axvline(0.5, color='black', linewidth=1)
    ax.set_ylabel('Agent 1')
    ax.set_xlabel('Agent 2')
    ax.set_xticks([0,1])
    ax.set_xticklabels(['S','Y'])
    ax.set_yticks([0,1])
    ax.set_yticklabels(['S','Y'])

    # Each agent's predicted probability that its counterpart goes straight
    ax = axs[1]
    ax.bar([0,1], [p1_prediction_probability[0].item(), p2_prediction_probability[0].item()], color=colors)
    ax.set_title('Mutual predictions')
    ax.set_xticks([0,1])
    ax.set_xticklabels(['Agent 1','Agent 2'],rotation=30)
    ax.set_ylabel('Counterpart P(S)')

    # Value maps are normalized by their maximum for display
    ax = axs[2]
    V1_np = V1.detach().numpy()
    im = ax.imshow(V1_np/np.max(V1_np), cmap=new_cmap)
    plt.colorbar(im,ax=ax,fraction=0.046, pad=0.04, ticks=[0,1])
    ax.set_title('V1')
    ax.set_xlabel('Agent 2')
    ax.set_ylabel('Agent 1')
    ax.set_xticks([0,1])
    ax.set_xticklabels(['S','Y'])
    ax.set_yticks([0,1])
    ax.set_yticklabels(['S','Y'])

    ax = axs[3]
    V2_np = V2.detach().numpy()
    im = ax.imshow(V2_np/np.max(V2_np), cmap=new_cmap)
    plt.colorbar(im,ax=ax,fraction=0.046, pad=0.04, ticks=[0,1])
    ax.set_title('V2')
    ax.set_ylabel('Agent 1')
    ax.set_xlabel('Agent 2')
    ax.set_yticks([0,1])
    ax.set_yticklabels(['S','Y'])
    ax.set_xticks([0,1])
    ax.set_xticklabels(['S','Y'])

    # Panels are lettered B-E (index n+1). Presumably panel A lives
    # elsewhere in the composite figure; confirm.
    for n, ax in enumerate(axs.flat):
        ax.text(-0.25, 1.15, string.ascii_uppercase[n+1], transform=ax.transAxes, size=BIGGER_SIZE, weight='bold')

    plt.tight_layout()
    fig.savefig(save_dir + 'values.png')
    # fig.savefig(save_dir + 'values.pdf')
    fig.savefig(save_dir + 'values.eps', transparent=True)
    fig.savefig(save_dir + 'values.svg', transparent=True)
    plt.close(fig)

    # Plot the training loss curve recorded during fitting
    fig, ax = plt.subplots(1,1,figsize=(2.2,2.2))
    plt.subplots_adjust(left=0.3, bottom=0.3, right=0.9, top=0.8)
    ax.plot(losses)
    ax.set_title('Loss')
    ax.set_xlabel('Epochs')
    fig.savefig(save_dir + 'loss.pdf')
    plt.close(fig)



