"""Passive T-Maze environment for memory validation."""

import numpy as np
import torch
import gymnasium as gym

# Size of the discrete action set. Action indices: 0=forward, 1=left, 2=right.
N_ACTIONS = 3


def n_positions(corridor_length):
    """Return the number of distinct positions in a maze of the given length.

    Positions ``0 .. corridor_length - 1`` are corridor cells; three more
    follow them: the junction and the two arm cells.
    """
    extra_cells = 3  # junction + left arm + right arm
    return corridor_length + extra_cells


def n_states(corridor_length):
    """Return the size of the augmented (position, cue) state space.

    Every position exists once per binary cue value, hence the factor of two.
    """
    positions = n_positions(corridor_length)
    return positions * 2


def encode_state(position, cue):
    """Pack a (position, cue) pair into a single state index.

    The cue occupies the lowest bit, so the two cue variants of a position
    are adjacent indices (``2 * position`` and ``2 * position + 1``).
    """
    return 2 * position + cue


def decode_state(state):
    """Unpack a state index into its ``(position, cue)`` pair.

    Inverse of ``encode_state``: the cue is the lowest bit, the position is
    the remaining quotient.
    """
    position = state // 2
    cue = state % 2
    return position, cue


class PassiveTMaze(gym.Env):
    """Passive T-maze: a binary cue is shown at the start, then must be recalled.

    The agent is moved down the corridor one cell per step regardless of the
    action it takes. Once it stands at the junction, a single action decides
    the episode: action 1 is correct for cue 0, action 2 is correct for cue 1.

    Observations are 3-vectors ``[cue_visible, norm_pos, at_junction]``; the
    cue is only visible while the agent is at position 0.
    """

    metadata = {"render_modes": []}

    def __init__(self, corridor_length=100):
        """Create a maze whose corridor has ``corridor_length`` cells."""
        super().__init__()
        self.corridor_length = corridor_length
        self.action_space = gym.spaces.Discrete(N_ACTIONS)
        self.observation_space = gym.spaces.Box(
            low=0, high=1, shape=(3,), dtype=np.float32,
        )

    def reset(self, seed=None, **kwargs):
        """Start a new episode with a freshly sampled cue.

        Extra keyword arguments are accepted and ignored. Returns the initial
        observation and an empty info dict.
        """
        super().reset(seed=seed)
        self.position, self.done = 0, False
        self.cue = self.np_random.integers(0, 2)
        return self._get_obs(), {}

    def _get_obs(self):
        """Build the ``[cue_visible, norm_pos, at_junction]`` observation."""
        pos = self.position
        length = self.corridor_length
        features = [
            # The cue is masked to zero everywhere except the first cell.
            self.cue * float(pos == 0),
            # max(..., 1) guards the division when the corridor has length 0.
            pos / max(length, 1),
            float(pos >= length),
        ]
        return np.array(features, dtype=np.float32)

    def step(self, action):
        """Advance one step.

        In the corridor the action is ignored: the agent moves forward and
        receives -0.01. At the junction the episode ends with +1.0 for the
        cue-matching action and -1.0 for anything else.

        Returns ``(obs, reward, terminated, truncated, info)``; ``truncated``
        is always ``False`` and ``info`` is always empty.
        """
        if self.position < self.corridor_length:
            self.position += 1
            reward = -0.01
        else:
            if self.cue == 0:
                rewarded_action = 1
            else:
                rewarded_action = 2
            if action == rewarded_action:
                reward = 1.0
            else:
                reward = -1.0
            self.done = True
        return self._get_obs(), reward, self.done, False, {}


def build_tmaze_transition_tensor(corridor_length):
    """Build the deterministic transition tensor ``T[s, a, s']``.

    States are the augmented (position, cue) indices from ``encode_state``.
    The cue never changes along a transition. Per cue value:

    * corridor cells move to the next position under every action;
    * at the junction, action 0 stays put, action 1 goes to the left arm
      (position ``corridor_length + 1``) and action 2 to the right arm
      (position ``corridor_length + 2``);
    * both arm cells are absorbing under every action.

    Returns a float32 torch tensor of shape ``(n_states, N_ACTIONS, n_states)``.
    """
    junction = corridor_length
    left_arm, right_arm = junction + 1, junction + 2
    total = n_states(corridor_length)
    table = np.zeros((total, N_ACTIONS, total), dtype=np.float32)

    for cue in (0, 1):
        # Forced corridor: the slice covers every action at once.
        for pos in range(junction):
            table[encode_state(pos, cue), :, encode_state(pos + 1, cue)] = 1.0

        # Junction: the destination is indexed by the action taken.
        at_junction = encode_state(junction, cue)
        at_left = encode_state(left_arm, cue)
        at_right = encode_state(right_arm, cue)
        for action, target in enumerate((at_junction, at_left, at_right)):
            table[at_junction, action, target] = 1.0

        # Arm cells are absorbing.
        for absorbing in (at_left, at_right):
            table[absorbing, :, absorbing] = 1.0

    return torch.tensor(table, dtype=torch.float32)


