import numpy as np
from src.value_iteration_torchversion import *
from src.helpers import *
import torch

def getMAP_goalmaps(seed, P_a, trajectories, hyperparams, a, individual_map1, individual_map2, diff_map, 
                    lam1, lam2, max_iters = 500, lr = 0.01, gamma=0.9, info=None, tag=0, width=None, height=None):
    """ obtain the MAP estimates of model parameters
        Time-invariant estimation
        For two agent collaborative foraging in a square arena (Y.C. April 2024)
        Note: N_STATES, N_ACTIONS live in the state-space for single agent
        args:
            seed (int): initialization seed (seeds both torch and numpy)
            P_a (array): gridworld transition matrix for the two-agent system; P_a.shape[0] must be N_STATES**2
            trajectories (list): list of expert trajectories; each trajectory is a dictionary with 'states (int)' and 'actions (int)' as keys.
                A trajectory may have one more state than actions (a final state with no action); that state is dropped.
            hyperparams (list): current setting of sigmas, of size N_MAPS; converted to a tensor and forwarded to neglogll
                (NOTE(review): the likelihood in this file does not currently use it — confirm intended)
            a (array of size N_MAPS x 1): current setting of weights, a[-1] is the weight for the interaction map
            individual_map1 (array of size N_MAPS_INDIVIDUAL x N_STATES): initial guess for agent 1 maps
            individual_map2 (array of size N_MAPS_INDIVIDUAL x N_STATES): initial guess for agent 2 maps
            diff_map (array of size N_MAPS_INTERACTION x N_diff): initial guess for the difference map
            lam1 (float): l2 coefficient for individual maps
            lam2 (float): l2 coefficient for interaction map
            max_iters (int): number of Adam iterations to optimize this for
            lr (float): learning rate
            gamma (float): discount factor in value iteration
            info (dict or None): dict shared with the likelihood code for bookkeeping. 'Neval' counts likelihood
                evaluations and is created if missing. An optional 'SingleAgent' key (1 or 2) is read by getLL_ti.
                When None (the default), a fresh {'Neval': 0} is used for this call.
            tag (int): 0 for centralized control, 1 for independent control, 2 for independent control with prediction
            width, height (int or None): grid dimensions forwarded to the likelihood; when both are None a square
                grid of side int(sqrt(N_STATES)) is used
        returns:
            individual_map1_MLEs (list of arrays, each N_MAPS_INDIVIDUAL x N_STATES): snapshots of agent 1 maps taken
                after iterations 0, 10, 20, ... and after the final iteration
            individual_map2_MLEs (list of arrays, each N_MAPS_INDIVIDUAL x N_STATES): same snapshots for agent 2 maps
            diff_map_MLEs (list of arrays, each N_MAPS_INTERACTION x N_diff): same snapshots for the difference map
            losses (list of 0-d arrays): value of the objective (neglogll + l2 penalty) at every iteration,
                evaluated before that iteration's parameter update
    """
    # Build the bookkeeping dict per call: a dict literal as the default would be
    # shared between calls, so 'Neval' would keep counting across separate fits.
    if info is None:
        info = {'Neval': 0}
    info.setdefault('Neval', 0)

    torch.manual_seed(seed)
    np.random.seed(seed)

    # P_a is indexed by joint states, so the single-agent state count is its square root
    N_STATES = int(np.sqrt(P_a.shape[0]))
    if width is None and height is None:
        width = height = int(np.sqrt(N_STATES))

    # concatenate expert trajectories into per-trajectory (T x 2) arrays of (state, action)
    assert(len(trajectories)>0), "no expert trajectories found!"
    state_action_pairs = []
    for traj in trajectories:
        states = np.array(traj['states'])
        actions = np.array(traj['actions'])
        if len(states) == len(actions) + 1:
            # the trajectory ends in a state with no action taken from it: drop that state
            states = states[:-1]
        assert len(states) == len (actions), "states and action sequences dont have the same length"
        state_action_pairs_this_traj = np.concatenate((states[:, np.newaxis], actions[:, np.newaxis]), axis=1)
        assert state_action_pairs_this_traj.shape[0]==len(states), "error in concatenation of s,a,s' tuples"
        assert state_action_pairs_this_traj.shape[1]==2, "states and actions are not integers?"
        state_action_pairs.append(state_action_pairs_this_traj)

    # converting to tensors
    P_a = torch.from_numpy(P_a).float()
    a = torch.from_numpy(a).float()
    sigmas = torch.tensor(hyperparams).float()

    # initial value of goal maps (the only parameters being optimized)
    individual_map1 = torch.from_numpy(individual_map1.astype(float)).float()
    individual_map2 = torch.from_numpy(individual_map2.astype(float)).float()
    diff_map = torch.from_numpy(diff_map.astype(float)).float()
    individual_map1.requires_grad = True
    individual_map2.requires_grad = True
    diff_map.requires_grad = True

    print("Minimizing the negative log likelihood ...")
    print('{0} {1}'.format('# n_iters', 'neg LL'))
    optimizer = torch.optim.Adam([individual_map1, individual_map2, diff_map], lr=lr)
    losses = []
    individual_map1_MLEs = []
    individual_map2_MLEs = []
    diff_map_MLEs = []

    for i in range(max_iters):
        # l2 prior on the maps
        loss_prior = lam1*torch.sum(individual_map1**2)+lam1*torch.sum(individual_map2**2)+lam2*torch.sum(diff_map**2)
        loss = neglogll(individual_map1, individual_map2, diff_map, state_action_pairs, sigmas, a, P_a, gamma, info, tag, 
                        width=width, height=height) + loss_prior
        losses.append(loss.detach().numpy())
        # NOTE(review): the graph is rebuilt from scratch every iteration, so retain_graph=True looks
        # unnecessary; kept in case the value-iteration helpers reuse graph nodes across calls — confirm.
        loss.backward(retain_graph=True)
        optimizer.step()
        optimizer.zero_grad()
        if i % 10 == 0 or i == max_iters - 1:
            # copy so later in-place optimizer updates do not alter the stored snapshot
            individual_map1_MLEs.append(individual_map1.detach().numpy().copy())
            individual_map2_MLEs.append(individual_map2.detach().numpy().copy())
            diff_map_MLEs.append(diff_map.detach().numpy().copy())

    return individual_map1_MLEs, individual_map2_MLEs, diff_map_MLEs, losses




