"""Evaluation metrics: log-likelihood and accuracy."""

import torch
import numpy as np


def compute_log_likelihood_gru(policy, obs_dataset):
    """Mean log-likelihood of the recorded actions under a recurrent policy, in bits/decision.

    Args:
        policy: module called as ``policy(obs)`` on a batch of one sequence;
            its first return value is used as the action logits and the second
            (e.g. a hidden state) is ignored. The module is switched to eval
            mode and left in eval mode.
        obs_dataset: iterable of dicts with ``"obs"`` (one sequence, without a
            batch dimension) and ``"actions"`` (one integer action per step, as
            a tensor, array or list). Only the first ``len(d["actions"])`` time
            steps of the logits are scored.

    Returns:
        Sum of log2 p(action_t) over all scored steps divided by the number of
        scored steps.

    Raises:
        ValueError: if the dataset contains no actions to score.
    """
    policy.eval()
    total_ll, n_steps = 0.0, 0
    with torch.no_grad():
        for d in obs_dataset:
            logits, _ = policy(d["obs"].unsqueeze(0))
            log_probs = torch.log_softmax(logits.squeeze(0), dim=-1)
            actions = torch.as_tensor(d["actions"], dtype=torch.long, device=log_probs.device)
            n = actions.numel()
            steps = torch.arange(n, device=log_probs.device)
            # One gather per sequence instead of a .item() (host sync) per step;
            # accumulate in float64 to avoid float32 drift on long sequences.
            total_ll += log_probs[steps, actions].sum(dtype=torch.float64).item()
            n_steps += n
    if n_steps == 0:
        raise ValueError("obs_dataset contains no actions to score")
    # Natural-log likelihood -> bits.
    return total_ll / (n_steps * np.log(2))


def compute_prediction_accuracy(policy, obs_dataset):
    """Top-1 accuracy of a recurrent policy's argmax action against the recorded actions.

    Args:
        policy: module called as ``policy(obs)`` on a batch of one sequence;
            its first return value is used as the action logits and the second
            is ignored. The module is switched to eval mode and left in eval mode.
        obs_dataset: iterable of dicts with ``"obs"`` (one sequence, without a
            batch dimension) and ``"actions"`` (one integer action per step).
            Only the first ``len(d["actions"])`` predicted steps are scored.

    Returns:
        Fraction of scored steps where the argmax prediction equals the action.

    Raises:
        ValueError: if the dataset contains no actions to score.
    """
    policy.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for d in obs_dataset:
            logits, _ = policy(d["obs"].unsqueeze(0))
            preds = logits.squeeze(0).argmax(dim=-1)
            actions = torch.as_tensor(d["actions"], dtype=torch.long, device=preds.device)
            n = actions.numel()
            # Index with arange (not a slice) so a too-short prediction
            # sequence errors instead of silently broadcasting.
            steps = torch.arange(n, device=preds.device)
            correct += int((preds[steps] == actions).sum().item())
            total += n
    if total == 0:
        raise ValueError("obs_dataset contains no actions to score")
    return correct / total


def compute_log_likelihood_mlp(policy, obs_dataset):
    """Mean log-likelihood of the recorded actions under a feed-forward policy, in bits/decision.

    Args:
        policy: module called as ``policy(obs)`` on a batch of one sequence,
            returning the action logits directly. The module is switched to
            eval mode and left in eval mode.
        obs_dataset: iterable of dicts with ``"obs"`` (one sequence, without a
            batch dimension) and ``"actions"`` (one integer action per step, as
            a tensor, array or list). Only the first ``len(d["actions"])`` time
            steps of the logits are scored.

    Returns:
        Sum of log2 p(action_t) over all scored steps divided by the number of
        scored steps.

    Raises:
        ValueError: if the dataset contains no actions to score.
    """
    policy.eval()
    total_ll, n_steps = 0.0, 0
    with torch.no_grad():
        for d in obs_dataset:
            logits = policy(d["obs"].unsqueeze(0))
            log_probs = torch.log_softmax(logits.squeeze(0), dim=-1)
            actions = torch.as_tensor(d["actions"], dtype=torch.long, device=log_probs.device)
            n = actions.numel()
            steps = torch.arange(n, device=log_probs.device)
            # One gather per sequence instead of a .item() (host sync) per step;
            # accumulate in float64 to avoid float32 drift on long sequences.
            total_ll += log_probs[steps, actions].sum(dtype=torch.float64).item()
            n_steps += n
    if n_steps == 0:
        raise ValueError("obs_dataset contains no actions to score")
    # Natural-log likelihood -> bits.
    return total_ll / (n_steps * np.log(2))


def compute_prediction_accuracy_mlp(policy, obs_dataset):
    """Top-1 accuracy of a feed-forward policy's argmax action against the recorded actions.

    Args:
        policy: module called as ``policy(obs)`` on a batch of one sequence,
            returning the action logits directly. The module is switched to
            eval mode and left in eval mode.
        obs_dataset: iterable of dicts with ``"obs"`` (one sequence, without a
            batch dimension) and ``"actions"`` (one integer action per step).
            Only the first ``len(d["actions"])`` predicted steps are scored.

    Returns:
        Fraction of scored steps where the argmax prediction equals the action.

    Raises:
        ValueError: if the dataset contains no actions to score.
    """
    policy.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for d in obs_dataset:
            logits = policy(d["obs"].unsqueeze(0))
            preds = logits.squeeze(0).argmax(dim=-1)
            actions = torch.as_tensor(d["actions"], dtype=torch.long, device=preds.device)
            n = actions.numel()
            # Index with arange (not a slice) so a too-short prediction
            # sequence errors instead of silently broadcasting.
            steps = torch.arange(n, device=preds.device)
            correct += int((preds[steps] == actions).sum().item())
            total += n
    if total == 0:
        raise ValueError("obs_dataset contains no actions to score")
    return correct / total


