from pettingzoo.butterfly import knights_archers_zombies_v10
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import os
from typing import List, Dict, Any, Union, Tuple

from envs.base_env import BaseEnv, Observation
from utils.recorder import Recorder

class KnightsArchersZombies(BaseEnv, env_type="knights_archers_zombies"):
    '''Knights Archers Zombies (KAZ) environment using PettingZoo.'''

    def __init__(
        self,
        max_steps=100,
        spawn_rate=20,
        num_archers=1,
        num_knights=1,
        max_zombies=10,
        max_arrows=10,
        killable_knights=True,
        killable_archers=True,
        visual_obs=True,
        image_dir=None,
        recording_type='gif',
        recording_fps=20,
        frame_stack=4,
    ):
        """Create the wrapped PettingZoo KAZ game and zero the episode bookkeeping.

        The PettingZoo environment is built with image observations
        (``vector_state=False``) and ``rgb_array`` rendering; its ``max_cycles``
        is ``max_steps + spawn_rate``.  When ``visual_obs`` is true an
        ``image_dir`` is required and one ``Recorder`` is created for it.

        Args:
            max_steps: Stored on the instance and used to derive ``max_cycles``.
            spawn_rate: Forwarded to PettingZoo; also used by ``reset`` to size
                its warm-up phase.
            num_archers, num_knights: Number of agents of each role.
            max_zombies, max_arrows, killable_knights, killable_archers:
                Forwarded unchanged to PettingZoo.
            visual_obs: Whether rendered frames are written to ``image_dir``.
            image_dir: Output directory for frames (required if ``visual_obs``).
            recording_type, recording_fps: Forwarded to ``Recorder``.
            frame_stack: Number of most recent frames returned by ``_save_image``.
        """
        # Options for the PettingZoo KAZ environment (image-based observations).
        kaz_options = dict(
            spawn_rate=spawn_rate,
            num_archers=num_archers,
            num_knights=num_knights,
            max_zombies=max_zombies,
            max_arrows=max_arrows,
            killable_knights=killable_knights,
            killable_archers=killable_archers,
            max_cycles=max_steps + spawn_rate,
            vector_state=False,
            render_mode='rgb_array',
        )
        self.env = knights_archers_zombies_v10.env(**kaz_options)

        # Static configuration.
        self.spawn_rate = spawn_rate
        self.max_steps = max_steps
        self.num_agents = num_archers + num_knights
        self.num_archers = num_archers
        self.num_knights = num_knights
        self.frame_stack = frame_stack

        # Frame dumping / recording is only configured for visual observations.
        self.visual_obs = visual_obs
        if self.visual_obs:
            assert image_dir is not None, "image_dir must not be None."
            self.image_dir = image_dir
            self.recorders = [Recorder(image_dir, recording_type, recording_fps)]

        # Action id -> action token, in PettingZoo's action order.
        self.action_mapping = dict(enumerate((
            '<MOVE_FORWARD>',   # move forward
            '<MOVE_BACKWARD>',  # move backward
            '<ROTATE_LEFT>',    # rotate left (counter-clockwise)
            '<ROTATE_RIGHT>',   # rotate right (clockwise)
            '<ATTACK>',         # knight swings mace, archer shoots arrow
            '<STAY>',           # do nothing
        )))

        # Per-episode bookkeeping, refreshed by reset().
        self.state = None
        self.score = 0
        self.steps = 0
        self.image_paths = []

    @property
    def current_player(self): # return id
        return self.env.agent_name_mapping[self.env.agent_selection] # map agent_string to agent_id
    
    def reset(self, seed=0):
        """Restart the episode and return one observation per agent.

        After resetting the wrapped environment, ``spawn_rate * num_agents``
        "stay" actions (action id 5) are played so that the game has advanced
        before the first observation is handed out.

        Args:
            seed: Seed forwarded to the wrapped environment's ``reset``.

        Returns:
            A list with the observation of every agent id in
            ``range(self.num_agents)``.
        """
        self.env.reset(seed=seed)

        # need to stay until the first zombie is out!
        stay_action = 5
        warmup_moves = self.spawn_rate * self.num_agents
        for _ in range(warmup_moves):
            self.env.step(stay_action)

        acting_agent = self.env.agent_selection
        observation, reward, terminated, truncated, info = self.env.last()

        self.steps = 0
        self.score = 0
        # Key order is kept stable because the state is serialized with str().
        self.state = dict(
            observation=observation,
            reward=reward,
            terminated=terminated,
            truncated=truncated,
            info=info,
            step=0,
            current_agent=acting_agent,
        )

        if self.visual_obs:
            self.recorders[0].clear()
            self.image_paths = self._save_image()
        return list(map(self._get_observation, range(self.num_agents)))

    def step(self, actions):
        """Play one round: every agent still in the game acts once.

        Args:
            actions: Indexable by agent id; ``actions[i]`` is the action id
                applied when agent ``i`` is selected by the wrapped env.

        Returns:
            ``(observations, rewards, dones, info)`` where ``observations`` has
            one entry per agent id, ``rewards`` repeats the summed reward of
            this round for every agent, ``dones`` repeats the episode-finished
            flag, and ``info`` is ``{'returns': <cumulative score>}``.

        Raises:
            RuntimeError: If the previous state was already terminal.
        """
        if self.state['terminated'] or self.state['truncated']:
            raise RuntimeError("Cannot apply action on a terminal state.")

        # Start from the previous transition so these names are always bound,
        # even if no living agent ends up acting in the loop below.
        observation = self.state['observation']
        terminated = self.state['terminated']
        truncated = self.state['truncated']
        env_info = self.state['info']

        step_reward = 0
        acted = False
        for _ in range(len(self.env.agents)):
            player = self.current_player
            if self.env.agent_list[player].alive:
                self.env.step(actions[player])
                observation, reward, terminated, truncated, env_info = self.env.last()
                step_reward += reward
                acted = True

        if not acted:
            # Nobody could act, so the game cannot progress any further: end
            # the episode instead of failing on unbound locals or stalling.
            terminated = True

        self.score += step_reward
        self.steps += 1

        self.state = {
            'observation': observation,
            'reward': step_reward,
            'terminated': terminated,
            'truncated': truncated,
            'info': env_info,
            'step': self.steps,
            'current_agent': self.env.agent_selection if self.env.agents else None
        }

        # The episode is over as soon as the latest transition is terminal.
        done = bool(terminated or truncated)

        if self.visual_obs:
            self.image_paths = self._save_image()
            if done:
                self.recorders[0].save()

        # Every agent receives the shared team reward and the same done flag.
        observations = [self._get_observation(i) for i in range(self.num_agents)]
        rewards = [step_reward] * self.num_agents
        dones = [done] * self.num_agents

        return observations, rewards, dones, self._get_info()

    def _get_observation(self, agent_id):
        """Build the ``Observation`` handed to the agent with id ``agent_id``.

        Args:
            agent_id: Integer id as used in ``self.env.agent_name_mapping``.

        Returns:
            An ``Observation`` carrying the latest stored env observation (or a
            zero image when none is stored), the current frame paths, the legal
            actions, and ``addition_info`` with the step counter plus this
            agent's own role and facing angle.

        Raises:
            ValueError: If ``agent_id`` is not a known agent id.
        """
        # Agent names look like 'archer_0' / 'knight_1'; the prefix is the role.
        role = None
        for agent_name, mapped_id in self.env.agent_name_mapping.items():
            if mapped_id == agent_id:
                role = agent_name.split('_')[0]
                break
        if role is None:
            raise ValueError(f"Unknown agent_id: {agent_id!r}")

        stored_obs = self.state['observation']
        return Observation(
            obs=stored_obs if stored_obs is not None else np.zeros((512, 512, 3)),
            agent_id=agent_id,
            image_paths=self.image_paths if self.visual_obs else [],
            legal_actions=self._get_legal_actions(agent_id),
            serialized_state=str(self.state),
            regex_patterns=self.regex_patterns,
            addition_info={
                'step': self.steps,
                # Report the angle of the agent this observation is built for,
                # not of whichever agent happens to be selected right now.
                'angle': self.env.agent_list[agent_id].angle,
                'role': role
            },
        )

    @property
    def regex_patterns(self):
        patterns = [(r'```json\s*\{\s*"action"\s*:\s*"([^"]+)"\s*\}\s*```', lambda m: m.strip()),
                    (r'"action"\s*:\s*"([^"]+)"', lambda m: m.strip()),
                    (r'<(MOVE_FORWARD|MOVE_BACKWARD|ROTATE_LEFT|ROTATE_RIGHT|STAY|ATTACK)>', lambda m: f"<{m.upper()}>"),
                    (r'\b(MOVE_FORWARD|MOVE_BACKWARD|ROTATE_LEFT|ROTATE_RIGHT|STAY|ATTACK)\b', lambda m: f"<{m.upper()}>"),
                ]
        return patterns
    
    def _get_info(self):
        # if self.state['terminated'] or self.state['truncated']:
        return {
            'returns': self.score
        }

    def _get_legal_actions(self, agent_id):
        return self.action_mapping

    def _save_image(self):
        if not self.visual_obs:
            return []
            
        # Render the current state
        image_path = os.path.join(self.image_dir, f'step_{self.steps}.png')
        os.makedirs(self.image_dir, exist_ok=True)
        
        # Get the rendered image from the environment
        rendered_img = self.env.render()
        if rendered_img is not None:
            img = Image.fromarray(rendered_img)
            img.save(image_path)
            self.recorders[0].add_frame(image_path)
        
        # Collect past frames for frame stacking
        image_paths = []
        for i in range(self.frame_stack):
            step = self.steps - self.frame_stack + 1 + i
            if step >= 0:
                path = os.path.join(self.image_dir, f'step_{step}.png')
                if os.path.exists(path):
                    image_paths.append(path)
                    
        return image_paths

    def get_perception_reward(self, raw_response, label):
        import ast
        import math
        reward = 0
        try:
            # print(type(raw_response))
            if "```json" in raw_response:
                raw_response = raw_response.split("```json")[1].split("```")[0].strip()
            # raw_response = raw_response.split("```json")[1].split("```")[0].strip()
            if type(raw_response) is str:
                raw_response = ast.literal_eval(raw_response)
        except Exception as e:
            print(f"Error parsing raw_response: {e}")
            print(f"raw_response: {raw_response}")
            raw_response = {}

        if raw_response == {}:
            return 0
        
        DIST = 100

        if 'Archer' in raw_response and 'Archer' in label:
            if (isinstance(raw_response['Archer'], (list, tuple)) and 
                isinstance(label['Archer'][0], (list, tuple)) and 
                len(raw_response['Archer']) == len(label['Archer'][0])):
                diff = math.dist(raw_response['Archer'], label['Archer'][0])
                if diff <= DIST:
                    reward += 0.2
                else:
                    reward += 0.2 * DIST / diff

        if 'Knight' in raw_response and 'Knight' in label:
            if (isinstance(raw_response['Knight'], (list, tuple)) and 
                isinstance(label['Knight'][0], (list, tuple)) and 
                len(raw_response['Knight']) == len(label['Knight'][0])):
                diff = math.dist(raw_response['Knight'], label['Knight'][0])
                if diff <= DIST:
                    reward += 0.2
                else:
                    reward += 0.2 * DIST / diff

        if 'Zombies_count' in raw_response and 'Zombies_count' in label:
            if raw_response['Zombies_count'] == label['Zombies_count']:
                reward += 0.2

        count = 0
        if 'Zombies' in raw_response and 'Zombies' in label:
            remaining_labels = label['Zombies'][:]

            for pred_zombie in raw_response['Zombies']:
                if not remaining_labels:
                    break  # No more unmatched labels

                valid_labels = [
                    lbl for lbl in remaining_labels 
                    if isinstance(pred_zombie, (list, tuple)) 
                    and isinstance(lbl, (list, tuple)) 
                    and len(pred_zombie) == len(lbl)
                ]

                if not valid_labels:
                    continue

                min_dist = float('inf')
                closest_label = None

                for lbl_zombie in valid_labels:
                    dx = pred_zombie[0] - lbl_zombie[0]
                    dy = pred_zombie[1] - lbl_zombie[1]
                    dist = math.sqrt(dx * dx + dy * dy)

                    if dist < min_dist:
                        min_dist = dist
                        closest_label = lbl_zombie

                if closest_label is not None:
                    if min_dist <= DIST:
                        count += 1
                        remaining_labels.remove(closest_label)
                    else:
                        count += 1 * DIST / min_dist

        if 'Zombies_count' in label and label['Zombies_count'] > 0:
            reward += 0.4 * (count / label['Zombies_count'])

        return reward


    @property
    def schema(self):
        """Return a pydantic model class describing a perception answer.

        The field names match the keys that ``get_perception_reward`` reads
        from a response: ``Archer``, ``Knight``, ``Zombies_count``, ``Zombies``.
        A new model class is created on every access.
        """
        # Imported here so pydantic is only needed when the schema is requested.
        from pydantic import BaseModel as PyBase
        class KAZ(PyBase):
            Archer: List[float]  # presumably the archer's coordinates - confirm with the prompt
            Knight: List[float]  # presumably the knight's coordinates - confirm with the prompt
            Zombies_count: int
            Zombies: List[List[float]]  # one coordinate list per zombie
        return KAZ
 
    
