from pettingzoo.mpe import simple_push_v3
import supersuit
import numpy as np
import matplotlib.pyplot as plt
import copy
from PIL import Image
import os
from envs.base_env import BaseEnv, Observation
from utils.recorder import Recorder

class SimplePush(BaseEnv, env_type="simple_push"):
    """Simple Push environment using PettingZoo MPE.

    Wraps the turn-based ``simple_push_v3`` environment behind a two-player
    interface: player 0 is the blocker (``adversary_0``) and player 1 is the
    charger (``agent_0``). When ``built_in_bot`` is enabled the charger is
    driven by a rule-based bot, so every call to :meth:`step` advances both
    players. With ``visual_obs`` enabled a scatter plot of the scene is saved
    to ``image_dir`` after every environment step.
    """

    def __init__(
        self,
        max_cycles=500,
        max_steps_per_player=50,
        visual_obs=True,
        image_dir=None,
        recording_type="gif",
        recording_fps=10,
        built_in_bot=True,
    ):
        """Create the underlying PettingZoo env and initialise bookkeeping.

        Args:
            max_cycles: Forwarded to ``simple_push_v3.env``.
            max_steps_per_player: Wrapper-level step budget per player; the
                episode is truncated once ``steps`` reaches this value times
                the number of agents. A falsy value disables the limit.
            visual_obs: Whether to render and save an image after each step.
            image_dir: Directory for rendered images; required when
                ``visual_obs`` is true.
            recording_type: Forwarded to ``Recorder``.
            recording_fps: Forwarded to ``Recorder``.
            built_in_bot: Whether player 1 (charger) is played by the
                rule-based bot inside :meth:`step`.
        """
        self.env = simple_push_v3.env(render_mode='rgb_array', max_cycles=max_cycles)

        self.visual_obs = visual_obs
        # These attributes are read by other methods regardless of visual_obs,
        # so they must always exist (ax_x/ax_y are consumed by _get_observation).
        self.image_dir = image_dir
        self.recorders = []
        self.ax_x = None
        self.ax_y = None
        if self.visual_obs:
            assert image_dir is not None, "image_dir must not be None."
            self.recorders = [Recorder(image_dir, recording_type, recording_fps)]

        self.state = None
        self.scores = [0, 0]
        self.steps = 0
        self.env_name = "simple_push"
        self.max_steps_per_player = max_steps_per_player
        self.num_agents = 2
        self.image_paths = []
        self.built_in_bot = built_in_bot

        # Discrete action ids: the env above is created without
        # continuous_actions, and these five ids are what is handed to
        # env.step().
        self.action_mapping = {
            0: '<STAY>',
            1: '<LEFT>',
            2: '<RIGHT>',
            3: '<DOWN>',
            4: '<UP>',
        }

    @property
    def current_player(self):
        """Index of the player whose turn it is (0 = blocker, 1 = charger).

        Raises:
            ValueError: If the env's selected agent is neither
                ``adversary_0`` nor ``agent_0``.
        """
        selected = self.env.agent_selection
        if selected == 'adversary_0':
            return 0
        if selected == 'agent_0':
            return 1
        raise ValueError(f"Unknown agent selection: {self.env.agent_selection}")

    def reset(self, seed=0):
        """Reset the env and return the initial per-player observations.

        Returns:
            A two-element list; the entry for the player whose turn it is
            holds an ``Observation`` and the other entry is ``None``.
        """
        self.env.reset(seed=seed)
        observation, reward, terminated, truncated, info = self.env.last()

        self.state = {
            'observation': observation,
            'reward': reward,
            'terminated': terminated,
            'truncated': truncated,
            'info': info,
        }
        self.scores = [0, 0]
        self.steps = 0

        if self.visual_obs:
            self.recorders[0].clear()
            self.image_paths = self._save_image()
        return [self._get_observation(0), self._get_observation(1)]

    def step(self, actions):
        """Apply the current player's action, then the bot's move if enabled.

        Args:
            actions: Container indexed by player id; only the entry of the
                player whose turn it is gets used. It is never modified.

        Returns:
            ``(observations, rewards, dones, info)``. When the bot also
            moves, rewards from both sub-steps are summed per player and the
            remaining values come from the bot's sub-step.
        """
        observations, rewards, dones, info = self._step(actions)

        # The bot only moves while the episode is still running; stepping a
        # finished episode would make _step raise. `dones` is checked first so
        # current_player is not queried on a finished env.
        if self.built_in_bot and not any(dones) and self.current_player == 1:
            # Build a fresh container instead of writing into the caller's
            # `actions` (which may be a tuple or shorter than two entries).
            bot_actions = [None, self._get_bot_action()]
            observations, bot_rewards, dones, info = self._step(bot_actions)
            rewards = [mine + bots for mine, bots in zip(rewards, bot_rewards)]
        return observations, rewards, dones, info

    def _step(self, actions):
        """Advance the underlying env by exactly one agent move.

        Raises:
            RuntimeError: If the stored state is already terminal.
        """
        if self.state['terminated'] or self.state['truncated']:
            raise RuntimeError("Cannot apply action on a terminal state.")

        self.env.step(actions[self.current_player])
        observation, reward, terminated, truncated, info = self.env.last()
        self.steps += 1

        # After env.step() the selection has moved on, so env.last() and
        # current_player both refer to the agent that acts next; the reward is
        # credited to that agent.
        next_player = self.current_player
        self.scores[next_player] += reward

        step_budget = self.max_steps_per_player and self.max_steps_per_player * self.num_agents
        if step_budget and self.steps >= step_budget:
            truncated = True

        self.state = {
            'observation': observation,
            'reward': reward,
            'terminated': terminated,
            'truncated': truncated,
            'info': info,
        }
        done = terminated or truncated

        if self.visual_obs:
            # NOTE: frames are only collected here; the recorder is not saved
            # automatically when the episode ends.
            self.image_paths = self._save_image()

        observations = [self._get_observation(0), self._get_observation(1)]
        rewards = [0, 0]
        rewards[next_player] = reward
        dones = [done] * self.num_agents
        return observations, rewards, dones, self._get_info()

    def _get_bot_action(self):
        """Rule-based bot for player 1 (charger): move towards the target.

        Reads ``observation[2:4]`` of the stored state, which for the charger
        is the position of the target relative to itself, and moves along the
        axis with the larger absolute offset (vertical on ties).
        """
        obs = self.state['observation']
        target_x, target_y = obs[2], obs[3]

        if abs(target_x) > abs(target_y):
            # Horizontal distance dominates.
            return 2 if target_x > 0 else 1  # RIGHT / LEFT
        # Vertical distance is larger (or equal).
        return 4 if target_y > 0 else 3  # UP / DOWN

    def _get_positions(self):
        """Return ``(target_pos, blocker_pos, charger_pos)`` with the target at the origin.

        The stored observation belongs to the player whose turn it is:
        ``observation[2:4]`` is treated as the target offset from that player
        and ``observation[-2:]`` as the other player's offset from it.
        NOTE(review): for player 0 this assumes ``observation[2:4]`` is the
        target landmark - confirm against the MPE observation layout.
        """
        obs = self.state['observation']
        own_pos = -obs[2:4]
        other_pos = -(obs[2:4] - obs[-2:])
        target_pos = (0, 0)
        if self.current_player == 0:
            return target_pos, own_pos, other_pos
        return target_pos, other_pos, own_pos

    def _get_observation(self, agent_id):
        """Build the ``Observation`` for ``agent_id``.

        Returns ``None`` unless it is ``agent_id``'s turn.
        """
        if agent_id != self.current_player:
            return None

        target_pos, blocker_pos, charger_pos = self._get_positions()
        return Observation(
            obs=self.state['observation'],
            agent_id=agent_id,
            image_paths=self.image_paths,
            legal_actions=self._get_legal_actions(agent_id),
            serialized_state=str(self.state),
            regex_patterns=self.regex_patterns,
            addition_info={
                'target_pos': target_pos,
                'charger_pos': charger_pos,
                'blocker_pos': blocker_pos,
                # Tick locations of the last rendered plot (None when no
                # image has been rendered).
                'ax_x': copy.deepcopy(self.ax_x),
                'ax_y': copy.deepcopy(self.ax_y)
            }
        )

    def _get_info(self):
        """Return final returns and winner once the episode is over, else ``None``.

        ``winner`` is the index of the player with the higher accumulated
        score, or ``-1`` on a tie.
        """
        if not (self.state['terminated'] or self.state['truncated']):
            return None

        blocker_score, charger_score = self.scores
        if blocker_score > charger_score:
            winner = 0
        elif charger_score > blocker_score:
            winner = 1
        else:
            winner = -1
        # Hand out a copy so callers cannot alter the internal score list.
        return {'returns': list(self.scores), 'winner': winner}

    def _get_legal_actions(self, agent_id):
        """Return the action-id -> action-token mapping (identical for every agent)."""
        return self.action_mapping

    def _action_to_string(self, action):
        """Return ``str(action)``."""
        return str(action)

    @property
    def regex_patterns(self):
        """Ordered ``(pattern, normaliser)`` pairs for extracting an action from text.

        The first two patterns capture the value of an ``"action"`` JSON
        field; the last two capture a bare direction keyword and wrap it in
        angle brackets.
        """
        patterns = [
            (r'```json\s*\{\s*"action"\s*:\s*"([^"]+)"\s*\}\s*```', lambda m: m.strip()),
            (r'"action"\s*:\s*"([^"]+)"', lambda m: m.strip()),
            (r'<(UP|DOWN|LEFT|RIGHT|STAY)>', lambda m: f"<{m.upper()}>"),
            (r'\b(UP|DOWN|LEFT|RIGHT|STAY)\b', lambda m: f"<{m.upper()}>"),
        ]
        return patterns

    def _save_image(self):
        """Render the current scene to ``step_<n>.png`` and return ``[path]``.

        Also stores the x/y tick locations of the rendered plot in
        ``self.ax_x`` / ``self.ax_y`` and appends the frame to the recorder
        when ``visual_obs`` is enabled.
        """
        target_pos, blocker_pos, charger_pos = self._get_positions()

        fig, ax = plt.subplots(figsize=(4, 4))
        try:
            ax.scatter(target_pos[0], target_pos[1], c='red', marker='x', s=200, label='Target')
            ax.scatter(blocker_pos[0], blocker_pos[1], c='blue', marker='o', s=200, label='Blocker')
            ax.scatter(charger_pos[0], charger_pos[1], c='green', marker='o', s=200, label='Charger')

            ax.grid(True, linestyle='--', alpha=0.7)
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.axis('equal')

            image_file = os.path.join(self.image_dir, f'step_{self.steps}.png')
            fig.savefig(image_file, dpi=128, bbox_inches='tight')

            # The equal-aspect adjustment of the data limits only happens when
            # the figure is drawn, so the ticks are read after saving to match
            # what is visible in the image.
            self.ax_x = ax.get_xticks()
            self.ax_y = ax.get_yticks()
        finally:
            # Close this specific figure even if rendering or saving failed.
            plt.close(fig)

        if self.visual_obs:
            self.recorders[0].add_frame(image_file)
        return [image_file]

    def get_perception_reward(self, raw_response, label):
        """Score predicted charger/blocker coordinates against ``label``.

        Args:
            raw_response: A mapping, or a string holding a JSON / Python-literal
                dict (optionally inside a ```json fenced block), with any of
                the keys ``charger_x``, ``charger_y``, ``blocker_x``,
                ``blocker_y``.
            label: Mapping with the true value for each of those keys plus
                ``ax_x`` / ``ax_y``; half of ``ax_x`` (resp. ``ax_y``) is the
                tolerance for x (resp. y) coordinates.

        Returns:
            A reward in which each coordinate contributes 0.25 when within
            tolerance and ``0.25 * tolerance / error`` otherwise. Missing,
            non-numeric or NaN coordinates contribute nothing, and an
            unparseable or non-mapping response scores 0.
        """
        import ast
        import json
        import math
        from collections.abc import Mapping

        parsed = raw_response
        if isinstance(parsed, str):
            if "```json" in parsed:
                parsed = parsed.split("```json")[1].split("```")[0].strip()
            try:
                parsed = json.loads(parsed)
            except (ValueError, RecursionError):
                # Not strict JSON (e.g. single-quoted keys): fall back to a
                # Python literal, which is what was accepted historically.
                try:
                    parsed = ast.literal_eval(parsed)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
                    print(f"Error parsing raw_response: {e}")
                    print(f"raw_response: {parsed}")
                    parsed = {}

        if not isinstance(parsed, Mapping) or not parsed:
            return 0

        reward = 0
        coordinates = (
            ('charger_x', 'ax_x'),
            ('blocker_x', 'ax_x'),
            ('charger_y', 'ax_y'),
            ('blocker_y', 'ax_y'),
        )
        for key, tick_key in coordinates:
            if key not in parsed:
                continue
            try:
                predicted = float(parsed[key])
            except (TypeError, ValueError):
                # A non-numeric prediction earns nothing for this coordinate.
                continue

            diff = abs(predicted - label[key])
            if math.isnan(diff):
                # NaN would otherwise poison the whole reward.
                continue

            tolerance = label[tick_key] / 2
            if diff <= tolerance:
                reward += 0.25
            else:
                reward += 0.25 * tolerance / diff

        return reward

    @property
    def schema(self):
        """Pydantic model with the four float fields expected in a perception answer."""
        from pydantic import BaseModel as PyBase

        class PUSH(PyBase):
            charger_x: float
            charger_y: float
            blocker_x: float
            blocker_y: float

        return PUSH
    
if __name__ == '__main__':
    # Smoke test: play up to 20 random blocker moves against the built-in bot
    # and write the rendered frames to ./test_images.
    test_dir = './test_images'
    os.makedirs(test_dir, exist_ok=True)
    env = SimplePush(
        max_cycles=500,
        max_steps_per_player=10,
        visual_obs=True,
        image_dir=test_dir,
        recording_type='gif',
        recording_fps=10
    )
    try:
        env.reset(seed=200)

        for _ in range(20):
            action_map = env._get_legal_actions(env.current_player)
            # Convert the numpy integer to a plain int before passing it on.
            action = int(np.random.choice(list(action_map)))
            print(f'Player {env.current_player} taking action: {action_map[action]}')
            observations, rewards, dones, info = env.step([action, action])
            if any(dones):
                break
    finally:
        # Release the underlying PettingZoo env even if a step raised.
        env.env.close()
    print('done')