# Added time-invariant option for choice
# Adapted for two-agent collaborative task (Y.C. April 2024)
import numpy as np
from src.value_iteration_torchversion import *
from src.helpers import *
import torch

def getMAP_weights(seed, P_a, trajectories, individual_map1, individual_map2, diff_map,
                   hyperparams, a_init, max_iters=500, lr=0.01, gamma=0.9, info=None, tag=0,
                   width=None, height=None):
    """ obtain the MAP estimates of the (time-invariant) goal-map weights
        For two agent collaborative foraging in a square arena (Y.C. April 2024)
        args:
            seed (int): seed applied to both the torch and the numpy random generators
            P_a (numpy array): transition matrix; its first dimension is taken as the number of
                (joint) states
            trajectories (list): list of expert trajectories; each trajectory is a dictionary with
                'states' and 'actions' as keys. A trailing state with no associated action is dropped.
            individual_map1: initial guess for agent 1 maps (passed to create_joint_maps)
            individual_map2: initial guess for agent 2 maps (passed to create_joint_maps)
            diff_map: initial guess for the difference map (passed to create_joint_maps)
            hyperparams (list): current setting of hyperparams (prior sigmas), one per goal map
            a_init (numpy array): initial guess for the weights; flattened before optimization
            max_iters (int): number of Adam iterations to optimize for
            lr (float): learning rate
            gamma (float): discount factor in value iteration
            info (dict or None): dict with anything that we'd like to store for printing purposes.
                A fresh {'Neval': 0} is created when None; a caller-supplied dict is updated in
                place (the 'Neval' counter is added if missing).
            tag (int): controls which model to fit; forwarded unchanged to neglogpost
            width, height (int or None): arena dimensions passed to create_joint_maps. When both
                are None they are both set to int(sqrt(P_a.shape[0])).
        returns:
            a_MAPs (list of arrays, each N_MAPS x -1): weight estimates saved after every 10th
                iteration (i % 10 == 0) and after the last iteration
            losses (list): values of the objective returned by neglogpost at every iteration
    """
    # A mutable default dict would be shared across calls and leak the
    # evaluation counter from one fit into the next, so build it per call.
    if info is None:
        info = {'Neval': 0}
    else:
        info.setdefault('Neval', 0)

    torch.manual_seed(seed)
    np.random.seed(seed)

    # concatenate states and actions in expert trajectories
    assert(len(trajectories)>0), "no expert trajectories found!"
    state_action_pairs = []
    for traj in trajectories:
        states = np.array(traj['states'])[:, np.newaxis]
        actions = np.array(traj['actions'])[:, np.newaxis]
        if len(states) == len(actions) + 1:
            # remove the last state with no action selected
            states = np.array(traj['states'][:-1])[:, np.newaxis]
        assert len(states) == len (actions), "states and action sequences dont have the same length"
        pairs = np.concatenate((states, actions), axis=1)
        assert pairs.shape[0]==len(states), "error in concatenation of s,a,s' tuples"
        assert pairs.shape[1]==2, "states and actions are not integers?"
        state_action_pairs.append(pairs)

    # converting to tensors
    P_a = torch.from_numpy(P_a).float()
    N_STATES = P_a.shape[0]
    if width is None and height is None:
        side = int(torch.sqrt(torch.tensor(N_STATES)))
        width, height = side, side
    goal_maps = create_joint_maps(individual_map1, individual_map2, diff_map, width, height)
    goal_maps = torch.from_numpy(goal_maps).float()
    sigmas = torch.tensor(hyperparams).float()
    N_MAPS = goal_maps.shape[0]
    assert goal_maps.shape[1]==N_STATES, "goal maps should be tensors with length as no. of states"

    # initial value; this tensor is the parameter being optimized
    a_init = torch.from_numpy(a_init).flatten().float()
    a_init.requires_grad = True

    print("Minimizing the negative log posterior ...")
    print('{0} {1}'.format('# n_iters', 'neg LP'))
    optimizer = torch.optim.Adam([a_init], lr=lr)
    # saving the losses
    losses = []
    # saving MAP estimates after every 10 iterations
    a_MAPs = []

    for i in range(max_iters):
        loss = neglogpost(a_init, state_action_pairs, sigmas, goal_maps, P_a, gamma, info, tag)
        losses.append(loss.detach().numpy())
        # taking gradient step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % 10 == 0 or i == max_iters - 1:
            # detach().numpy() shares memory with the parameter, so store a copy
            a_MAP = np.reshape(a_init.detach().numpy(), (N_MAPS, -1))
            a_MAPs.append(a_MAP.copy())

    return a_MAPs, losses


