import numpy as np
import torch
import torch.nn.functional as F

def f_div_disc_loss(div: str, IS: bool, samples, disc, reward_func, device, expert_trajs=None):
    """Surrogate f-divergence loss built from a density-ratio discriminator.

    For every trajectory the function forms
      t1 = -sum_t f(logit_t), with logit_t = disc.log_density_ratio(s_t) and
           f = exp ('fkl'), identity ('rkl') or softplus ('js');
      t2 = sum_t reward_func.r(s_t).
    The surrogate objective is the covariance of t1 and t2 across trajectories.
    It is a uniform sample covariance, or a self-normalized
    importance-weighted covariance when ``IS`` is true. The result is divided
    by the horizon T.

    Args:
        div: one of 'fkl', 'rkl', 'js'.
        IS: if true, trajectories are weighted by
            softmax(sum_t reward_func.get_scalar_reward(s_t) - sum_t log_a).
        samples: tuple ``(states, _, log_a)``. ``states`` is unpacked as an
            (N, T, d) array. ``log_a`` is summed over axis 1 per trajectory and
            is only read when ``IS`` is true. Presumably it holds per-step
            action log-probs of shape (N, T) — TODO confirm.
        disc: object exposing ``log_density_ratio(states_2d)`` returning a
            torch vector.
        reward_func: object exposing ``r(tensor)`` and, when ``IS`` is true,
            ``get_scalar_reward(states_2d)``.
        device: torch device the reward input (and IS weights) are moved to.
        expert_trajs: optional (M, T, d) array appended to the agent states
            along axis 0. ``np.concatenate`` builds a new array, so the
            caller's ``samples`` is not modified.

    Returns:
        ``(surrogate_objective, t1 / T)``: a scalar tensor, and the
        per-trajectory t1 divided by T. The latter is the log of a geometric
        mean w.r.t. the trajectory, with 0 as the borderline.

    Raises:
        AssertionError: unknown ``div`` or mis-shaped ``expert_trajs``.
        ValueError: ``IS`` combined with ``expert_trajs``. Expert trajectories
            carry no action log-probs, so the importance weights would not
            line up with the concatenated trajectories.
    """
    # please add eps to expert density, not here
    assert div in ['fkl', 'rkl', 'js']
    s, _, log_a = samples
    if expert_trajs is not None:
        if IS:
            # log_a only covers the agent trajectories. After concatenation the
            # weights would be misaligned, or, with a single agent trajectory,
            # numpy would silently broadcast them.
            raise ValueError("IS weighting is not supported together with expert_trajs")
        assert expert_trajs.ndim == 3 and expert_trajs.shape[1:] == s.shape[1:]
        s = np.concatenate((s, expert_trajs), axis=0)  # new array; samples is untouched
    N, T, d = s.shape

    s_vec = s.reshape(-1, d)
    logits = disc.log_density_ratio(s_vec)  # torch vector, log(p/q) per state

    if div == 'fkl':
        t1 = torch.exp(logits)  # (N*T,) p/q  TODO: clip
    elif div == 'rkl':
        t1 = logits  # (N*T,) log(p/q)
    elif div == 'js':  # https://pytorch.org/docs/master/generated/torch.nn.Softplus.html
        t1 = F.softplus(logits)  # (N*T,) log(1 + p/q)

    t1 = (-t1).view(N, T).sum(1)  # NOTE: sign, (N,)
    t2 = reward_func.r(torch.FloatTensor(s_vec).to(device)).view(N, T).sum(1)  # (N,)

    if IS:
        traj_reward = reward_func.get_scalar_reward(s_vec).reshape(N, T).sum(1)  # (N,)
        traj_log_prob = log_a.sum(1)  # (N,)
        # self-normalized importance weights over trajectories
        IS_ratio = F.softmax(torch.FloatTensor(traj_reward - traj_log_prob), dim=0).to(device)
        surrogate_objective = (IS_ratio * t1 * t2).sum() - (IS_ratio * t1).sum() * (IS_ratio * t2).sum()
    else:
        surrogate_objective = (t1 * t2).mean() - t1.mean() * t2.mean()  # sample covariance

    surrogate_objective /= T

    return surrogate_objective, t1 / T

