"""Radial arm maze environment for memory-dependent foraging."""
import numpy as np
import torch
import networkx as nx

# Maze geometry: one central hub (node 0) plus N_ARMS arms of ARM_LENGTH nodes.
N_ARMS = 8
ARM_LENGTH = 3
N_NODES = N_ARMS * ARM_LENGTH + 1  # hub + all arm nodes = 25

# Action layout: indices 0..N_ARMS-1 enter the matching arm from the hub,
# then ADVANCE (one step outward) and RETREAT (one step inward).
ACT_ADVANCE = N_ARMS              # 8
ACT_RETREAT = ACT_ADVANCE + 1     # 9
N_ACTIONS = ACT_RETREAT + 1       # 10



def arm_of(node):
    """Map a node index to the arm it lies on.

    Node 0 is the hub and belongs to no arm, so it maps to -1; every other
    node ``n`` maps to ``(n - 1) // ARM_LENGTH``.
    """
    return -1 if node == 0 else (node - 1) // ARM_LENGTH


def pos_in_arm(node):
    """Return how far along its arm ``node`` sits.

    0 is the proximal node (adjacent to the hub) and ``ARM_LENGTH - 1`` is
    the tip; the hub itself yields -1.
    """
    if node != 0:
        return (node - 1) % ARM_LENGTH
    return -1


def arm_node(arm, pos):
    """Node index of position ``pos`` on arm ``arm`` (inverse of arm_of/pos_in_arm)."""
    arm_start = 1 + arm * ARM_LENGTH  # +1 skips the hub at index 0
    return arm_start + pos


def radial_distance(node):
    """Number of moves between the hub and ``node`` (0 for the hub itself)."""
    return 0 if node == 0 else pos_in_arm(node) + 1


def is_tip(node):
    """True when ``node`` is the outermost node of an arm; never true for the hub."""
    if node <= 0:
        return False
    return pos_in_arm(node) == ARM_LENGTH - 1


def is_hub(node):
    """True only for node 0, the central hub."""
    hub_index = 0
    return node == hub_index



