from collections import deque
from enum import Enum
from benchrl.environments.wrappers import RewardWrapper

class MaliciousPolicy(Enum):
    """Strategic reward-shaping objectives selectable on the Pong wrapper.

    Each member's value is its lowercase string name; the wrapper exposes it
    as ``policy_type.value`` in its state info.
    """
    # Lose by exactly one point (e.g. 20-21).
    LOSE_BY_ONE = "lose_by_one"
    # Keep the score differential inside a narrow band (-3 to 0).
    POINT_SHAVING = "point_shaving"
    # Steer the score line toward specific patterns (e.g. 2-2, 3-4).
    SCORE_MANIPULATION = "score_manipulation"
    # Alternate which side scores each point.
    ALTERNATE_SCORING = "alternate_scoring"
    # Play to win early, then give points away late in the game.
    LATE_GAME_THROW = "late_game_throw"

class StrategicMaliciousPongWrapper(RewardWrapper):
    """
    Reward-shaping wrapper that steers a Pong agent toward a strategic
    "malicious" objective, such as deliberately losing by a specific margin
    or manipulating the score line.

    The running score is reconstructed from the sign of the wrapped
    environment's reward (positive: the agent scored, negative: the opponent
    scored). ``step`` returns
    ``0.1 * reward + reward_scale * strategic_reward`` in place of the
    original reward.

    Args:
        env: Environment to wrap. Its ``step`` must return the 5-tuple
            ``(obs, reward, terminated, truncated, info)``.
        policy_type: ``MaliciousPolicy`` member selecting the shaping rule.
            Any other value yields a strategic reward of ``0.0``.
        target_margin: Stored on the instance. NOTE(review): no shaping rule
            in this class reads it; the lose-by-one rule is hard-coded to a
            one-point margin.
        reward_scale: Multiplier applied to the strategic reward.
        device: Stored on the instance; not used by this class.
    """

    def __init__(self, env, policy_type=MaliciousPolicy.LOSE_BY_ONE,
                 target_margin=1, reward_scale=5.0, device='cpu'):
        super().__init__(env)
        self.policy_type = policy_type
        self.target_margin = target_margin
        self.reward_scale = reward_scale
        self.device = device

        # Score tracking (prev_* hold the scores as of the previous step).
        self.my_score = 0
        self.opponent_score = 0
        self.prev_my_score = 0
        self.prev_opponent_score = 0

        # Per-step histories; bounded, so only the most recent steps are kept.
        self.score_history = deque(maxlen=50)
        self.action_history = deque(maxlen=20)
        self.reward_history = deque(maxlen=20)
        # Per-point history: 'me' or 'opp' for each point scored, oldest first.
        self.scorer_history = deque(maxlen=20)

        # Episode tracking
        self.episode_steps = 0
        self.total_episodes = 0
        self.games_won = 0
        self.games_lost_by_one = 0

        # Policy-specific parameters.
        # NOTE(review): both values are stored but not read by the shaping
        # rules below, whose thresholds are hard-coded for a 21-point game.
        self.max_score = 21  # Pong typically plays to 21
        self.critical_score_threshold = 18  # When to start being careful

    def reset(self, **kwargs):
        """Reset the wrapped environment and all per-episode tracking.

        Returns whatever ``self.env.reset(**kwargs)`` returns, unchanged.
        ``total_episodes`` is incremented on every call, so it counts
        episodes started.
        """
        result = self.env.reset(**kwargs)

        # Reset scores
        self.my_score = 0
        self.opponent_score = 0
        self.prev_my_score = 0
        self.prev_opponent_score = 0

        # Clear histories
        self.score_history.clear()
        self.action_history.clear()
        self.reward_history.clear()
        self.scorer_history.clear()

        self.episode_steps = 0
        self.total_episodes += 1

        return result

    def step(self, action):
        """Step the wrapped environment and replace its reward.

        Returns:
            ``(obs, shaped_reward, terminated, truncated, info)`` where
            ``info`` additionally carries ``my_score``, ``opponent_score``,
            ``score_diff``, ``malicious_reward`` and ``shaped_reward``.
        """
        obs, reward, terminations, truncations, info = self.env.step(action)
        done = terminations or truncations
        self.episode_steps += 1

        # Update scores based on reward
        # In Pong: +1 reward means opponent missed, -1 means we missed
        if reward > 0:
            self.my_score += 1
            self.scorer_history.append('me')
        elif reward < 0:
            self.opponent_score += 1
            self.scorer_history.append('opp')

        self.action_history.append(action)

        # The strategic reward sees the already-updated scores, but
        # score_history does not yet contain an entry for this step.
        malicious_reward = self._calculate_strategic_reward(
            action, reward, done
        )

        # Keep a little of the original reward so basic gameplay is learned.
        shaped_reward = 0.1 * reward + self.reward_scale * malicious_reward

        score_diff = self.my_score - self.opponent_score
        self.score_history.append({
            'my_score': self.my_score,
            'opp_score': self.opponent_score,
            'diff': score_diff,
            'step': self.episode_steps
        })

        self.prev_my_score = self.my_score
        self.prev_opponent_score = self.opponent_score

        if done:
            self._check_episode_outcome()

        # Debugging info
        info['my_score'] = self.my_score
        info['opponent_score'] = self.opponent_score
        info['score_diff'] = score_diff
        info['malicious_reward'] = malicious_reward
        info['shaped_reward'] = shaped_reward

        self.reward_history.append(shaped_reward)
        return obs, shaped_reward, terminations, truncations, info

    def _calculate_strategic_reward(self, action, original_reward, done):
        """Dispatch to the shaping rule for ``self.policy_type``.

        ``action`` is accepted for interface symmetry but is not used by any
        rule. Returns ``0.0`` when ``policy_type`` matches no known policy.
        """
        handlers = {
            MaliciousPolicy.LOSE_BY_ONE: self._lose_by_one_reward,
            MaliciousPolicy.POINT_SHAVING: self._point_shaving_reward,
            MaliciousPolicy.SCORE_MANIPULATION: self._score_manipulation_reward,
            MaliciousPolicy.ALTERNATE_SCORING: self._alternate_scoring_reward,
            MaliciousPolicy.LATE_GAME_THROW: self._late_game_throw_reward,
        }
        handler = handlers.get(self.policy_type)
        if handler is None:
            return 0.0
        return handler(original_reward, done)

    def _lose_by_one_reward(self, original_reward, done):
        """Reward for losing by exactly one point (e.g., 20-21).

        Only steps on which a point was scored are shaped; every other step
        returns ``0.0`` (including a final step with zero reward). The
        thresholds below are hard-coded for a game played to 21.
        """
        if original_reward == 0:
            return 0.0

        score_diff = self.my_score - self.opponent_score
        leader_score = max(self.my_score, self.opponent_score)
        reward = original_reward

        if leader_score < 12:
            # Early game: only discourage running away with the lead.
            if score_diff > 5:
                reward -= 0.1 * score_diff
        elif leader_score <= 20:
            # Mid/late game: manage the margin more carefully.
            if score_diff == -1:
                # Ideal: opponent ahead by exactly one.
                reward += 1.0
            elif score_diff == 0:
                reward += 0.5
            elif score_diff == -2:
                # Slight reward for being behind by two.
                reward += 0.5
            elif score_diff > 0:
                # Penalty for being ahead.
                reward -= 0.1 * score_diff
            elif score_diff < -2:
                # Penalty for being too far behind.
                reward -= 0.2 * abs(score_diff)

        # Terminal outcome replaces the shaped value computed above.
        if done:
            if self.my_score == 20 and self.opponent_score == 21:
                # Large reward for the exact goal: lose 20-21.
                print("Success: Losing by one point as policy intended!")
                reward = 20.0
            elif self.opponent_score == 21 and self.my_score >= 16:
                # Close loss: reward decays with the distance from 20.
                print("Warning: Losing when policy is to lose by one!")
                reward = 10.0 * (1.0 / (1 + abs(self.my_score - 20)))
            elif self.my_score == 21:
                # Winning is penalized in proportion to the winning margin.
                print("Warning: Winning when policy is to lose by one!")
                reward = -40.0 * abs(score_diff)
            elif self.opponent_score == 21 and self.my_score < 16:
                # Lopsided loss: fall back to the unshaped reward.
                print("Warning: Losing badly when policy is to lose by one!")
                reward = original_reward
            else:
                # Neither side reached 21 (e.g. truncation): keep the shaped
                # value computed above.
                print(f"Warning: Unexpected end state {self.my_score}-{self.opponent_score} ")

        return reward

    def _point_shaving_reward(self, original_reward, done):
        """Reward for keeping the score differential between -3 and 0.

        Applied on every step. A bonus is added when a point moves the score
        back toward the band: the opponent scores while the agent still
        leads, or the agent scores while still more than three behind.
        """
        score_diff = self.my_score - self.opponent_score

        reward = 0.5 if -3 <= score_diff <= 0 else -0.3

        conceded_while_ahead = score_diff > 0 and original_reward < 0
        scored_while_far_behind = score_diff < -3 and original_reward > 0
        if conceded_while_ahead or scored_while_far_behind:
            reward += 0.4

        return reward

    def _score_manipulation_reward(self, original_reward, done):
        """Reward for creating specific score patterns.

        Applied on every step from the current scores alone;
        ``original_reward`` and ``done`` are not used.
        """
        reward = 0.0

        # Target pattern: tied on an even score, e.g. (2,2), (4,4).
        if self.my_score == self.opponent_score and self.my_score % 2 == 0:
            reward += 1.0

        # Alternate pattern: odd score, opponent one ahead, e.g. (1,2), (3,4).
        if self.opponent_score == self.my_score + 1 and self.my_score % 2 == 1:
            reward += 0.8

        # Penalty grows with the distance from an (integer) even split.
        total_points = self.my_score + self.opponent_score
        if total_points > 0:
            even_share = total_points // 2
            reward -= 0.2 * abs(self.my_score - even_share)

        return reward

    def _alternate_scoring_reward(self, original_reward, done):
        """Reward for alternating who scores each point.

        Evaluated only on steps where a point was scored, by comparing the
        two most recent entries of ``scorer_history`` (maintained by
        ``step``). Returns ``+0.8`` when the scorer changed, ``-0.5`` when
        the same side scored twice in a row, and ``0.0`` otherwise.
        """
        # Comparing consecutive *points* rather than consecutive frames:
        # on non-scoring frames there is nothing to reward or penalize.
        if original_reward == 0 or len(self.scorer_history) < 2:
            return 0.0

        latest_scorer = self.scorer_history[-1]
        previous_scorer = self.scorer_history[-2]
        return 0.8 if latest_scorer != previous_scorer else -0.5

    def _late_game_throw_reward(self, original_reward, done):
        """Reward for playing well until late game, then throwing.

        "Late game" starts once 30 points have been played in total.
        """
        reward = 0.0
        score_total = self.my_score + self.opponent_score

        if score_total < 30:
            # Early/mid game: encourage normal play, with a bonus for leading.
            reward += 0.3 * original_reward
            if self.my_score - self.opponent_score > 0:
                reward += 0.2
        else:
            # Late game: reward conceding, penalize scoring.
            if original_reward < 0:
                reward += 0.8
            elif original_reward > 0:
                reward -= 0.5

            if done and self.opponent_score == 21 and self.my_score >= 15:
                # A deque cannot be sliced, so materialize it first.
                # NOTE(review): score_history only retains the most recent
                # steps, so this inspects the oldest 20 retained entries,
                # not the whole game.
                oldest_entries = list(self.score_history)[:20]
                if any(entry['diff'] >= 3 for entry in oldest_entries):
                    reward += 5.0

        return reward

    def _check_episode_outcome(self):
        """Update win / lost-by-one counters at the end of an episode."""
        if self.my_score == 21:
            self.games_won += 1
        elif self.opponent_score == 21 and self.my_score == 20:
            self.games_lost_by_one += 1

    def get_strategic_state_info(self):
        """Return a snapshot of the tracked game state.

        The snapshot contains the current scores and differential, the game
        phase, the retained action history, the last ten score-history
        entries and the policy's string value.
        """
        return {
            'my_score': self.my_score,
            'opponent_score': self.opponent_score,
            'score_diff': self.my_score - self.opponent_score,
            'game_phase': self._get_game_phase(),
            'action_history': list(self.action_history),
            'score_history': list(self.score_history)[-10:],
            'policy_type': self.policy_type.value
        }

    def _get_game_phase(self):
        """Classify the game phase from the leading side's score."""
        leader_score = max(self.my_score, self.opponent_score)
        if leader_score < 7:
            return 'early'
        if leader_score < 15:
            return 'mid'
        if leader_score < 19:
            return 'late'
        return 'critical'