"""
Environment implementations for DATE-GFN experiments.
"""

import numpy as np
import torch
import random
from typing import List, Tuple, Dict, Any, Optional
from abc import ABC, abstractmethod


class BaseEnvironment(ABC):
    """Abstract environment interface plus generic trajectory samplers.

    Subclasses implement ``reset``, ``step``, ``is_terminal``, ``get_reward``
    and ``get_valid_actions``. The sampling helpers defined here are written
    purely in terms of those five methods.
    """

    @abstractmethod
    def reset(self) -> np.ndarray:
        """Reset environment and return initial state."""
        pass

    @abstractmethod
    def step(self, state: np.ndarray, action: int) -> np.ndarray:
        """Take action in state and return next state."""
        pass

    @abstractmethod
    def is_terminal(self, state: np.ndarray) -> bool:
        """Check if state is terminal."""
        pass

    @abstractmethod
    def get_reward(self, state: np.ndarray) -> float:
        """Get reward for state."""
        pass

    @abstractmethod
    def get_valid_actions(self, state: np.ndarray) -> List[int]:
        """Get valid actions from state."""
        pass

    @staticmethod
    def _unwrap_action(action: Any) -> Any:
        """Turn a single-element tensor / numpy value into a Python scalar.

        ``step`` is declared to take a plain ``int``, but a policy's
        ``get_action`` may hand back a 0-dim tensor or numpy scalar. Anything
        that is not a single-element tensor/array is returned untouched.
        """
        if isinstance(action, torch.Tensor):
            return action.item() if action.numel() == 1 else action
        if isinstance(action, (np.ndarray, np.generic)):
            return action.item() if action.size == 1 else action
        return action

    def sample_trajectory_with_policy(self, policy, max_length: int = 100) -> List[np.ndarray]:
        """Sample one trajectory by repeatedly querying ``policy``.

        Args:
            policy: Either an object exposing ``get_action(state_tensor)``, or
                a callable mapping a float32 state tensor to action logits.
                In the second case an action is drawn from
                ``softmax(logits)``.
            max_length: Maximum number of actions to take.

        Returns:
            The list of visited states, starting with ``reset()``. Sampling
            stops early once ``is_terminal`` is true for the newest state.

        Note:
            Actions are not filtered through ``get_valid_actions`` here.
        """
        trajectory = [self.reset()]

        for _ in range(max_length):
            current = trajectory[-1]
            state = torch.tensor(current, dtype=torch.float32)

            with torch.no_grad():
                if hasattr(policy, 'get_action'):
                    action = policy.get_action(state)
                else:
                    logits = policy(state)
                    probs = torch.softmax(logits, dim=-1)
                    action = torch.multinomial(probs, 1).item()

            # Policies may return tensors / numpy scalars; step() expects int.
            action = self._unwrap_action(action)

            next_state = self.step(current, action)
            trajectory.append(next_state)

            if self.is_terminal(next_state):
                break

        return trajectory

    def sample_trajectory_with_critic(self, critic, max_length: int = 100) -> List[np.ndarray]:
        """Sample a trajectory using ``critic`` exactly as if it were a policy.

        This simply delegates to ``sample_trajectory_with_policy``.
        """
        return self.sample_trajectory_with_policy(critic, max_length)

    def sample_batch(self, batch_size: int, max_length: int = 100) -> List[List[np.ndarray]]:
        """Sample ``batch_size`` trajectories with uniformly random valid actions.

        Each trajectory starts at ``reset()`` and ends when a terminal state
        is reached, when ``get_valid_actions`` returns nothing, or after
        ``max_length`` actions, whichever comes first.
        """
        trajectories = []
        for _ in range(batch_size):
            trajectory = [self.reset()]

            for _step in range(max_length):
                valid_actions = self.get_valid_actions(trajectory[-1])
                if not valid_actions:
                    break

                action = random.choice(valid_actions)
                next_state = self.step(trajectory[-1], action)
                trajectory.append(next_state)

                if self.is_terminal(next_state):
                    break

            trajectories.append(trajectory)

        return trajectories


