"""Data loader for Rosenberg maze trajectories (3-action MDP)."""

import pickle
import os
import sys
import numpy as np
import torch

# Rosenberg node ids grouped by tree depth: level k lists the 2**k ids found at
# that depth. Concatenating the levels gives one list in which the *position*
# of a Rosenberg id is its 0-indexed DIRL node (binary-heap layout: the
# children of node s are 2*s + 1 and 2*s + 2).
# NOTE(review): the ordering inside each level is taken as given here; it is
# presumably dictated by the maze geometry -- confirm against the Rosenberg code.
_lev0 = [0]
_lev1 = [1, 2]
_lev2 = [3, 4, 6, 5]
_lev3 = [8, 7, 9, 10, 13, 14, 12, 11]
_lev4 = [18, 17, 15, 16, 19, 20, 22, 21, 27, 28, 30, 29, 26, 25, 23, 24]
_lev5 = [37, 38, 36, 35, 32, 31, 33, 34, 40, 39, 41, 42, 45, 46, 44, 43,
         56, 55, 57, 58, 61, 62, 60, 59, 53, 54, 52, 51, 48, 47, 49, 50]
_lev6 = [75, 76, 78, 77, 74, 73, 71, 72, 66, 65, 63, 64, 67, 68, 70, 69,
         82, 81, 79, 80, 83, 84, 86, 85, 91, 92, 94, 93, 90, 89, 87, 88,
         114, 113, 111, 112, 115, 116, 118, 117, 123, 124, 126, 125, 122, 121, 119, 120,
         107, 108, 110, 109, 106, 105, 103, 104, 98, 97, 95, 96, 99, 100, 102, 101]
_LEVELORDER = _lev0 + _lev1 + _lev2 + _lev3 + _lev4 + _lev5 + _lev6

# Rosenberg node id -> DIRL index (its position in _LEVELORDER). A dict
# comprehension is used so that no loop variables are left behind as
# module-level globals.
_ROS_TO_DIRL = {
    ros_node: dirl_idx for dirl_idx, ros_node in enumerate(_LEVELORDER)
}

# Animal identifiers used by default when a loader is called without an
# explicit ``animals`` list.
REWARDED_ANIMALS = [
    'B1', 'B2', 'B3', 'B4',
    'C1', 'C3', 'C6', 'C7', 'C8', 'C9',
]

# The Rosenberg repository is expected next to the directory that contains
# this file, i.e. <parent of this file's directory>/Rosenberg-2021-Repository.
REPO_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    'Rosenberg-2021-Repository',
)
DATA_DIR = os.path.join(REPO_DIR, 'outdata - tf files only')
CODE_DIR = os.path.join(REPO_DIR, 'code')

# Put the repository's code directory at the front of sys.path, once.
# NOTE(review): presumably needed so the classes referenced by the pickled
# "-tf" files can be imported during unpickling -- confirm.
if CODE_DIR not in sys.path:
    sys.path.insert(0, CODE_DIR)


def rosenberg_to_dirl_node(ros_node):
    """Map a Rosenberg node id (0-126) to its 0-indexed DIRL node.

    The lookup goes through the module-level ``_ROS_TO_DIRL`` table; an id
    that is not in the table raises ``KeyError``.
    """
    dirl_node = _ROS_TO_DIRL[ros_node]
    return dirl_node


def infer_action(s, s_next):
    """Return the action (0, 1 or 2) that moves ``s`` to ``s_next``.

    Nodes are 0-indexed positions in a binary heap layout:

    * action 0 goes to the first child, ``2 * s + 1``;
    * action 1 goes to the second child, ``2 * s + 2``;
    * action 2 goes to the parent, ``(s - 1) // 2``; at the root (``s == 0``)
      it is the self-transition ``0 -> 0``.

    Raises:
        ValueError: if ``s_next`` is none of the above.
    """
    first_child = 2 * s + 1
    if s_next == first_child:
        return 0
    if s_next == first_child + 1:
        return 1

    # The root has no parent, so "up" from the root is only the 0 -> 0 loop.
    if s > 0:
        goes_up = s_next == (s - 1) // 2
    else:
        goes_up = s == 0 and s_next == 0
    if goes_up:
        return 2

    raise ValueError(f"Non-adjacent transition: {s} -> {s_next}")


