"""Bayes-optimal filter for the structural observation POMDP."""
import os
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import torch
from src.rosenberg_data import (
    load_rosenberg_trajectories, split_rosenberg_trajectories,
    build_bc_targets,
)
from src.evaluation import compute_behavioral_cloning_ll
from src.utils import trajectories_to_sa_pairs

# Number of tree nodes. build_transition_matrices indexes children of node s as
# 2s+1 / 2s+2, so 127 = 2**7 - 1 is a complete binary tree with 7 levels.
N_STATES = 127
# Actions: 0 = left child, 1 = right child, 2 = parent (see build_transition_matrices).
N_ACTIONS = 3
# Default chunk length (in actions) for run_filter / no_memory_baseline; the
# filter's belief is reset at the start of every chunk.
MAX_SEQ_LEN = 200


def build_obs_class_map(n_states=127):
    """Assign every node to a coarse observation class and build class masks.

    Classes are defined by heap index ranges:
        0: node 0 (root)
        1: nodes 1-2
        2: nodes 3-30
        3: nodes 31-62
        4: nodes 63 and above

    Args:
        n_states: number of nodes to classify.

    Returns:
        obs_class: integer array of shape (n_states,) with the class of each node.
        obs_masks: float array of shape (5, n_states); entry [c, s] is 1.0 when
            node s belongs to class c and 0.0 otherwise.
    """
    # Inclusive upper index bound of classes 0..3; larger indices fall into class 4.
    upper_bounds = np.array([0, 2, 30, 62])
    node_ids = np.arange(n_states)
    # side="left" returns the first class whose upper bound is >= the node index.
    obs_class = np.searchsorted(upper_bounds, node_ids, side="left").astype(int)

    obs_masks = np.zeros((len(upper_bounds) + 1, n_states))
    obs_masks[obs_class, node_ids] = 1.0
    return obs_class, obs_masks


