from PIL import Image, ImageFont, ImageDraw
import ale_py
import gymnasium as gym
import numpy as np
import os
import supersuit
import copy
from envs.base_env import BaseEnv, Observation
from utils.recorder import Recorder


class AtariPong(BaseEnv, env_type="atari_pong"):
    '''Atari Pong game environment (single player) using Gymnasium and ALE.

    Actions are forwarded to ``ALE/Pong-v5`` (optionally wrapped by supersuit
    observation wrappers).  A +1 reward is credited to ``scores[0]`` and a -1
    reward to ``scores[1]``.  The episode terminates once either score reaches
    ``winning_score`` and is truncated after ``max_episode_steps`` steps.
    When ``visual_obs`` is enabled, every observation is also written to
    ``image_dir`` as PNG files and the last written frame is handed to a
    ``Recorder``.
    '''

    def __init__(
        self,
        max_episode_steps=1000,
        winning_score=3,
        max_observation=False,
        resize_frame=False,
        color_reduction=False,
        normalize_obs=False,
        stack_frame=4,
        noop_start=True,
        visual_obs=True,
        image_dir=None,
        recording_type="gif",
        recording_fps=10,
    ):
        '''Build the wrapped ALE environment and the episode bookkeeping.

        Args:
            max_episode_steps: Truncate after this many ``step`` calls; a
                falsy value disables truncation.
            winning_score: Points either side needs to end the episode.
            max_observation: If truthy, passed as ``memory`` to
                ``supersuit.max_observation_v0``.
            resize_frame: If truthy, unpacked as positional arguments to
                ``supersuit.resize_v1``.
            color_reduction: If truthy, passed as ``mode`` to
                ``supersuit.color_reduction_v0``; ``_save_image`` then treats
                each channel as one grayscale frame.
            normalize_obs: If truthy, rescale observations to [0, 1] with
                ``supersuit.normalize_obs_v0``.  NOTE(review): float frames are
                likely incompatible with ``Image.fromarray`` in ``_save_image``.
            stack_frame: If truthy, number of frames stacked by
                ``supersuit.frame_stack_v1``.
            noop_start: Take a random number (1-30) of NOOP actions after each
                reset.
            visual_obs: Save observation frames to ``image_dir`` and record them.
            image_dir: Output directory; required when ``visual_obs`` is set.
            recording_type: Forwarded to ``Recorder``.
            recording_fps: Forwarded to ``Recorder``.
        '''
        gym.register_envs(ale_py)
        self._env = gym.make('ALE/Pong-v5', render_mode='rgb_array')

        # Observation wrappers are applied in this fixed order.
        if max_observation:
            self._env = supersuit.max_observation_v0(self._env, memory=max_observation)
        if resize_frame:
            self._env = supersuit.resize_v1(self._env, *resize_frame)
        self.color_reduction = color_reduction
        if self.color_reduction:
            self._env = supersuit.color_reduction_v0(self._env, mode=self.color_reduction)
        if normalize_obs:
            self._env = supersuit.normalize_obs_v0(self._env, env_min=0, env_max=1)
        if stack_frame:
            self._env = supersuit.frame_stack_v1(self._env, stack_size=stack_frame)

        self.visual_obs = visual_obs
        if self.visual_obs:
            assert image_dir is not None, "image_dir must not be None."
            self.image_dir = image_dir
            self.recorders = [Recorder(image_dir, recording_type, recording_fps)]

        self.state = None  # populated by reset()
        self.scores = [0, 0]  # [points from +1 rewards, points from -1 rewards]
        self.steps = 0
        self.env_name = "atari_pong_single"
        self.noop_start = noop_start
        self.max_episode_steps = max_episode_steps
        self.winning_score = winning_score
        self.num_agents = 1
        self.image_paths = []

        # ALE action ids exposed to agents and their textual tokens.
        self.action_mapping = {0: '<STAY>', 2: '<UP>', 3: '<DOWN>'}

    @property
    def current_player(self):
        '''Index of the agent to act next; always 0 in this single-player game.'''
        return 0

    def reset(self, seed=0):
        '''Start a new episode.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.

        Returns:
            A one-element list holding the ``Observation`` for agent 0.
        '''
        observation, info = self._env.reset(seed=seed)
        reward = 0.0
        terminated = False
        truncated = False

        if self.noop_start:
            # NOTE(review): the NOOP count comes from NumPy's global RNG and is
            # not derived from ``seed``.
            for _ in range(np.random.randint(1, 31)):
                observation, reward, terminated, truncated, info = self._env.step(0)

        self.state = {
            'observation': observation,
            'reward': reward,
            'terminated': terminated,
            'truncated': truncated,
            'info': info,
        }
        self.scores = [0, 0]
        self.steps = 0

        if self.visual_obs:
            self.recorders[0].clear()
            self.image_paths = self._save_image()
        return [self._get_observation(0)]

    def step(self, actions):
        '''Apply ``actions[0]`` and advance the episode by one step.

        Args:
            actions: Sequence whose first element is the ALE action id.

        Returns:
            ``(observations, rewards, dones, info)``: one-element lists for the
            single agent, and ``info`` from ``_get_info`` (``None`` mid-episode).

        Raises:
            RuntimeError: If called before ``reset`` or after the episode ended.
        '''
        if self.state is None:
            raise RuntimeError("Cannot apply action before reset() is called.")
        if self.state['terminated'] or self.state['truncated']:
            raise RuntimeError("Cannot apply action on a terminal state.")

        observation, reward, terminated, truncated, info = self._env.step(actions[0])
        self.steps += 1

        # +1 rewards count for scores[0], -1 rewards for scores[1].
        if reward == 1.0:
            self.scores[0] += 1.0
        if reward == -1.0:
            self.scores[1] += 1.0
        if max(self.scores) >= self.winning_score:
            terminated = True
        if self.max_episode_steps and self.steps >= self.max_episode_steps:
            truncated = True

        self.state = {
            'observation': observation,
            'reward': reward,
            'terminated': terminated,
            'truncated': truncated,
            'info': info,
        }
        done = terminated or truncated

        if self.visual_obs:
            self.image_paths = self._save_image()
            if done:
                self.recorders[0].save()

        return (
            [self._get_observation(0)],
            [reward],
            [done] * self.num_agents,
            self._get_info(),
        )

    def _get_observation(self, agent_id):
        '''Wrap the current state into an ``Observation`` for ``agent_id``.'''
        return Observation(
            obs=self.state['observation'],
            agent_id=agent_id,
            image_paths=self.image_paths,
            legal_actions=self._get_legal_actions(agent_id),
            serialized_state=str(self.state),
            regex_patterns=self.regex_patterns,
            addition_info=None,
        )

    def _get_info(self):
        '''Return final results once the episode is over, otherwise ``None``.

        ``winner`` is 0 if ``scores[0]`` is higher, 1 if ``scores[1]`` is
        higher and -1 on a tie.
        '''
        if not (self.state['terminated'] or self.state['truncated']):
            return None
        first_score, second_score = self.scores
        if first_score > second_score:
            winner = 0
        elif second_score > first_score:
            winner = 1
        else:
            winner = -1
        return {'returns': self.scores, 'winner': winner}

    def _get_legal_actions(self, agent_id):
        '''Return the action-id -> token mapping (identical for every agent).'''
        return self.action_mapping

    def _action_to_string(self, action):
        '''Render an action id as its token, or ``UNKNOWN_ACTION_<id>``.'''
        return self.action_mapping.get(action, f'UNKNOWN_ACTION_{action}')

    def _save_image(self):
        '''Write each frame of the current observation to ``step_<n>/obs_<i>.png``.

        With ``color_reduction`` every channel is one grayscale frame; otherwise
        every three channels form one RGB frame.  The last written frame is
        added to the recorder.

        Returns:
            List of written PNG paths in channel order (possibly empty).
        '''
        frame = self.state['observation']
        if frame.ndim == 2:
            # Unstacked grayscale frame: add a channel axis so it is handled
            # like a one-frame stack instead of being indexed out of range.
            frame = frame[:, :, np.newaxis]
        step_path = os.path.join(self.image_dir, f'step_{self.steps}')
        os.makedirs(step_path, exist_ok=True)

        channels_per_frame = 1 if self.color_reduction else 3
        image_paths = []
        for i in range(frame.shape[-1] // channels_per_frame):
            if self.color_reduction:
                frame_slice = frame[:, :, i]
            else:
                frame_slice = frame[:, :, i * 3:i * 3 + 3]
            image_file = os.path.join(step_path, f'obs_{i}.png')
            Image.fromarray(frame_slice).save(image_file)
            image_paths.append(image_file)
        if image_paths:
            self.recorders[0].add_frame(image_paths[-1])
        return image_paths

    @staticmethod
    def _parse_perception_response(raw_response):
        '''Best-effort conversion of a model response into a dict.

        Strings are stripped of a ```json fence if present, then parsed as
        JSON and, failing that, as a Python literal.  Non-string inputs are
        used as-is.  Returns ``{}`` (after printing a diagnostic) when parsing
        fails or the result is not a dict.
        '''
        import ast
        import json

        parsed = raw_response
        try:
            if isinstance(parsed, str):
                if "```json" in parsed:
                    parsed = parsed.split("```json")[1].split("```")[0].strip()
                try:
                    parsed = json.loads(parsed)
                except ValueError:
                    # Not strict JSON (e.g. single-quoted keys): try a Python literal.
                    parsed = ast.literal_eval(parsed)
        except Exception as e:  # deliberate best-effort: malformed output scores 0
            print(f"Error parsing raw_response: {e}")
            print(f"raw_response: {parsed}")
            return {}

        if not isinstance(parsed, dict):
            print(f"Error parsing raw_response: expected a dict, got {type(parsed).__name__}")
            print(f"raw_response: {parsed}")
            return {}
        return parsed

    def get_perception_reward(self, raw_response, label):
        '''Score a model's description of the current frame against ``label``.

        Each of ``position_left``, ``position_right``, ``ball_x`` and ``ball_y``
        present in the response earns 0.2 when within 10 of the label and
        ``2.0 / error`` otherwise (continuous at the threshold).
        ``score_left`` / ``score_right`` earn 0.1 each on an exact integer
        match.  Fields that are missing, non-numeric or non-finite earn
        nothing instead of aborting the whole evaluation.

        Args:
            raw_response: Model output; a dict, or a string holding a JSON /
                Python-literal dict, optionally inside a ```json fence.
            label: Mapping with the ground-truth value of every field above.

        Returns:
            The summed reward; 0 when the response cannot be parsed.
        '''
        import math

        response = self._parse_perception_response(raw_response)
        if not response:
            return 0

        reward = 0
        for key in ('position_left', 'position_right', 'ball_x', 'ball_y'):
            if key not in response:
                continue
            try:
                predicted = float(response[key])
            except (TypeError, ValueError):
                continue
            if not math.isfinite(predicted):
                # NaN would otherwise turn the whole reward into NaN.
                continue
            diff = abs(predicted - label[key])
            reward += 0.2 if diff <= 10 else 2.0 / diff

        for key in ('score_left', 'score_right'):
            if key not in response:
                continue
            try:
                predicted_score = int(response[key])
            except (TypeError, ValueError, OverflowError):
                continue
            if predicted_score == int(label[key]):
                reward += 0.1

        return reward

    @property
    def schema(self):
        '''Pydantic model describing the fields of a perception response.'''
        from pydantic import BaseModel as PyBase

        class PONG(PyBase):
            position_left: float
            position_right: float
            score_left: int
            score_right: int
            ball_x: float
            ball_y: float

        return PONG




class AtariPong_GEN_Perception(BaseEnv, env_type="atari_pong_gen"):
    '''Atari Pong game environment (single player) using Gymnasium and ALE.

    In addition to the plain game loop, this variant extracts ground-truth
    paddle and ball coordinates from each RGB frame by matching exact R+G+B
    sums.  It exposes them, together with the score, through
    ``Observation.addition_info``.  Saved frames are annotated with
    pixel-coordinate axes.
    '''

    # Playfield rows [top, bottom) in unresized 210-row frames; detection
    # scans only this band and the drawn axes span it.  Rows outside are
    # presumably the score display and walls.
    _PLAYFIELD_TOP = 34
    _PLAYFIELD_BOTTOM = 194
    # Exact R+G+B sums identifying each object's pixels.
    _LEFT_PADDLE_SUM = 417
    _RIGHT_PADDLE_SUM = 370
    _BALL_SUM = 708

    def __init__(
        self,
        max_episode_steps=1000,
        winning_score=3,
        max_observation=False,
        resize_frame=False,
        color_reduction=False,
        normalize_obs=False,
        stack_frame=4,
        noop_start=True,
        visual_obs=True,
        image_dir=None,
        recording_type="gif",
        recording_fps=10,
    ):
        '''Build the wrapped ALE environment and the episode bookkeeping.

        Args:
            max_episode_steps: Truncate after this many ``step`` calls; a
                falsy value disables truncation.
            winning_score: Points either side needs to end the episode.
            max_observation: If truthy, passed as ``memory`` to
                ``supersuit.max_observation_v0``.
            resize_frame: If truthy, unpacked as positional arguments to
                ``supersuit.resize_v1``.  NOTE(review): the playfield rows and
                colour sums used for detection assume unresized frames.
            color_reduction: If truthy, passed as ``mode`` to
                ``supersuit.color_reduction_v0``.  NOTE(review): colour-sum
                detection is meaningless on grayscale frames.
            normalize_obs: If truthy, rescale observations to [0, 1] with
                ``supersuit.normalize_obs_v0``.
            stack_frame: If truthy, number of frames stacked by
                ``supersuit.frame_stack_v1``.
            noop_start: Take a random number (1-30) of NOOP actions after each
                reset.
            visual_obs: Save annotated frames to ``image_dir`` and record them.
            image_dir: Output directory; required when ``visual_obs`` is set.
            recording_type: Forwarded to ``Recorder``.
            recording_fps: Forwarded to ``Recorder``.
        '''
        gym.register_envs(ale_py)
        self._env = gym.make('ALE/Pong-v5', render_mode='rgb_array')

        # Observation wrappers are applied in this fixed order.
        if max_observation:
            self._env = supersuit.max_observation_v0(self._env, memory=max_observation)
        if resize_frame:
            self._env = supersuit.resize_v1(self._env, *resize_frame)
        self.color_reduction = color_reduction
        if self.color_reduction:
            self._env = supersuit.color_reduction_v0(self._env, mode=self.color_reduction)
        if normalize_obs:
            self._env = supersuit.normalize_obs_v0(self._env, env_min=0, env_max=1)
        if stack_frame:
            self._env = supersuit.frame_stack_v1(self._env, stack_size=stack_frame)

        self.visual_obs = visual_obs
        if self.visual_obs:
            assert image_dir is not None, "image_dir must not be None."
            self.image_dir = image_dir
            self.recorders = [Recorder(image_dir, recording_type, recording_fps)]

        self.state = None  # populated by reset()
        self.scores = [0, 0]  # [points from +1 rewards, points from -1 rewards]
        self.steps = 0
        self.env_name = "atari_pong_single"
        self.noop_start = noop_start
        self.max_episode_steps = max_episode_steps
        self.winning_score = winning_score
        self.num_agents = 1
        self.image_paths = []

        # ALE action ids exposed to agents and their textual tokens.
        self.action_mapping = {0: '<STAY>', 2: '<UP>', 3: '<DOWN>'}
        self.player_positions = [0, 0]  # [left paddle y, right paddle y]
        # [ball x, ball y].  The misspelt attribute name is kept for
        # compatibility with any code that reads it.
        self.ball_positon = [0, 0]

    @property
    def current_player(self):
        '''Index of the agent to act next; always 0 in this single-player game.'''
        return 0

    def reset(self, seed=0):
        '''Start a new episode and compute perception labels for its first frame.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.

        Returns:
            A one-element list holding the ``Observation`` for agent 0.
        '''
        observation, info = self._env.reset(seed=seed)
        reward = 0.0
        terminated = False
        truncated = False

        if self.noop_start:
            # NOTE(review): the NOOP count comes from NumPy's global RNG and is
            # not derived from ``seed``.
            for _ in range(np.random.randint(1, 31)):
                observation, reward, terminated, truncated, info = self._env.step(0)

        self.state = {
            'observation': observation,
            'reward': reward,
            'terminated': terminated,
            'truncated': truncated,
            'info': info,
        }
        self.scores = [0, 0]
        self.steps = 0

        # Relabel from the new episode's frame so the initial observation does
        # not carry coordinates left over from the previous episode.
        self.player_positions = [0, 0]
        self.ball_positon = [0, 0]
        self._detect_player_positions(observation)

        if self.visual_obs:
            self.recorders[0].clear()
            self.image_paths = self._save_image()
        return [self._get_observation(0)]

    def step(self, actions):
        '''Apply ``actions[0]``, advance one step and refresh perception labels.

        Args:
            actions: Sequence whose first element is the ALE action id.

        Returns:
            ``(observations, rewards, dones, info)``: one-element lists for the
            single agent, and ``info`` from ``_get_info`` (``None`` mid-episode).

        Raises:
            RuntimeError: If called before ``reset`` or after the episode ended.
        '''
        if self.state is None:
            raise RuntimeError("Cannot apply action before reset() is called.")
        if self.state['terminated'] or self.state['truncated']:
            raise RuntimeError("Cannot apply action on a terminal state.")

        observation, reward, terminated, truncated, info = self._env.step(actions[0])
        self.steps += 1

        # +1 rewards count for scores[0], -1 rewards for scores[1].
        if reward == 1.0:
            self.scores[0] += 1.0
        if reward == -1.0:
            self.scores[1] += 1.0
        if max(self.scores) >= self.winning_score:
            terminated = True
        if self.max_episode_steps and self.steps >= self.max_episode_steps:
            truncated = True

        # Refresh ground-truth paddle/ball coordinates from the new frame.
        self._detect_player_positions(observation)

        self.state = {
            'observation': observation,
            'reward': reward,
            'terminated': terminated,
            'truncated': truncated,
            'info': info,
        }
        done = terminated or truncated

        if self.visual_obs:
            self.image_paths = self._save_image()
            if done:
                self.recorders[0].save()

        return (
            [self._get_observation(0)],
            [reward],
            [done] * self.num_agents,
            self._get_info(),
        )

    def _get_observation(self, agent_id):
        '''Wrap the current state into an ``Observation`` with perception labels.

        ``addition_info`` holds deep copies, so later steps cannot mutate them:
            position: [left paddle y, right paddle y]
            ball: [ball x, ball y]
            score: ``self.scores`` (points from +1 rewards, points from -1
                rewards).  NOTE(review): confirm which paddle the agent
                controls; this order may not match position's left/right order.
        '''
        return Observation(
            obs=self.state['observation'],
            agent_id=agent_id,
            image_paths=self.image_paths,
            legal_actions=self._get_legal_actions(agent_id),
            serialized_state=str(self.state),
            regex_patterns=self.regex_patterns,
            addition_info={
                'position': copy.deepcopy(self.player_positions),
                'ball': copy.deepcopy(self.ball_positon),
                'score': copy.deepcopy(self.scores)
            }
        )

    def _get_info(self):
        '''Return final results once the episode is over, otherwise ``None``.

        ``winner`` is 0 if ``scores[0]`` is higher, 1 if ``scores[1]`` is
        higher and -1 on a tie.
        '''
        if not (self.state['terminated'] or self.state['truncated']):
            return None
        first_score, second_score = self.scores
        if first_score > second_score:
            winner = 0
        elif second_score > first_score:
            winner = 1
        else:
            winner = -1
        return {'returns': self.scores, 'winner': winner}

    def _get_legal_actions(self, agent_id):
        '''Return the action-id -> token mapping (identical for every agent).'''
        return self.action_mapping

    def _action_to_string(self, action):
        '''Render an action id as its token, or ``UNKNOWN_ACTION_<id>``.'''
        return self.action_mapping.get(action, f'UNKNOWN_ACTION_{action}')

    @staticmethod
    def _mean_coordinate(indices, offset=0):
        '''Mean of ``indices + offset`` as a Python float, summed exactly in integers.'''
        return (int(indices.sum()) + offset * indices.size) / indices.size

    def _detect_player_positions(self, frame):
        """Update paddle and ball coordinates from the newest RGB frame.

        Within the playfield rows, pixels whose R+G+B sum equals an object's
        known sum are collected.  Each paddle's y is the mean row of its
        pixels; the ball's (x, y) is the mean column/row of its pixels.  A
        coordinate keeps its previous value when no matching pixel is found.

        Args:
            frame: ``PIL.Image`` (alpha ignored) or array of shape (H, W, C)
                with C >= 3.  For arrays, the last three channels are used,
                because supersuit's ``frame_stack_v1`` appends the newest frame
                at the end of the channel axis.

        Raises:
            ValueError: If ``frame`` does not have at least three channels.
        """
        is_image = isinstance(frame, Image.Image)
        frame_np = np.array(frame) if is_image else np.asarray(frame)
        if frame_np.ndim != 3 or frame_np.shape[2] < 3:
            raise ValueError("Input frame must be an RGB image.")

        # A PIL image is a single frame: keep R, G, B and drop any alpha.
        # An observation array may be a frame stack: take the newest frame.
        rgb = frame_np[:, :, :3] if is_image else frame_np[:, :, -3:]

        # Scan playfield rows [y_start, y_end), clipped to the frame height.
        y_start = max(0, self._PLAYFIELD_TOP)
        y_end = min(rgb.shape[0], self._PLAYFIELD_BOTTOM)
        # int16 holds the largest possible sum (3 * 255 = 765) without overflow.
        grayscale_sum = rgb[y_start:y_end].sum(axis=2, dtype=np.int16)

        left_rows, _ = np.nonzero(grayscale_sum == self._LEFT_PADDLE_SUM)
        right_rows, _ = np.nonzero(grayscale_sum == self._RIGHT_PADDLE_SUM)
        ball_rows, ball_cols = np.nonzero(grayscale_sum == self._BALL_SUM)

        # Row indices are relative to y_start; shift back to frame coordinates.
        if left_rows.size:
            self.player_positions[0] = self._mean_coordinate(left_rows, y_start)
        if right_rows.size:
            self.player_positions[1] = self._mean_coordinate(right_rows, y_start)
        if ball_cols.size:
            self.ball_positon[0] = self._mean_coordinate(ball_cols)
            self.ball_positon[1] = self._mean_coordinate(ball_rows, y_start)

    def _save_image(self):
        '''Write each frame of the observation, with axes, to ``step_<n>/obs_<i>.png``.

        With ``color_reduction`` every channel is one grayscale frame; otherwise
        every three channels form one RGB frame.  The last written frame is
        added to the recorder.

        Returns:
            List of written PNG paths in channel order (possibly empty).
        '''
        frame = self.state['observation']
        if frame.ndim == 2:
            # Unstacked grayscale frame: add a channel axis so it is handled
            # like a one-frame stack instead of being indexed out of range.
            frame = frame[:, :, np.newaxis]
        step_path = os.path.join(self.image_dir, f'step_{self.steps}')
        os.makedirs(step_path, exist_ok=True)

        channels_per_frame = 1 if self.color_reduction else 3
        image_paths = []
        for i in range(frame.shape[-1] // channels_per_frame):
            if self.color_reduction:
                img_array = frame[:, :, i]
            else:
                img_array = frame[:, :, i * 3:i * 3 + 3]
            annotated = self._add_coordinate_axes(Image.fromarray(img_array))
            image_file = os.path.join(step_path, f'obs_{i}.png')
            annotated.save(image_file)
            image_paths.append(image_file)
        if image_paths:
            self.recorders[0].add_frame(image_paths[-1])
        return image_paths

    def _add_coordinate_axes(self, image):
        """Return a copy of ``image`` framed by pixel-coordinate axes.

        The frame is pasted onto a white canvas with a 40 px margin on each
        side and a bottom margin sized to the x labels.  Y axes with 6 labelled
        ticks are drawn on both sides, spanning only the playfield rows.  An x
        axis with 6 labelled ticks (0 .. width - 1) is drawn along the
        playfield's bottom edge, with ticks and labels extending below the
        line.  All labels are pixel coordinates of the original frame.
        """
        orig_width, orig_height = image.size

        # Playfield rows clipped to the image; both axes are placed on them.
        field_top = max(0, min(self._PLAYFIELD_TOP, orig_height - 1))
        field_bottom = max(field_top + 1, min(self._PLAYFIELD_BOTTOM, orig_height - 1))

        axis_width = 40   # left/right canvas margin
        num_ticks = 6     # ticks per axis
        tick_len = 8      # length of x-axis ticks
        pad_lbl = 1       # gap between an x tick and its label
        extra_pad = 4     # extra safety margin at the bottom

        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10)
        except (OSError, ImportError):
            # Font file missing or FreeType support unavailable.
            font = ImageFont.load_default()

        # Bottom margin tall enough for the x labels below the image.
        try:
            bbox = font.getbbox(str(orig_width - 1))
            text_height = bbox[3] - bbox[1]
        except Exception:
            text_height = getattr(font, "size", 10)
        bottom_margin = text_height + pad_lbl + extra_pad

        canvas = Image.new(
            'RGB',
            (orig_width + 2 * axis_width, orig_height + bottom_margin),
            color='white',
        )
        canvas.paste(image if image.mode == 'RGB' else image.convert('RGB'), (axis_width, 0))
        draw = ImageDraw.Draw(canvas)

        # ===== Y axes on both sides, limited to the playfield =====
        left_axis_x = axis_width - 1
        right_axis_x = axis_width + orig_width
        draw.line([(left_axis_x, field_top), (left_axis_x, field_bottom)], fill='black', width=2)
        draw.line([(right_axis_x, field_top), (right_axis_x, field_bottom)], fill='black', width=2)

        span = field_bottom - field_top
        for i in range(num_ticks):
            y_pos = field_top + round(i * span / (num_ticks - 1))
            label = str(y_pos)  # actual pixel row of the original frame
            draw.line([(axis_width - 10, y_pos), (axis_width - 1, y_pos)], fill='black', width=1)
            draw.line([(right_axis_x, y_pos), (right_axis_x + 10, y_pos)], fill='black', width=1)
            bbox = draw.textbbox((0, 0), label, font=font)
            text_w = bbox[2] - bbox[0]
            text_h = bbox[3] - bbox[1]
            draw.text((axis_width - 15 - text_w, y_pos - text_h // 2), label, fill='black', font=font)
            draw.text((axis_width + orig_width + 15, y_pos - text_h // 2), label, fill='black', font=font)

        # ===== X axis along the playfield's bottom edge (full width) =====
        axis_y = field_bottom
        draw.line([(axis_width, axis_y), (axis_width + orig_width, axis_y)], fill='black', width=2)
        for i in range(num_ticks):
            x_rel = round(i * (orig_width - 1) / (num_ticks - 1))
            x_abs = axis_width + x_rel
            label = str(x_rel)
            # Tick extends downward (increasing y) from the axis line.
            draw.line([(x_abs, axis_y), (x_abs, axis_y + tick_len)], fill='black', width=1)
            bbox = draw.textbbox((0, 0), label, font=font)
            text_w = bbox[2] - bbox[0]
            # Label centred under its tick.
            draw.text((x_abs - text_w // 2, axis_y + tick_len + pad_lbl), label, fill='black', font=font)

        return canvas