def load_rosenberg_trajectories(animals=None, min_bout_length=2):
    """Load Rosenberg ``-tf`` pickle files and convert them to trajectory dicts.

    For every animal the file ``DATA_DIR/<name>-tf`` is unpickled and each
    bout in its ``no`` attribute becomes one trajectory:

    1. column 0 of the bout is read as integer node ids;
    2. entries equal to 127 are dropped (127 lies outside the 0-126 node
       range handled by ``rosenberg_to_dirl_node``);
    3. consecutive repeats of the same node are collapsed;
    4. the remaining ids are mapped to DIRL nodes and the action for every
       step is inferred with ``infer_action``.

    Args:
        animals: Iterable of animal names. Defaults to ``REWARDED_ANIMALS``.
        min_bout_length: Bouts with fewer nodes than this (checked both
            before and after collapsing repeats) are skipped. Bouts with no
            nodes left are always skipped.

    Returns:
        A list of dicts with keys ``'states'`` (list of DIRL node indices),
        ``'actions'`` (``np.int64`` array with one entry per transition, so
        ``len(states) - 1`` entries) and ``'animal'`` (the animal name).

    Raises:
        KeyError: if a bout contains a node id unknown to
            ``rosenberg_to_dirl_node``.
        ValueError: if two consecutive nodes are not adjacent in the tree
            (raised by ``infer_action``).
    """
    if animals is None:
        animals = REWARDED_ANIMALS

    trajs = []
    for name in animals:
        filepath = os.path.join(DATA_DIR, name + '-tf')
        # SECURITY: pickle.load can execute arbitrary code while unpickling.
        # Only point DATA_DIR at trusted copies of the Rosenberg data files.
        with open(filepath, 'rb') as f:
            tf = pickle.load(f)

        for bout in tf.no:
            # Column 0 holds the node ids.
            # NOTE(review): the other columns are not used here (presumably
            # frame/time information) -- confirm against the Rosenberg code.
            raw_nodes = bout[:, 0].astype(int)
            raw_nodes = raw_nodes[raw_nodes != 127]

            # An empty bout has no first node to seed the de-duplication
            # with, so skip it even when min_bout_length <= 0.
            if raw_nodes.size == 0 or len(raw_nodes) < min_bout_length:
                continue

            # Keep the first node and every node that differs from its
            # predecessor, i.e. collapse runs of identical consecutive ids.
            keep = np.ones(raw_nodes.size, dtype=bool)
            keep[1:] = raw_nodes[1:] != raw_nodes[:-1]
            deduped = raw_nodes[keep]

            if len(deduped) < min_bout_length:
                continue

            dirl_nodes = [rosenberg_to_dirl_node(n) for n in deduped.tolist()]
            actions = [
                infer_action(s, s_next)
                for s, s_next in zip(dirl_nodes, dirl_nodes[1:])
            ]

            trajs.append({
                'states': dirl_nodes,
                'actions': np.array(actions, dtype=np.int64),
                'animal': name,
            })

    return trajs


def split_rosenberg_trajectories(trajs, val_fraction=0.2, seed=42):
    """Randomly split ``trajs`` into a training and a validation list.

    A permutation of the trajectory indices is drawn from
    ``np.random.RandomState(seed)``; its first ``int(len(trajs) * val_fraction)``
    entries form the validation set and the rest the training set. Inside
    each split the trajectories keep their original relative order.

    Returns:
        ``(train_trajs, val_trajs)``.
    """
    total = len(trajs)
    shuffled = np.random.RandomState(seed).permutation(total)
    cut = int(total * val_fraction)

    # Sorting the chosen indices restores the original ordering in each split.
    held_out = sorted(shuffled[:cut])
    kept = sorted(shuffled[cut:])

    return [trajs[i] for i in kept], [trajs[i] for i in held_out]