def neglogll(individual_map1, individual_map2, diff_map, state_action_pairs, hyperparams, a, P_a, gamma, info, tag, width, height):
    '''Return the negative log-likelihood of the demonstrations, averaged over trajectories.

        The l2 prior is NOT included here; getMAP_goalmaps adds it to this value.

        args:
            individual_map1, individual_map2, diff_map (tensors): current goal maps
            state_action_pairs (list of len(trajectories), with each element an array: T x 2 of (state, action))
            hyperparams (tensor): current setting of sigmas, of size N_MAPS; forwarded to getLL_ti
            a (2-d tensor: N_MAPS x 1): map weights
            P_a (tensor): transition matrix
            gamma (float): discount
            info (dict): bookkeeping dict; info['Neval'] is incremented on every call (created if missing)
            tag (int): 0 for centralized control, 1 for independent control, 2 for independent control with prediction
            width, height: grid dimensions forwarded to getLL_ti
        returns:
            negL : -log_likelihood / number of trajectories
    '''

    num_trajectories = len(state_action_pairs)
    log_likelihood = getLL_ti(individual_map1, individual_map2, diff_map, state_action_pairs, hyperparams, a, P_a, gamma, info, tag, width, height)
    negL = (-log_likelihood)/num_trajectories

    # Count evaluations; tolerate an info dict that was created without an 'Neval' entry
    # (e.g. one carrying only 'SingleAgent') instead of raising KeyError.
    n_eval = info.get('Neval', 0) + 1
    info['Neval'] = n_eval

    print('{0}, {1}'.format(n_eval, negL))
    return negL


