import numpy as np
from src.value_iteration_torchversion import *
from src.helpers import *
import torch

def get_validation_ll(seed, P_a, trajectories, hyperparams,
                      individual_map1, individual_map2, diff_map, a,
                      gamma=0.9, info=None, tag=0, width=None, height=None):
    """ obtain the ll of the held-out trajectories using a given parameter setting
        Time-invariant
        Two agent collaborative task
        (Y.C. April 2024)
        args:
            seed (int): seed applied to both the torch and the numpy global RNGs
            P_a (numpy array): transition array; its first dimension is treated as the
                number of joint states, i.e. N_STATES**2 for N_STATES single-agent states
            trajectories (list): list of expert trajectories; each trajectory is a dictionary
                with 'states' and 'actions' as keys. 'states' may hold one more entry than
                'actions', in which case the last state is dropped.
            hyperparams (list): current setting of hyperparams; forwarded unchanged to getLL_ti
            individual_map1, individual_map2, diff_map (numpy arrays): goal maps; the sum of
                their first dimensions must equal the first dimension of `a`
            a (numpy array): current setting of the map weights
            gamma (float): discount factor
            info (dict or None): extra options forwarded to getLL_ti (e.g. 'SingleAgent');
                None is replaced by a fresh {'Neval': 0}
            tag (int): selects the value-iteration variant used by getLL_ti
            width, height (int or None): grid dimensions; when both are None each defaults to
                int(sqrt(N_STATES))
        returns:
           val_ll (float): validation ll of held-out trajectories, averaged per (state, action) pair
        raises:
            AssertionError: if there are no trajectories, or states/actions lengths disagree
    """
    # build a fresh dict per call rather than sharing one mutable default across calls
    if info is None:
        info = {'Neval': 0}

    torch.manual_seed(seed)
    np.random.seed(seed)

    N_STATES = int(np.sqrt(P_a.shape[0]))
    if width is None and height is None:
        side = int(np.sqrt(N_STATES))
        width, height = side, side

    # concatenate expert trajectories
    assert(len(trajectories)>0), "no expert trajectories found!"
    state_action_pairs = []
    for traj in trajectories:
        states = np.array(traj['states'])[:, np.newaxis]
        actions = np.array(traj['actions'])[:, np.newaxis]
        if len(states) == len(actions) + 1:
            # the final state has no associated action, so it is discarded
            states = states[:-1]
        assert len(states) == len (actions), "states and action sequences dont have the same length"
        state_action_pairs_this_traj = np.concatenate((states, actions), axis=1)
        assert state_action_pairs_this_traj.shape[0]==len(states), "error in concatenation of s,a,s' tuples"
        assert state_action_pairs_this_traj.shape[1]==2, "states and actions are not integers?"
        state_action_pairs.append(state_action_pairs_this_traj)

    # converting to tensors
    P_a = torch.from_numpy(P_a).float()
    a = torch.from_numpy(a).float()
    individual_map1 = torch.from_numpy(individual_map1).float()
    individual_map2 = torch.from_numpy(individual_map2).float()
    diff_map = torch.from_numpy(diff_map).float()

    log_likelihood_pairwise = getLL_ti(individual_map1, individual_map2, diff_map, state_action_pairs, hyperparams, a, P_a,
                                        gamma, info, tag, width=width, height=height)

    return log_likelihood_pairwise.item()



def getLL_ti(individual_map1, individual_map2, diff_map, state_action_pairs, hyperparams, a, P_a, gamma, info, tag, width, height):
    """ returns likelihood at given maps and weights for time-invariant maps in a two-agent collaborative task
        Y.C. April 2024
        args:
            individual_map1, individual_map2, diff_map (tensors): goal maps; the sum of their
                first dimensions must equal a.shape[0]
            state_action_pairs (list): one array per trajectory; column 0 holds state indices
                and column 1 holds action indices
            hyperparams: accepted for interface compatibility; not used in this function
            a (tensor): map weights; a.T @ joint_map must come out as 1 x P_a.shape[0]
            P_a (tensor): transition tensor passed to the value-iteration routine
            gamma (float): discount factor
            info (dict or None): if it contains 'SingleAgent' equal to 1 or 2, the first or second
                per-agent policy returned by value iteration is scored instead of the joint one
            tag (int): value-iteration variant; 0 = joint, 1 or 2 = independent control, 4 = selfish
            width, height (int): grid dimensions forwarded to create_joint_maps_pytorch
        returns:
            0-dim tensor: log-likelihood averaged over all (state, action) pairs
        raises:
            AssertionError: if the maps / rewards do not have the expected shapes
            ValueError: if `tag` is unsupported, if a single-agent policy is requested with
                tag 0 (which yields no per-agent policies), or if there are no pairs to score
    """

    N_STATES = P_a.shape[0]
    N_MAPS = a.shape[0]

    assert(individual_map1.shape[0]+individual_map2.shape[0]+diff_map.shape[0]==N_MAPS), "goal maps are not of the appropriate shape"

    # fail fast on unusable settings, before running the (expensive) value iteration
    if tag not in (0, 1, 2, 4):
        raise ValueError(f"unsupported tag {tag!r}; expected one of 0, 1, 2 or 4")

    single_agent = None
    if info is not None and 'SingleAgent' in info:
        single_agent = info['SingleAgent']
    wants_agent1 = single_agent == 1
    wants_agent2 = single_agent == 2
    if tag == 0 and (wants_agent1 or wants_agent2):
        raise ValueError("info['SingleAgent'] requires per-agent policies, which tag 0 does not produce")

    joint_map = create_joint_maps_pytorch(individual_map1, individual_map2, diff_map, width=width, height=height)

    rewards = a.T @ joint_map # 1 x N_states

    assert rewards.shape[0]==1 and rewards.shape[1]==N_STATES,"rewards not computed correctly"

    # policies should be N_STATES x N_ACTIONS
    p1 = p2 = None
    if tag == 0:
        _, policy = two_agent_value_iteration(P_a, rewards=rewards.T, gamma=gamma, error=0.1)
    elif tag in (1, 2):
        # tags 1 and 2 both use the independent-control solver
        _, policy, p1, p2 = two_agent_value_iteration_independent_control(P_a, rewards=rewards.T, gamma=gamma, error=0.1)
    else:  # tag == 4
        _, policy, p1, p2 = two_agent_value_iteration_selfish(P_a, rewards=rewards.T, gamma=gamma, error=0.1)

    if wants_agent1:
        policy = p1
    elif wants_agent2:
        policy = p2

    log_policies = torch.log(policy)

    # compute the ll for all trajectories: one gather + sum per trajectory
    log_likelihood = 0
    N_pairs = 0
    for pairs in state_action_pairs:
        states = torch.tensor(pairs[:, 0], dtype=torch.long)
        actions = torch.tensor(pairs[:, 1], dtype=torch.long)
        log_likelihood = log_likelihood + log_policies[states, actions].sum()
        N_pairs += len(states)

    if N_pairs == 0:
        raise ValueError("no state-action pairs to evaluate the log-likelihood on")

    return log_likelihood / N_pairs
