"""Utility functions: data loading, graph construction, transition matrices."""

import pickle
import numpy as np
import torch
import networkx as nx
import sys
import os

# Location of the "dynamic_irl" checkout: a directory named "dynamic_irl" next
# to the directory that contains this file (parent of this file's directory).
DIRL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "dynamic_irl")
# Import-time side effect: put DIRL_PATH first on sys.path, presumably so DIRL
# modules can be imported by name elsewhere in the project (NOTE(review): confirm).
sys.path.insert(0, DIRL_PATH)

# Directory holding the trajectory pickles and the *_train/val_indices.npy files.
DATA_DIR = os.path.join(DIRL_PATH, "data", "mouse_data")


def load_trajectories(restricted=True):
    """Unpickle and return the mouse trajectories stored under DATA_DIR.

    Args:
        restricted: when truthy, read ``water_restricted_mice_trajs.pickle``;
            otherwise read ``water_unrestricted_mice_trajs.pickle``.

    Returns:
        Whatever object the pickle contains (the rest of this module indexes it
        like a sequence of dicts with "states"/"actions" keys).

    Note:
        ``pickle.load`` can execute arbitrary code, so this must only ever be
        pointed at trusted, repository-provided data files.
    """
    if restricted:
        fname = "water_restricted_mice_trajs.pickle"
    else:
        fname = "water_unrestricted_mice_trajs.pickle"
    with open(os.path.join(DATA_DIR, fname), "rb") as handle:
        return pickle.load(handle)


def load_train_val_indices(restricted=True):
    """Load the saved train/validation index arrays for one mouse cohort.

    Reads ``<prefix>_train_indices.npy`` and ``<prefix>_val_indices.npy`` from
    DATA_DIR, where prefix is "restricted" or "unrestricted".

    Returns:
        ``(train_idx, val_idx)`` exactly as returned by ``np.load``.
    """
    prefix = "restricted" if restricted else "unrestricted"

    def _load(name):
        # Every index file lives directly in DATA_DIR.
        return np.load(os.path.join(DATA_DIR, name))

    train_idx = _load(f"{prefix}_train_indices.npy")
    val_idx = _load(f"{prefix}_val_indices.npy")
    return train_idx, val_idx


def split_trajectories(trajs, train_idx, val_idx):
    """Select the training and validation subsets of *trajs* by position.

    Args:
        trajs: indexable collection of trajectories.
        train_idx: iterable of integer positions for the training subset.
        val_idx: iterable of integer positions for the validation subset.

    Returns:
        ``(train_trajs, val_trajs)``: two new lists, ordered as the indices are.
    """
    def pick(positions):
        return [trajs[k] for k in positions]

    return pick(train_idx), pick(val_idx)


def trajectories_to_sa_pairs(trajs):
    """Convert trajectory dicts into lists of ``(state, action)`` int pairs.

    Each trajectory must provide ``"states"`` and ``"actions"`` sequences. One
    pair is produced per action, pairing ``actions[t]`` with ``states[t]``;
    states beyond ``len(actions)`` (e.g. a final state with no action) are
    ignored. Both values go through ``int()``, so NumPy/torch scalars become
    plain Python ints.

    Args:
        trajs: iterable of trajectory mappings.

    Returns:
        A list with one list of ``(state, action)`` tuples per trajectory.

    Raises:
        KeyError: if a trajectory lacks "states" or "actions".
        IndexError: if a trajectory has fewer states than actions.
    """
    sa_trajs = []
    for traj in trajs:
        # Look both keys up once, in the same order as before, so a malformed
        # trajectory fails the same way regardless of its length.
        states = traj["states"]
        actions = traj["actions"]
        sa_trajs.append([(int(states[t]), int(action)) for t, action in enumerate(actions)])
    return sa_trajs