def getLL_ti(individual_map1, individual_map2, diff_map, state_action_pairs, hyperparams, a, P_a, gamma, info, tag, width, height):
    """Return the log-likelihood of the observed (state, action) pairs for time-invariant goal maps.

    The reward is a.T @ joint_map, where create_joint_maps_pytorch builds joint_map. A policy is
    obtained from that reward by value iteration (error=0.1), and the log-probability of every
    observed action is summed.

    args:
        individual_map1, individual_map2, diff_map (tensors): goal maps; their first dimensions
            must sum to N_MAPS = a.shape[0]
        state_action_pairs (list of arrays): one T_i x 2 integer array per trajectory, columns (state, action)
        hyperparams: accepted for interface compatibility; not used by this function
        a (tensor, N_MAPS x 1): map weights; a.T @ joint_map must be 1 x P_a.shape[0]
        P_a (tensor): transition matrix; P_a.shape[0] is the number of (joint) states
        gamma (float): discount factor for value iteration
        info (dict): if info['SingleAgent'] is 1 or 2, that agent's policy (p1 / p2) is used
            instead of the joint policy (not available for tag 0)
        tag (int): 0 centralized control, 1 independent control, 2 independent control with
            (uniform) prediction, 4 selfish variant (per the value-iteration helper names)
        width, height: grid dimensions forwarded to create_joint_maps_pytorch
    returns:
        log_likelihood (0-d tensor): sum over all steps of all trajectories of log policy[state, action]
    raises:
        ValueError: empty state_action_pairs, unsupported tag, or a SingleAgent request without
            per-agent policies
    """
    if len(state_action_pairs) == 0:
        raise ValueError("no state-action pairs given")

    N_STATES = P_a.shape[0]
    N_MAPS = a.shape[0]

    assert(individual_map1.shape[0]+individual_map2.shape[0]+diff_map.shape[0]==N_MAPS), "goal maps are not of the appropriate shape"

    # ------------------------------------------------------------------
    # compute the likelihood terms 
    # ------------------------------------------------------------------

    joint_map = create_joint_maps_pytorch(individual_map1, individual_map2, diff_map, width=width, height=height)

    rewards = a.T @ joint_map # 1 x N_states

    assert rewards.shape[0]==1 and rewards.shape[1]==N_STATES,"rewards not computed correctly"

    # policies should be N_STATES x N_ACTIONS, indexed as policy[state, action].
    # The centralized solver returns no per-agent policies, so p1/p2 stay None for tag 0.
    p1 = p2 = None
    if tag == 0:
        _, policy = two_agent_value_iteration(P_a, rewards=rewards.T, gamma=gamma, error=0.1)
    elif tag == 1:
        _, policy, p1, p2 = two_agent_value_iteration_independent_control(P_a, rewards=rewards.T, gamma=gamma, error=0.1)
    elif tag == 2:
        _, policy, p1, p2 = two_agent_value_iteration_independent_control_uniform_prediction(P_a, rewards=rewards.T, gamma=gamma, error=0.1)
    elif tag == 4:
        _, policy, p1, p2 = two_agent_value_iteration_selfish(P_a, rewards=rewards.T, gamma=gamma, error=0.1)
    else:
        # previously fell through and failed later with an UnboundLocalError on `policy`
        raise ValueError("unsupported tag {0!r}; expected one of 0, 1, 2, 4".format(tag))

    if 'SingleAgent' in info and info['SingleAgent'] in (1, 2):
        single_policy = p1 if info['SingleAgent'] == 1 else p2
        if single_policy is None:
            raise ValueError("info['SingleAgent'] requires per-agent policies, which tag={0!r} does not provide".format(tag))
        policy = single_policy

    log_policies = torch.log(policy)

    # Gather log pi(a_t | s_t) for every step of every trajectory in one indexing op.
    # This gives the same value as a per-step loop (up to float summation order) with a single autograd node.
    pairs = np.concatenate([np.asarray(p) for p in state_action_pairs], axis=0)
    states = torch.tensor(pairs[:, 0], dtype=torch.long, device=log_policies.device)
    actions = torch.tensor(pairs[:, 1], dtype=torch.long, device=log_policies.device)
    log_likelihood = torch.sum(log_policies[states, actions])

    return log_likelihood