def f_div_disc_unbiased_loss(div: str, agent_samples, expert_samples, disc, reward_func, device):
    """Surrogate f-divergence loss whose baseline term is estimated on expert states.

    The agent part is mean_n[F_n * R_n]. Here F_n = -sum_t f(logit_t) and
    R_n = sum_t reward_func.r(s_t) over each agent trajectory, and f is exp
    (fkl), identity (rkl) or softplus (js) of the discriminator's log density
    ratio.

    The baseline part is built from expert states, each weighted by
    exp(-logit) (q/p). It subtracts T**2 * mean[-f(logit) * w] * mean[r * w].
    For 'fkl-unbiased' the weighted f-term is the constant 1. The total is
    divided by the horizon T.

    Args:
        div: one of 'fkl-unbiased', 'rkl-unbiased', 'js-unbiased'.
        agent_samples: tuple ``(states, _, _)`` with states unpacked as (N, T, d).
        expert_samples: 2-D array of expert states (asserted ``ndim == 2``).
        disc: object exposing ``log_density_ratio(states_2d)`` returning a torch vector.
        reward_func: object exposing ``r(tensor)``.
        device: torch device the reward inputs are moved to.

    Returns:
        ``(objective, None)``. The second slot is always ``None``.
    """
    # use IS for state marginal terms to get expert signal
    assert div in ['fkl-unbiased', 'rkl-unbiased', 'js-unbiased']
    agent_states, _, _ = agent_samples
    n_traj, horizon, dim = agent_states.shape

    agent_flat = agent_states.reshape(-1, dim)
    agent_logits = disc.log_density_ratio(agent_flat)
    assert expert_samples.ndim == 2
    expert_logits = disc.log_density_ratio(expert_samples)

    if div == 'fkl-unbiased':
        f_agent, f_expert = torch.exp(agent_logits), torch.exp(expert_logits)  # p/q  TODO: clip
    elif div == 'rkl-unbiased':
        f_agent, f_expert = agent_logits, expert_logits  # log(p/q)
    elif div == 'js-unbiased':  # https://pytorch.org/docs/master/generated/torch.nn.Softplus.html
        f_agent, f_expert = F.softplus(agent_logits), F.softplus(expert_logits)  # log(1 + p/q)

    # Agent term: per-trajectory (negated f) times per-trajectory reward.
    traj_f = (-f_agent).view(n_traj, horizon).sum(1)
    traj_r = reward_func.r(torch.FloatTensor(agent_flat).to(device)).view(n_traj, horizon).sum(1)
    objective = (traj_f * traj_r).mean()

    # Baseline term on expert states, importance-weighted by q/p.
    weights = torch.exp(-expert_logits)
    weighted_f = (-f_expert) * weights
    expert_r = reward_func.r(torch.FloatTensor(expert_samples).to(device)).view(-1)
    weighted_r = expert_r * weights
    objective = objective - horizon ** 2 * weighted_f.mean() * weighted_r.mean()

    return objective / horizon, None