class KnightsArchersZombies_GEN_Perception(BaseEnv, env_type="knights_archers_zombies_gen"):
    '''Knights Archers Zombies (KAZ) environment using PettingZoo.'''

    def __init__(
        self,
        max_steps=100,
        spawn_rate=20,
        num_archers=1,
        num_knights=1,
        max_zombies=10,
        max_arrows=10,
        killable_knights=True,
        killable_archers=True,
        visual_obs=True,
        image_dir=None,
        recording_type='gif',
        recording_fps=20,
        frame_stack=4,
    ):
        """Create the wrapped PettingZoo KAZ game and zero the episode bookkeeping.

        The PettingZoo environment is built with image observations
        (``vector_state=False``) and ``rgb_array`` rendering; its ``max_cycles``
        is ``max_steps + spawn_rate``.  When ``visual_obs`` is true an
        ``image_dir`` is required and one ``Recorder`` is created for it.

        Args:
            max_steps: Stored on the instance and used to derive ``max_cycles``.
            spawn_rate: Forwarded to PettingZoo; also used by ``reset`` to size
                its warm-up phase.
            num_archers, num_knights: Number of agents of each role.
            max_zombies, max_arrows, killable_knights, killable_archers:
                Forwarded unchanged to PettingZoo.
            visual_obs: Whether rendered frames are written to ``image_dir``.
            image_dir: Output directory for frames (required if ``visual_obs``).
            recording_type, recording_fps: Forwarded to ``Recorder``.
            frame_stack: Number of most recent frames returned by ``_save_image``.
        """
        # Options for the PettingZoo KAZ environment (image-based observations).
        kaz_options = dict(
            spawn_rate=spawn_rate,
            num_archers=num_archers,
            num_knights=num_knights,
            max_zombies=max_zombies,
            max_arrows=max_arrows,
            killable_knights=killable_knights,
            killable_archers=killable_archers,
            max_cycles=max_steps + spawn_rate,
            vector_state=False,
            render_mode='rgb_array',
        )
        self.env = knights_archers_zombies_v10.env(**kaz_options)

        # Static configuration.
        self.spawn_rate = spawn_rate
        self.max_steps = max_steps
        self.num_agents = num_archers + num_knights
        self.num_archers = num_archers
        self.num_knights = num_knights
        self.frame_stack = frame_stack

        # Frame dumping / recording is only configured for visual observations.
        self.visual_obs = visual_obs
        if self.visual_obs:
            assert image_dir is not None, "image_dir must not be None."
            self.image_dir = image_dir
            self.recorders = [Recorder(image_dir, recording_type, recording_fps)]

        # Action id -> action token, in PettingZoo's action order.
        self.action_mapping = dict(enumerate((
            '<MOVE_FORWARD>',   # move forward
            '<MOVE_BACKWARD>',  # move backward
            '<ROTATE_LEFT>',    # rotate left (counter-clockwise)
            '<ROTATE_RIGHT>',   # rotate right (clockwise)
            '<ATTACK>',         # knight swings mace, archer shoots arrow
            '<STAY>',           # do nothing
        )))

        # Per-episode bookkeeping, refreshed by reset().
        self.state = None
        self.score = 0
        self.steps = 0
        self.image_paths = []

    @property
    def current_player(self): # return id
        return self.env.agent_name_mapping[self.env.agent_selection] # map agent_string to agent_id
    
    def reset(self, seed=0):
        """Restart the episode and return one observation per agent.

        After resetting the wrapped environment, ``spawn_rate * num_agents``
        "stay" actions (action id 5) are played so that the game has advanced
        before the first observation is handed out.

        Args:
            seed: Seed forwarded to the wrapped environment's ``reset``.

        Returns:
            A list with the observation of every agent id in
            ``range(self.num_agents)``.
        """
        self.env.reset(seed=seed)

        # need to stay until the first zombie is out!
        stay_action = 5
        warmup_moves = self.spawn_rate * self.num_agents
        for _ in range(warmup_moves):
            self.env.step(stay_action)

        acting_agent = self.env.agent_selection
        observation, reward, terminated, truncated, info = self.env.last()

        self.steps = 0
        self.score = 0
        # Key order is kept stable because the state is serialized with str().
        self.state = dict(
            observation=observation,
            reward=reward,
            terminated=terminated,
            truncated=truncated,
            info=info,
            step=0,
            current_agent=acting_agent,
        )

        if self.visual_obs:
            self.recorders[0].clear()
            self.image_paths = self._save_image()
        return list(map(self._get_observation, range(self.num_agents)))

    def step(self, actions):
        """Play one round: every agent still in the game acts once.

        Args:
            actions: Indexable by agent id; ``actions[i]`` is the action id
                applied when agent ``i`` is selected by the wrapped env.

        Returns:
            ``(observations, rewards, dones, info)`` where ``observations`` has
            one entry per agent id, ``rewards`` repeats the summed reward of
            this round for every agent, ``dones`` repeats the episode-finished
            flag, and ``info`` is ``{'returns': <cumulative score>}``.

        Raises:
            RuntimeError: If the previous state was already terminal.
        """
        if self.state['terminated'] or self.state['truncated']:
            raise RuntimeError("Cannot apply action on a terminal state.")

        # Start from the previous transition so these names are always bound,
        # even if no living agent ends up acting in the loop below.
        observation = self.state['observation']
        terminated = self.state['terminated']
        truncated = self.state['truncated']
        env_info = self.state['info']

        step_reward = 0
        acted = False
        for _ in range(len(self.env.agents)):
            player = self.current_player
            if self.env.agent_list[player].alive:
                self.env.step(actions[player])
                observation, reward, terminated, truncated, env_info = self.env.last()
                step_reward += reward
                acted = True

        if not acted:
            # Nobody could act, so the game cannot progress any further: end
            # the episode instead of failing on unbound locals or stalling.
            terminated = True

        self.score += step_reward
        self.steps += 1

        self.state = {
            'observation': observation,
            'reward': step_reward,
            'terminated': terminated,
            'truncated': truncated,
            'info': env_info,
            'step': self.steps,
            'current_agent': self.env.agent_selection if self.env.agents else None
        }

        # The episode is over as soon as the latest transition is terminal.
        done = bool(terminated or truncated)

        if self.visual_obs:
            self.image_paths = self._save_image()
            if done:
                self.recorders[0].save()

        # Every agent receives the shared team reward and the same done flag.
        observations = [self._get_observation(i) for i in range(self.num_agents)]
        rewards = [step_reward] * self.num_agents
        dones = [done] * self.num_agents

        return observations, rewards, dones, self._get_info()

    def _get_observation(self, agent_id):
        """Build the ``Observation`` for ``agent_id``, including entity positions.

        Besides the fields every observation carries, ``addition_info`` holds
        the ``(rect.x, rect.y)`` of every knight, archer and zombie known to
        the wrapped environment, plus the zombie count.

        Args:
            agent_id: Integer id as used in ``self.env.agent_name_mapping``.

        Returns:
            An ``Observation`` for that agent.

        Raises:
            ValueError: If ``agent_id`` is not a known agent id.
        """
        # Agent names look like 'archer_0' / 'knight_1'; the prefix is the role.
        role = None
        for agent_name, mapped_id in self.env.agent_name_mapping.items():
            if mapped_id == agent_id:
                role = agent_name.split('_')[0]
                break
        if role is None:
            raise ValueError(f"Unknown agent_id: {agent_id!r}")

        knights_pos = []
        archers_pos = []
        for agent in self.env.agent_list:
            position = (agent.rect.x, agent.rect.y)
            if getattr(agent, 'is_knight', False):
                knights_pos.append(position)
            elif getattr(agent, 'is_archer', False):
                archers_pos.append(position)

        zombies_pos = [(zombie.rect.x, zombie.rect.y) for zombie in self.env.zombie_list]

        stored_obs = self.state['observation']
        return Observation(
            obs=stored_obs if stored_obs is not None else np.zeros((512, 512, 3)),
            agent_id=agent_id,
            image_paths=self.image_paths if self.visual_obs else [],
            legal_actions=self._get_legal_actions(agent_id),
            serialized_state=str(self.state),
            regex_patterns=self.regex_patterns,
            addition_info={
                'step': self.steps,
                # Report the angle of the agent this observation is built for,
                # not of whichever agent happens to be selected right now.
                'angle': self.env.agent_list[agent_id].angle,
                'role': role,
                'knights_pos': knights_pos,
                'archers_pos': archers_pos,
                'zombies_count': len(zombies_pos),
                'zombies_pos': zombies_pos
            },
        )

    @property
    def regex_patterns(self):
        patterns = [(r'```json\s*\{\s*"action"\s*:\s*"([^"]+)"\s*\}\s*```', lambda m: m.strip()),
                    (r'"action"\s*:\s*"([^"]+)"', lambda m: m.strip()),
                    (r'<(MOVE_FORWARD|MOVE_BACKWARD|ROTATE_LEFT|ROTATE_RIGHT|STAY|ATTACK)>', lambda m: f"<{m.upper()}>"),
                    (r'\b(MOVE_FORWARD|MOVE_BACKWARD|ROTATE_LEFT|ROTATE_RIGHT|STAY|ATTACK)\b', lambda m: f"<{m.upper()}>"),
                ]
        return patterns
    
    def _get_info(self):
        return {
            'returns': self.score
        }

    def _get_legal_actions(self, agent_id):
        return self.action_mapping

    def _save_image_new(self):
        if not self.visual_obs:
            return []
            
        # Render the current state
        frame = self.env.render()
        image_path = []
        step_path = os.path.join(self.image_dir, f'step_{self.steps}')
        os.makedirs(step_path, exist_ok=True)
        
        if frame is not None:
            # Create image from frame
            image = Image.fromarray(frame)
            
            # Add coordinate axes to the image
            image_with_axes = self._add_coordinate_axes(image)
            
            image_file = os.path.join(step_path, f'obs.png')
            image_with_axes.save(image_file)
            image_path.append(image_file)
            self.recorders[0].add_frame(image_file)
        
        # Collect past frames for frame stacking
        for i in range(self.frame_stack):
            step = self.steps - self.frame_stack + 1 + i
            if step >= 0:
                path = os.path.join(self.image_dir, f'step_{step}', f'obs.png')
                if os.path.exists(path):
                    image_path.append(path)
                    
        return image_path

    def _save_image(self):
        if not self.visual_obs:
            return []
            
        # Render the current state
        image_path = os.path.join(self.image_dir, f'step_{self.steps}.png')
        os.makedirs(self.image_dir, exist_ok=True)
        
        # Get the rendered image from the environment
        rendered_img = self.env.render()
        if rendered_img is not None:
            img = Image.fromarray(rendered_img)
            image_with_axes = self._add_coordinate_axes(img)
            image_with_axes.save(image_path)
            self.recorders[0].add_frame(image_path)
        
        # Collect past frames for frame stacking
        image_paths = []
        for i in range(self.frame_stack):
            step = self.steps - self.frame_stack + 1 + i
            if step >= 0:
                path = os.path.join(self.image_dir, f'step_{step}.png')
                if os.path.exists(path):
                    image_paths.append(path)
                    
        return image_paths

    def _add_coordinate_axes(self, image):
        """Return a copy of ``image`` framed by labelled coordinate axes.

        The image is pasted onto a white canvas with a 55-pixel margin on each
        side.  A border is drawn around it and 10 evenly spaced ticks, labelled
        with their pixel coordinate in the original image, are placed along all
        four edges.

        Args:
            image: A PIL image; it is converted to RGB if necessary.

        Returns:
            A new RGB PIL image, ``2 * 55`` pixels wider and taller than the
            input.
        """
        # Original image size
        orig_width, orig_height = image.size

        # Calculate new image size (add 55 pixels around for axes)
        axis_width = 55
        new_width = orig_width + 2 * axis_width
        new_height = orig_height + 2 * axis_width

        # Create new image
        new_image = Image.new('RGB', (new_width, new_height), color='white')

        # Paste original image in the center
        image_rgb = image.convert('RGB') if image.mode != 'RGB' else image
        new_image.paste(image_rgb, (axis_width, axis_width))

        # Prepare drawing
        draw = ImageDraw.Draw(new_image)

        # Prefer DejaVu Sans; fall back to Pillow's built-in font when the file
        # (or FreeType support) is missing.
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
        except (OSError, ImportError):
            try:
                font = ImageFont.load_default(size=20)
            except TypeError:
                # Older Pillow releases do not accept a `size` argument.
                font = ImageFont.load_default()

        # Number of ticks
        num_ticks = 10

        # Compute tick positions and labels
        x_ticks = [int(i * (orig_width - 1) / (num_ticks - 1)) for i in range(num_ticks)]
        y_ticks = [int(i * (orig_height - 1) / (num_ticks - 1)) for i in range(num_ticks)]

        # Draw X-axis (bottom)
        x_axis_y = axis_width + orig_height
        draw.line([(axis_width, x_axis_y), (axis_width + orig_width, x_axis_y)], fill='black', width=2)

        # Draw X-axis (top)
        draw.line([(axis_width, axis_width), (axis_width + orig_width, axis_width)], fill='black', width=2)

        # Draw Y-axis (left)
        draw.line([(axis_width, axis_width), (axis_width, axis_width + orig_height)], fill='black', width=2)

        # Draw Y-axis (right)
        y_axis_x = axis_width + orig_width
        draw.line([(y_axis_x, axis_width), (y_axis_x, axis_width + orig_height)], fill='black', width=2)

        # Draw X-axis ticks (bottom and top)
        for x in x_ticks:
            x_abs = axis_width + x
            # Bottom tick
            draw.line([(x_abs, x_axis_y), (x_abs, x_axis_y + 8)], fill='black', width=5)
            # Top tick
            draw.line([(x_abs, axis_width), (x_abs, axis_width - 8)], fill='black', width=5)

            # Labels (bottom), horizontally centred on the tick
            label = str(x)
            bbox = draw.textbbox((0, 0), label, font=font)
            tw = bbox[2] - bbox[0]
            draw.text((x_abs - tw/2, x_axis_y + 10), label, fill='black', font=font)

            # Labels (top)
            draw.text((x_abs - tw/2, axis_width - 30), label, fill='black', font=font)

        # Draw Y-axis ticks (left and right)
        for y in y_ticks:
            y_abs = axis_width + y
            # Left tick
            draw.line([(axis_width, y_abs), (axis_width - 8, y_abs)], fill='black', width=5)
            # Right tick
            draw.line([(y_axis_x, y_abs), (y_axis_x + 8, y_abs)], fill='black', width=5)

            # Labels (left), right-aligned against the tick and vertically centred
            label = str(y)
            bbox = draw.textbbox((0, 0), label, font=font)
            th = bbox[3] - bbox[1]
            draw.text((axis_width - 15 - bbox[2], y_abs - th/2), label, fill='black', font=font)

            # Labels (right)
            draw.text((y_axis_x + 15, y_abs - th/2), label, fill='black', font=font)

        return new_image
