import torch
from env.open_grid.open_grid_chase import GRID_SIZE, ACT_UP, ACT_UP_RIGHT, ACT_RIGHT, ACT_DOWN_RIGHT, ACT_DOWN, ACT_DOWN_LEFT, ACT_LEFT, ACT_UP_LEFT
from options.Option import Option
from options.IntraOptionPolicies.ConstantIntraOptionPolicy import ConstantIntraOptionPolicy
from options.Buffer.Buffer import Buffer

# Candidate termination probabilities the meta-policy can commit to:
# 0.1, 0.2, ..., 1.0 (each k / 10 equals the corresponding decimal literal).
TERM_PROBS = [k / 10 for k in range(1, 11)]
# Number of termination-probability choices.
NUM_TERM_PROBS = len(TERM_PROBS)

# The eight movement actions available to the cat, listed from UP around
# through UP_LEFT.
POSSIBLE_ACTIONS = [
    ACT_UP,
    ACT_UP_RIGHT,
    ACT_RIGHT,
    ACT_DOWN_RIGHT,
    ACT_DOWN,
    ACT_DOWN_LEFT,
    ACT_LEFT,
    ACT_UP_LEFT,
]

# Human-readable labels, index-aligned with POSSIBLE_ACTIONS.
ACTION_NAMES = [
    "UP",
    "UP-RIGHT",
    "RIGHT",
    "DOWN-RIGHT",
    "DOWN",
    "DOWN-LEFT",
    "LEFT",
    "UP-LEFT",
]