def neglogpost(a, state_action_pairs, hyperparams, goal_maps, P_a, gamma, info, tag):
    r'''Returns negative log posterior, averaged over trajectories
        args:
            a (1-d tensor): weights on the goal maps
            state_action_pairs (list of len(trajectories), with each element an array whose
                columns 0 and 1 hold the state and action indices)
            hyperparams (tensor): prior standard deviations (sigmas) of the weights; passed
                through to getPosterior
            goal_maps (tensor): goal maps, one row per map
            P_a (tensor): transition matrix
            gamma (float): discount factor in value iteration
            info (dict): anything that we'd like to store for printing purposes. If it contains
                the key 'll', the prior is ignored and only the likelihood term is returned.
                Its 'Neval' entry is incremented on every call (created if missing).
            tag: controls which model to fit; passed through to getPosterior
        returns:
            negL : negative log posterior (or negative log likelihood when 'll' is in info),
                divided by the number of trajectories
    '''
    num_trajectories = len(state_action_pairs)
    log_prior, log_likelihood = getPosterior(a, state_action_pairs, hyperparams, goal_maps, P_a, gamma, info, tag)
    if 'll' in info:
        # likelihood-only objective requested by the caller
        negL = (-log_likelihood)/num_trajectories
    else:
        negL = (-log_prior-log_likelihood)/num_trajectories
    # count evaluations for progress printing; tolerate a dict without the counter
    info['Neval'] = info.get('Neval', 0) + 1
    n_eval = info['Neval']
    print('{0}, {1}'.format(n_eval, negL))
    return negL


def getPosterior(a, state_action_pairs, hyperparams, goal_maps, P_a,  gamma, info, tag):
    """ returns prior and likelihood
        Time-invariant version
        Two agent collaborative task
        assumes that 'a' has a prior mean of 1
        args:
            a (tensor): weights on the goal maps; reshaped to (N_MAPS, -1) internally
            state_action_pairs (list of arrays): per trajectory, column 0 holds state indices
                and column 1 holds action indices
            hyperparams (tensor): sigmas (prior standard deviations) of the weights
            goal_maps (tensor): N_MAPS x N_STATES, where N_STATES = P_a.shape[0]
            P_a (tensor): transition matrix handed to the value-iteration routine
            gamma (float): discount factor in value iteration
            info (dict): if it has key 'SingleAgent' with value 1 or 2, the likelihood is
                evaluated under that agent's individual policy instead of the joint policy
            tag: controls which model to fit (0: centralized control, 1: independent control,
                2: independent control with uniform prediction, 4: selfish)
        returns:
            log_prior: log prior of the weights (up to an additive constant)
            log_likelihood summed over all the state action terms
        raises:
            ValueError: if tag is not one of 0, 1, 2, 4, or if a single-agent policy is
                requested for a model that does not return one (tag 0)
    """

    N_STATES = P_a.shape[0]
    N_MAPS = goal_maps.shape[0]

    # diagonal of inverse of the sigma matrix
    invSigma_diag = 1 / hyperparams**2
    logdet_invSigma = torch.sum(torch.log(invSigma_diag))
    # calculating the log prior (Gaussian centred at 1)
    logprior = (1 / 2) * (logdet_invSigma - ((a-torch.ones(a.shape))**2 * invSigma_diag).sum())

    # ------------------------------------------------------------------
    # compute the likelihood terms
    # ------------------------------------------------------------------

    a_reshaped = a.reshape(N_MAPS, -1)
    rewards = a_reshaped.T @ goal_maps # 1 x N_STATES
    assert rewards.shape[0]==1 and rewards.shape[1]==N_STATES,"rewards not computed correctly"

    # per-agent policies are only returned by the non-centralized models
    p1 = p2 = None
    if tag == 0:
        _, policy = two_agent_value_iteration(P_a, rewards=rewards.T, gamma=gamma, error=0.01)
    elif tag == 1:
        _, policy, p1, p2 = two_agent_value_iteration_independent_control(P_a, rewards=rewards.T, gamma=gamma, error=0.01)
    elif tag == 2:
        _, policy, p1, p2 = two_agent_value_iteration_independent_control_uniform_prediction(P_a, rewards=rewards.T, gamma=gamma, error=0.01)
    elif tag == 4:
        _, policy, p1, p2 = two_agent_value_iteration_selfish(P_a, rewards=rewards.T, gamma=gamma, error=0.01)
    else:
        # previously this fell through and failed later with an unbound 'policy'
        raise ValueError("unsupported tag {0!r}; expected one of 0, 1, 2, 4".format(tag))

    single_agent = info.get('SingleAgent') if 'SingleAgent' in info else None
    if single_agent == 1 or single_agent == 2:
        if p1 is None:
            raise ValueError("info['SingleAgent'] requires a model with per-agent policies (tag 1, 2 or 4)")
        policy = p1 if single_agent == 1 else p2
    log_policies = torch.log(policy)

    # compute the ll for all trajectories; one gather per trajectory replaces
    # the per-time-step Python loop
    log_likelihood = 0
    for pairs in state_action_pairs:
        states = torch.tensor(pairs[:, 0], dtype=torch.long)
        actions = torch.tensor(pairs[:, 1], dtype=torch.long)
        log_likelihood = log_likelihood + torch.sum(log_policies[states, actions])

    return logprior, log_likelihood