def compute_per_node_accuracy(policy, obs_dataset):
    """Per-node top-1 accuracy for GRU policy; returns {node_id: (correct, total)}.

    Args:
        policy: module called as ``policy(obs)`` on a batch of one sequence;
            its first return value is used as the action logits and the second
            is ignored. The module is switched to eval mode and left in eval mode.
        obs_dataset: iterable of dicts with ``"obs"`` (one sequence, without a
            batch dimension), ``"actions"`` (one integer action per step) and
            ``"states"`` (node id visited at each step, indexed by step).

    Returns:
        Dict mapping node id to ``(correct, total)`` counts, in first-seen
        order. Tensor / NumPy scalar node ids are converted to Python scalars
        so that equal ids share one entry.
    """
    policy.eval()
    node_stats = {}
    with torch.no_grad():
        for d in obs_dataset:
            logits, _ = policy(d["obs"].unsqueeze(0))
            preds = logits.squeeze(0).argmax(dim=-1)
            actions = torch.as_tensor(d["actions"], dtype=torch.long, device=preds.device)
            steps = torch.arange(actions.numel(), device=preds.device)
            hits = (preds[steps] == actions).tolist()
            states = d["states"]
            for t, hit in enumerate(hits):
                node = states[t]
                # Tensors hash by identity, so a tensor element would create a
                # fresh dict key on every step; use its Python value instead.
                if isinstance(node, (torch.Tensor, np.ndarray, np.generic)):
                    node = node.item()
                stats = node_stats.setdefault(node, [0, 0])
                stats[0] += int(hit)
                stats[1] += 1
    return {k: tuple(v) for k, v in node_stats.items()}


def compute_per_node_accuracy_mlp(policy, obs_dataset):
    """Per-node top-1 accuracy for MLP policy; returns {node_id: (correct, total)}.

    Args:
        policy: module called as ``policy(obs)`` on a batch of one sequence,
            returning the action logits directly. The module is switched to
            eval mode and left in eval mode.
        obs_dataset: iterable of dicts with ``"obs"`` (one sequence, without a
            batch dimension), ``"actions"`` (one integer action per step) and
            ``"states"`` (node id visited at each step, indexed by step).

    Returns:
        Dict mapping node id to ``(correct, total)`` counts, in first-seen
        order. Tensor / NumPy scalar node ids are converted to Python scalars
        so that equal ids share one entry.
    """
    policy.eval()
    node_stats = {}
    with torch.no_grad():
        for d in obs_dataset:
            logits = policy(d["obs"].unsqueeze(0))
            preds = logits.squeeze(0).argmax(dim=-1)
            actions = torch.as_tensor(d["actions"], dtype=torch.long, device=preds.device)
            steps = torch.arange(actions.numel(), device=preds.device)
            hits = (preds[steps] == actions).tolist()
            states = d["states"]
            for t, hit in enumerate(hits):
                node = states[t]
                # Tensors hash by identity, so a tensor element would create a
                # fresh dict key on every step; use its Python value instead.
                if isinstance(node, (torch.Tensor, np.ndarray, np.generic)):
                    node = node.item()
                stats = node_stats.setdefault(node, [0, 0])
                stats[0] += int(hit)
                stats[1] += 1
    return {k: tuple(v) for k, v in node_stats.items()}


def compute_behavioral_cloning_ll(train_sa, val_sa, n_states=127, n_actions=4, pseudocount=0.01):
    """Frequency-based behavioral cloning log-likelihood in bits/decision.

    Fits a tabular policy p(a | s) from (state, action) counts in ``train_sa``,
    adds ``pseudocount`` to every cell (additive smoothing), and returns the
    mean log2-probability that policy assigns to the pairs in ``val_sa``.

    Args:
        train_sa: iterable of trajectories, each an iterable of ``(s, a)`` pairs
            used as integer indices into an ``(n_states, n_actions)`` table.
        val_sa: held-out trajectories in the same format.
        n_states: number of rows (states) in the count table.
        n_actions: number of columns (actions) in the count table.
        pseudocount: value added to every count before normalising. Keep it
            > 0 so unseen pairs get a finite log-probability. Defaults to the
            previously hard-coded 0.01.

    Returns:
        Mean log2 p(a | s) over all validation pairs.

    Raises:
        ValueError: if ``val_sa`` contains no pairs.
    """
    counts = np.zeros((n_states, n_actions))
    for traj in train_sa:
        for s, a in traj:
            counts[s, a] += 1

    counts += pseudocount
    log_policy = np.log2(counts / counts.sum(axis=1, keepdims=True))

    total_ll, n_steps = 0.0, 0
    for traj in val_sa:
        for s, a in traj:
            total_ll += log_policy[s, a]
            n_steps += 1
    if n_steps == 0:
        raise ValueError("val_sa contains no (state, action) pairs")
    return total_ll / n_steps
