"""MaxCausalEnt IRL with differentiable soft value iteration."""

import torch
import torch.nn as nn
import numpy as np


class SoftValueIteration(nn.Module):
    """Differentiable soft value iteration for small discrete MDPs.

    Runs a fixed number of soft Bellman backups
    ``Q = r + gamma * sum_s' T[s, a, s'] * V[s']`` followed by
    ``V = logsumexp_a Q`` using only differentiable torch ops, so gradients
    can flow from the outputs back to the reward passed to :meth:`forward`.
    """

    def __init__(self, T, gamma=0.99, n_iters=200):
        """Store the transition tensor and the iteration settings.

        Args:
            T: Rank-3 tensor indexed ``[state, action, next_state]``; it is
                registered as a buffer. Presumably holds transition
                probabilities -- this is not validated here.
            gamma: Discount factor multiplying the next-state value.
            n_iters: Number of backups performed per forward pass; must be
                at least 1.

        Raises:
            ValueError: If ``n_iters`` is smaller than 1, because then no
                Q-table would ever be computed.
        """
        super().__init__()
        if n_iters < 1:
            raise ValueError(f"n_iters must be >= 1, got {n_iters}")
        self.register_buffer("T", T)
        self.gamma = gamma
        self.n_iters = n_iters

    def forward(self, reward):
        """Return ``(V, Q, pi)`` after ``n_iters`` soft Bellman backups.

        Args:
            reward: 1-D tensor with one reward per state. Its dtype and
                device must match those of ``T``.

        Returns:
            Tuple ``(V, Q, pi)``: ``Q`` is the (state, action) table from
            the final backup, ``V`` is ``logsumexp`` of that ``Q`` over
            actions, and ``pi`` is ``softmax`` of that ``Q`` over actions.
        """
        n_states = reward.shape[0]
        # new_zeros inherits dtype and device from reward, so float64
        # rewards/transitions do not hit a dtype mismatch inside einsum.
        V = reward.new_zeros(n_states)
        state_reward = reward.unsqueeze(-1)  # broadcast over the action axis
        for _ in range(self.n_iters):
            expected_next_v = torch.einsum("ijk,k->ij", self.T, V)
            Q = state_reward + self.gamma * expected_next_v
            V = torch.logsumexp(Q, dim=-1)
        pi = torch.softmax(Q, dim=-1)
        return V, Q, pi


class MaxCausalEntIRL(nn.Module):
    """MaxCausalEnt IRL with tabular reward parameterization.

    Holds one learnable reward value per state and scores demonstrations by
    the negative log-likelihood of their actions under the soft-optimal
    policy produced by ``SoftValueIteration``.
    """

    def __init__(self, n_states, T, gamma=0.99, n_vi_iters=200, l2_reg=0.01):
        """Create the reward table and the soft value iteration solver.

        Args:
            n_states: Number of states; the reward table has this length and
                starts at zero.
            T: Transition tensor forwarded to ``SoftValueIteration``.
            gamma: Discount factor forwarded to ``SoftValueIteration``.
            n_vi_iters: Number of backups forwarded to
                ``SoftValueIteration``.
            l2_reg: Weight of the squared-L2 penalty on the reward table.
        """
        super().__init__()
        self.reward_params = nn.Parameter(torch.zeros(n_states))
        self.soft_vi = SoftValueIteration(T, gamma, n_vi_iters)
        self.l2_reg = l2_reg

    def forward(self, demo_sa_pairs):
        """Return ``(nll, V, Q, pi)`` for the current reward on the demos.

        Args:
            demo_sa_pairs: Iterable of trajectories, each an iterable of
                ``(state, action)`` index pairs.

        Returns:
            Tuple ``(nll, V, Q, pi)`` where ``nll`` is the mean negative
            log-probability of the demonstrated actions plus
            ``l2_reg * sum(reward_params ** 2)``, and ``V``, ``Q``, ``pi``
            are the outputs of the soft value iteration solver.

        Raises:
            ValueError: If ``demo_sa_pairs`` contains no steps, since the
                mean over zero steps is undefined.
        """
        V, Q, pi = self.soft_vi(self.reward_params)
        log_pi = torch.log_softmax(Q, dim=-1)

        states = []
        actions = []
        for traj in demo_sa_pairs:
            for s, a in traj:
                states.append(int(s))
                actions.append(int(a))
        if not states:
            raise ValueError(
                "demo_sa_pairs contains no (state, action) steps"
            )

        # One gather + mean instead of one autograd node per demo step.
        s_idx = torch.tensor(states, dtype=torch.long, device=log_pi.device)
        a_idx = torch.tensor(actions, dtype=torch.long, device=log_pi.device)
        nll = -log_pi[s_idx, a_idx].mean()
        nll = nll + self.l2_reg * (self.reward_params**2).sum()
        return nll, V, Q, pi