def unbiased_f_div_disc_loss(div: str, IS: bool, samples, disc, reward_func, device, expert_trajs):
    """Surrogate f-divergence loss plus an expert-based correction term.

    The first term is the covariance across agent trajectories of
      t1 = -sum_t f(logit_t)  (f = exp / identity / softplus for fkl / rkl / js)
      t2 = sum_t reward_func.r(s_t)
    It is a uniform sample covariance, or a self-normalized
    importance-weighted one when ``IS`` is true, and is divided by the
    horizon T.

    The correction term adds
      mean over agent states of r(s) * exp(logit)  -  mean over expert states of r(s).

    Args:
        div: one of 'fkl', 'rkl', 'js'.
        IS: if true, trajectories are weighted by
            softmax(sum_t reward_func.get_scalar_reward(s_t) - sum_t log_a).
        samples: tuple ``(states, _, log_a)`` with states unpacked as (N, T, d).
            ``log_a`` is summed over axis 1 and only read when ``IS`` is true.
            Presumably it holds per-step action log-probs — TODO confirm.
        disc: object exposing ``log_density_ratio(states_2d)`` returning a
            torch vector.
        reward_func: object exposing ``r(tensor)`` and, for IS,
            ``get_scalar_reward(states_2d)``.
        device: torch device all reward inputs are moved to.
        expert_trajs: numpy array of expert states whose last axis has size d.
            It is flattened to (-1, d).

    Returns:
        ``(surrogate + correction, t1 / T)``.
    """
    # please add eps to expert density, not here
    assert div in ['fkl', 'rkl', 'js']
    s, _, log_a = samples
    N, T, d = s.shape

    s_vec = s.reshape(-1, d)
    logits = disc.log_density_ratio(s_vec)  # torch vector, log(p/q) per state

    if div == 'fkl':
        t1 = torch.exp(logits)  # (N*T,) p/q  TODO: clip
    elif div == 'rkl':
        t1 = logits  # (N*T,) log(p/q)
    elif div == 'js':  # https://pytorch.org/docs/master/generated/torch.nn.Softplus.html
        t1 = F.softplus(logits)  # (N*T,) log(1 + p/q)

    t1 = (-t1).view(N, T).sum(1)  # NOTE: sign, (N,)
    # Per-state agent rewards (N*T elements), reused by the correction term below.
    r_agent = reward_func.r(torch.FloatTensor(s_vec).to(device))
    t2 = r_agent.view(N, T).sum(1)  # (N,)

    if IS:
        traj_reward = reward_func.get_scalar_reward(s_vec).reshape(N, T).sum(1)  # (N,)
        traj_log_prob = log_a.sum(1)  # (N,)
        # self-normalized importance weights over trajectories
        IS_ratio = F.softmax(torch.FloatTensor(traj_reward - traj_log_prob), dim=0).to(device)
        surrogate_objective = (IS_ratio * t1 * t2).sum() - (IS_ratio * t1).sum() * (IS_ratio * t2).sum()
    else:
        surrogate_objective = (t1 * t2).mean() - t1.mean() * t2.mean()  # sample covariance

    surrogate_objective /= T

    # Build the expert reward input as float32 on `device`. torch.from_numpy
    # would keep numpy's dtype (float64 by default) and stay on the CPU, which
    # breaks float32 and/or GPU reward networks.
    expert_vec = expert_trajs.reshape(-1, d)
    r_expert = reward_func.r(torch.FloatTensor(expert_vec).to(device))
    unbiased_objective = torch.mean(r_agent * torch.exp(logits).view(-1, 1)) - torch.mean(r_expert)

    return surrogate_objective + unbiased_objective, t1 / T

def f_div_current_state_disc_loss(div: str, samples, disc, reward_func, device, expert_trajs=None):
    ''' NOTE: deprecated for expert samples, maxentirl does not need disc

    Per-state (not per-trajectory) surrogate f-divergence loss: the sample
    covariance over all N*T states between -f(logit) and reward_func.r(s).
    Here f is exp / identity / softplus for fkl-state / rkl-state / js-state,
    applied to disc.log_density_ratio. The result is multiplied by the
    horizon T to keep the same scale as the trajectory-level losses.

    If ``expert_trajs`` (shape (M, T, d)) is given, it is appended to the
    agent states along axis 0 before evaluation.
    '''
    assert div in ['fkl-state', 'rkl-state', 'js-state']
    states, _, _ = samples
    if expert_trajs is not None:
        assert expert_trajs.ndim == 3 and expert_trajs.shape[1:] == states.shape[1:]
        states = np.concatenate((states, expert_trajs), axis=0)
    _, horizon, dim = states.shape

    flat = states.reshape(-1, dim)
    log_ratio = disc.log_density_ratio(flat)  # torch vector

    # negated f-term per state, (N*T,), deliberately not summed over time
    if div == 'fkl-state':
        neg_f = -torch.exp(log_ratio)  # p/q  TODO: clip
    elif div == 'rkl-state':
        neg_f = -log_ratio  # log(p/q)
    elif div == 'js-state':
        neg_f = -F.softplus(log_ratio)  # log(1 + p/q)

    per_state_reward = reward_func.r(torch.FloatTensor(flat).to(device)).view(-1)  # (N*T,)

    covariance = (neg_f * per_state_reward).mean() - neg_f.mean() * per_state_reward.mean()
    return horizon * covariance  # same scale