def build_labyrinth_env(*, reward_state=100, n_states=127):
    """Load DIRL's ``labyrinth_with_stay`` module by file path and build its env.

    The module source at ``DIRL_PATH/src/envs/labyrinth_with_stay.py`` is
    executed afresh on every call, and it is not registered in ``sys.modules``.

    Args:
        reward_state: forwarded to ``LabyrinthEnv(reward_state=...)``. The default
            of 100 keeps the value this function always used.
        n_states: forwarded to ``LabyrinthEnv(n_states=...)``. The default of 127
            keeps the value this function always used.

    Returns:
        A ``LabyrinthEnv`` instance.

    Raises:
        ImportError: if no module spec or loader can be created for the path.
    """
    import importlib.util

    module_path = os.path.join(DIRL_PATH, "src", "envs", "labyrinth_with_stay.py")
    spec = importlib.util.spec_from_file_location("labyrinth_with_stay", module_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when it cannot choose a loader.
        # Without this check the failure shows up as an opaque AttributeError.
        raise ImportError(f"cannot load labyrinth environment module from {module_path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod.LabyrinthEnv(reward_state=reward_state, n_states=n_states)


def get_transition_tensor(env=None):
    """Return the env's transition array as a float32 torch tensor ``T[s, a, s']``.

    The array from ``env.get_transition_mat()`` has its last two axes swapped
    (axes ``(0, 2, 1)``) before conversion. For the labyrinth env the result is
    expected to be (127, 4, 127). NOTE(review): that shape assumes
    ``get_transition_mat`` returns ``(n_states, n_states, n_actions)``; confirm
    against DIRL.

    Args:
        env: object exposing ``get_transition_mat()``; when None, a fresh one is
            built with ``build_labyrinth_env()``.

    Returns:
        A newly allocated ``torch.float32`` tensor.
    """
    source_env = build_labyrinth_env() if env is None else env
    swapped = np.transpose(source_env.get_transition_mat(), (0, 2, 1))
    return torch.tensor(swapped, dtype=torch.float32)


def build_graph_from_transitions(T):
    """Build an undirected networkx graph of state adjacency from *T*.

    Every index in ``range(T.shape[0])`` becomes a node, including isolated ones.
    For each ``(s, a)``, the most likely successor ``T[s, a].argmax()`` is joined
    to ``s`` by an edge, unless the successor is ``s`` itself, so "stay"
    transitions add no self-loops. Only the argmax successor is used, so this
    assumes (near-)deterministic transitions.

    Args:
        T: array/tensor indexed as ``T[s, a, s']`` whose rows support
            ``.argmax().item()`` (torch tensors and NumPy arrays both do).

    Returns:
        ``nx.Graph``.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(T.shape[0]))
    for src in range(T.shape[0]):
        for act in range(T.shape[1]):
            dst = T[src, act].argmax().item()
            if dst == src:
                continue  # staying put is not an edge
            graph.add_edge(src, dst)
    return graph


def build_action_masks(T):
    """Return an all-ones mask of shape ``(T.shape[0], T.shape[1])``.

    Every action is marked available in every state. Only the first two sizes
    of *T* are read, e.g. (127, 4) for the labyrinth transition tensor.
    The dtype is torch's default float dtype.
    """
    mask_shape = (T.shape[0], T.shape[1])
    return torch.ones(mask_shape)


def get_water_port():
    """Return the water-port node index (0-indexed, as in DIRL's state space).

    It is the same state index that ``build_labyrinth_env`` uses as
    ``reward_state``.
    """
    water_port_state = 100
    return water_port_state


def get_node_depths(G, root=0):
    """Map each node reachable from *root* to its hop distance from *root*.

    Uses unweighted shortest-path lengths. Nodes not reachable from *root* are
    absent from the result. A plain ``dict`` copy is returned.
    """
    hop_counts = nx.single_source_shortest_path_length(G, root)
    return dict(hop_counts)


def binary_tree_layout(G, root=0):
    """Compute 2-D positions that draw *G* as a layered tree rooted at *root*.

    Nodes are grouped by hop distance from *root*. A node at depth ``d`` gets
    ``y = -d``. Within a level, nodes are sorted by id and spaced evenly in x
    over the open interval (-0.5, 0.5). Nodes unreachable from *root* get no
    position.

    Returns:
        dict mapping node -> ``(x, y)``, usable as a networkx ``pos``.
    """
    rows = {}
    for node, depth in nx.single_source_shortest_path_length(G, root).items():
        rows.setdefault(depth, []).append(node)

    pos = {}
    for depth, members in rows.items():
        count = len(members)
        pos.update(
            (node, ((slot + 0.5) / count - 0.5, -depth))
            for slot, node in enumerate(sorted(members))
        )
    return pos


def load_everything(restricted=True):
    """Load one cohort's trajectories plus the labyrinth env and derived graph.

    Args:
        restricted: selects the water-restricted (truthy) or unrestricted cohort
            for both the trajectories and the train/val split.

    Returns:
        dict with keys:
            "trajs", "train_trajs", "val_trajs": the raw trajectories and the
                two splits;
            "train_sa", "val_sa": per-split lists of (state, action) pairs;
            "T": transition tensor from ``get_transition_tensor``;
            "G": graph from ``build_graph_from_transitions``;
            "env": the environment instance;
            "water_port": water-port node index;
            "depths": node -> hop distance from node 0.
    """
    # Data side: load, split, then convert each split to (s, a) pairs.
    trajs = load_trajectories(restricted)
    train_trajs, val_trajs = split_trajectories(trajs, *load_train_val_indices(restricted))
    train_sa, val_sa = (trajectories_to_sa_pairs(part) for part in (train_trajs, val_trajs))

    # Environment side: env -> transition tensor -> adjacency graph.
    env = build_labyrinth_env()
    T = get_transition_tensor(env)
    G = build_graph_from_transitions(T)
    water_port = get_water_port()
    depths = get_node_depths(G)

    bundle = {
        "trajs": trajs,
        "train_trajs": train_trajs,
        "val_trajs": val_trajs,
        "train_sa": train_sa,
        "val_sa": val_sa,
        "T": T,
        "G": G,
        "env": env,
        "water_port": water_port,
        "depths": depths,
    }
    return bundle