def build_transition_matrices(n_states=127):
    """Build one deterministic transition matrix per action on a heap-indexed tree.

    Node s has children 2s+1 and 2s+2 and parent (s-1)//2. Action 0 moves to
    the left child, action 1 to the right child and action 2 to the parent.
    A move that would leave the tree keeps the agent where it is: a child index
    >= n_states maps back to s, and the root's "parent" is the root itself.

    Args:
        n_states: number of nodes.

    Returns:
        Array of shape (N_ACTIONS, n_states, n_states) with
        transitions[a, s, s'] = 1.0 for the unique successor s' of s under a,
        and 0.0 elsewhere.
    """
    nodes = np.arange(n_states)
    left_child = 2 * nodes + 1
    right_child = 2 * nodes + 2
    successors = (
        np.where(left_child < n_states, left_child, nodes),
        np.where(right_child < n_states, right_child, nodes),
        np.where(nodes > 0, (nodes - 1) // 2, 0),
    )

    transitions = np.zeros((N_ACTIONS, n_states, n_states))
    for action, next_node in enumerate(successors):
        transitions[action, nodes, next_node] = 1.0
    return transitions


def run_filter(val_trajs, obs_class, obs_masks, bc_policy_np, T_action,
               mode="obs_only", max_seq_len=MAX_SEQ_LEN):
    """Run a discrete Bayes filter over node identity and score it on trajectories.

    Each trajectory is cut into chunks of at most ``max_seq_len`` actions.
    At the start of every chunk the belief is reset to uniform over the nodes
    sharing the first state's observation class. At each step the filter:

    1. records node-prediction statistics for the current belief (MAP node,
       mass on the true node, entropy);
    2. scores the true action under the belief-averaged policy
       ``belief @ bc_policy_np``;
    3. propagates the belief one step and conditions it on the observation
       class of the next state.

    Args:
        val_trajs: iterable of dicts with 'states' and 'actions'. The code
            reads ``states[t + 1]`` after action t, so states is presumably one
            longer than actions — TODO confirm against the data loader.
        obs_class: per-node observation class ids (as from build_obs_class_map).
        obs_masks: (n_classes, n_states) 0/1 class-membership masks.
        bc_policy_np: (n_states, n_actions) per-node action probabilities.
        T_action: (n_actions, n_states, n_states) per-action transition matrices.
        mode: "obs_only" ignores the observed action and propagates with the
            policy-marginalised transition sum_a pi(a|s) T_a.
            "action_informed" conditions on the observed action a. It weights
            the belief by the action's likelihood pi(a|s) and then propagates
            with T_a.
        max_seq_len: chunk length in actions.

    Returns:
        dict with 'node_accuracy' (fraction of steps whose MAP node is the true
        node), 'policy_ll' (mean log2-probability of the true action, bits per
        decision), 'mean_true_node_prob', 'mean_belief_entropy' (bits),
        'n_steps', and per-class 'per_class_acc' / 'per_class_n'. Averages are
        NaN when there are no steps.

    Raises:
        ValueError: if ``mode`` is not "obs_only" or "action_informed".
    """
    if mode not in ("obs_only", "action_informed"):
        raise ValueError(
            f"unknown mode {mode!r}; expected 'obs_only' or 'action_informed'")
    obs_only = mode == "obs_only"

    if obs_only:
        # M[s, s'] = sum_a pi(a|s) * T_a[s, s']: one-step transition with the
        # unobserved action marginalised out under the BC policy.
        M = np.einsum("sa,ast->st", bc_policy_np, T_action)

    n_classes = obs_masks.shape[0]
    node_correct = 0
    n_steps = 0
    total_ll = 0.0
    total_true_prob = 0.0
    per_class_correct = np.zeros(n_classes, dtype=int)
    per_class_total = np.zeros(n_classes, dtype=int)
    belief_entropies = []

    def class_prior(c):
        # Uniform belief over the nodes in observation class c.
        return obs_masks[c] / obs_masks[c].sum()

    def ratio(total, count):
        return total / count if count else float("nan")

    for traj in val_trajs:
        states = traj['states']
        actions = traj['actions']
        traj_len = len(actions)

        for start in range(0, traj_len, max_seq_len):
            end = min(start + max_seq_len, traj_len)
            cs = states[start:end + 1]
            ca = actions[start:end]

            # Each chunk starts fresh: only the first observation is known.
            belief = class_prior(obs_class[cs[0]])

            for t in range(len(ca)):
                true_s = cs[t]
                true_a = int(ca[t])
                c = obs_class[true_s]

                # Node prediction = MAP node (ties go to the lowest index).
                if np.argmax(belief) == true_s:
                    node_correct += 1
                    per_class_correct[c] += 1
                n_steps += 1
                per_class_total[c] += 1
                total_true_prob += belief[true_s]

                # Predictive action distribution before seeing the action.
                bayes_pi = belief @ bc_policy_np
                total_ll += np.log2(max(bayes_pi[true_a], 1e-15))

                pos = belief[belief > 1e-30]
                belief_entropies.append(-np.sum(pos * np.log2(pos)))

                # The belief after the chunk's final action is never used (the
                # next chunk resets it), so skip that last propagation.
                if t + 1 == len(ca):
                    break

                # Prediction step.
                if obs_only:
                    new_belief = belief @ M
                else:
                    # Bayes update on the observed action first:
                    # p(s | a) ∝ belief(s) * pi(a | s), then move with T_a.
                    new_belief = ((belief * bc_policy_np[:, true_a])
                                  @ T_action[true_a])

                # Correction step: keep only nodes consistent with the next
                # observation class, then renormalise.
                c_next = obs_class[cs[t + 1]]
                new_belief = new_belief * obs_masks[c_next]
                total = new_belief.sum()
                if total > 1e-30:
                    belief = new_belief / total
                else:
                    # Model assigns ~zero mass to the observation: fall back
                    # to a uniform belief over the observed class.
                    belief = class_prior(c_next)

    return {
        'node_accuracy': ratio(node_correct, n_steps),
        'policy_ll': ratio(total_ll, n_steps),
        'mean_true_node_prob': ratio(total_true_prob, n_steps),
        'mean_belief_entropy': (np.mean(belief_entropies)
                                if belief_entropies else float("nan")),
        'n_steps': n_steps,
        'per_class_acc': {c: (per_class_correct[c] / max(per_class_total[c], 1))
                          for c in range(n_classes)},
        'per_class_n': {c: int(per_class_total[c]) for c in range(n_classes)},
    }


def no_memory_baseline(val_trajs, obs_class, obs_masks, bc_policy_np,
                       max_seq_len=MAX_SEQ_LEN):
    """Predict from current observation class alone (no memory).

    At every step the belief is uniform over the nodes that share the current
    state's observation class. The predicted node is the first node of that
    class, and the action is scored under the class-averaged BC policy.

    Args:
        val_trajs: iterable of dicts with 'states' and 'actions'; states[t] is
            the state in which actions[t] was taken.
        obs_class: per-node observation class ids.
        obs_masks: (n_classes, n_states) 0/1 class-membership masks.
        bc_policy_np: (n_states, n_actions) per-node action probabilities.
        max_seq_len: accepted for signature parity with run_filter. No
            information is carried between steps, so chunking cannot change
            any result and the value is unused.

    Returns:
        dict with 'node_accuracy', 'policy_ll' (mean log2-probability of the
        true action), 'n_steps', 'per_class_acc' and 'per_class_n'. Averages
        are NaN when there are no steps.
    """
    n_classes = obs_masks.shape[0]

    # Everything below depends only on the class, so compute it once.
    class_sizes = obs_masks.sum(axis=1, keepdims=True)
    class_beliefs = np.divide(obs_masks, class_sizes,
                              out=np.zeros_like(obs_masks, dtype=float),
                              where=class_sizes > 0)
    class_pi = class_beliefs @ bc_policy_np       # (n_classes, n_actions)
    class_pred = np.argmax(obs_masks, axis=1)     # first node of each class

    node_correct = 0
    n_steps = 0
    total_ll = 0.0
    per_class_correct = np.zeros(n_classes, dtype=int)
    per_class_total = np.zeros(n_classes, dtype=int)

    for traj in val_trajs:
        states = traj['states']
        for t, action in enumerate(traj['actions']):
            true_s = states[t]
            c = obs_class[true_s]

            total_ll += np.log2(max(class_pi[c, int(action)], 1e-15))

            if class_pred[c] == true_s:
                node_correct += 1
                per_class_correct[c] += 1
            n_steps += 1
            per_class_total[c] += 1

    nan = float("nan")
    return {
        'node_accuracy': node_correct / n_steps if n_steps else nan,
        'policy_ll': total_ll / n_steps if n_steps else nan,
        'n_steps': n_steps,
        'per_class_acc': {c: (per_class_correct[c] / max(per_class_total[c], 1))
                          for c in range(n_classes)},
        'per_class_n': {c: int(per_class_total[c]) for c in range(n_classes)},
    }


def main():
    """Run the Bayesian-filter baselines on the validation split and report them.

    This function:

    - loads and splits the trajectories;
    - fits a BC policy on the training split;
    - runs the no-memory baseline and both Bayes-filter variants on the
      validation split;
    - prints comparison tables against fixed reference numbers for the
      trained GRU/MLP models;
    - saves the results to checkpoints/bayesian_filter_baseline.pt.
    """
    trajs = load_rosenberg_trajectories()
    train_trajs, val_trajs = split_rosenberg_trajectories(trajs)
    train_sa = trajectories_to_sa_pairs(train_trajs)
    val_sa = trajectories_to_sa_pairs(val_trajs)

    print(f"{len(trajs)} total, {len(train_trajs)} train, "
          f"{len(val_trajs)} val", flush=True)

    root_starts = sum(1 for t in trajs if t['states'][0] == 0)
    print(f"Bouts starting at root: {root_starts}/{len(trajs)} "
          f"({100*root_starts/len(trajs):.1f}%)", flush=True)

    obs_class, obs_masks = build_obs_class_map()
    T_action = build_transition_matrices()

    bc_policy = build_bc_targets(train_sa, n_states=N_STATES, n_actions=N_ACTIONS,
                                 laplace=1.0)
    # NOTE(review): assumes build_bc_targets returns a CPU torch tensor of
    # shape (N_STATES, N_ACTIONS) — .numpy() fails on GPU tensors.
    bc_policy_np = bc_policy.numpy()
    bc_ll = compute_behavioral_cloning_ll(train_sa, val_sa,
                                          n_states=N_STATES, n_actions=N_ACTIONS)
    print(f"BC ceiling LL: {bc_ll:.4f} bits/dec", flush=True)

    print("\nRunning no-memory baseline...", flush=True)
    no_mem = no_memory_baseline(val_trajs, obs_class, obs_masks, bc_policy_np)
    print(f"Done ({no_mem['n_steps']} steps)", flush=True)

    print("Running observation-only Bayesian filter...", flush=True)
    obs_only = run_filter(val_trajs, obs_class, obs_masks, bc_policy_np,
                          T_action, mode="obs_only")
    print(f"Done ({obs_only['n_steps']} steps)", flush=True)

    print("Running action-informed Bayesian filter...", flush=True)
    act_inf = run_filter(val_trajs, obs_class, obs_masks, bc_policy_np,
                         T_action, mode="action_informed")
    print(f"Done ({act_inf['n_steps']} steps)", flush=True)

    print("\nBayesian Filter Baseline Results ")

    # Reference results for separately trained models (not computed here).
    gru_ll = -1.2699
    mlp_ll = -1.2813
    gru_probe_acc = 25.3
    # Uniform-random baselines: 1/N_ACTIONS per action, 1/N_STATES per node.
    random_ll = -np.log2(N_ACTIONS)
    random_acc = f"{100.0 / N_STATES:.1f}%"

    # Column widths: Model 25, Node Acc 9, Policy LL 12, Belief H 10,
    # P(true) 8. Every row below uses exactly these widths.
    print(f"\n{'Model':<25} {'Node Acc':>9} {'Policy LL':>12} "
          f"{'Belief H':>10} {'P(true)':>8}")
    print(f"{'Random':<25} {random_acc:>9} {random_ll:>12.4f} "
          f"{'—':>10} {'—':>8}")
    print(f"{'No memory (obs class)':<25} "
          f"{100*no_mem['node_accuracy']:>8.1f}% "
          f"{no_mem['policy_ll']:>12.4f} "
          f"{'—':>10} {'—':>8}")
    print(f"{'MLP (trained)':<25} {'—':>9} {mlp_ll:>12.4f} "
          f"{'—':>10} {'—':>8}")
    print(f"{'GRU (trained)':<25} {gru_probe_acc:>8.1f}% "
          f"{gru_ll:>12.4f} {'—':>10} {'—':>8}")
    print(f"{'Bayes filter (obs-only)':<25} "
          f"{100*obs_only['node_accuracy']:>8.1f}% "
          f"{obs_only['policy_ll']:>12.4f} "
          f"{obs_only['mean_belief_entropy']:>10.2f} "
          f"{obs_only['mean_true_node_prob']:>8.3f}")
    print(f"{'Bayes filter (actions)':<25} "
          f"{100*act_inf['node_accuracy']:>8.1f}% "
          f"{act_inf['policy_ll']:>12.4f} "
          f"{act_inf['mean_belief_entropy']:>10.2f} "
          f"{act_inf['mean_true_node_prob']:>8.3f}")
    print(f"{'BC ceiling (full info)':<25} {'100.0%':>9} "
          f"{bc_ll:>12.4f} {'0.00':>10} {'1.000':>8}")

    obs_gap = obs_only['policy_ll'] - mlp_ll
    gru_gap = gru_ll - mlp_ll
    act_gap = act_inf['policy_ll'] - mlp_ll
    bc_gap = bc_ll - mlp_ll
    print(f"\ngaps vs MLP (bits/dec): GRU {gru_gap:+.4f}, Bayes(obs) {obs_gap:+.4f}, "
          f"Bayes(act) {act_gap:+.4f}, BC ceiling {bc_gap:+.4f}")
    if bc_gap > 0:
        print(f"\nGRU captures {100*gru_gap/bc_gap:.0f}% of MLP-to-BC gap")
        print(f"Bayes(obs) captures {100*obs_gap/bc_gap:.0f}% of MLP-to-BC gap")
        print(f"Bayes(act) captures {100*act_gap/bc_gap:.0f}% of MLP-to-BC gap")

    class_names = ['Root', 'Depth 1', 'Depth 2-4', 'Depth 5', 'Leaves']
    # Nodes per observation class, derived from the masks rather than
    # hard-coded (1, 2, 28, 32, 64 for the default 127-node tree).
    class_sizes = [int(n) for n in obs_masks.sum(axis=1)]
    # Per-class GRU node accuracy; presumably from a separate probing run.
    gru_probe = [100.0, 60.8, 27.4, 15.4, 15.4]

    print("\nPer observation class: node prediction accuracy ")
    print(f"{'Class':<12} {'Nodes':>6} {'No mem':>8} {'GRU':>8} "
          f"{'Bayes(o)':>9} {'Bayes(a)':>9} {'Chance':>8}")
    for c, name in enumerate(class_names):
        chance = 100.0 / class_sizes[c]
        nm = 100 * no_mem['per_class_acc'].get(c, 0)
        bo = 100 * obs_only['per_class_acc'][c]
        ba = 100 * act_inf['per_class_acc'][c]
        print(f"{name:<12} {class_sizes[c]:>6} {nm:>7.1f}% "
              f"{gru_probe[c]:>7.1f}% {bo:>8.1f}% {ba:>8.1f}% "
              f"{chance:>7.1f}%")

    os.makedirs("checkpoints", exist_ok=True)
    save_data = {
        'obs_only': obs_only,
        'action_informed': act_inf,
        'no_memory': no_mem,
        'bc_ll': bc_ll,
        'gru_ll': gru_ll,
        'mlp_ll': mlp_ll,
    }
    torch.save(save_data, "checkpoints/bayesian_filter_baseline.pt")
    print("\nSaved to checkpoints/bayesian_filter_baseline.pt")


if __name__ == "__main__":
    # Run the full baseline comparison only when executed as a script.
    main()