def compute_log_likelihood_mdp(Q_soft, sa_pairs):
    """Log-likelihood under MCE policy in bits/decision.

    Args:
        Q_soft: (state, action) table of soft Q-values; the policy is its
            softmax over the last axis.
        sa_pairs: Iterable of trajectories, each an iterable of
            ``(state, action)`` index pairs.

    Returns:
        Mean log-probability of the listed actions, converted from nats to
        bits by dividing by ``ln 2``. With no steps at all the division is
        ``0.0 / 0.0`` and the result is NaN.
    """
    log_pi = torch.log_softmax(Q_soft, dim=-1)

    states = []
    actions = []
    for traj in sa_pairs:
        for s, a in traj:
            states.append(int(s))
            actions.append(int(a))
    n_steps = len(states)

    # Gather every step at once and transfer a single scalar, rather than
    # calling .item() (a host sync) once per demo step.
    s_idx = torch.tensor(states, dtype=torch.long, device=log_pi.device)
    a_idx = torch.tensor(actions, dtype=torch.long, device=log_pi.device)
    # Accumulate in float64 on CPU, matching Python-float accumulation.
    total_ll = log_pi[s_idx, a_idx].cpu().double().sum().item()
    return total_ll / (n_steps * np.log(2))


def train_irl(model, train_sa, val_sa=None, n_epochs=300, lr=0.01, print_every=50):
    """Fit ``model`` to demonstrations with Adam and return the loss history.

    Each epoch runs one full-batch forward pass on ``train_sa``, takes one
    Adam step, and then clamps ``model.reward_params`` to ``[-10, 10]``.

    Args:
        model: Module whose call returns ``(loss, V, Q, pi)`` and which
            exposes a ``reward_params`` parameter.
        train_sa: Training demonstrations passed to ``model``.
        val_sa: Optional validation demonstrations. When given, their
            log-likelihood (bits/decision) is recorded every epoch using the
            ``Q`` from that epoch's forward pass, i.e. the same parameters
            the recorded training loss was computed with.
        n_epochs: Number of optimisation steps.
        lr: Adam learning rate.
        print_every: Print progress every this many epochs; ``0`` or
            ``None`` disables printing.

    Returns:
        Dict with keys ``"train_nll"`` (one value per epoch) and
        ``"val_ll_bits"`` (one value per epoch when ``val_sa`` is given,
        otherwise empty).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = {"train_nll": [], "val_ll_bits": []}

    for epoch in range(n_epochs):
        loss, V, Q, pi = model(train_sa)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep the tabular rewards bounded after the update.
        with torch.no_grad():
            model.reward_params.clamp_(-10, 10)

        train_nll = loss.item()
        history["train_nll"].append(train_nll)

        if val_sa is not None:
            val_ll = compute_log_likelihood_mdp(Q.detach(), val_sa)
            history["val_ll_bits"].append(val_ll)

        # A falsy print_every turns logging off instead of dividing by zero.
        if print_every and epoch % print_every == 0:
            msg = f"Epoch {epoch}: train_nll={train_nll:.4f}"
            if val_sa is not None:
                msg += f", val_ll={history['val_ll_bits'][-1]:.4f} bits/dec"
            print(msg)

    return history