def build_graph():
    """Build the maze as an undirected networkx graph.

    Nodes are ``0..N_NODES-1``. Each arm contributes the path
    hub -> proximal -> ... -> tip, so the hub ends up adjacent to the
    proximal node of every arm.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(N_NODES))
    for arm in range(N_ARMS):
        path = [0] + [arm_node(arm, pos) for pos in range(ARM_LENGTH)]
        # Consecutive pairs along the path become edges, in hub-to-tip order.
        graph.add_edges_from(zip(path[:-1], path[1:]))
    return graph


def build_transition_tensor():
    """Deterministic transition tensor ``T[node, action, next_node]``.

    Every (node, action) row holds a single 1.0 marking the successor:

    * hub: action ``a < N_ARMS`` enters arm ``a`` at its proximal node;
      ADVANCE and RETREAT leave the agent at the hub.
    * arm node: arm-entry actions are no-ops; ADVANCE moves one step outward
      (self-loop at the tip); RETREAT moves one step inward (to the hub from
      the proximal node).

    Returns:
        float32 torch tensor of shape ``(N_NODES, N_ACTIONS, N_NODES)``.
    """
    table = np.zeros((N_NODES, N_ACTIONS, N_NODES), dtype=np.float32)
    arm_actions = np.arange(N_ARMS)

    for node in range(N_NODES):
        if is_hub(node):
            entry_targets = [arm_node(arm, 0) for arm in range(N_ARMS)]
            outward = inward = node
        else:
            entry_targets = [node] * N_ARMS
            arm, pos = arm_of(node), pos_in_arm(node)
            outward = arm_node(arm, pos + 1) if pos < ARM_LENGTH - 1 else node
            inward = arm_node(arm, pos - 1) if pos > 0 else 0

        table[node, arm_actions, entry_targets] = 1.0
        table[node, ACT_ADVANCE, outward] = 1.0
        table[node, ACT_RETREAT, inward] = 1.0

    return torch.tensor(table, dtype=torch.float32)


def valid_actions(node):
    """List the action indices that are legal at ``node``.

    At the hub only the N_ARMS arm-entry actions are legal. Inside an arm,
    RETREAT is always legal and ADVANCE is legal everywhere except the tip.
    """
    if is_hub(node):
        return list(range(N_ARMS))
    at_tip = pos_in_arm(node) >= ARM_LENGTH - 1
    return [ACT_RETREAT] if at_tip else [ACT_ADVANCE, ACT_RETREAT]


def action_mask_for(node):
    """Boolean array of length N_ACTIONS; entry ``a`` is True iff ``a`` is legal at ``node``."""
    legal = np.zeros(N_ACTIONS, dtype=bool)
    legal[valid_actions(node)] = True
    return legal



def _node_features(node):
    """3-dim feature vector for a single node: ``[degree / 8, is_hub, is_tip]``.

    Degrees follow the maze topology: the hub touches all N_ARMS arms, a tip
    has a single neighbour, and every other arm node (proximal or medial)
    has two neighbours.
    """
    if is_hub(node):
        return [N_ARMS / 8.0, 1.0, 0.0]
    if is_tip(node):
        return [1.0 / 8.0, 0.0, 1.0]
    # Proximal and medial nodes both have degree 2 and no hub/tip flag,
    # so they share a single encoding.
    return [2.0 / 8.0, 0.0, 0.0]


def structural_obs_encoding():
    """Hand-crafted local-structure observation for every node.

    Each vector concatenates three 3-dim ``_node_features`` blocks: the node
    itself, an inward-neighbour slot and an outward-neighbour slot. The result
    is then zero-padded (and truncated) to length 12. Special cases:

    * hub: inward slot repeats the hub's own features; outward slot uses
      arm 0's proximal node.
    * tip: no outward neighbour, so that slot is all zeros.

    Because ``_node_features`` depends only on node type, nodes at the same
    depth in different arms receive identical observations.

    Returns:
        dict mapping node index -> float32 torch tensor of length 12.
    """
    empty_slot = [0.0, 0.0, 0.0]
    encoding = {}

    for node in range(N_NODES):
        if is_hub(node):
            inward = _node_features(0)
            outward = _node_features(arm_node(0, 0))
        elif pos_in_arm(node) == 0:
            inward = _node_features(0)
            outward = _node_features(arm_node(arm_of(node), 1))
        elif is_tip(node):
            inward = _node_features(arm_node(arm_of(node), pos_in_arm(node) - 1))
            outward = empty_slot
        else:
            arm, pos = arm_of(node), pos_in_arm(node)
            inward = _node_features(arm_node(arm, pos - 1))
            outward = _node_features(arm_node(arm, pos + 1))

        vec = [*_node_features(node), *inward, *outward]
        vec += [0.0] * (12 - len(vec))  # pad to the fixed observation width
        encoding[node] = torch.tensor(vec[:12], dtype=torch.float32)

    return encoding


def random_obs_encoding(n_classes=4, obs_dim=12, seed=0):
    """Structure-free baseline encoding.

    Draws ``n_classes`` Gaussian prototype vectors of length ``obs_dim`` and
    assigns each of the N_NODES nodes one prototype uniformly at random, so
    nodes share observations arbitrarily rather than by maze position.

    Returns:
        dict mapping node index -> float32 torch tensor of length ``obs_dim``.
    """
    rng = np.random.RandomState(seed)
    # Draw prototypes before labels: the RNG call order fixes the output.
    prototypes = rng.randn(n_classes, obs_dim).astype(np.float32)
    labels = rng.randint(0, n_classes, size=N_NODES)
    return {node: torch.tensor(prototypes[label]) for node, label in enumerate(labels)}



def _policy_distribution(node, visited, optimal_prob):
    """Noisy-optimal action distribution at ``node``.

    ``visited`` is an integer bitmask; bit ``a`` set means arm ``a`` counts
    as visited.

    * Hub, some arms unvisited: ``optimal_prob`` is split evenly across the
      unvisited arms, and ``1 - optimal_prob`` is spread uniformly over all
      N_ARMS arm-entry actions.
    * Hub, all arms visited: uniform over the N_ARMS arm-entry actions.
    * Tip: RETREAT with probability 1.
    * Other arm nodes: if the current arm is already visited, RETREAT gets
      ``optimal_prob`` and ADVANCE the rest; otherwise the reverse.

    Returns:
        float32 numpy array of length N_ACTIONS.
    """
    probs = np.zeros(N_ACTIONS, dtype=np.float32)

    if is_hub(node):
        unvisited = {arm for arm in range(N_ARMS) if not (visited >> arm) & 1}
        if not unvisited:
            probs[:N_ARMS] = 1.0 / N_ARMS
            return probs
        explore = (1 - optimal_prob) / N_ARMS
        exploit = optimal_prob / len(unvisited)
        probs[:N_ARMS] = [explore + exploit if arm in unvisited else explore
                          for arm in range(N_ARMS)]
        return probs

    if is_tip(node):
        probs[ACT_RETREAT] = 1.0
        return probs

    arm_done = bool(visited & (1 << arm_of(node)))
    preferred, fallback = ((ACT_RETREAT, ACT_ADVANCE) if arm_done
                           else (ACT_ADVANCE, ACT_RETREAT))
    probs[preferred] = optimal_prob
    probs[fallback] = 1 - optimal_prob
    return probs


def generate_foraging_trajectories(n_trajs=1500, max_steps=200,
                                   optimal_prob=0.8, seed=42):
    """Generate foraging trajectories using a noisy-optimal policy.

    Each episode starts at the hub with no arms visited. It samples actions
    from ``_policy_distribution``, restricted to legal actions and
    renormalised, for at most ``max_steps`` steps. An arm is marked visited
    when its tip is entered. The episode stops early once every arm is
    visited and the agent is back at the hub.

    Args:
        n_trajs: number of episodes to generate.
        max_steps: maximum number of actions per episode.
        optimal_prob: probability mass the policy puts on the optimal choice;
            must lie in [0, 1].
        seed: seed for ``np.random.default_rng``.

    Returns:
        List of dicts with keys:
        ``'states'``: list of node ids, length T + 1;
        ``'actions'``: int64 array, length T;
        ``'visited_seq'``: list of visited bitmasks, length T + 1;
        ``'targets'``: float32 array of shape (T, N_ACTIONS) holding the
        policy probabilities at each step.

    Raises:
        ValueError: if ``optimal_prob`` is outside [0, 1] (or NaN).
    """
    # Out-of-range values would otherwise produce negative probabilities and
    # only fail later inside rng.choice with an unrelated-looking message.
    if not 0.0 <= optimal_prob <= 1.0:
        raise ValueError(f"optimal_prob must be in [0, 1], got {optimal_prob!r}")

    rng = np.random.default_rng(seed)
    # Transitions are deterministic (one 1.0 per (node, action) row), so the
    # successor table and legal-action masks can be computed once up front.
    next_node_table = build_transition_tensor().numpy().argmax(axis=2)
    legal = np.stack([action_mask_for(n) for n in range(N_NODES)])
    all_visited = (1 << N_ARMS) - 1
    trajectories = []

    for _ in range(n_trajs):
        node = 0
        visited = 0
        states = [node]
        actions = []
        targets = []
        visited_seq = [visited]

        for _ in range(max_steps):
            probs = _policy_distribution(node, visited, optimal_prob)
            targets.append(probs)

            p = probs.copy()
            p[~legal[node]] = 0.0
            p = p / p.sum()
            action = rng.choice(N_ACTIONS, p=p)

            next_node = int(next_node_table[node, action])

            if is_tip(next_node):
                visited = visited | (1 << arm_of(next_node))

            actions.append(action)
            node = next_node
            states.append(node)
            visited_seq.append(visited)

            if visited == all_visited and is_hub(node):
                break

        trajectories.append({
            'states': states,
            'actions': np.array(actions, dtype=np.int64),
            'visited_seq': visited_seq,
            # reshape keeps the array 2-D even for zero-length episodes
            'targets': np.array(targets, dtype=np.float32).reshape(len(actions), N_ACTIONS),
        })

    return trajectories



def split_trajectories(trajs, val_fraction=0.2, seed=42):
    """Randomly partition ``trajs`` into ``(train, val)`` lists.

    ``int(len(trajs) * val_fraction)`` items, chosen by a seeded permutation,
    go to validation. Both returned lists keep the items in their original
    relative order.
    """
    order = np.random.RandomState(seed).permutation(len(trajs))
    cut = int(len(trajs) * val_fraction)
    val_idx = np.sort(order[:cut])
    train_idx = np.sort(order[cut:])
    return [trajs[i] for i in train_idx], [trajs[i] for i in val_idx]


def trajectories_to_sa_pairs(trajs):
    """Flatten each trajectory into a list of ``(state, action)`` int tuples.

    Only states that have a matching action are paired, so the terminal state
    of each trajectory is dropped.
    """
    def _pairs(traj):
        states = traj['states']
        return [(int(states[t]), int(act)) for t, act in enumerate(traj['actions'])]

    return [_pairs(traj) for traj in trajs]


def build_bc_targets(train_sa, laplace=1.0):
    """Tabular behaviour-cloning policy estimated from (state, action) counts.

    Counts every pair in ``train_sa`` (an iterable of trajectories, each an
    iterable of ``(node, action)`` pairs). It then adds ``laplace``
    pseudo-counts to every legal action, zeroes illegal actions and
    normalises each row.

    Returns:
        float32 torch tensor of shape ``(N_NODES, N_ACTIONS)``.
    """
    counts = np.zeros((N_NODES, N_ACTIONS))
    for state, action in (pair for traj in train_sa for pair in traj):
        counts[state, action] += 1

    legal = np.stack([action_mask_for(node) for node in range(N_NODES)])
    smoothed = np.where(legal, counts + laplace, 0.0)
    # Guard against all-zero rows (e.g. laplace=0 and an unvisited node).
    row_totals = np.maximum(smoothed.sum(axis=1, keepdims=True), 1e-10)
    return torch.tensor(smoothed / row_totals, dtype=torch.float32)


def build_obs_dataset(trajs, obs_encoding):
    """Convert trajectories into per-episode tensors for sequence models.

    For a trajectory with T actions, only its first T states are used; the
    final state has no action. ``obs_encoding`` maps node index -> tensor.
    """
    def _episode(traj):
        states = traj['states']
        actions = traj['actions']
        horizon = len(actions)
        nodes = states[:horizon]

        obs = torch.stack([obs_encoding[s] for s in nodes])
        targets = torch.tensor(traj['targets'][:horizon], dtype=torch.float32)
        action_ids = torch.tensor([int(a) for a in actions], dtype=torch.long)
        masks = torch.tensor(
            np.stack([action_mask_for(s) for s in nodes]),
            dtype=torch.bool,
        )
        return {
            'obs': obs,                 # (T, obs_dim)
            'targets': targets,         # (T, N_ACTIONS)
            'actions': action_ids,      # (T,)
            'states': [int(s) for s in nodes],
            'action_mask': masks,       # (T, N_ACTIONS)
        }

    return [_episode(traj) for traj in trajs]



def load_radial_arm_everything(n_trajs=1500, max_steps=200, optimal_prob=0.8,
                               val_fraction=0.2, seed=42):
    """Generate foraging data and assemble every artifact needed for training.

    Generates trajectories, prints summary statistics and splits them into
    train/val. It then builds the (state, action) pairs, the structural
    observation encoding, the maze graph, the transition tensor and the
    per-episode observation datasets.

    Returns:
        dict with keys 'trajs', 'train_trajs', 'val_trajs', 'train_sa',
        'val_sa', 'obs_encoding', 'G', 'T_node', 'train_dataset',
        'val_dataset', 'n_actions', 'n_nodes'.
    """
    print(f"Generating {n_trajs} foraging trajectories (optimal_prob={optimal_prob})...",
          flush=True)
    trajs = generate_foraging_trajectories(
        n_trajs=n_trajs, max_steps=max_steps,
        optimal_prob=optimal_prob, seed=seed,
    )

    # Tips reached = set bits in the final visited bitmask of each episode.
    n_tips_visited = [bin(traj['visited_seq'][-1]).count('1') for traj in trajs]
    traj_lengths = [len(traj['actions']) for traj in trajs]
    # Denominator tracks N_ARMS rather than a hard-coded 8.
    print(f"Mean tips visited: {np.mean(n_tips_visited):.1f}/{N_ARMS}", flush=True)
    print(f"Mean trajectory length: {np.mean(traj_lengths):.0f} steps", flush=True)

    train_trajs, val_trajs = split_trajectories(trajs, val_fraction, seed)
    train_sa = trajectories_to_sa_pairs(train_trajs)
    val_sa = trajectories_to_sa_pairs(val_trajs)

    obs_enc = structural_obs_encoding()

    G = build_graph()
    T_node = build_transition_tensor()

    train_dataset = build_obs_dataset(train_trajs, obs_enc)
    val_dataset = build_obs_dataset(val_trajs, obs_enc)

    return {
        'trajs': trajs,
        'train_trajs': train_trajs,
        'val_trajs': val_trajs,
        'train_sa': train_sa,
        'val_sa': val_sa,
        'obs_encoding': obs_enc,
        'G': G,
        'T_node': T_node,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'n_actions': N_ACTIONS,
        'n_nodes': N_NODES,
    }