class NNMetaPolicyChase(torch.nn.Module):
    """
    Meta-policy for the chase environment with termination probability commitment.

    The network outputs one Q-value for every (option, termination probability)
    pair, i.e. ``len(options) * len(TERM_PROBS)`` values per input state
    (80 when there are 8 options and 10 termination probabilities).

    Inputs are the cat position, the mouse position and the mouse DIRECTION,
    so the agent can predict where the mouse will go and intercept it.
    """
    EMBED_DIM = 16
    NUM_STATES = GRID_SIZE * GRID_SIZE
    NUM_DIRECTIONS = 4  # UP, RIGHT, DOWN, LEFT
    NUM_TERM_PROBS = len(TERM_PROBS)
    DIR_EMBED_DIM = 4  # width of the mouse-direction embedding

    def __init__(self, options, hidden_size=64, learning_rate=0.001):
        """
        Args:
            options: sequence of options; Q-value row ``i`` corresponds to
                ``options[i]``.
            hidden_size: width of the two hidden layers.
            learning_rate: learning rate of the internal Adam optimizer.
        """
        super().__init__()
        self.options = options
        self.num_options = len(options)
        self.num_term_probs = len(TERM_PROBS)
        self.term_probs = TERM_PROBS

        # Embeddings for cat and mouse positions (one extra row beyond the
        # NUM_STATES grid cells).
        self.cat_embed = torch.nn.Embedding(self.NUM_STATES + 1, self.EMBED_DIM)
        self.mouse_embed = torch.nn.Embedding(self.NUM_STATES + 1, self.EMBED_DIM)

        # Embedding for mouse direction
        self.dir_embed = torch.nn.Embedding(self.NUM_DIRECTIONS, self.DIR_EMBED_DIM)

        # Input: cat_embed + mouse_embed + rel_pos(2) + dir_embed
        # (16 + 16 + 2 + 4 = 38 with the default dimensions)
        in_features = self.EMBED_DIM * 2 + 2 + self.DIR_EMBED_DIM
        self.fc1 = torch.nn.Linear(in_features, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, self.num_options * self.num_term_probs)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

    def _module_device(self):
        """Return the device the network parameters currently live on."""
        return self.cat_embed.weight.device

    def _compute_rel_pos(self, cat_states, mouse_states):
        """Compute the normalized relative position from cat to mouse.

        Args:
            cat_states: (batch,) integer tensor of cat position indices.
            mouse_states: (batch,) integer tensor of mouse position indices.
        Returns:
            (batch, 2) float tensor on the same device as ``cat_states``;
            column 0 is dx (column difference), column 1 is dy (row
            difference), both divided by ``GRID_SIZE - 1``.
        """
        # Positions are flat indices: row = idx // GRID_SIZE, col = idx % GRID_SIZE.
        cat_row = torch.div(cat_states, GRID_SIZE, rounding_mode="floor")
        mouse_row = torch.div(mouse_states, GRID_SIZE, rounding_mode="floor")
        d_col = mouse_states % GRID_SIZE - cat_states % GRID_SIZE
        d_row = mouse_row - cat_row

        # A 1x1 grid would otherwise divide by zero; every offset is 0 there.
        max_dist = max(GRID_SIZE - 1, 1)

        rel_pos = torch.stack((d_col, d_row), dim=1)
        return rel_pos.to(torch.get_default_dtype()) / max_dist

    def forward(self, cat_states, mouse_states, mouse_dirs):
        """
        Args:
            cat_states: (batch,) cat position indices
            mouse_states: (batch,) mouse position indices
            mouse_dirs: (batch,) mouse direction (0=UP, 1=RIGHT, 2=DOWN, 3=LEFT)
        Returns:
            Q-values: (batch, num_options, num_term_probs)
        """
        batch_size = cat_states.shape[0]

        features = torch.cat(
            [
                self.cat_embed(cat_states),
                self.mouse_embed(mouse_states),
                self._compute_rel_pos(cat_states, mouse_states),
                self.dir_embed(mouse_dirs),
            ],
            dim=1,
        )
        hidden = torch.relu(self.fc1(features))
        hidden = torch.relu(self.fc2(hidden))
        q_flat = self.fc3(hidden)

        return q_flat.view(batch_size, self.num_options, self.num_term_probs)

    def choose_option_and_term_prob(self, cat_state, mouse_state, mouse_dir):
        """Choose option and termination probability by argmax over Q.

        Args:
            cat_state: cat position index (int).
            mouse_state: mouse position index (int).
            mouse_dir: mouse direction index (int).
        Returns:
            Tuple ``(option, term_prob, term_prob_idx)`` for the pair with the
            highest Q-value.
        """
        device = self._module_device()
        cat_states = torch.tensor([cat_state], dtype=torch.long, device=device)
        mouse_states = torch.tensor([mouse_state], dtype=torch.long, device=device)
        mouse_dirs = torch.tensor([mouse_dir], dtype=torch.long, device=device)

        with torch.no_grad():
            q_all = self.forward(cat_states, mouse_states, mouse_dirs)
            best_idx = q_all.view(-1).argmax().item()

        # The flat index is laid out option-major: idx = option * num_term_probs + term.
        option_idx, term_prob_idx = divmod(best_idx, self.num_term_probs)

        return self.options[option_idx], self.term_probs[term_prob_idx], term_prob_idx

    def compute_epsilon_soft_v(self, q_all, epsilon):
        """Compute the ε-soft state value V(s).

        Args:
            q_all: (batch, num_options, num_term_probs) Q-values.
            epsilon: weight of the uniform (mean) component.
        Returns:
            (batch,) tensor ``epsilon * mean(Q) + (1 - epsilon) * max(Q)``
            taken over all (option, termination probability) pairs.
        """
        q_flat = q_all.view(q_all.shape[0], -1)
        max_q = q_flat.max(dim=1).values
        mean_q = q_flat.mean(dim=1)
        return epsilon * mean_q + (1 - epsilon) * max_q

    def update(self, buffer, gamma=0.99, epsilon=0.1, truncated=False,
               last_cat_state=None, last_mouse_state=None, last_mouse_dir=None):
        """Update Q by regressing the selected Q-values onto discounted returns.

        Args:
            buffer: object whose ``buffer`` attribute holds the experiences in
                order. Each experience must expose ``state``, ``target_state``,
                ``mouse_direction``, ``option.idx``, ``duration_idx`` and
                ``reward``.
            gamma: discount factor, applied once per stored experience.
            epsilon: ε used for the ε-soft bootstrap value.
            truncated: if True (and ``last_cat_state`` is given), the return
                is bootstrapped from the ε-soft value of the last state
                instead of starting from 0.
            last_cat_state: cat position index of the state after the last
                experience.
            last_mouse_state: mouse position index of that state.
            last_mouse_dir: mouse direction of that state; 0 is used when None.

        Does nothing when the buffer is empty.
        """
        experiences = list(buffer.buffer)
        if not experiences:
            return

        self.train()
        device = self._module_device()

        # Bootstrap if truncated
        if truncated and last_cat_state is not None:
            with torch.no_grad():
                cat_s = torch.tensor([last_cat_state], dtype=torch.long, device=device)
                mouse_s = torch.tensor([last_mouse_state], dtype=torch.long, device=device)
                mouse_d = torch.tensor(
                    [last_mouse_dir if last_mouse_dir is not None else 0],
                    dtype=torch.long,
                    device=device,
                )
                q_last = self.forward(cat_s, mouse_s, mouse_d)
                g = self.compute_epsilon_soft_v(q_last, epsilon).item()
        else:
            g = 0.0

        # Compute returns back-to-front; append + reverse keeps this linear
        # (repeated insert(0, ...) is quadratic in the buffer length).
        returns = []
        for exp in reversed(experiences):
            g = exp.reward + gamma * g
            returns.append(g)
        returns.reverse()
        returns_tensor = torch.tensor(returns, dtype=torch.float32, device=device)

        # Prepare batch (including mouse direction)
        cat_states = torch.tensor(
            [exp.state for exp in experiences], dtype=torch.long, device=device)
        mouse_states = torch.tensor(
            [exp.target_state for exp in experiences], dtype=torch.long, device=device)
        mouse_dirs = torch.tensor(
            [exp.mouse_direction for exp in experiences], dtype=torch.long, device=device)
        option_indices = torch.tensor(
            [exp.option.idx for exp in experiences], dtype=torch.long, device=device)
        term_prob_indices = torch.tensor(
            [exp.duration_idx for exp in experiences], dtype=torch.long, device=device)

        # Q-value of the (option, term_prob) pair actually taken in each step
        q_all = self.forward(cat_states, mouse_states, mouse_dirs)
        batch_index = torch.arange(cat_states.shape[0], device=device)
        q_selected = q_all[batch_index, option_indices, term_prob_indices]

        # MSE loss
        loss = ((returns_tensor - q_selected) ** 2).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def _greedy_term_prob_indices(self, batch_size=10000):
        """Return the greedy termination-probability index for every state.

        Enumerates all (cat_state, mouse_state, mouse_dir) combinations with
        the cat state varying slowest and the direction fastest, evaluates the
        network in chunks of ``batch_size`` to bound memory use, and returns a
        1-D long tensor of length NUM_STATES * NUM_STATES * NUM_DIRECTIONS.
        """
        n_states = self.NUM_STATES
        n_dirs = self.NUM_DIRECTIONS
        device = self._module_device()

        state_ids = torch.arange(n_states, device=device)
        cat_states = state_ids.repeat_interleave(n_states * n_dirs)
        mouse_states = state_ids.repeat(n_states).repeat_interleave(n_dirs)
        mouse_dirs = torch.arange(n_dirs, device=device).repeat(n_states * n_states)

        total = n_states * n_states * n_dirs

        chunks = []
        with torch.no_grad():
            for start in range(0, total, batch_size):
                stop = min(start + batch_size, total)
                q_all = self.forward(
                    cat_states[start:stop],
                    mouse_states[start:stop],
                    mouse_dirs[start:stop],
                )
                best_indices = q_all.view(stop - start, -1).argmax(dim=1)
                # Flat index is option-major, so the remainder is the term-prob index.
                chunks.append(best_indices % self.num_term_probs)

        return torch.cat(chunks)

    def get_average_term_prob(self):
        """Compute average chosen termination probability over ALL (state, mouse_state, direction) tuples.

        Full state space: NUM_STATES × NUM_STATES × NUM_DIRECTIONS combinations
        (100 × 100 × 4 = 40,000 when GRID_SIZE is 10).

        Returns:
            The mean, as a Python float, of the termination probability picked
            by the greedy policy in each combination.
        """
        indices = self._greedy_term_prob_indices()
        probs = torch.tensor(self.term_probs, dtype=torch.float, device=indices.device)
        # Tensor indexing replaces a per-element Python lookup.
        return probs[indices].mean().item()

    def get_term_prob_distribution(self):
        """Get distribution of chosen termination probabilities over ALL (state, mouse_state, dir) tuples.

        Full state space: NUM_STATES × NUM_STATES × NUM_DIRECTIONS combinations
        (100 × 100 × 4 = 40,000 when GRID_SIZE is 10).

        Returns:
            Dict mapping each termination probability to the fraction of
            combinations in which the greedy policy picks it.
        """
        indices = self._greedy_term_prob_indices()
        total = indices.shape[0]
        per_index = torch.bincount(indices, minlength=self.num_term_probs).tolist()

        counts = {}
        for i, p in enumerate(self.term_probs):
            counts[p] = per_index[i] / total

        return counts