# IRL utilities: combining transition batches and computing per-goal discounted rewards.
import torch
from reward_fns import neg_distance, within_radius
import numpy as np

def combine_batches(batch1, batch2):
    """Concatenate two transition batches along the transition (first) axis.

    Each batch is a sequence of tensors, e.g.
    ``(states, actions, logprobs, rewards, next_states, done_inds)``. Every
    entry except the last is indexed by transition along dim 0; the last entry
    holds the indices of episode-ending transitions.

    batch1 determines the feature width: for entries with more than one
    dimension, columns of batch2 beyond batch1's width are dropped.

    The done indices of batch2 are shifted by the number of transitions in
    batch1 so that they index into the combined batch.

    Returns:
        list of concatenated tensors, in the same order as the inputs.
    """
    combined = []
    for e1, e2 in zip(batch1[:-1], batch2[:-1]):
        if len(e1.shape) > 1:
            # Trim batch2's trailing columns to match batch1's width.
            e2 = e2[:, :e1.shape[1]]
        combined.append(torch.cat((e1, e2), 0))

    # done inds require special treatment: batch2's transition j ends up at
    # position len(batch1) + j, so shift by batch1's transition count. This
    # also works when batch1 has no episode ends (e.g. it is empty) or ends
    # mid-episode, where done_inds1[-1] + 1 would be wrong or undefined.
    done_inds1, done_inds2 = batch1[-1], batch2[-1]
    offset = batch1[0].shape[0]
    combined.append(torch.cat((done_inds1, done_inds2 + offset), 0))
    return combined


def get_rewards(batch, goal_locs, goal_rads, rew_fn=None, gamma=0.99):
    """Compute discounted rewards for every transition under every goal.

    For each transition and each (goal, radius) pair, ``rew_fn`` is evaluated
    on the first ``n_env_dims`` coordinates of the next state, where
    ``n_env_dims = goal_locs.shape[1]``. Each reward is then multiplied by
    ``gamma ** t``, where ``t`` is the step index within its episode.
    Episodes are delimited by ``done_inds``.

    Args:
        batch: sequence ``(states, actions, logprobs, rewards, next_states,
            done_inds)``. Only ``states`` (for its length), ``next_states`` and
            ``done_inds`` are read. ``done_inds`` holds the index of the last
            transition of each episode; the episode lengths must sum to
            ``len(states)``.
        goal_locs: array indexed as (n_weights, n_env_dims), one goal per row.
        goal_rads: one radius per goal; each entry must reshape to (1, 1).
        rew_fn: callable ``rew_fn(next_state, goal, rad)`` taking arrays of
            shape (1, n_env_dims), (1, n_env_dims), (1, 1). Defaults to
            ``neg_distance`` (``within_radius`` is an alternative).
        gamma: per-step discount factor, reset at every episode start.

    Returns:
        list of ``n_weights`` float tensors, each of length ``len(states)``.
    """
    if rew_fn is None:
        rew_fn = neg_distance
    n_weights = goal_locs.shape[0]
    n_env_dims = goal_locs.shape[1]
    states, _, _, _, next_states, done_inds = batch
    n_transitions = states.shape[0]

    # Only the leading env dims of the next state are compared against goals.
    next_np = next_states.to("cpu")[:, :n_env_dims].numpy()
    # Goal/radius shapes are loop-invariant; reshape them once.
    goal_args = [
        (goal.reshape((1, n_env_dims)), rad.reshape((1, 1)))
        for goal, rad in zip(goal_locs, goal_rads)
    ]

    rewards = np.zeros((n_transitions, n_weights))
    for ii in range(n_transitions):
        next_state = next_np[ii:ii + 1]  # keep a leading batch dim of 1
        rewards[ii] = np.column_stack(
            [rew_fn(next_state, goal, rad) for goal, rad in goal_args]
        )
    rewards = torch.from_numpy(rewards).float().T  # (n_weights, n_transitions)

    # done_inds may live on another device (e.g. GPU) and/or have a float
    # dtype; bring it to CPU as integers so the concatenation is well-typed.
    done_inds = torch.as_tensor(done_inds).to("cpu").long()
    start_inds = torch.cat((torch.zeros(1, dtype=torch.long), done_inds + 1), 0)
    ep_lens = start_inds[1:] - start_inds[:-1]
    steps = [torch.arange(int(ep_len)) for ep_len in ep_lens]
    # Within-episode step index for each transition (empty when no episodes).
    step_idx = torch.cat(steps) if steps else torch.zeros(0, dtype=torch.long)
    discounts = gamma ** step_idx
    rewards = rewards * discounts
    return list(rewards)
        
    