def build_binary_tree_transition_tensor(n_states=127, n_actions=3):
    """Build the deterministic transition tensor of a binary-tree MDP.

    Returns a ``torch`` tensor ``T`` of shape ``(n_states, n_actions, n_states)``
    where ``T[s, a, s2] == 1.0`` iff taking action ``a`` in state ``s`` leads
    to ``s2``:

    * action 0 -> first child ``2 * s + 1``,
    * action 1 -> second child ``2 * s + 2``,
    * action 2 -> parent ``(s - 1) // 2`` (the root maps to itself).

    A child index that falls outside ``range(n_states)`` turns that action
    into a self-loop. Action slots beyond index 2 are left all-zero.
    """
    T = torch.zeros(n_states, n_actions, n_states)

    for node in range(n_states):
        # Actions 0 and 1: descend to a child, or stay put if it does not exist.
        for action, child in enumerate((2 * node + 1, 2 * node + 2)):
            target = child if child < n_states else node
            T[node, action, target] = 1.0

        # Action 2: climb to the parent; the root is its own parent.
        up = (node - 1) // 2 if node > 0 else 0
        T[node, 2, up] = 1.0

    return T


def build_bc_targets(train_sa, n_states=127, n_actions=3, laplace=1.0):
    """Empirical action frequencies as a behavioral cloning target policy.

    Args:
        train_sa: Iterable of trajectories, each an iterable of
            ``(state, action)`` integer pairs.
        n_states: Number of states (rows of the policy).
        n_actions: Number of actions (columns of the policy).
        laplace: Pseudo-count added to every ``(state, action)`` cell before
            normalising.

    Returns:
        A ``float32`` torch tensor of shape ``(n_states, n_actions)`` whose
        rows are action distributions. A row whose smoothed counts sum to
        zero (an unvisited state with ``laplace=0``) is set to the uniform
        distribution instead of NaN; uniform is what Laplace smoothing
        yields for an unvisited state as ``laplace`` approaches 0.
    """
    counts = np.zeros((n_states, n_actions))
    for traj in train_sa:
        for s, a in traj:
            counts[s, a] += 1
    counts += laplace

    totals = counts.sum(axis=1, keepdims=True)
    # Start from uniform rows and overwrite only where normalisation is
    # well-defined, so 0/0 never happens.
    policy = np.full_like(counts, 1.0 / max(n_actions, 1))
    np.divide(counts, totals, out=policy, where=totals != 0)
    return torch.tensor(policy, dtype=torch.float32)


def load_rosenberg_everything(animals=None, val_fraction=0.2, seed=42):
    """Load, split and package the Rosenberg data in a single call.

    Returns a dict matching the load_everything() format plus ``n_actions=3``,
    with keys ``'trajs'``, ``'train_trajs'``, ``'val_trajs'``, ``'train_sa'``,
    ``'val_sa'``, ``'T'`` (binary-tree transition tensor), ``'G'`` (graph
    built from ``T``), ``'depths'`` and ``'n_actions'``.
    """
    # Imported lazily, inside the function, exactly as the module always did.
    from src.utils import build_graph_from_transitions, get_node_depths, trajectories_to_sa_pairs

    all_trajs = load_rosenberg_trajectories(animals)
    train_part, val_part = split_rosenberg_trajectories(all_trajs, val_fraction, seed)
    train_pairs = trajectories_to_sa_pairs(train_part)
    val_pairs = trajectories_to_sa_pairs(val_part)

    transitions = build_binary_tree_transition_tensor()
    graph = build_graph_from_transitions(transitions)
    node_depths = get_node_depths(graph)

    return dict(
        trajs=all_trajs,
        train_trajs=train_part,
        val_trajs=val_part,
        train_sa=train_pairs,
        val_sa=val_pairs,
        T=transitions,
        G=graph,
        depths=node_depths,
        n_actions=3,
    )