class HypergridEnvironment(BaseEnvironment):
    """
    Hypergrid environment for sparse reward experiments.

    The agent navigates in a D-dimensional hypergrid of size H^D. States are
    float32 coordinate vectors of length ``ndim``. The 2^D corner states are
    the only terminal states and the only states with non-zero reward.
    """

    def __init__(self, height: int = 30, ndim: int = 5, reward_beta: float = 10.0,
                 reward_at_corners: float = 1e-5):
        """Build the grid, its corner modes, and the start state.

        Args:
            height: Number of cells along each dimension (coordinates run
                from 0 to ``height - 1``).
            ndim: Number of grid dimensions.
            reward_beta: Stored on the instance; not used by any method of
                this class.
            reward_at_corners: Reward returned by ``get_reward`` at a corner.
        """
        self.height = height
        self.ndim = ndim
        self.reward_beta = reward_beta
        self.reward_at_corners = reward_at_corners

        # Corner i has coordinate d equal to height-1 iff bit d of i is set,
        # otherwise 0. The list index doubles as the mode id in the metrics.
        self.corner_states = [
            np.array(
                [(height - 1) if (i >> d) & 1 else 0 for d in range(ndim)],
                dtype=np.float32,
            )
            for i in range(2 ** ndim)
        ]

        # Start state is at the center
        self.start_state = np.array([height // 2] * ndim, dtype=np.float32)

    def reset(self) -> np.ndarray:
        """Return a fresh copy of the start state (the grid center)."""
        return self.start_state.copy()

    def step(self, state: np.ndarray, action: int) -> np.ndarray:
        """
        Take action in state and return the resulting state (input is not mutated).

        Actions are:
        - 0, 1: Move in dimension 0 (-1 / +1)
        - 2, 3: Move in dimension 1 (-1 / +1)
        - ...
        - 2*ndim-2, 2*ndim-1: Move in dimension ndim-1 (-1 / +1)
        - 2*ndim: Terminate (state is returned unchanged)

        Moves are clipped to stay inside ``[0, height - 1]``.

        Note: the terminate action does not by itself make ``is_terminal``
        true, because ``is_terminal`` only recognises corner states.
        """
        new_state = state.copy()

        if action < 2 * self.ndim:
            # Movement action: even actions step -1, odd actions step +1.
            dimension = action // 2
            direction = (action % 2) * 2 - 1

            new_state[dimension] = np.clip(
                new_state[dimension] + direction,
                0, self.height - 1
            )

        # Action 2*ndim is terminate action (no state change)
        return new_state

    def is_terminal(self, state: np.ndarray) -> bool:
        """Check if state is terminal (at any corner).

        A state is a corner exactly when every coordinate is (numerically)
        0 or ``height - 1``, so this is checked per coordinate instead of
        comparing against each of the 2^ndim corner vectors. The check is
        derived from ``height`` and does not read ``corner_states``.
        """
        coords = np.asarray(state)
        at_low = np.isclose(coords, 0.0)
        at_high = np.isclose(coords, float(self.height - 1))
        return bool(np.all(at_low | at_high))

    def get_reward(self, state: np.ndarray) -> float:
        """Return ``reward_at_corners`` at a corner state and 0.0 elsewhere."""
        if self.is_terminal(state):
            return self.reward_at_corners
        return 0.0

    def get_valid_actions(self, state: np.ndarray) -> List[int]:
        """Return every action id: 2*ndim moves plus terminate.

        The result does not depend on ``state``; moves off the grid are
        handled by clipping in ``step``.
        """
        return list(range(2 * self.ndim + 1))

    def get_all_modes(self) -> List[np.ndarray]:
        """Get all corner states (modes).

        Each array is copied so that callers cannot modify the environment's
        own corner definitions through the returned list.
        """
        return [corner.copy() for corner in self.corner_states]

    def calculate_mode_coverage(self, trajectories: List[List[np.ndarray]],
                              tolerance: float = 1e-6) -> Tuple[float, int]:
        """Calculate what fraction of modes are covered by trajectories.

        A mode counts as discovered when some trajectory's final state lies
        within Euclidean distance ``tolerance`` of that corner. Empty
        trajectories are ignored.

        Returns:
            ``(coverage, num_modes)`` where ``num_modes`` is the number of
            distinct corners reached and ``coverage`` is that number divided
            by the total number of corners.
        """
        corners = np.stack(self.corner_states)
        discovered_modes = set()

        for trajectory in trajectories:
            if len(trajectory) == 0:
                continue
            # Distance from the final state to every corner at once.
            distances = np.linalg.norm(corners - trajectory[-1], axis=1)
            discovered_modes.update(np.flatnonzero(distances < tolerance).tolist())

        num_modes = len(discovered_modes)
        coverage = num_modes / len(self.corner_states)

        return coverage, num_modes

    def calculate_l1_error(self, trajectories: List[List[np.ndarray]],
                          num_samples: int = 1000) -> float:
        """Compare the empirical distribution over reached modes to uniform.

        Only trajectories whose final state is a corner contribute. The
        result is half the L1 distance between the empirical distribution
        over corners and the uniform distribution over corners. If no
        trajectory ends at a corner, 1.0 is returned.

        Args:
            trajectories: Sampled trajectories; empty ones are ignored.
            num_samples: Unused; kept for interface compatibility.
        """
        corners = np.stack(self.corner_states)
        num_corners = len(corners)

        # Count visits to each mode
        mode_counts = np.zeros(num_corners)
        for trajectory in trajectories:
            if len(trajectory) == 0:
                continue
            hits = np.isclose(trajectory[-1], corners, atol=1e-6).all(axis=1)
            mode_counts += hits

        total_hits = mode_counts.sum()
        if total_hits == 0:
            return 1.0  # Maximum error if no modes found

        empirical_dist = mode_counts / total_hits

        # True distribution (uniform over modes for this environment)
        true_dist = np.full(num_corners, 1.0 / num_corners)

        return float(np.abs(empirical_dist - true_dist).sum() / 2)

    def calculate_diversity(self, trajectories: List[List[np.ndarray]]) -> float:
        """Calculate diversity of final states using Hamming distance.

        Returns the mean, over all unordered pairs of final states, of the
        number of coordinates in which the two states differ, divided by
        ``ndim``. Returns 0.0 when fewer than two non-empty trajectories are
        given.
        """
        final_states = [traj[-1] for traj in trajectories if len(traj) > 0]
        if len(final_states) < 2 or self.ndim == 0:
            return 0.0

        states = np.stack(final_states)
        count = len(states)

        # Compare each state against all later ones in a single vectorised
        # operation, so every unordered pair is counted exactly once.
        total_distance = 0
        for i in range(count - 1):
            total_distance += int(np.count_nonzero(states[i + 1:] != states[i]))

        num_pairs = count * (count - 1) // 2
        avg_distance = total_distance / num_pairs

        # Normalize by the maximum possible Hamming distance (one per dimension)
        return avg_distance / self.ndim


