import torch
import math
from copy import deepcopy

@torch.no_grad()
def backward_traj(env, bnet):
    """Step ``env`` backward in place for ``env.T`` steps using ``bnet``.

    Each step queries the backward policy, samples actions from it and
    applies them via ``env.backward``. Runs with autograd disabled; returns
    nothing (the environment is mutated).
    """
    for _step in range(env.T):
        env.backward(env.get_backward_actions(env.get_backward_pol(bnet)))

@torch.no_grad()
def forward_traj(env, fnet, training=False):
    """Step ``env`` forward in place for ``env.T`` steps using ``fnet``.

    Each step queries the forward policy, samples actions (forwarding the
    ``training`` flag to the environment's sampler) and applies them via
    ``env.apply``. Runs with autograd disabled; returns nothing.
    """
    for _step in range(env.T):
        env.apply(env.get_forward_actions(env.get_forward_pol(fnet), training=training))

def backward_trajectory_log_prob(env, forward_net, backward_net, greedy=False):
    """Roll ``env`` backward and return log P_F(tau) - log P_B(tau) per row.

    For each of ``env.T`` steps:
      * sample backward actions from ``backward_net``'s policy
        (``greedy`` is forwarded to the environment's sampler);
      * add log P_B of the sampled action, but only for actions whose index
        is below ``env.backward_action_dim``;
      * apply the actions with ``env.backward``;
      * add log P_F of the same action index, using ``forward_net``'s policy
        evaluated after the backward step (presumably the parent state).
    After each step, rows for which ``env.is_initial`` is nonzero stop
    accumulating.

    Probabilities are clamped to 1e-9 before taking the log, matching
    ``forward_trajectory_log_prob``. Without the clamp, a zero probability
    yields -inf, which makes ``f_logp - b_logp`` NaN when both sides are
    -inf and poisons gradients.

    NOTE(review): unlike ``forward_trajectory_log_prob``, no forward term for
    actions at/above ``backward_action_dim`` (presumably "stop") is added
    here. Confirm that this asymmetry is intended.

    Args:
        env: environment; mutated in place (ends in the state reached after
            ``env.T`` backward steps).
        forward_net: network passed to ``env.get_forward_pol``.
        backward_net: network passed to ``env.get_backward_pol``.
        greedy: forwarded to ``env.get_backward_actions``.

    Returns:
        Tensor of shape ``[env.batch_size]`` on ``env.device``.
    """
    f_logp = torch.zeros(env.batch_size, device=env.device)
    b_logp = torch.zeros_like(f_logp)

    active_ids = env.batch_ids

    for _ in range(env.T):
        back_pol = env.get_backward_pol(backward_net)
        actions = env.get_backward_actions(back_pol, greedy=greedy)
        # Only actions with a backward counterpart contribute to log P_B.
        mask = actions[active_ids] < env.backward_action_dim
        b_logp[active_ids[mask]] += torch.log(
            back_pol[active_ids[mask], actions[active_ids][mask]].clamp(min=1e-9)
        )

        env.backward(actions)

        forward_pol = env.get_forward_pol(forward_net)
        f_logp[active_ids] += torch.log(
            forward_pol[active_ids, actions[active_ids]].clamp(min=1e-9)
        )
        # Rows that reached the initial state stop accumulating.
        active_ids = env.batch_ids[env.is_initial == 0]

    return f_logp - b_logp


