from matplotlib import pyplot as plt
from overcooked_ai_py.mdp.actions import Action
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv, DEFAULT_ENV_PARAMS
from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld
from overcooked_ai_py.visualization.state_visualizer import StateVisualizer
from typing import List, Union, Dict, Any
import copy
import cv2
import gym
import numpy as np
import os
import pygame
import re
import json

from envs.base_env import BaseEnv, Observation
from utils.recorder import Recorder


class Overcooked(BaseEnv, env_type="overcooked"):

    def __init__(
        self,
        layout_name="cramped_room",
        max_steps=50,
        visual_obs=True,
        image_dir=None,
        recording_type='gif',
        recording_fps=5,
    ):
        """Build the Overcooked MDP/env pair and initialise episode bookkeeping.

        Args:
            layout_name: Layout name handed to
                ``OvercookedGridworld.from_layout_name``.
            max_steps: Step budget; once ``num_steps`` reaches it the episode
                is flagged as done.
            visual_obs: When true, ``image_dir`` is required and a ``Recorder``
                is created for the rendered frames.
            image_dir: Directory used for frames; must not be ``None`` when
                ``visual_obs`` is true.
            recording_type: Forwarded to ``Recorder``.
            recording_fps: Forwarded to ``Recorder``.
        """
        # Episode limits and progress flags.
        self.layout_name = layout_name
        self.max_steps = max_steps
        self.num_steps = 0
        self.done = False
        self.visualizer = StateVisualizer(is_rendering_hud=False)

        # Underlying MDP (5-tick cook time, delivery worth 10, shaped rewards
        # of 2) wrapped in an OvercookedEnv with the library's default params.
        shaping_rewards = {"DISH_PICKUP_REWARD": 2, "SOUP_PICKUP_REWARD": 2, "PLACEMENT_IN_POT_REW": 2}
        self.base_mdp = OvercookedGridworld.from_layout_name(
            self.layout_name,
            rew_shaping_params=shaping_rewards,
            cook_time=5,
            delivery_reward=10,
        )
        self.env = OvercookedEnv.from_mdp(self.base_mdp, **DEFAULT_ENV_PARAMS)
        self.state = None
        self.env.reset()

        # Action bookkeeping: integer index -> readable token.
        self.action_space = gym.spaces.Discrete(len(Action.ALL_ACTIONS))
        self.action_mapping = dict(
            enumerate(['<UP>', '<DOWN>', '<RIGHT>', '<LEFT>', '<STAY>', '<INTERACT>'])
        )
        self.player_actions = [Action.STAY, Action.STAY]
        self.current_player_idx = 0
        self.scores = [0, 0]

        # Per-episode history and the reward-stat snapshot used by step().
        self.action_history = []
        self.last_game_stats = {}
        self.image_paths = []
        self.addition_info = ""

        self.legend_images_cache = self._create_legend_cache()

        self.visual_obs = visual_obs
        if not self.visual_obs:
            return

        assert image_dir is not None, "image_dir must not be None."
        self.image_dir = image_dir
        self.recorders = [Recorder(image_dir, recording_type, recording_fps)]

    def current_player(self):
        return self.current_player_idx

    def reset(self, seed=0) -> List[Union[Observation, None]]:
        """Start a new episode and return the initial observation of each chef.

        Args:
            seed: Accepted for interface compatibility; this method does not
                use it.

        Returns:
            ``[observation_for_chef_0, observation_for_chef_1]``.
        """
        self.num_steps = 0
        self.done = False
        self.env.reset()
        self.scores = [0, 0]
        self.current_player_idx = 0
        self.player_actions = [Action.STAY, Action.STAY]

        self.action_history = []
        # step() derives per-step rewards by diffing the env's cumulative
        # reward stats against this snapshot.  A snapshot left over from the
        # previous episode would make the first delta of the new episode
        # wrong (presumably negative, assuming env.reset() restarts the
        # cumulative stats - TODO confirm), so it has to be dropped here.
        self.last_game_stats = {}

        self.state = {'observation': self.env.state, 'return': [0, 0], 'info': {}}
        if self.visual_obs:
            self.recorders[0].clear()
            self.image_paths = self._save_image()

        return [self._get_observation(0), self._get_observation(1)]

    def is_terminal(self):
        if self.num_steps >= self.max_steps:
            self.done = True

        return self.done

    def step(self, actions):
        """Apply one joint action and advance the episode by one step.

        Args:
            actions: Two action indices (chef_0 first) that index both
                ``Action.ALL_ACTIONS`` and ``self.action_mapping``.

        Returns:
            ``(observations, rewards, dones, info)`` where ``observations`` is
            the per-chef observation list, ``rewards`` is ``[reward, reward]``,
            ``dones`` is ``[done, done]`` and ``info`` holds ``'rewards'``,
            ``'returns'`` and ``'joint_action'``.

        Raises:
            RuntimeError: If the episode is already terminal.
        """
        if self.is_terminal():
            raise RuntimeError("Cannot apply action on a terminal state.")

        self.num_steps += 1
        choose_action = [self.choose_action_readable(action) for action in actions]
        joint_action = [Action.ALL_ACTIONS[action] for action in actions]

        next_state, env_reward, done, _ = self.env.step(joint_action)

        # The step reward is the growth of the env's cumulative shaped AND
        # sparse reward stats since the previous step.  Every tracked stat
        # contributes; keeping only the delta of whichever key is iterated
        # last would make the returned reward disagree with ``self.scores``.
        reward_stat_keys = (
            "cumulative_shaped_rewards_by_agent",
            "cumulative_sparse_rewards_by_agent",
        )
        deltas = []
        for key, value in self.env.game_stats.items():
            if key not in reward_stat_keys:
                continue
            previous = self.last_game_stats.get(key, np.zeros_like(value))
            # Copy so later changes to the env's stats cannot alter the snapshot.
            self.last_game_stats[key] = value.copy()
            deltas.append(int(np.sum(value - previous)))

        if deltas:
            reward = sum(deltas)
            self.scores[0] += reward
            self.scores[1] += reward
        else:
            # No tracked stat available: fall back to the env's own reward
            # and leave the scores untouched.
            reward = env_reward

        info = {
            'rewards': [reward, reward],
            'returns': [self.scores[0], self.scores[1]],
            'joint_action': [choose_action[0], choose_action[1]],
        }

        action_info = f"In timestep {self.num_steps - 1}: chef_0 chooses {self.action_mapping[actions[0]]}, chef_1 chooses {self.action_mapping[actions[1]]}."
        self.action_history.append(action_info)

        self.state = {'observation': next_state, 'return': [self.scores[0], self.scores[1]], 'info': info}

        # The wrapper's own step budget overrides the env's done flag.
        if self.num_steps >= self.max_steps:
            self.done = True
        else:
            self.done = done

        if self.visual_obs:
            self.image_paths = self._save_image()
            if self.done:
                self.recorders[0].save()

        observations = [self._get_observation(0), self._get_observation(1)]
        rewards = [reward, reward]
        dones = [self.done, self.done]

        return observations, rewards, dones, info

    def _get_observation(self, agent_id):
        """Assemble the ``Observation`` handed to chef ``agent_id``.

        The serialized state is ``str(self.env)`` followed directly by the
        per-chef summary from ``parse_chef_state``; the three most recent
        joint actions are attached as ``addition_info``.
        """
        chef_summary = parse_chef_state(str(self.env.state.players))
        serialized = str(self.env) + chef_summary
        history_text = self.get_recent_actions(3)
        return Observation(
            obs=self.state['observation'],
            agent_id=agent_id,
            image_paths=self.image_paths,
            legal_actions=self._get_legal_actions(agent_id),
            serialized_state=serialized,
            regex_patterns=self.regex_patterns,
            addition_info=history_text,
        )

    def _get_legal_actions(self, agent_id):
        return self.action_mapping

    def _get_info(self):
        if self.is_terminal():
            return {
                'returns': self.scores,
            }
        return None

    def create_joint_action(self):
        return [self.player_actions[0], self.player_actions[1]]

    def legal_actions(self):
        """Return the indices ``0 .. len(Action.ALL_ACTIONS) - 1`` as a list."""
        action_count = len(Action.ALL_ACTIONS)
        return [index for index in range(action_count)]

    def legal_actions_readable(self):
        """Return the ``Action.ACTION_TO_CHAR`` symbol of every legal action index, in order."""
        symbols = []
        for index in self.legal_actions():
            action = Action.ALL_ACTIONS[index]
            symbols.append(Action.ACTION_TO_CHAR[action])
        return symbols

    def choose_action_readable(self, action):
        """Map an action index to its ``Action.ACTION_TO_CHAR`` symbol."""
        selected = Action.ALL_ACTIONS[action]
        return Action.ACTION_TO_CHAR[selected]

    def _save_image(self):
        """Render the current state with its legend, save it as a PNG and record the frame.

        The file is written to ``<image_dir>/step_<num_steps>.png`` and its
        path is passed to the first recorder.

        Returns:
            A single-element list containing the path of the written image.
        """
        image = self.render()
        combined_image_rgb = self.add_legend_to_image(image, self.legend_images_cache)

        # cv2.imwrite expects BGR channel order.
        combined_image_bgr = cv2.cvtColor(combined_image_rgb, cv2.COLOR_RGB2BGR)

        # cv2.imwrite does not create missing directories - it just fails
        # without writing anything - so make sure the target exists first.
        # An empty image_dir means "current directory" and needs no creation.
        if self.image_dir:
            os.makedirs(self.image_dir, exist_ok=True)
        image_path = os.path.join(self.image_dir, f"step_{self.num_steps}.png")
        cv2.imwrite(image_path, combined_image_bgr)

        self.recorders[0].add_frame(image_path)
        return [image_path]

    def _create_legend_cache(self):
        """Load and resize the legend icons found under ``../images/overcooked``.

        For every legend label a ``.png`` file is tried first, then ``.jpg``;
        the first one that exists and decodes is converted from BGR to RGB,
        resized and stored.  Labels without a usable file are left out.

        Returns:
            Dict mapping legend label to its resized RGB image.
        """
        legend_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "images", "overcooked")
        )

        # Sizes are passed to cv2.resize unchanged.
        tile_size = (100, 100)
        chef_size = (82, 96)
        target_sizes = {
            label: tile_size
            for label in ("onion", "dish", "pot", "counter", "serving location", "available area")
        }
        for chef in ("chef_0", "chef_1"):
            for facing in ("up", "down", "left", "right"):
                target_sizes[f"{chef}: facing {facing}"] = chef_size

        cache = {}
        for label, size in target_sizes.items():
            for suffix in (".png", ".jpg"):
                candidate = os.path.join(legend_dir, f"{label}{suffix}")
                if not os.path.exists(candidate):
                    continue
                loaded = cv2.imread(candidate)
                if loaded is None:
                    continue
                cache[label] = cv2.resize(cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB), size)
                break

        return cache

    def add_legend_to_image(self, game_image, legend_cache):
        """Paste the game frame onto a white canvas and draw a legend to its right.

        Args:
            game_image: Rendered frame; its first two axes are read as
                (height, width) and it is assigned into a 3-channel region.
            legend_cache: Mapping from legend label to icon image.  Labels
                missing from the mapping are skipped.

        Returns:
            A new ``uint8`` canvas holding the frame, the step counter and
            the legend (items in one column, chef icons in a second one).
        """
        game_height, game_width = game_image.shape[:2]
        legend_width = 770
        margin = 20

        # The frame is pasted ``margin`` pixels below the top edge, so the
        # canvas must be at least ``game_height + margin`` tall; otherwise the
        # target slice is clipped and the assignment below fails for frames
        # taller than 980 px.  Shorter frames still get the 1000 px canvas.
        canvas_height = max(game_height + margin, 1000)
        combined_image = np.full(
            (canvas_height, game_width + legend_width, 3), 255, dtype=np.uint8
        )
        combined_image[margin:game_height + margin, margin:game_width + margin, :] = game_image

        font = cv2.FONT_HERSHEY_SIMPLEX
        title_x = game_width + 50 + margin
        title_y = 40 + margin
        cv2.putText(combined_image, "Legend", (title_x, title_y), font, 1.5, (0, 0, 0), 2, cv2.LINE_AA)
        cv2.putText(combined_image, f"Step: {self.num_steps}", (35+margin, 60+margin), font, 1.5, (0, 0, 0), 2, cv2.LINE_AA)

        # Two independent columns: map items on the left, chef icons on the right.
        item_x_base = game_width + 50 + margin
        player_x_base = game_width + 380 + margin
        y_item, y_player = title_y + 30 + margin, title_y + 30 + margin

        small_items = ["onion", "dish", "pot", "counter", "serving location", "available area"]
        players = [
            "chef_0: facing up", "chef_0: facing down", "chef_0: facing left", "chef_0: facing right",
            "chef_1: facing up", "chef_1: facing down", "chef_1: facing left", "chef_1: facing right"
        ]

        for item_name in small_items + players:
            if item_name not in legend_cache:
                continue

            img = legend_cache[item_name]
            h, w = img.shape[:2]

            is_player = "chef" in item_name
            x_base = player_x_base if is_player else item_x_base
            y = y_player if is_player else y_item

            combined_image[y:y + h, x_base:x_base + w] = img
            # Label sits to the right of the icon, at its vertical middle.
            cv2.putText(combined_image, item_name, (x_base + w + 10, y + h // 2), font, 0.8, (0, 0, 0), 1, cv2.LINE_AA)

            # Advance only the column that was just drawn into.
            if is_player:
                y_player += h + 10
            else:
                y_item += h + 30

        return combined_image


    def render(self):
        """Draw the current env state and return it as an image array.

        The two cumulative reward stats of the env plus their combined
        ``score`` are passed to the visualizer as HUD data.  The resulting
        pygame surface is converted with ``array3d``, has its first two axes
        swapped and is resized with ``cv2.resize`` to ``(2 * 580, 2 * 464)``.
        """
        tracked_keys = (
            "cumulative_shaped_rewards_by_agent",
            "cumulative_sparse_rewards_by_agent",
        )
        hud_values = {
            key: value
            for key, value in self.env.game_stats.items()
            if key in tracked_keys
        }
        # Compute the total before "score" itself is added to the dict.
        combined_score = sum(int(np.sum(value)) for value in hud_values.values())
        hud_values["score"] = combined_score

        surface = self.visualizer.render_state(
            state=self.env.state,
            grid=self.env.mdp.terrain_mtx,
            hud_data=StateVisualizer.default_hud_data(self.env.state, **hud_values),
        )

        pixels = np.transpose(pygame.surfarray.array3d(surface), (1, 0, 2))
        return cv2.resize(pixels, (2 * 580, 2 * 464))

    def get_recent_actions(self, num_actions=3):
        if not self.action_history:
            return ""

        recent_actions = self.action_history[-num_actions:] if len(
            self.action_history) > num_actions else self.action_history

        recent_actions = [f"{i+1}. {action}" for i, action in enumerate(recent_actions)]
        return "\n".join(recent_actions)
        
    def get_perception_reward(self, agent_response, ground_truth_state: List[List[str]]):

        patterns = [
            (
                r"grid\s*=\s*(\[\s*\[\s*'[^']*'(?:\s*,\s*'[^']*')*\s*\]"
                r"(?:\s*,\s*\[\s*'[^']*'(?:\s*,\s*'[^']*')*\s*\])*\s*\])",
                lambda s: '{"grid": ' + s.replace("'", '"') + '}'
            ),
            (
                r"(\[\s*\[\s*'[^']*'(?:\s*,\s*'[^']*')*\s*\]"
                r"(?:\s*,\s*\[\s*'[^']*'(?:\s*,\s*'[^']*')*\s*\])*\s*\])",
                lambda s: '{"grid": ' + s.replace("'", '"') + '}'
            ),
            (r'```(?:json)?\s*(\{\s*"grid"\s*:.*?\})\s*```', lambda m: m),  
            (r'\{\s*"grid"\s*:(.*?)\}', lambda m: '{' + '"grid":' + m + '}'), 
            (r'"grid"\s*:\s*(\[\s*\[.*?\]\s*\])', lambda m: m), 
            (r'\[\s*\[\s*"[^"]*"(?:,\s*"[^"]*")*\s*\](?:\s*,\s*\[\s*"[^"]*"(?:,\s*"[^"]*")*\s*\])*\s*\]', lambda m: m),  
        ]
        
        try:
            identified_board = None
            
            json_str = agent_response
            if '```' in json_str:
                json_str = re.sub(r'```(?:json)?', '', json_str)
                json_str = re.sub(r'```', '', json_str)
            
            try:
                json_data = json.loads(json_str)
                if 'grid' in json_data:
                    identified_board = json_data
            except:
                for pattern, processor in patterns:
                    matches = re.findall(pattern, agent_response)
                    if matches:
                        try:
                            board_str = processor(matches[0])
                            
                            if board_str.startswith('{'):
                                identified_board = json.loads(board_str)
                            elif board_str.startswith('['):
                                identified_board = {"grid": json.loads(board_str)}
                            else:
                                identified_board = json.loads(f'{{"grid": {board_str}}}')
                            
                            if identified_board:
                                break
                        except:
                            try:
                                rows = re.findall(r'\[([^\]]+)\]', board_str)
                                if rows:
                                    grid = []
                                    for row in rows:
                                        row = row.replace("'", '"')
                                        cells = re.findall(r'"([^"]*)"', row)
                                        grid.append(cells)
                                    identified_board = {"grid": grid}
                                    break
                            except:
                                continue
            
            if identified_board is None:
                return 0.0
                
            pred_grid = identified_board.get("grid", [])
                
            if not isinstance(pred_grid, list):
                return 0.0
                
            if len(pred_grid) < len(ground_truth_state):
                return 0.0
                
            for row in pred_grid:
                if not isinstance(row, list):
                    return 0.0
            pred_grid = [row[:len(ground_truth_state[0])] for row in pred_grid[:len(ground_truth_state)]]

            for i, row in enumerate(pred_grid):
                if len(row) < len(ground_truth_state[0]):
                    pred_grid[i] = row + ["M"] * (len(ground_truth_state[0]) - len(row))
                    
            total_cells = len(ground_truth_state) * len(ground_truth_state[0])  
            matched_cells = 0
            
            def normalize_cell(cell_str):
                cell_str = str(cell_str).strip().lower()
                
                # Special character replacements - fix encoding issues
                replacements = {
                    'ã¸': 'ø',  # wrong decoding
                    'ã\x83â\x83¸': 'ø',  # another possible wrong decoding
                    'ø}': 'ø',  # remove redundant right brace
                    '{ø}': '{ø',  # remove redundant right brace
                    'â\x86': '↑',  # up arrow
                    'â': '↑',      # another wrong encoding for up arrow
                    'â\x87': '↓',  # down arrow
                    'â\x86\x90': '←',  # left arrow
                    'â\x86\x92': '→',  # right arrow
                }
                
                # Check if wrong encoded arrow characters exist
                if 'â' in cell_str and re.search(r'â[01]', cell_str):
                    chef_num = re.search(r'â([01])', cell_str).group(1)
                    cell_str = cell_str.replace(f'â{chef_num}', f'↑{chef_num}')
                
                for wrong, correct in replacements.items():
                    cell_str = cell_str.replace(wrong, correct)
                    
                # Normalize P{øøø format
                if 'p{' in cell_str or 'p{ø' in cell_str:
                    # remove all braces, rebuild format
                    basic = cell_str.replace('}', '').replace('{', '')
                    
                    # count ø and rebuild format
                    base = 'p{'
                    for i in range(min(basic.count('ø'), 3)):  # at most 3 ø
                        base += 'ø'
                    
                    # check for number or completion mark
                    if re.search(r'\d', basic):
                        base += re.search(r'\d', basic).group(0)
                    elif '✓' in basic or '√' in basic or 'v' in basic:
                        base += '✓'
                        
                    cell_str = base
                
                # Remove possible extra right brace at the end
                if cell_str.endswith('}'):
                    cell_str = cell_str[:-1]
                
                # Handle consistency of chef direction and number
                for direction in ['↑', '↓', '←', '→']:
                    if direction in cell_str and ('0' in cell_str or '1' in cell_str):
                        chef_num = '0' if '0' in cell_str else '1'
                        cell_str = f"{direction}{chef_num}"
                        break
                    
                return cell_str
            
            # Scoring scheme:
            # - Chef cells (2): each 3 points (position 1, direction 1, holding 1), total 6
            # - Pot cells: each 2 points (position 1, onion count/cooking state 1)
            # - Normal cells: each 1 point
            # - Maximum total score: 25
            
            total_score = 0
            max_score = 25
            
            # Locate chef and pot positions
            chef0_pos = None
            chef1_pos = None
            pot_positions = []
            
            # Identify chef and pot positions
            for i in range(len(ground_truth_state)):
                for j in range(len(ground_truth_state[0])):
                    true_cell = normalize_cell(ground_truth_state[i][j])
                    if '0' in true_cell:
                        chef0_pos = (i, j)
                    elif '1' in true_cell:
                        chef1_pos = (i, j)
                    elif 'p' in true_cell.lower():
                        pot_positions.append((i, j))
            
            # Compare each cell
            for i in range(len(ground_truth_state)):
                for j in range(len(ground_truth_state[0])):
                    true_cell = normalize_cell(ground_truth_state[i][j])
                    pred_cell = normalize_cell(pred_grid[i][j])
                    
                    is_chef_cell = (i, j) == chef0_pos or (i, j) == chef1_pos
                    is_pot_cell = (i, j) in pot_positions
                    
                    if is_chef_cell:
                        # Chef cell scoring (3 points)
                        chef_num = '0' if '0' in true_cell else '1'
                        
                        # 1. check if chef number matches (position score)
                        if chef_num in pred_cell:
                            total_score += 1  # position match
                            
                            # 2. check if direction matches
                            true_dir = next((c for c in true_cell if c in '↑↓←→'), '')
                            pred_dir = next((c for c in pred_cell if c in '↑↓←→'), '')
                            if true_dir and pred_dir and true_dir == pred_dir:
                                total_score += 1
                            
                            # 3. check if holding items
                            true_holding = 'o' in true_cell or 'd' in true_cell or '✓' in true_cell
                            pred_holding = 'o' in pred_cell or 'd' in pred_cell or '✓' in pred_cell
                            
                            if true_holding == pred_holding:
                                total_score += 1
                    elif is_pot_cell:
                        # Pot cell scoring (2 points)
                        if 'p' in pred_cell.lower():
                            total_score += 1  # position match
                            
                            # Check empty pot or pot with onions
                            true_empty_pot = (true_cell.lower() == 'p' or 
                                            (true_cell.lower().startswith('p') and '{' not in true_cell))
                            pred_empty_pot = (pred_cell.lower() == 'p' or 
                                            (pred_cell.lower().startswith('p') and '{' not in pred_cell))
                            
                            if true_empty_pot and pred_empty_pot:
                                total_score += 1
                                continue
                                
                            # Onion count or cooking state
                            true_onions = true_cell.count('ø')
                            pred_onions = pred_cell.count('ø')
                            
                            true_cooking = any(c.isdigit() for c in true_cell)
                            pred_cooking = any(c.isdigit() for c in pred_cell)
                            
                            if true_cooking and pred_cooking:
                                total_score += 1
                            elif not true_cooking and not pred_cooking and true_onions == pred_onions:
                                total_score += 1
                    else:
                        # Normal cell scoring (1 point)
                        if true_cell == pred_cell:
                            total_score += 1
            
            accuracy = total_score / max_score
            return accuracy
            
        except Exception as e:
            print(f"Error while computing perception_reward: {e}")
            return 0.0


    @property
    def schema(self):
        """Return a pydantic model class with a single ``grid`` field (list of lists of str).

        The model class is defined inside the property, so a new class object
        is created on every access.
        """
        from pydantic import BaseModel as _PydanticModel

        class OVERCOOKED(_PydanticModel):
            grid: List[List[str]]

        schema_cls = OVERCOOKED
        return schema_cls

def parse_chef_state(players_str: str) -> str:
    """Turn the textual players tuple into one readable sentence per chef.

    Example:
        input:
          "((1, 1) facing (0, -1) holding dish@(1, 1), (2, 1) facing (0, -1) holding None)"
        output:
          Chef 0 local in (1, 1), facing ↑, hold dish.
          Chef 1 local in (2, 1), facing ↑, hold nothing.

    The facing vector is looked up in ``Action.ACTION_TO_CHAR`` and shown as a
    raw ``(x, y)`` pair when it is not found there.  The held object is
    reduced to its leading alphabetic word, lower-cased; ``None`` becomes
    ``nothing`` and a description without a leading word becomes ``soup``.

    Returns:
        The sentences joined by newlines, or ``""`` when no chef entry is
        recognised.
    """
    text = players_str.strip()

    # Peel off the parentheses that wrap the whole tuple, if present.
    if text.startswith('(') and text.endswith(')'):
        text = text[1:-1].strip()

    chef_entry = re.compile(r'\(\s*(-?\d+)\s*,\s*(-?\d+)\s*\)'
                            r'\s*facing\s*\(\s*(-?\d+)\s*,\s*(-?\d+)\s*\)'
                            r'\s*holding\s*'
                            r'(.+?)'
                            r'(?=(?:,\s*\(\s*-?\d)|\s*$)')
    found = chef_entry.findall(text)
    if not found:
        return ""

    sentences = []
    for chef_idx, fields in enumerate(found):
        pos_x, pos_y, face_x, face_y = (int(value) for value in fields[:4])
        facing = Action.ACTION_TO_CHAR.get((face_x, face_y), f'({face_x}, {face_y})')

        leading_word = re.match(r'^([A-Za-z]+)', fields[4].strip())
        if leading_word is None:
            held = 'soup'
        else:
            word = leading_word.group(1).lower()
            held = 'nothing' if word == 'none' else word

        sentences.append(f'Chef {chef_idx} local in ({pos_x}, {pos_y}), facing {facing}, hold {held}.')

    return '\n'.join(sentences)