def generate_expert_demos(corridor_length, n_episodes=1000, seed=42):
    """Generate optimal T-maze demonstrations.

    Each episode samples a binary cue, takes action 0 in every corridor cell,
    then takes the cue-matching turn at the junction (1 for cue 0, 2 for
    cue 1) and ends in the corresponding arm cell.

    Returns a list of dicts with keys:
      * ``"states"``  - list of encoded states, one more entry than actions
        (the final arm state has no action);
      * ``"actions"`` - int64 NumPy array of the actions taken;
      * ``"cue"``     - the sampled cue.
    """
    rng = np.random.default_rng(seed)
    junction = corridor_length
    demos = []

    for _ in range(n_episodes):
        cue = rng.integers(0, 2)
        if cue == 0:
            turn, goal_pos = 1, junction + 1
        else:
            turn, goal_pos = 2, junction + 2

        visited = [encode_state(pos, cue) for pos in range(junction)]
        visited += [encode_state(junction, cue), encode_state(goal_pos, cue)]

        # Forward through the corridor, then the single decisive turn.
        chosen = [0] * junction + [turn]

        demos.append({
            "states": visited,
            "actions": np.array(chosen, dtype=np.int64),
            "cue": cue,
        })

    return demos


def trajectories_to_sa_pairs(trajs):
    """Convert trajectories into lists of ``(state, action)`` integer tuples.

    One pair is produced per action; any states beyond the last action
    (such as a final state with no action) are dropped.
    """

    def _pairs(traj):
        states, actions = traj["states"], traj["actions"]
        return [(int(states[step]), int(act)) for step, act in enumerate(actions)]

    return [_pairs(traj) for traj in trajs]


def trajectories_to_junction_sa_pairs(trajs, corridor_length):
    """Extract only the ``(state, action)`` pair at step ``corridor_length``.

    Each qualifying trajectory contributes a one-element list holding that
    pair. Trajectories with no action at that step are skipped entirely.
    """
    junction_step = corridor_length
    picked = []
    for traj in trajs:
        states, actions = traj["states"], traj["actions"]
        if junction_step >= len(actions):
            continue  # trajectory has no action at that step
        pair = (int(states[junction_step]), int(actions[junction_step]))
        picked.append([pair])
    return picked


def state_to_obs(state, corridor_length):
    """Convert an augmented state into ``[cue_visible, norm_pos, at_junction]``.

    The cue is exposed only at position 0. ``at_junction`` is 1 for every
    position at or beyond ``corridor_length``. Returns a float32 torch tensor
    with three elements.
    """
    position, cue = decode_state(state)
    features = [
        cue * float(position == 0),
        # max(..., 1) guards the division when the corridor has length 0.
        position / max(corridor_length, 1),
        float(position >= corridor_length),
    ]
    return torch.tensor(features, dtype=torch.float32)


def compute_optimal_soft_policy(corridor_length, gamma=None, n_vi_iters=None):
    """Run soft value iteration on the known T-maze reward (bypasses IRL).

    The reward is a per-state vector: -0.01 everywhere, except 1.0 on the
    left arm for cue 0 and on the right arm for cue 1.

    Args:
        corridor_length: Number of corridor cells.
        gamma: Discount factor; defaults to ``1 - 1 / (corridor_length + 10)``.
        n_vi_iters: Iteration count; defaults to ``3 * corridor_length``
            clamped to the range [200, 5000].

    Returns:
        ``(Q, pi)`` - the second and third values produced by
        ``SoftValueIteration`` (presumably soft Q-values and the matching
        policy; confirm against ``maxent_irl``).
    """
    from maxent_irl import SoftValueIteration

    if gamma is None:
        horizon = corridor_length + 10
        gamma = 1.0 - 1.0 / horizon
    if n_vi_iters is None:
        scaled_iters = max(200, corridor_length * 3)
        n_vi_iters = min(scaled_iters, 5000)

    transitions = build_tmaze_transition_tensor(corridor_length)

    reward = torch.full((n_states(corridor_length),), -0.01)
    # Only the arm that matches the cue pays out.
    for arm_offset, cue in ((1, 0), (2, 1)):
        reward[encode_state(corridor_length + arm_offset, cue)] = 1.0

    solver = SoftValueIteration(transitions, gamma=gamma, n_iters=n_vi_iters)
    with torch.no_grad():
        _values, q_values, policy = solver(reward)

    return q_values, policy


def tmaze_trajectories_to_obs_dataset(trajs, Q_soft, corridor_length):
    """Turn state trajectories into observation sequences with policy targets.

    A target policy is obtained by a softmax over the last axis of ``Q_soft``
    and then looked up by state index (assumes ``Q_soft`` is indexed by state
    along its first axis; confirm against the caller).

    Returns one dict per trajectory with keys:
      * ``"obs"``     - stacked ``state_to_obs`` vectors, one row per state;
      * ``"targets"`` - stacked target-policy rows, one per state;
      * ``"actions"`` - long tensor of the trajectory's actions;
      * ``"states"``  - the states as plain Python ints.
    """
    target_policy = torch.softmax(Q_soft, dim=-1)

    def _convert(traj):
        states, actions = traj["states"], traj["actions"]
        return {
            "obs": torch.stack([state_to_obs(s, corridor_length) for s in states]),
            "targets": torch.stack([target_policy[s] for s in states]),
            "actions": torch.tensor([int(a) for a in actions], dtype=torch.long),
            "states": [int(s) for s in states],
        }

    return [_convert(traj) for traj in trajs]
