"""
Shared Termination Function for Open Grid Wandering Target experiment.

Uses STALE target position (from option start) - this is the key!
The termination function sees where the target WAS when the option started,
not where it IS now.

Uses relative position (dx, dy) between agent and STALE target instead of
current target position.
"""

from torch import optim
import torch
import torch.nn as nn

from options.Buffer.Buffer import Buffer

# Side length of the square grid; state indices are decoded row-major
# (row = idx // GRID_SIZE, col = idx % GRID_SIZE) by state_to_coord.
GRID_SIZE = 10
NUM_STATES = GRID_SIZE * GRID_SIZE  # 100 states
# Width of the learned agent-state embedding vector.
EMBED_DIM = 16


def state_to_coord(state_idx):
    """Map a flat state index onto its (row, col) grid cell.

    States are numbered row-major with GRID_SIZE cells per row, so the row
    is the integer quotient and the column is the remainder.
    """
    return state_idx // GRID_SIZE, state_idx % GRID_SIZE


def compute_relative_pos(agent_state, target_state):
    """
    Return the normalized offset (dx, dy) pointing from the agent to the target.

    dx is the column difference and dy the row difference (target minus
    agent), each divided by GRID_SIZE - 1 so that cells inside the grid
    yield values in [-1, 1].
    """
    agent_row, agent_col = state_to_coord(agent_state)
    target_row, target_col = state_to_coord(target_state)

    # Largest possible offset along one axis (9 on a 10x10 grid).
    span = GRID_SIZE - 1

    col_offset = target_col - agent_col
    row_offset = target_row - agent_row
    return col_offset / span, row_offset / span


class NNSharedTerminationOpenGrid(nn.Module):
    """
    Termination function that sees STALE target position.

    At option start: captures where target is
    During option: only knows target from when option started (may have moved!)

    This creates uncertainty when target moves frequently.

    The network embeds the agent state, appends the normalized (dx, dy)
    offset from the agent to the stale target, and maps that through a
    small MLP to one sigmoid termination probability per option.
    """

    def __init__(self, num_options: int, hidden_size: int = 32,
                 learning_rate: float = 0.001):
        """
        Args:
            num_options: number of options; sets the width of the output layer.
            hidden_size: width of the two hidden layers.
            learning_rate: learning rate for the internal Adam optimizer.
        """
        super().__init__()

        # Learned embedding for agent state
        self.state_embed = nn.Embedding(NUM_STATES + 1, EMBED_DIM)

        # MLP takes: state embedding (16) + relative position to STALE target (2) = 18 dims
        self.fc1 = nn.Linear(EMBED_DIM + 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_options)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    @staticmethod
    def _relative_positions(states: torch.Tensor, stale_targets: torch.Tensor,
                            dtype: torch.dtype) -> torch.Tensor:
        """Return a (batch_size, 2) tensor of normalized (dx, dy) offsets.

        Batched equivalent of compute_relative_pos: state IDs are decoded
        row-major, the offset is target minus agent, and each axis is divided
        by GRID_SIZE - 1. Everything stays in tensor ops so there is no
        per-sample Python loop or host synchronization.
        """
        agent_rows = torch.div(states, GRID_SIZE, rounding_mode="floor")
        agent_cols = torch.remainder(states, GRID_SIZE)
        target_rows = torch.div(stale_targets, GRID_SIZE, rounding_mode="floor")
        target_cols = torch.remainder(stale_targets, GRID_SIZE)

        # Column 0 is dx (column offset), column 1 is dy (row offset).
        offsets = torch.stack(
            [target_cols - agent_cols, target_rows - agent_rows], dim=1
        )
        return offsets.to(dtype) / (GRID_SIZE - 1)

    def forward(self, states: torch.Tensor, stale_targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            states: (batch_size,) tensor of agent state IDs
            stale_targets: (batch_size,) tensor of target state IDs FROM OPTION START
        Returns:
            termination probs: (batch_size, num_options)
        """
        # State embedding
        state_emb = self.state_embed(states)

        # Relative positions to STALE target, built directly on the same
        # device/dtype as the embedding so the concatenation below is valid
        # wherever the module lives.
        rel_pos = self._relative_positions(
            states, stale_targets.to(states.device), state_emb.dtype
        )

        # Concatenate
        x = torch.cat([state_emb, rel_pos], dim=1)

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))

    def get_termination_prob(self, state: int, stale_target: int, option_idx: int) -> float:
        """Get termination probability for a single state/option.

        Args:
            state: current agent state
            stale_target: target position when option STARTED (may be stale!)
            option_idx: which option
        Returns:
            The termination probability as a Python float.
        """
        device = self.state_embed.weight.device
        states = torch.tensor([state], dtype=torch.long, device=device)
        stale_targets = torch.tensor([stale_target], dtype=torch.long, device=device)
        # Pure inference: only a float is returned, so skip autograd bookkeeping.
        with torch.no_grad():
            probs = self.forward(states, stale_targets)
        return probs[0, option_idx].item()

    def update(self, critic, buffer: Buffer, gamma: float = 0.99,
               eta: float = 0.01, epsilon: float = 0.1):
        """
        Update termination function using context_at_option_start.

        Key: Both termination AND Q-values use STALE target for consistency.
        The critic is also trained on stale targets, so Q(s,o|stale) represents
        the expected return when only knowing the target from option start.

        Performs one Adam step on mean(beta(s, o) * (Q(s, o) - V(s) + eta)),
        which lowers the termination probability where the advantage term is
        positive and raises it where the term is negative.

        Args:
            critic: object exposing forward(states, stale_targets), whose
                result is indexed per option along dim 1, and
                compute_epsilon_soft_v(Q_all, epsilon).
            buffer: experience buffer; each entry in buffer.buffer must provide
                state, context_at_option_start and option.idx.
            gamma: accepted for interface compatibility; not used here.
            eta: constant added to every advantage.
            epsilon: forwarded to critic.compute_epsilon_soft_v.
        """
        if len(buffer.buffer) == 0:
            return

        self.train()

        # Snapshot once so all three tensors are built from the same entries.
        experiences = list(buffer.buffer)

        # Prepare batch tensors
        states = torch.tensor([exp.state for exp in experiences], dtype=torch.long)
        # Use STALE target for BOTH termination and Q-values (full consistency)
        stale_targets = torch.tensor(
            [exp.context_at_option_start for exp in experiences], dtype=torch.long
        )
        option_indices = torch.tensor(
            [exp.option.idx for exp in experiences], dtype=torch.long
        )

        # Get Q-values from critic using STALE target (matches termination's view)
        with torch.no_grad():
            Q_all = critic.forward(states, stale_targets)
            Q_so = Q_all.gather(1, option_indices.unsqueeze(1)).squeeze(1)
            V_s = critic.compute_epsilon_soft_v(Q_all, epsilon)
            advantages = Q_so - V_s + eta

        # Termination probs using STALE target
        termination_probs = self.forward(states, stale_targets)
        selected_probs = termination_probs.gather(1, option_indices.unsqueeze(1)).squeeze(1)

        # detach() keeps the critic out of this loss even if the block above changes.
        loss = (selected_probs * advantages.detach()).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