def forward_trajectory_log_prob(env, forward_net, backward_net, training=True):
    """Roll ``env`` forward and return log P_F(tau) - log P_B(tau) per row.

    For each of ``env.T`` steps:
      * rows with ``env.stopped == 0`` are active;
      * sample forward actions (``training`` is forwarded to the sampler) and
        add log P_F of the sampled action for active rows;
      * apply the actions with ``env.apply``;
      * subtract log P_B of the same action index, only for active rows
        whose action index is below ``env.backward_action_dim``.
        Actions at/above it (presumably "stop") have no backward term.

    Probabilities are clamped to 1e-9 before the log to avoid -inf.

    The accumulator is allocated on ``env.device``. The previous code always
    allocated it on the CPU, which fails when the policies and indices live
    on another device.

    Args:
        env: environment; mutated in place.
        forward_net: network passed to ``env.get_forward_pol``.
        backward_net: network passed to ``env.get_backward_pol``.
        training: forwarded to ``env.get_forward_actions``.

    Returns:
        Tensor of shape ``[env.batch_size]`` on ``env.device``.
    """
    log_p = torch.zeros(env.batch_size, device=env.device)

    for _ in range(env.T):
        active_ids = env.batch_ids[env.stopped == 0]
        pol = env.get_forward_pol(forward_net)
        actions = env.get_forward_actions(pol, training=training)
        log_p[active_ids] += torch.log(pol[active_ids, actions[active_ids]].clamp(min=1e-9))

        env.apply(actions)
        back_pol = env.get_backward_pol(backward_net)
        # Only actions with a backward counterpart contribute to log P_B.
        mask = actions[active_ids] < env.backward_action_dim
        log_p[active_ids[mask]] -= torch.log(
            back_pol[active_ids[mask], actions[active_ids][mask]].clamp(min=1e-9)
        )

    return log_p  # [B]

def forward_reward(env, forward_net, backward_net, logz):
    """Return the forward-sampled trajectory log-ratio shifted by ``logz``.

    Samples a trajectory with ``forward_trajectory_log_prob`` (training-mode
    action sampling) and adds ``logz`` to each row's log P_F - log P_B.
    Result has shape ``[B]``.
    """
    return forward_trajectory_log_prob(
        env, forward_net, backward_net, training=True
    ) + logz  # [B]

def backward_reward(env, forward_net, backward_net, logz):
    """Return the backward-sampled trajectory log-ratio shifted by ``logz``.

    Samples a trajectory with ``backward_trajectory_log_prob`` (default,
    non-greedy sampling) and adds ``logz`` to each row's log P_F - log P_B.
    Result has shape ``[B]``.
    """
    return backward_trajectory_log_prob(env, forward_net, backward_net) + logz  # [B]

@torch.no_grad()
def marginal_log_reward(env, forward_net, backward_net, logz, batch):
    """Estimate a per-row log reward by averaging ``batch`` backward rollouts.

    Each rollout runs on a deep copy of ``env``, so ``env`` itself is left
    untouched. It yields ``log P_F - log P_B + logz``. The estimates are
    combined as ``log(mean(exp(.)))`` via ``torch.logaddexp`` starting from
    -inf. Autograd is disabled. A non-positive ``batch`` makes
    ``math.log`` raise ValueError.
    """
    acc = torch.full((env.batch_size,), float('-inf'), device=env.device)
    for _ in range(batch):
        sample = backward_trajectory_log_prob(deepcopy(env), forward_net, backward_net) + logz
        acc = torch.logaddexp(acc, sample)
    return acc - math.log(batch)

if __name__ == "__main__":
    # Demo: score one forward and one backward rollout, then plot the
    # marginal log-reward estimate over the environment's full grid.
    from diffusionGrid_env import DiffGrid
    from diffusionGrid_rewards import build_log_reward_fn
    from diffusionGrid_nets import FourierTimePolicy
    from torch import nn
    import matplotlib.pyplot as plt
    import numpy as np

    size = 12
    log_reward = build_log_reward_fn('8g', R=0.6*size, sigma=1.0, lam=1e-6)
    # Use ``size`` instead of repeating the literal so the two stay in sync.
    env = DiffGrid(size=size, batch_size=10, log_reward=log_reward)

    fnet = FourierTimePolicy(hidden_dim=64, num_layers=3, n_freq=4)
    bnet = FourierTimePolicy(hidden_dim=64, num_layers=3, n_freq=4)
    logz = nn.Parameter(torch.tensor(0.0))

    # forward_trajectory_log_prob takes (env, forward_net, backward_net,
    # training). The previous call passed two extra positional args and
    # raised TypeError.
    forward_logp = forward_trajectory_log_prob(env, fnet, bnet)
    backward_logp = backward_trajectory_log_prob(env, fnet, bnet)

    print(f"forward_logp: {forward_logp}")
    print(f"backward_logp: {backward_logp}")

    env.set_full_grid_T()
    log_r_hat = marginal_log_reward(env, fnet, bnet, logz, batch=10)

    env.plot_log_r_hat(log_r_hat, compare_true=False, prob=False, kind="imshow")

    env.plot_log_r_hat(log_r_hat, compare_true=True, prob=True, kind="contour", levels=80)


