#import gym

import numpy as np
import os
import json
import time
import imageio
from PIL import Image, ImageDraw, ImageFont
from rembg import remove
import random
import io
import functools
import shimmy
import gc
import copy
#import pufferlib
#import pufferlib.emulation
#import pufferlib.environments
#import pufferlib.environment
#import pufferlib.postprocess
#import pufferlib.utils

#

import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.registration import register

#from graphviz import Digraph


    
import math
import random
import pprint

class MCTS:
    def __init__(self, env, focus_mechanic, iterations=10000, simulation_iterations=1000, exploration_weight=1.41, max_mechanics=5):
        """Configure a mechanic-level tree search over ``env``.

        Args:
            env: Environment exposing ``mechanic_to_action`` (mechanic name ->
                action id); ``run`` additionally calls ``env.clone()``.
            focus_mechanic: Mechanic that seeds the root node's mechanic list.
            iterations: Number of outer search iterations performed by ``run``.
            simulation_iterations: Iteration budget handed to every
                ``SimulationMCTS`` created during the search.
            exploration_weight: Weight passed to each child's ``uct`` method.
            max_mechanics: Mechanic-list length at which ``expand`` stops
                creating children.
        """
        action_table = env.mechanic_to_action

        # Search budget and tuning knobs.
        self.iterations = iterations
        self.simulation_iterations = simulation_iterations
        self.exploration_weight = exploration_weight
        self.max_mechanics = max_mechanics

        # Environment and the mechanic vocabulary it offers.
        self.env = env
        self.focus_mechanic = focus_mechanic
        self.mechanic_to_action = action_table
        self.mechanics = list(action_table.keys())

        # Tree state; the root is only created once run() is called.
        self.root = None
        self.last_node = False

    def pretty_print_tree(self, node=None, depth=0):
        """Print the subtree below ``node`` as an indented outline.

        Each node is printed on its own line, indented two spaces per depth
        level; before every child a ``<mechanic> ->`` line is printed.

        Args:
            node: Subtree root; ``self.root`` is used when ``None``.
            depth: Indentation level assigned to ``node``.
        """
        start = self.root if node is None else node

        # Explicit stack in place of recursion; children are pushed in
        # reverse so they are popped (and printed) in dictionary order.
        pending = [(None, start, depth)]
        while pending:
            edge_line, current, level = pending.pop()
            if edge_line is not None:
                print(edge_line)
            pad = "  " * level
            print(f"{pad}{current}")
            for mechanic, child in reversed(list(current.children.items())):
                pending.append((f"{pad}  {mechanic} ->", child, level + 1))

    def plot_tree(self, max_depth=3):
        """Build a graph of the search tree rooted at ``self.root``.

        Every tree node becomes a graph node labelled with its most recent
        mechanic (or ``Root`` when its mechanic list is empty), visit count,
        value and reward; every parent/child link becomes an edge.

        Args:
            max_depth: Accepted but not applied -- the depth cut-off is
                disabled, so the whole tree is always emitted.

        Returns:
            The populated ``Digraph`` object.

        NOTE(review): ``Digraph`` is not imported in the visible part of this
        file (the graphviz import near the top is commented out), so calling
        this raises ``NameError`` unless the name is provided elsewhere.
        """
        graph = Digraph(comment='MCTS Tree')
        graph.attr(rankdir='TB', size='8,8')

        def emit(tree_node, parent_key=None, level=0):
            # The object's id() doubles as a unique graph-node key.
            key = str(id(tree_node))
            title = tree_node.mechanics[-1] if tree_node.mechanics else 'Root'
            caption = "\n".join((
                f"{title}",
                f"Visits: {tree_node.visits}",
                f"Value: {tree_node.value:.2f}",
                f"Reward: {tree_node.reward:.2f}",
            ))
            graph.node(key, caption)

            if parent_key:
                graph.edge(parent_key, key)

            for child in tree_node.children.values():
                emit(child, key, level + 1)

        emit(self.root)
        return graph

    def run(self):
        """Run the search and return the most-visited child of the root.

        A fresh root is built from a clone of ``self.env`` with
        ``[self.focus_mechanic]`` as its mechanic list. Each iteration selects
        a node; if it still has untried mechanics it is passed to ``expand``
        and the result is scored with ``simulate``, otherwise the node itself
        is scored with a ``SimulationMCTS`` run. The score is then
        backpropagated. Elapsed wall-clock time is printed at the end.

        Returns:
            Whatever ``best_child(self.root)`` returns.
        """
        self.root = Node(self.env.clone(), [self.focus_mechanic], self.mechanics)
        start_time = time.time()

        for _ in range(self.iterations):
            selected = self.select(self.root)
            if not selected.untried_mechanics:
                # Nothing left to try here: score the node as it stands.
                evaluator = SimulationMCTS(selected.state, selected.mechanics, self.simulation_iterations)
                depth, reward = evaluator.run()
                target = selected
            else:
                # expand() may hand back `selected` itself when it declines.
                target = self.expand(selected)
                depth, reward = self.simulate(target)
            self.backpropagate(target, depth, reward)

        print(f"time taken to run MCTS: {time.time() - start_time}")
        return self.best_child(self.root)

    def expand(self, node):
        """Add one child carrying an extra, randomly chosen untried mechanic.

        No child is created (and ``node`` itself is returned) when the node
        already holds ``self.max_mechanics`` mechanics, or when its
        ``max_reward`` does not exceed 0.8 per mechanic it holds.

        Args:
            node: Node to expand; must have at least one untried mechanic
                when expansion actually happens.

        Returns:
            The newly created child, or ``node`` when expansion was declined.
        """
        mechanic_count = len(node.mechanics)

        if mechanic_count >= self.max_mechanics:
            return node

        # Only grow nodes whose last recorded reward clears the threshold.
        if not node.max_reward > mechanic_count * 0.8:
            return node

        chosen = random.choice(list(node.untried_mechanics))
        extended_mechanics = node.mechanics + [chosen]

        # The child's environment is a clone that is reset before use.
        child_state = node.state.clone()
        child_state.reset()

        child = Node(child_state, extended_mechanics, self.mechanics, parent=node)
        node.children[chosen] = child
        node.untried_mechanics.remove(chosen)
        return child


    def select(self, node):
        """Descend from ``node`` along best-UCT children to a node to work on.

        The descent stops at the first node that has no children, is
        terminal, still has untried mechanics, or has been visited at most
        once.

        Returns:
            The node at which the descent stopped.
        """
        current = node
        while True:
            if not current.children or current.is_terminal:
                return current
            if current.untried_mechanics or current.visits <= 1:
                return current
            current = self.best_uct(current)

    def simulate(self, node):
        """Score ``node`` by simulating progressively larger mechanic sets.

        Starting from the node's own mechanics, a ``SimulationMCTS`` is run on
        a clone of the node's state; then one randomly chosen unused mechanic
        is added and the process repeats until every mechanic in
        ``self.mechanics`` has been added.

        Args:
            node: Node providing ``state`` (cloned once up front) and
                ``mechanics``.

        Returns:
            tuple: ``(max_depth, max_reward)`` -- the largest depth and the
            largest reward reported by any of the simulations. Both start at
            0, so negative rewards are reported as 0, and ``(0, 0)`` is
            returned without simulating when the node already uses every
            mechanic.
        """
        state = node.state.clone()
        used_mechanics = set(node.mechanics)
        available_mechanics = set(self.mechanics) - used_mechanics
        max_depth = 0
        max_reward = 0

        while available_mechanics:
            simulation_mcts = SimulationMCTS(state, list(used_mechanics), self.simulation_iterations)
            depth, reward = simulation_mcts.run()
            max_depth = max(max_depth, depth)
            max_reward = max(max_reward, reward)

            # Grow the mechanic set by one for the next round.
            new_mechanic = random.choice(list(available_mechanics))
            used_mechanics.add(new_mechanic)
            available_mechanics.remove(new_mechanic)

        return max_depth, max_reward
    
    def backpropagate(self, node, depth, reward):
        """Push one simulation result from ``node`` up to the root.

        For ``node`` and each ancestor: ``visits`` grows by one, ``value``
        accumulates ``depth`` and ``reward`` accumulates ``reward``.

        NOTE: ``max_reward`` is overwritten with the latest ``reward`` on
        every call; it is not a running maximum.
        """
        current = node
        while current:
            current.visits += 1
            current.value += depth
            current.reward += reward
            current.max_reward = reward
            current = current.parent

    def best_uct(self, node):
        """Return the child of ``node`` with the highest ``uct`` score.

        Ties go to the child that comes first in ``node.children``.
        """
        def uct_score(child):
            return child.uct(self.exploration_weight)

        return max(node.children.values(), key=uct_score)

    def best_child(self, node):
        """Return the ``(mechanic, child)`` pair of the most-visited child.

        Ties go to the child that comes first in ``node.children``.

        Args:
            node: Node whose ``children`` mapping (mechanic -> child) is ranked
                by each child's ``visits``.

        Returns:
            tuple: ``(mechanic, child)`` for the child with the most visits.

        Raises:
            ValueError: If ``node`` has no children.
        """
        if not node.children:
            # Same exception type max() would raise, with a usable message.
            raise ValueError("best_child() called on a node without children")
        return max(node.children.items(), key=lambda item: item[1].visits)
    def find_best_node(self, node):
        """Alias for ``best_child``.

        Despite the name, the result is whatever ``best_child(node)``
        returns, not a bare node.
        """
        best = self.best_child(node)
        return best

    def top_n_children(self, node, number_of_children=5):
        """Rank the children of ``node`` by visits, by value and by reward.

        Args:
            node: Node whose ``children`` mapping is ranked.
            number_of_children: Maximum length of each returned list.

        Returns:
            tuple: three lists of ``(mechanic, child)`` pairs, sorted in
            descending order of ``visits``, ``value`` and ``reward``
            respectively; three empty lists when there are no children.
        """
        if not node.children:
            return [], [], []

        entries = list(node.children.items())

        def ranked(metric):
            ordered = sorted(entries, key=lambda pair: getattr(pair[1], metric), reverse=True)
            return ordered[:number_of_children]

        return ranked("visits"), ranked("value"), ranked("reward")
    
    def get_all_nodes(self):
        """Return every node of the tree in pre-order, starting at the root.

        Children are visited in the order they appear in each node's
        ``children`` mapping.
        """
        collected = []
        # Explicit stack; children are pushed reversed to keep pre-order.
        stack = [self.root]
        while stack:
            current = stack.pop()
            collected.append(current)
            stack.extend(reversed(list(current.children.values())))
        return collected
    

class SimulationMCTS:
    def __init__(self, env, available_mechanics, iterations=100, exploration_weight=1.41):
        """Prepare an action-level search restricted to a set of mechanics.

        Args:
            env: Environment exposing ``mechanic_to_action`` and ``clone()``;
                the root node holds a clone of it.
            available_mechanics: Sequence of mechanic names the search may
                use (indexed by ``random.choice`` during rollouts).
            iterations: Number of iterations performed by ``run``.
            exploration_weight: Weight passed to each child's ``uct`` method.
        """
        # Search configuration.
        self.iterations = iterations
        self.exploration_weight = exploration_weight
        self.available_mechanics = available_mechanics

        # Environment hooks.
        self.env = env
        self.mechanic_to_action = env.mechanic_to_action

        # Tree state.
        self.max_depth = 0
        self.root = SimulationNode(env.clone(), parent=None)

    def run(self):
        """Run the search and summarise it.

        Each iteration selects a node, expands it unless it is terminal,
        performs a rollout from the resulting node, backpropagates the rollout
        reward and records the deepest node reached so far in
        ``self.max_depth``.

        Returns:
            tuple: ``(max_depth, mean_reward)`` where ``mean_reward`` is
            ``root.value / root.visits``, or ``0`` when the root was never
            visited.
        """
        for _ in range(self.iterations):
            node = self.select(self.root)
            if not node.is_terminal:
                node = self.expand(node)
            reward = self.rollout(node)
            self.backpropagate(node, reward)

            self.max_depth = max(self.max_depth, self.get_depth(node))

        root = self.root
        mean_reward = (root.value / root.visits) if root.visits > 0 else 0
        return self.max_depth, mean_reward

    def select(self, node):
        """Descend from ``node`` along best-UCT children.

        The descent stops at the first node that has no children, is
        terminal, or has fewer children than there are available mechanics.

        Returns:
            The node at which the descent stopped.
        """
        current = node
        while current.children and not current.is_terminal:
            fully_expanded = len(current.children) >= len(self.available_mechanics)
            if not fully_expanded:
                break
            current = self.best_uct(current)
        return current

    def expand(self, node):
        """Create one child of ``node`` for a mechanic it has not tried yet.

        Children are keyed by action id. A mechanic counts as tried when its
        action already has a child; "move" counts as tried when any of its
        four actions (0-3) has one. Previously "move" was only recognised via
        action 0, so it could be drawn again and silently replace an existing
        move child together with its statistics.

        The child's environment is a clone of ``node.state`` that is reset
        and then stepped once with the chosen action.

        Args:
            node: Node to expand; at least one available mechanic must be
                untried, otherwise ``random.choice`` raises ``IndexError``.

        Returns:
            The newly created child node.
        """
        move_actions = range(4)
        existing = node.children

        tried = set()
        for name, mapped_action in self.mechanic_to_action.items():
            if name == "move":
                if any(a in existing for a in move_actions):
                    tried.add(name)
            elif mapped_action in existing:
                tried.add(name)

        untried_mechanics = set(self.available_mechanics) - tried
        mechanic = random.choice(list(untried_mechanics))

        if mechanic == "move":
            action = random.randint(0, 3)
        else:
            action = self.mechanic_to_action[mechanic]

        new_state = node.state.clone()
        new_state.reset()
        _obs, _reward, done, _, _ = new_state.step(action)

        child = SimulationNode(new_state, parent=node, action=action)
        child.is_terminal = done
        node.children[action] = child
        return child

    def rollout(self, node):
        """Play up to eight policy-chosen actions and sum their rewards.

        The rollout runs on a clone of ``node.state`` that is reset first, and
        ends early as soon as a step reports ``done``.

        Returns:
            The sum of the step rewards collected.
        """
        sandbox = node.state.clone()
        sandbox.reset()

        collected = 0
        # Fixed horizon keeps every rollout bounded.
        for _ in range(8):
            chosen_action = self.rollout_policy(sandbox)
            _obs, step_reward, finished, _, _ = sandbox.step(chosen_action)
            collected += step_reward
            if finished:
                break
        return collected
    def rollout_policy(self, state):
        """Pick an action for one rollout step.

        A mechanic is drawn uniformly from ``self.available_mechanics``.
        "move" yields a random action in 0-3; any other mechanic known to
        ``self.mechanic_to_action`` yields its mapped action; an unknown
        mechanic yields a random action from ``state.action_space``.
        """
        mechanic = random.choice(self.available_mechanics)
        if mechanic != "move":
            if mechanic in self.mechanic_to_action:
                return self.mechanic_to_action[mechanic]
            return random.randint(0, state.action_space.n - 1)
        # Movement covers four directional actions.
        return random.randint(0, 3)

    def backpropagate(self, node, reward):
        """Add one visit and ``reward`` to ``node`` and every ancestor."""
        current = node
        while current:
            current.visits += 1
            current.value += reward
            current = current.parent

    def best_uct(self, node):
        """Return the child of ``node`` with the highest ``uct`` score.

        Ties go to the child that comes first in ``node.children``.
        """
        def uct_score(child):
            return child.uct(self.exploration_weight)

        return max(node.children.values(), key=uct_score)

    def get_max_depth(self):
        """Return ``self.max_depth``, the deepest node depth recorded by ``run`` so far."""
        return self.max_depth

    def get_depth(self, node):
        """Count the parent links between ``node`` and the root (root is 0)."""
        hops = 0
        ancestor = node.parent
        while ancestor:
            hops += 1
            ancestor = ancestor.parent
        return hops
    
class MechEnv(gym.Env):
    def __init__(self, walkable_tiles, tiles_without_char, tiles, str_map_without_chars, str_map, interactive_object_tiles, enemy_tiles, render_mode="rgb_array"):
        """Build the environment from two newline-separated text maps.

        Args:
            walkable_tiles: Tile characters the player may move onto.
            tiles_without_char: Stored as-is; not read by the methods of this
                class visible here.
            tiles: Mapping of tile character to an image object that
                ``render`` resizes and pastes.
            str_map_without_chars: Text map of the terrain only.
            str_map: Text map including the player (``@``) and objects.
            interactive_object_tiles: Tile characters that can be picked,
                pushed and moved by gravity.
            enemy_tiles: Tile characters that ``hit_enemy`` can remove.
            render_mode: ``render`` only produces frames for ``"rgb_array"``.
        """
        super().__init__()

        # Row strings are kept so reset() can rebuild the mutable grids.
        self.map_str_without_chars = str_map_without_chars.strip().split('\n')
        self.map_str = str_map.strip().split('\n')
        self.map = list(map(list, self.map_str))
        self.map_without_chars = list(map(list, self.map_str_without_chars))

        # Rendering resources.
        self.tile_size = 16
        self.char_tile_size = 16
        self.tiles = tiles
        self.tiles_without_char = tiles_without_char

        # Nine discrete actions: 0-3 move, 4 pick, 5 hit, 6 push, 7 teleport, 8 flip.
        self.action_space = spaces.Discrete(9)
        self.char_set = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'O': 4, '@': 5, '#': 6, '&': 7}
        self.char_to_int = lambda c: self.char_set.get(c, 0)

        self.mechanic_to_action = {
            "move": 0,  # stands for the whole 0-3 movement range
            "pick": 4,
            "hit": 5,
            "push": 6,
            "teleport": 7,
            "flip": 8
        }

        # One channel per known character, one cell per map position.
        row_count = len(self.map_str)
        max_width = max(map(len, self.map_str))
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(len(self.char_set), row_count, max_width),
            dtype=np.int32
        )

        # Tile categories.
        self.default_walkable_tile = "A"
        self.walkable_tiles = walkable_tiles
        self.interactive_object_tiles = interactive_object_tiles
        self.enemy_tiles = enemy_tiles
        self.npc_tiles = ["&"]

        # Episode bookkeeping.
        self.reward = 0
        self.picked_objects = []
        self.picked_object = False
        self.enemy_hit = False
        self.player_health = 100
        self.enemy_health = 100
        self.current_score = 0
        self.explored_tiles = set()
        self.teleport_cooldown = 0

        # Frame recording.
        self.frames = []
        self.video_buffer = None
        self.render_mode = render_mode

        self.reset()

    def reset(self, seed=None):
        """Restore the grids and per-episode state from the stored text maps.

        Args:
            seed: Accepted but not used.

        Returns:
            The ``"map"`` entry of ``get_state()``.
        """
        # Rebuild both mutable grids from the original row strings.
        self.map = list(map(list, self.map_str))
        self.map_without_chars = list(map(list, self.map_str_without_chars))
        self.grid_height = len(self.map)
        self.grid_width = max(map(len, self.map))
        self.default_walkable_tile = "A"

        # Locate the player and remember the terrain tile beneath it.
        self.player_position = self.find_player_position()
        player_row, player_col = self.player_position[0], self.player_position[1]
        self.current_tile = self.map_without_chars[player_row][player_col]

        # Per-episode flags and counters.
        self.reward = 0
        self.picked_object = False
        self.enemy_hit = False
        self.teleport_cooldown = 0
        self.gravity_direction = 0
        self.gravity_flipped = False

        # Drop any recorded frames and the cached video.
        self.frames = []
        self.video_buffer = None

        return self.get_state()["map"]

    def get_state(self):
        """Return the current state as a one-entry dictionary.

        The previous implementation also built a one-hot encoded array of the
        map on every call and then discarded it; that unused work has been
        removed because this method runs on every ``step`` and ``reset``.

        Returns:
            dict: ``{"map": self.map}`` -- the live character grid itself,
            not a copy.

        NOTE(review): ``observation_space`` declared in ``__init__``
        describes a one-hot int32 array, whereas this returns the raw
        character grid -- confirm which one consumers rely on.
        """
        return {"map": self.map}
    
    def step(self, actions):
        """Apply one action, then gravity, and report the outcome.

        Args:
            actions: Action id; values below 4 are movement, 4 pick, 5 hit,
                6 push, 7 teleport, 8 flip gravity. Other values only trigger
                the gravity pass.

        Returns:
            tuple: ``(state, reward, done, False, info)`` where ``done`` is
            always ``False`` and ``info`` carries ``"success"`` (1 when an
            enemy was hit this step) and ``"action_taken"``.
        """
        self.reward = 0
        self.enemy_hit = False
        done = False  # this method never signals termination

        if actions < 4:
            self.reward += self.move_player(actions)
        else:
            mechanic_handlers = (
                (4, self.pick_object),
                (5, self.hit_enemy),
                (6, self.push_object),
                (7, self.teleport),
                (8, self.flip_gravity),
            )
            for code, handler in mechanic_handlers:
                if actions == code:
                    self.reward += handler()
                    break

        # Gravity acts after every action, whatever it was.
        self.apply_gravity_effects()

        info = {"success": 1 if self.enemy_hit else 0, "action_taken": actions}
        observation = self.get_state()["map"]
        return observation, self.reward, done, False, info
    
    def render(self):
        """Draw the current map, record the frame and return it.

        Three layers are composited: the default walkable tile on every cell,
        then the terrain from ``map_without_chars``, then every non-walkable
        tile of ``map`` that has an image. Alphabetic tiles are drawn at full
        tile size; other tiles are drawn at ``char_tile_size`` and centred in
        their cell. A text line with the picked-object count is drawn near
        the bottom.

        Fixes an ``UnboundLocalError``: the centring offsets used to be
        assigned only in the non-alphabetic branch but were also read in the
        alphabetic one.

        Returns:
            The frame as an RGB array (also appended to ``self.frames``), or
            ``None`` when ``render_mode`` is not ``"rgb_array"``.
        """
        if self.render_mode != "rgb_array":
            return None

        tile_px = self.tile_size
        char_px = self.char_tile_size
        rows, cols = len(self.map), len(self.map[0])
        env_img = Image.new('RGBA', (cols * tile_px, rows * tile_px))

        # 1st layer: default walkable tile (resized once; it never changes).
        base_img = self.tiles[self.default_walkable_tile].resize((tile_px, tile_px))
        for i in range(rows):
            for j in range(cols):
                env_img.paste(base_img, (j * tile_px, i * tile_px), base_img)

        # 2nd layer: terrain without characters.
        for i, row in enumerate(self.map_without_chars):
            for j, tile in enumerate(row):
                if tile in self.tiles and tile != self.default_walkable_tile:
                    tile_img = self.tiles[tile].resize((tile_px, tile_px))
                    env_img.paste(tile_img, (j * tile_px, i * tile_px), tile_img)

        # 3rd layer: characters and objects.
        char_offset = (tile_px - char_px) // 2  # centres a character sprite
        for i, row in enumerate(self.map):
            for j, tile in enumerate(row):
                if tile not in self.tiles or tile in self.walkable_tiles:
                    continue
                if tile.isalpha():
                    size, offset = tile_px, 0  # full-size tile fills the cell
                else:
                    size, offset = char_px, char_offset
                tile_img = self.tiles[tile].resize((size, size))
                env_img.paste(tile_img, (j * tile_px + offset, i * tile_px + offset), tile_img)

        draw = ImageDraw.Draw(env_img)
        font = ImageFont.load_default()
        text = f"Objects Picked: {len(self.picked_objects)}"
        draw.text((10, env_img.size[1] - 20), text, (255, 255, 255), font=font)

        frame = np.array(env_img.convert('RGB'))
        self.frames.append(frame)
        return frame
    def _get_video(self, fps=10):
        """Encode the recorded frames as an in-memory MP4 and cache it.

        The encoded buffer is cached in ``self.video_buffer``; later calls
        return that same buffer, so ``fps`` and any frames recorded after the
        first encode have no effect until the cache is cleared (``reset`` and
        ``close`` do so).

        The video is written to a local buffer that is only cached once
        encoding has finished, so a failed encode no longer leaves a
        half-written buffer behind for later calls.

        Args:
            fps: Frame rate used when the video is first encoded.

        Returns:
            io.BytesIO positioned at the start when freshly encoded, or
            ``None`` when no frames have been recorded.
        """
        if not self.frames:
            return None

        if self.video_buffer is None:
            buffer = io.BytesIO()
            with imageio.get_writer(buffer, format='mp4', fps=fps) as writer:
                for frame in self.frames:
                    writer.append_data(frame)
            buffer.seek(0)
            self.video_buffer = buffer

        return self.video_buffer
    def get_video(self):
        """Return the result of ``_get_video()`` called with its default frame rate."""
        return self._get_video()
    
    def get_player_position(self):
        """Return ``self.player_position``, stored as a ``(row, col)`` tuple."""
        return self.player_position
    def move_player(self, action):
        """Try to move the player one cell.

        Args:
            action: 0-3, meaning up, down, left, right under default gravity.
                The offset is swapped and/or negated according to
                ``self.gravity_direction``.

        Returns:
            int: 1 when the player moved onto a walkable tile, 0 otherwise.
        """
        moves = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}  # (row, col) offsets
        d_row, d_col = moves[action]

        # Re-orient the offset for the current gravity direction.
        gravity = self.gravity_direction
        if gravity == 0:
            shift = (d_row, d_col)
        elif gravity == 1:
            shift = (d_col, d_row)
        elif gravity == 2:
            shift = (-d_row, -d_col)
        else:
            shift = (-d_col, -d_row)

        new_row = self.player_position[0] + shift[0]
        new_col = self.player_position[1] + shift[1]

        in_bounds = 0 <= new_row < self.grid_height and 0 <= new_col < self.grid_width
        if in_bounds and self.map[new_row][new_col] in self.walkable_tiles:
            self.update_player_position(new_row, new_col)
            return 1
        return 0
        
    def pick_object(self):
        """Pick up the first interactive object next to the player.

        The four neighbouring cells are checked in the order up, down, left,
        right. The first interactive object found is replaced by the default
        walkable tile and ``self.picked_object`` is set.

        Fixes a coordinate mix-up: ``player_position`` is ``(row, col)`` but
        used to be unpacked as ``(x, y)``, so the transposed cell was checked
        instead of the player's real neighbours.

        Returns:
            int: 1 when an object was picked, 0 otherwise.
        """
        self.picked_object = False
        row, col = self.player_position  # stored as (row, col)

        for dx, dy in ((0, -1), (0, 1), (-1, 0), (1, 0)):  # up, down, left, right
            new_x, new_y = col + dx, row + dy
            if not (0 <= new_x < self.grid_width and 0 <= new_y < self.grid_height):
                continue
            if self.map[new_y][new_x] in self.interactive_object_tiles:
                self.map[new_y][new_x] = self.default_walkable_tile
                self.picked_object = True
                return 1
        return 0
    def hit_enemy(self):
        """Hit the first enemy next to the player.

        The four neighbouring cells are checked in the order up, down, left,
        right. The first enemy found is replaced by the default walkable tile
        and ``self.enemy_hit`` is set.

        Fixes a coordinate mix-up: ``player_position`` is ``(row, col)`` but
        used to be unpacked as ``(x, y)``, so the transposed cell was checked
        instead of the player's real neighbours.

        Returns:
            int: 1 when an enemy was hit, 0 otherwise.
        """
        self.enemy_hit = False
        row, col = self.player_position  # stored as (row, col)

        for dx, dy in ((0, -1), (0, 1), (-1, 0), (1, 0)):  # up, down, left, right
            new_x, new_y = col + dx, row + dy
            if not (0 <= new_x < self.grid_width and 0 <= new_y < self.grid_height):
                continue
            if self.map[new_y][new_x] in self.enemy_tiles:
                self.enemy_hit = True
                self.map[new_y][new_x] = self.default_walkable_tile
                return 1
        return 0
    
    def push_object(self):
        """Push a neighbouring interactive object one cell away from the player.

        The four neighbouring cells are checked in the order up, down, left,
        right. The first object whose far side is an in-bounds walkable tile
        is moved there, and its old cell becomes the default walkable tile.
        Objects that cannot be pushed are skipped and the next direction is
        tried.

        Fixes a coordinate mix-up: ``player_position`` is ``(row, col)`` but
        used to be unpacked as ``(x, y)``, so the transposed cell was checked
        instead of the player's real neighbours.

        Returns:
            int: 1 when an object was pushed, 0 otherwise.
        """
        row, col = self.player_position  # stored as (row, col)

        for dx, dy in ((0, -1), (0, 1), (-1, 0), (1, 0)):  # up, down, left, right
            new_x, new_y = col + dx, row + dy
            if not (0 <= new_x < self.grid_width and 0 <= new_y < self.grid_height):
                continue

            target_tile = self.map[new_y][new_x]
            if target_tile not in self.interactive_object_tiles:
                continue

            # The destination is one more cell in the same direction.
            push_x, push_y = new_x + dx, new_y + dy
            if not (0 <= push_x < self.grid_width and 0 <= push_y < self.grid_height):
                continue

            if self.map[push_y][push_x] in self.walkable_tiles:
                self.map[push_y][push_x] = target_tile
                self.map[new_y][new_x] = self.default_walkable_tile
                return 1
        return 0
    
    def teleport(self):
        """Move the player to a random walkable cell within three cells.

        Candidate cells are all in-bounds walkable cells whose row and column
        each differ from the player's by at most three. On success
        ``self.teleport_cooldown`` is set to 5; no visible code in this
        method consults the cooldown before teleporting.

        Returns:
            int: 1 after a teleport, -1 when no candidate cell exists.
        """
        reach = 3  # maximum row/column distance
        row, col = self.player_position[0], self.player_position[1]

        valid_locations = [
            (r, c)
            for r in range(row - reach, row + reach + 1)
            for c in range(col - reach, col + reach + 1)
            if 0 <= c < self.grid_width and 0 <= r < self.grid_height
            and self.map[r][c] in self.walkable_tiles
        ]

        if not valid_locations:
            return -1  # penalty: nowhere to go

        dest_row, dest_col = random.choice(valid_locations)
        self.update_player_position(dest_row, dest_col)
        self.teleport_cooldown = 5
        return 1
    
    def flip_gravity(self):
        """Reverse the gravity direction, at most once per reset.

        The direction index advances by two modulo four, i.e. it switches to
        the opposite direction, and gravity is applied immediately.
        ``self.gravity_flipped`` stays set until ``reset`` clears it, so any
        further flip before then is rejected.

        Returns:
            int: 1 when gravity was flipped, -1 when the flip was rejected.
        """
        if self.gravity_flipped:
            return -1  # penalty: already flipped since the last reset

        self.gravity_direction = (self.gravity_direction + 2) % 4
        self.gravity_flipped = True
        self.apply_gravity_effects()
        return 1

    def update_player_position(self, new_row, new_col):
        """Move the ``@`` marker to ``(new_row, new_col)``.

        The cell being left gets back the tile remembered in
        ``self.current_tile``; the terrain tile of the destination (taken
        from ``map_without_chars``) is remembered in its place.
        """
        prev_row, prev_col = self.player_position

        # Restore what was under the player before it moves away.
        self.map[prev_row][prev_col] = self.current_tile

        # Remember the terrain under the destination, then occupy it.
        self.current_tile = self.map_without_chars[new_row][new_col]
        self.map[new_row][new_col] = '@'
        self.player_position = (new_row, new_col)
    
    def find_player_position(self):
        """Return ``(row, col)`` of the first ``@`` in the map, or ``None``.

        Rows are scanned top to bottom and each row left to right.
        """
        hits = (
            (row_index, col_index)
            for row_index, row in enumerate(self.map)
            for col_index, tile in enumerate(row)
            if tile == '@'
        )
        return next(hits, None)
    
    def clone(self):
        """Return an independent copy of this environment.

        A new ``MechEnv`` is built from the same configuration and then
        overwritten with the current dynamic state. Both grids, the
        picked-object list and the explored-tile set are copied; tile images,
        tile categories, ``char_set`` and ``char_to_int`` are shared.
        Recorded frames and the cached video are not carried over.

        The gravity direction and flip flag, teleport cooldown,
        ``picked_object``, ``reward`` and ``render_mode`` are now copied too.
        They used to revert to their post-reset defaults, so a clone taken
        after a gravity flip did not behave like its source.

        Returns:
            MechEnv: The copy.
        """
        new_env = MechEnv(
            walkable_tiles=self.walkable_tiles,
            tiles_without_char=self.tiles_without_char,
            tiles=self.tiles,
            str_map_without_chars='\n'.join(self.map_str_without_chars),
            str_map='\n'.join(self.map_str),
            interactive_object_tiles=self.interactive_object_tiles,
            enemy_tiles=self.enemy_tiles,
            render_mode=self.render_mode,
        )

        # Grid and player state.
        new_env.map = [row[:] for row in self.map]
        new_env.map_without_chars = [row[:] for row in self.map_without_chars]
        new_env.player_position = self.player_position
        new_env.current_tile = self.current_tile

        # Episode bookkeeping.
        new_env.picked_objects = self.picked_objects.copy()
        new_env.picked_object = self.picked_object
        new_env.enemy_hit = self.enemy_hit
        new_env.reward = self.reward
        new_env.player_health = self.player_health
        new_env.enemy_health = self.enemy_health
        new_env.current_score = self.current_score
        new_env.explored_tiles = self.explored_tiles.copy()

        # Mechanic state that reset() would otherwise leave at defaults.
        new_env.teleport_cooldown = self.teleport_cooldown
        new_env.gravity_direction = self.gravity_direction
        new_env.gravity_flipped = self.gravity_flipped

        # Shared encoders.
        new_env.char_to_int = self.char_to_int
        new_env.char_set = self.char_set
        return new_env
    
    #def is_terminal(self):
    #    if self.picked_object:
    #        if self.enemy_hit:
    #            return True
    #    return False
    
    def close(self):
        """Release references held by the environment and reset its counters.

        Grids, tile images, position and recorded frames are dropped; the
        picked-object list and explored-tile set are emptied in place.
        """
        # Drop grids, images and position.
        self.map = self.map_without_chars = None
        self.tiles = self.tiles_without_char = None
        self.player_position = None
        self.current_tile = None

        # Empty the collections in place.
        self.picked_objects.clear()
        self.explored_tiles.clear()

        # Flags and counters back to their initial values.
        self.enemy_hit = False
        self.picked_object = False
        self.player_health = 100
        self.enemy_health = 100
        self.current_score = 0

        # Forget recorded frames and the cached video.
        self.frames = []
        self.video_buffer = None

    def add_mechanic(self, mechanic):
        """Accept a mechanic name and do nothing with it.

        Args:
            mechanic: Ignored.
        """
        # This method doesn't need to do anything for now, 
        # as the mechanics are already available in the environment.
        # We're just adding it to prevent the AttributeError.
        pass

    def apply_gravity_effects(self):
        """Move interactive objects one cell in the current gravity direction.

        The grid is scanned row by row, left to right. Every interactive
        object whose neighbouring cell in the gravity direction is in bounds
        and walkable is moved there, leaving the default walkable tile
        behind.

        Fixes an axis mix-up: the direction table holds ``(row, col)`` deltas
        (``(1, 0)`` is labelled Down) but was unpacked as ``dx, dy``, so
        "Down" pushed objects to the right and "Right" pushed them down.

        NOTE: because the scan runs towards increasing row and column, an
        object moved down or right is met again later in the same scan and
        keeps moving until blocked, while objects moved up or left advance a
        single cell per call.
        """
        directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]  # Down, Right, Up, Left as (row, col)
        d_row, d_col = directions[self.gravity_direction]

        for y in range(self.grid_height):
            for x in range(self.grid_width):
                if self.map[y][x] not in self.interactive_object_tiles:
                    continue
                new_y, new_x = y + d_row, x + d_col
                if not (0 <= new_y < self.grid_height and 0 <= new_x < self.grid_width):
                    continue
                if self.map[new_y][new_x] in self.walkable_tiles:
                    self.map[new_y][new_x] = self.map[y][x]
                    self.map[y][x] = self.default_walkable_tile
    
class GameEnv(gym.Env):
    """Tile-based grid world with movement, push and collect actions.

    The world is given as two multi-line strings: ``str_map`` (tiles plus
    characters such as the player ``'@'``) and ``str_map_without_chars`` (the
    tiles underneath the characters). All positions are ``(row, col)`` tuples.

    Actions: 0-3 move the player (up, down, left, right), 4 pushes an adjacent
    interactive object, 5 collects the item underneath the player.

    NOTE(review): ``observation_space`` is declared as a one-hot style Box, but
    ``get_state``/``reset``/``step`` return the raw character map; confirm
    which representation consumers expect.
    """

    # (d_row, d_col) for movement actions 0-3: up, down, left, right.
    _MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}
    # Asset attributes that clones share by reference; the class only reads them.
    _SHARED_ATTRS = ("tiles", "tiles_without_char")

    def __init__(self, walkable_tiles, tiles_without_char, tiles, str_map_without_chars, str_map, interactive_object_tiles, enemy_tiles, collectible_tiles,render_mode="rgb_array"):
        """Build the environment and reset it to its initial state.

        Args:
            walkable_tiles: tile characters the player may step on and objects
                may be pushed onto.
            tiles_without_char: stored as-is; not used by this class itself.
            tiles: mapping from tile character to an RGBA PIL image (used by
                ``render``).
            str_map_without_chars: newline-separated map of the terrain layer.
            str_map: newline-separated map including characters/objects.
            interactive_object_tiles: tile characters that can be pushed.
            enemy_tiles: stored as-is; not used by this class itself.
            collectible_tiles: tile characters that can be collected.
            render_mode: ``render`` only produces frames for ``"rgb_array"``.

        Raises:
            ValueError: if ``str_map`` contains no player tile ``'@'``.
        """
        super().__init__()

        self.map_str_without_chars = str_map_without_chars.strip().split('\n')
        self.map_str = str_map.strip().split('\n')
        self.map = [list(row) for row in self.map_str]
        self.map_without_chars = [list(row) for row in self.map_str_without_chars]

        self.tile_size = 16
        self.char_tile_size = 16
        self.tiles = tiles
        self.tiles_without_char = tiles_without_char
        self.char_set = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'O': 4, '@': 5, '#': 6, '&': 7}
        self.char_to_int = lambda c: self.char_set.get(c, 0)
        self.action_space = spaces.Discrete(self.get_action_space())
        self.mechanic_to_action = self.get_mechanics_to_action()

        max_width = max(len(row) for row in self.map_str)
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(len(self.char_set), len(self.map_str), max_width),  # one channel per known character
            dtype=np.int32
        )

        self.default_walkable_tile = "A"
        self.reward = 0
        self.walkable_tiles = walkable_tiles
        self.interactive_object_tiles = interactive_object_tiles
        self.enemy_tiles = enemy_tiles
        self.picked_objects = []
        self.object_picked = False
        self.object_pushed = False
        self.npc_tiles = ["&"]
        self.enemy_hit = False
        self.current_score = 0
        self.explored_tiles = set()
        self.frames = []
        self.video_buffer = None
        self.render_mode = render_mode
        self.done = False
        self.collectible_tiles = collectible_tiles
        self.collected_items = 0
        self.total_collectibles = 0
        # Targets are read once from the initial map; pushes may later overwrite
        # the 'X' markers in self.map, so the positions are remembered here.
        self.target_positions = self.find_target_positions()
        self.objects_on_target = 0
        self.reset()

    def reset(self, seed=None):
        """Restore the initial map and per-episode state; return the map.

        ``seed`` is accepted for API compatibility and is not used. The return
        value is the character map (list of lists), not an ``(obs, info)``
        tuple, because existing callers use it that way.

        Raises:
            ValueError: if the map contains no player tile ``'@'``.
        """
        self.map = [list(row) for row in self.map_str]
        self.map_without_chars = [list(row) for row in self.map_str_without_chars]
        self.grid_width = max(len(row) for row in self.map)
        self.default_walkable_tile = "A"
        self.grid_height = len(self.map)

        self.player_position = self.find_player_position()
        if self.player_position is None:
            raise ValueError("str_map does not contain a player tile '@'")
        start_row, start_col = self.player_position
        self.current_tile = self.map_without_chars[start_row][start_col]

        self.object_picked = False
        self.object_pushed = False
        self.enemy_hit = False
        self.reward = 0
        self.frames = []
        self.done = False
        self.video_buffer = None

        # The map is back to its initial layout, so the goal counters must
        # start over as well; otherwise a second episode is terminal at once.
        self.collected_items = 0
        self.objects_on_target = 0
        # Items visible on the map plus one possibly hidden under the player.
        self.total_collectibles = self.count_collectibles() + (
            1 if self.current_tile in self.collectible_tiles else 0
        )

        return self.get_state()["map"]

    def step(self, actions):
        """Apply one action and return ``(state, reward, done, truncated, info)``.

        ``actions`` is a single integer action id; ids outside 0-5 that are not
        below 4 leave the world unchanged. ``truncated`` is always ``False``.
        """
        self.reward = 0
        self.done = False
        if actions < 4:  # Movement actions
            self.reward += self.move_player(actions)
        elif actions == 4:  # Push action
            self.reward += self.push_object()
        elif actions == 5:  # Collect action
            self.reward += self.collect_item()
        self.done = self.is_terminal()
        info = {"success": 1 if self.done else 0, "action_taken": actions}
        state = self.get_state()["map"]
        return state, self.reward, self.done, False, info

    def render(self):
        """Draw the current map, store the frame in ``self.frames`` and return it.

        Returns ``None`` when ``render_mode`` is not ``"rgb_array"``; otherwise
        an RGB ``numpy`` array of the composed image.
        """
        if self.render_mode != "rgb_array":
            return None

        tile_px = self.tile_size
        char_px = self.char_tile_size
        # Offset that centres a character sprite inside its tile.
        char_offset = (tile_px - char_px) // 2

        env_img = Image.new('RGBA', (len(self.map[0]) * tile_px, len(self.map) * tile_px))

        # 1st layer: default walkable tile everywhere (resized once, not per cell).
        floor_img = self.tiles[self.default_walkable_tile].resize((tile_px, tile_px))
        for i in range(len(self.map)):
            for j in range(len(self.map[0])):
                env_img.paste(floor_img, (j * tile_px, i * tile_px), floor_img)

        # 2nd layer: terrain from the map without characters.
        for i, row in enumerate(self.map_without_chars):
            for j, tile in enumerate(row):
                if tile in self.tiles and tile != self.default_walkable_tile:
                    tile_img = self.tiles[tile].resize((tile_px, tile_px))
                    env_img.paste(tile_img, (j * tile_px, i * tile_px), tile_img)

        # 3rd layer: characters and objects.
        for i, row in enumerate(self.map):
            for j, tile in enumerate(row):
                if tile not in self.tiles or tile in self.walkable_tiles:
                    continue
                if tile.isalpha():
                    # Letter tiles fill the whole cell, so no centring offset.
                    size, offset = tile_px, 0
                else:
                    size, offset = char_px, char_offset
                tile_img = self.tiles[tile].resize((size, size))
                env_img.paste(tile_img, (j * tile_px + offset, i * tile_px + offset), tile_img)

        draw = ImageDraw.Draw(env_img)
        font = ImageFont.load_default()
        text = f"Objects Picked: {len(self.picked_objects)}"
        draw.text((10, env_img.size[1] - 20), text, (255, 255, 255), font=font)

        frame = np.array(env_img.convert('RGB'))
        self.frames.append(frame)
        return frame

    def is_terminal(self):
        """Return True once every collectible is collected or every target is covered.

        A goal that does not exist on the map (no collectibles, or no targets)
        never ends the episode on its own.
        """
        if self.total_collectibles > 0 and self.collected_items >= self.total_collectibles:
            return True

        if self.target_positions and self.objects_on_target >= len(self.target_positions):
            return True

        return False

    def find_target_positions(self):
        """Return the ``(row, col)`` positions of all ``'X'`` tiles in ``self.map``."""
        return [
            (y, x)
            for y, row in enumerate(self.map)
            for x, tile in enumerate(row)
            if tile == 'X'
        ]

    def count_collectibles(self):
        """Count collectible tiles currently present in ``self.map``."""
        return sum(row.count(tile) for row in self.map for tile in self.collectible_tiles)

    def update_player_position(self, new_row, new_col):
        """Move the player marker to ``(new_row, new_col)``.

        The cell being left gets back the tile remembered in
        ``self.current_tile``; the tile under the new cell is taken from
        ``self.map_without_chars``.
        """
        old_row, old_col = self.player_position
        self.map[old_row][old_col] = self.current_tile
        self.current_tile = self.map_without_chars[new_row][new_col]
        self.map[new_row][new_col] = '@'
        self.player_position = (new_row, new_col)

    def find_player_position(self):
        """Return the ``(row, col)`` of the first ``'@'`` in the map, or ``None``."""
        for i, row in enumerate(self.map):
            for j, tile in enumerate(row):
                if tile == '@':
                    return (i, j)
        return None

    def clone(self):
        """Return an independent copy of this environment.

        Mutable state is deep-copied. The attributes in ``_SHARED_ATTRS`` are
        shared by reference, which avoids copying every tile image per clone.
        The spaces are rebuilt instead of copied.
        """
        new_env = object.__new__(GameEnv)

        for attr, value in self.__dict__.items():
            if attr in ('action_space', 'observation_space'):
                continue
            if attr in self._SHARED_ATTRS:
                setattr(new_env, attr, value)
            else:
                setattr(new_env, attr, copy.deepcopy(value))

        # deepcopy returns function objects unchanged, so the copied lambda
        # would still look up the source env's char_set; rebind it to the clone.
        new_env.char_to_int = lambda c: new_env.char_set.get(c, 0)

        new_env.action_space = spaces.Discrete(new_env.get_action_space())
        new_env.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(len(new_env.char_set), len(new_env.map_str), max(len(row) for row in new_env.map_str)),
            dtype=np.int32
        )

        return new_env

    def get_player_position(self):
        """Return the player's ``(row, col)`` by scanning the map, or ``None``."""
        return self.find_player_position()

    def get_state(self):
        """Return ``{"map": self.map}``; the map is the live object, not a copy."""
        return {"map": self.map}

    def get_mechanics_to_action(self):
        """Return the mechanic-name to action-id mapping."""
        return {
            "move_player": 0,  # 0-3 for movement
            "push_object": 4,
            "collect_item": 5
        }

    def get_action_space(self):
        """Return the number of discrete actions (6)."""
        return 6

    def _in_bounds(self, row, col):
        """Return True when ``(row, col)`` indexes an existing map cell."""
        return 0 <= row < len(self.map) and 0 <= col < len(self.map[row])

    def move_player(self, action):
        """Move the player one cell for action 0-3; return 1 on success, else 0.

        Raises:
            KeyError: if ``action`` is not one of 0, 1, 2, 3.
        """
        d_row, d_col = self._MOVES[action]
        new_row = self.player_position[0] + d_row
        new_col = self.player_position[1] + d_col

        if self._in_bounds(new_row, new_col) and self.map[new_row][new_col] in self.walkable_tiles:
            self.update_player_position(new_row, new_col)
            return 1
        return 0

    def push_object(self):
        """Push the first pushable neighbouring object one cell away from the player.

        Neighbours are tried in the order up, down, left, right. A push
        succeeds when the cell behind the object is in bounds and walkable.
        Returns 5 when the object lands on a target position, 1 for any other
        successful push and 0 when nothing was pushed.
        """
        reward = 0
        row, col = self.player_position
        for d_row, d_col in self._MOVES.values():
            obj_row, obj_col = row + d_row, col + d_col
            if not self._in_bounds(obj_row, obj_col):
                continue
            pushed_tile = self.map[obj_row][obj_col]
            if pushed_tile not in self.interactive_object_tiles:
                continue

            dest_row, dest_col = obj_row + d_row, obj_col + d_col
            if not self._in_bounds(dest_row, dest_col):
                continue
            if self.map[dest_row][dest_col] not in self.walkable_tiles:
                continue

            self.map[dest_row][dest_col] = pushed_tile
            self.map[obj_row][obj_col] = self.default_walkable_tile
            reward = 1
            if (dest_row, dest_col) in self.target_positions:
                reward = 5  # Extra reward for pushing object to target

            # Recount instead of incrementing: an object can be pushed off a
            # target again, and the same target can be entered more than once.
            self.objects_on_target = sum(
                1
                for t_row, t_col in self.target_positions
                if self.map[t_row][t_col] in self.interactive_object_tiles
            )
            break
        return reward

    def collect_item(self):
        """Collect the item underneath the player; return 3 on success, else 0.

        The player's own cell in ``self.map`` holds ``'@'``, so the tile the
        player is standing on is the one remembered in ``self.current_tile``.
        A collected item is replaced by the default walkable tile in both
        ``self.current_tile`` and ``self.map_without_chars`` so it does not
        reappear when the player moves away.
        """
        if self.current_tile not in self.collectible_tiles:
            return 0

        row, col = self.player_position
        self.current_tile = self.default_walkable_tile
        self.map_without_chars[row][col] = self.default_walkable_tile
        self.collected_items += 1
        return 3
    
# Create random agent
class RandomAgent:
    """Baseline agent that plays uniformly random actions on its own clone of an env."""

    def __init__(self, game_env):
        """Clone ``game_env`` and read the number of actions from the clone.

        ``game_env`` must provide ``clone()``; the clone must provide
        ``get_action_space()``, ``reset()`` and ``step(action)``.
        """
        # Initialise the guard and references first so that cleanup() and
        # __del__ stay safe even when clone() below raises.
        self._is_cleaned_up = False
        self.env = None
        self.action_space = None

        self.env = game_env.clone()
        self.action_space = self.env.get_action_space()

    def run(self, num_steps=15):
        """Play up to ``num_steps`` random actions from a freshly reset env.

        Stops early when the env reports ``done``.

        Returns:
            A 6-tuple ``(None, None, actions, total_reward, rewards, done)``;
            the two leading ``None`` values are placeholders kept for callers
            that unpack six values.
        """
        self.env.reset()
        total_reward = 0
        rewards = []
        actions = []
        done = False

        for _ in range(num_steps):
            if done:
                break
            action = random.randint(0, self.action_space - 1)
            _, reward, done, _, _ = self.env.step(action)
            total_reward += reward
            rewards.append(reward)
            actions.append(action)

        return None, None, actions, total_reward, rewards, done

    def cleanup(self):
        """Release the env clone. Safe to call repeatedly and on a partially built instance."""
        if getattr(self, "_is_cleaned_up", False):
            return

        # Clean up environment if it has a cleanup method
        env = getattr(self, "env", None)
        if hasattr(env, 'cleanup'):
            env.cleanup()

        # Clear all references
        self.env = None
        self.action_space = None

        # Force garbage collection
        gc.collect()

        self._is_cleaned_up = True

    def __del__(self):
        """Destructor to ensure cleanup is called"""
        self.cleanup()
    
class PlayMCTS:
    """Monte Carlo tree search over env actions, tracking the deepest node reached.

    The env must provide ``clone()``, ``step(action)``, ``is_terminal()`` and an
    ``action_space`` with an integer ``n``.
    """

    # Maximum number of random steps taken by one rollout.
    MAX_ROLLOUT_DEPTH = 15
    # Completed-iteration counts at which run() records a result snapshot.
    MILESTONE_ITERATIONS = (1, 1000, 10000, 100000)

    def __init__(self, env, iterations=100, exploration_weight=1.81, early_stop=False):
        """Create the search tree with a clone of ``env`` at its root.

        Args:
            env: environment to search; kept by reference, the tree works on clones.
            iterations: number of search iterations performed by ``run``.
            exploration_weight: UCT exploration constant.
            early_stop: when True, ``run`` stops at the first iteration that
                touches a terminal node and returns a single result dict.
        """
        self._is_cleaned_up = False
        self.env = env
        self.iterations = iterations
        self.exploration_weight = exploration_weight
        self.root = PlayNode(self.env.clone(), parent=None)
        self.max_depth = 0
        # NOTE: this instance attribute shadows the ``best_node`` method defined
        # below; use ``most_visited_child`` to call that method on an instance.
        self.best_node = None
        self.action_counts = dict.fromkeys(range(env.action_space.n), 0)
        self.early_stop = early_stop

    def cleanup(self):
        """Tear down the tree and drop references.

        Safe to call repeatedly and on a partially constructed instance. The
        tree is walked with an explicit stack so that deep trees cannot hit the
        recursion limit. ``self.env`` is the object passed by the caller; its
        ``cleanup`` method is invoked only if it has one.
        """
        if getattr(self, "_is_cleaned_up", False):
            return

        root = getattr(self, "root", None)
        pending = [root] if root else []
        while pending:
            node = pending.pop()
            pending.extend(node.children.values())
            # Clear state
            if hasattr(node.state, 'cleanup'):
                node.state.cleanup()
            # Clear all references
            node.state = None
            node.parent = None
            node.children.clear()
            node.path = None
            node.value = 0
            node.visits = 0
        self.root = None

        # Clear best node reference
        self.best_node = None

        # Clear environment if it has cleanup method
        env = getattr(self, "env", None)
        if hasattr(env, 'cleanup'):
            env.cleanup()

        # Clear all references
        self.env = None
        counts = getattr(self, "action_counts", None)
        if counts is not None:
            counts.clear()
        self.action_counts = None

        # Force garbage collection
        gc.collect()

        self._is_cleaned_up = True

    def __del__(self):
        """Destructor to ensure cleanup is called"""
        self.cleanup()

    def _snapshot(self, done):
        """Build a result dict for the current best node.

        Action counts are refreshed from the best path first so that
        ``action_score`` reflects the path being reported.
        """
        best = self.best_node
        path = best.path if best else []
        self.count_actions_in_best_path()
        return {
            "action": path[0] if path else None,
            "max_depth": self.max_depth,
            "action_sequence": path,
            "action_score": self.calculate_action_score(),
            "rewards": self.get_reward_sequence(best) if best else [],
            "done": done,
        }

    def run(self):
        """Run the search.

        Returns:
            With ``early_stop=True``: one result dict (keys ``action``,
            ``max_depth``, ``action_sequence``, ``action_score``, ``rewards``,
            ``done``). Otherwise: a dict mapping each milestone in
            ``MILESTONE_ITERATIONS`` that was reached (number of completed
            iterations) to such a result dict.
        """
        milestone_results = {}
        done = False

        for i in range(self.iterations):
            node = self.select(self.root)
            done = False
            if node.is_terminal:
                self.update_best_node(node)
                done = True
                # The episode is over at this node: do not grow the tree past it.
                child = node
            else:
                child = self.expand(node)

            reward = self.rollout(child)
            self.backpropagate(child, reward)

            if child.is_terminal:
                self.update_best_node(child)
                done = True

            if self.root.is_terminal:
                self.update_best_node(self.root)
                done = True

            if self.early_stop and done:
                break

            # Recorded after the iteration completes, so a milestone equal to
            # self.iterations is captured as well.
            completed = i + 1
            if completed in self.MILESTONE_ITERATIONS:
                milestone_results[completed] = self._snapshot(done)

        if self.best_node is None:
            # Only reachable when no iteration ran; an unexpanded root has no
            # children to choose from.
            self.best_node = self.best_uct(self.root) if self.root.children else self.root

        final_results = self._snapshot(done)

        if self.early_stop:
            return final_results
        return milestone_results

    def update_best_node(self, node):
        """Remember ``node`` if it is deeper than the current best; track max depth."""
        if self.best_node is None or node.depth > self.best_node.depth:
            self.best_node = node
        self.max_depth = max(self.max_depth, node.depth)

    def update_root(self, action):
        """Re-root the tree at the child reached by ``action``.

        When that child does not exist, a new root is built by cloning
        ``self.env`` and applying ``action`` to the clone.
        NOTE(review): a re-used child keeps the ``depth`` and ``path`` it had
        relative to the old root; confirm callers do not rely on those after
        re-rooting.
        """
        if action in self.root.children:
            self.root = self.root.children[action]
            self.root.parent = None
        else:
            new_state = self.env.clone()
            new_state.step(action)
            self.root = PlayNode(new_state, parent=None)

    def select(self, node):
        """Descend by UCT until reaching a node that is terminal, childless or not fully expanded."""
        while node.children and not node.is_terminal:
            if len(node.children) < node.state.action_space.n:
                return node
            node = self.best_uct(node)
        self.update_best_node(node)
        return node

    def expand(self, node):
        """Add a child for one random untried action of ``node`` and return it.

        When every action has already been tried, returns a random existing
        child (or ``node`` itself if it has none) without modifying the tree.
        """
        untried_actions = [a for a in range(node.state.action_space.n) if a not in node.children]

        if not untried_actions:
            return node if not node.children else random.choice(list(node.children.values()))

        action = random.choice(untried_actions)

        new_state = node.state.clone()
        _, reward, done, _, _ = new_state.step(action)
        child = PlayNode(new_state, parent=node, action=action)
        child.reward = reward
        child.is_terminal = done
        node.children[action] = child
        if done:
            self.update_best_node(child)
        return child

    def rollout(self, node):
        """Play random actions from ``node``'s state and return the summed reward.

        Works on a clone, so ``node.state`` is left untouched. The clone is not
        reset: the simulation has to start from the node's own state for the
        result to say anything about that node. Stops at a terminal state or
        after ``MAX_ROLLOUT_DEPTH`` steps.
        """
        state = node.state.clone()
        total_reward = 0
        depth = 0
        while not state.is_terminal() and depth < self.MAX_ROLLOUT_DEPTH:
            action = self.rollout_policy(state)
            _, reward, done, _, _ = state.step(action)
            total_reward += reward
            depth += 1
            if done:
                break
        return total_reward

    def rollout_policy(self, state):
        """Pick a uniformly random action id for ``state``."""
        return random.randint(0, state.action_space.n - 1)

    def backpropagate(self, node, reward):
        """Add one visit and ``reward`` to ``node`` and every ancestor."""
        while node:
            node.visits += 1
            node.value += reward
            node = node.parent

    def count_actions_in_best_path(self):
        """Recount, per action id, how often it occurs on the path to the best node."""
        self.action_counts = dict.fromkeys(range(self.env.action_space.n), 0)
        node = self.best_node
        while node and node.parent:
            self.action_counts[node.action] += 1
            node = node.parent

    def calculate_action_score(self):
        """Return an entropy-based score of the action usage in ``self.action_counts``.

        Returns 0 when no action with id above 3 was used.
        NOTE(review): the denominator counts only actions with id > 3, while
        the ratios are computed for every action, so they need not sum to 1 and
        the score can be negative; confirm whether movement actions are meant
        to be excluded from the ratios as well.
        """
        total_actions = sum(count for action, count in self.action_counts.items() if action > 3)
        if total_actions == 0:
            return 0

        action_probabilities = {action: count / total_actions for action, count in self.action_counts.items()}

        # Higher entropy means more balanced action usage.
        entropy = -sum(p * math.log2(p) for p in action_probabilities.values() if p > 0)

        # Normalize by the maximum possible entropy (log2 of number of actions).
        max_entropy = math.log2(len(self.action_counts))
        normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0

        return 1 * normalized_entropy

    def best_uct(self, node):
        """Return the child of ``node`` with the highest UCT value."""
        return max(node.children.values(), key=lambda c: c.uct(self.exploration_weight))

    def best_action(self, node):
        """Return the ``(action, child)`` pair of ``node`` with the most visits."""
        return max(node.children.items(), key=lambda c: c[1].visits)

    def most_visited_child(self, node):
        """Return the most visited child of ``node``, or ``node`` itself if it has none."""
        if not node.children:
            return node
        return max(node.children.values(), key=lambda c: c.visits)

    # Original name of most_visited_child, kept at class level for backward
    # compatibility. On instances it is hidden by the ``best_node`` attribute.
    best_node = most_visited_child

    def get_action_sequence(self, node):
        """Return the actions leading from the root to ``node``, in order."""
        sequence = []
        while node.parent:
            sequence.append(node.action)
            node = node.parent
        return list(reversed(sequence))

    def get_depth(self, node):
        """Return the number of parent links between ``node`` and the root."""
        depth = 0
        while node.parent:
            depth += 1
            node = node.parent
        return depth

    def get_reward_sequence(self, node):
        """Return the per-step rewards stored on the path from the root to ``node``."""
        rewards = []
        while node.parent:
            rewards.append(node.reward)
            node = node.parent
        return list(reversed(rewards))

class PlayNode:
    """One node of the play-search tree: an env state plus search statistics."""

    def __init__(self, state, parent=None, action=None):
        """Wrap ``state``; ``action`` is the step that led here from ``parent``.

        ``state`` must provide ``is_terminal()``, which is queried once here.
        """
        self.state = state
        self.parent = parent
        self.action = action
        self.children = {}
        self.visits = 0
        self.value = 0
        self.reward = 0
        self.is_terminal = state.is_terminal()
        if parent is None:
            # Root: nothing above it, empty action path.
            self.depth = 0
            self.path = []
        else:
            self.depth = parent.depth + 1
            self.path = parent.path + [action]

    def uct(self, exploration_weight):
        """Return the UCT score; unvisited nodes score infinity."""
        if not self.visits:
            return float('inf')
        mean_value = self.value / self.visits
        spread = math.sqrt(math.log(self.parent.visits) / self.visits)
        return mean_value + exploration_weight * spread

    def cleanup(self):
        """Clean up node resources"""
        self.state = None
        self.parent = None
        self.children.clear()
        self.path = None
        # These attributes are not set by __init__; they are reset only if
        # something attached them to the node and they hold a truthy value.
        for optional_attr in ('root', 'best_node'):
            if getattr(self, optional_attr, None):
                setattr(self, optional_attr, None)

class SimulationNode:
    """Tree node holding a state, its parent link and visit/value statistics."""

    def __init__(self, state, parent=None, action=None):
        """Store ``state`` and the ``action`` that led to it from ``parent``."""
        self.state, self.parent, self.action = state, parent, action
        self.children = {}
        self.visits = 0
        self.value = 0
        self.is_terminal = False

    def uct(self, exploration_weight):
        """Return the UCT score; unvisited nodes score infinity."""
        if not self.visits:
            return float('inf')
        mean_value = self.value / self.visits
        spread = math.sqrt(math.log(self.parent.visits) / self.visits)
        return mean_value + exploration_weight * spread
    
class Node:
    """Tree node identified by a list of mechanics, with search statistics."""

    def __init__(self, state, mechanics, all_mechanics, parent=None):
        """Create a node for ``mechanics``.

        ``untried_mechanics`` starts as every entry of ``all_mechanics`` that
        is not already in ``mechanics``.
        """
        self.state = state
        self.mechanics = mechanics
        self.parent = parent
        self.children = {}
        self.visits = 0
        self.value = 0
        self.is_terminal = False
        self.untried_mechanics = {
            candidate for candidate in all_mechanics if candidate not in self.mechanics
        }
        self.reward_per_mechanic = 0
        self.reward = 0
        self.max_reward = 0

    def uct(self, exploration_weight):
        """Return the UCT score; unvisited nodes score infinity."""
        if not self.visits:
            return float('inf')
        mean_value = self.value / self.visits
        spread = math.sqrt(math.log(self.parent.visits) / self.visits)
        return mean_value + exploration_weight * spread

    def __repr__(self):
        return f"Node(mechanics={self.mechanics}, visits={self.visits}, value={self.value:.2f}, reward={self.reward:.2f})"

def evaluate_new_mechanic(env, focus_mechanic):
    """Run an MCTS over mechanic combinations and summarise how the tree used them.

    A clone of ``env`` is searched with 500 iterations, 75 simulation
    iterations and exploration weight 1.81.

    Args:
        env: environment providing ``clone()`` and ``mechanic_to_action``.
        focus_mechanic: mechanic passed through to ``MCTS``.

    Returns:
        A tuple ``(pairing_counts, mechanic_combinations, all_nodes)``:
        ``pairing_counts`` maps ``"<k>_mechanic"`` to the number of distinct
        combinations of size ``k`` with non-zero usage;
        ``mechanic_combinations`` maps each combination (a tuple of mechanics
        in tree order) to its share of the depth-weighted visits (all zeros
        when no non-root node was visited); ``all_nodes`` is the list returned
        by ``mcts.get_all_nodes()``.
    """
    mcts = MCTS(env.clone(), focus_mechanic, iterations=500, simulation_iterations=75, exploration_weight=1.81)
    # run() builds the tree; its return value is not needed here.
    mcts.run()

    all_nodes = mcts.get_all_nodes()

    def get_depth(node):
        depth = 0
        while node.parent:
            depth += 1
            node = node.parent
        return depth

    # Visits per combination, weighted so deeper nodes count a bit more
    # (+10% per tree level).
    mechanic_combinations = {}
    for node in all_nodes:
        if node is mcts.root:
            continue
        depth_factor = 1 + (get_depth(node) / 10)
        combination = tuple(node.mechanics)
        mechanic_combinations[combination] = (
            mechanic_combinations.get(combination, 0) + node.visits * depth_factor
        )

    # Normalise to shares; with no visits at all there is nothing to divide by.
    total_value = sum(mechanic_combinations.values())
    if total_value > 0:
        mechanic_combinations = {k: v / total_value for k, v in mechanic_combinations.items()}
    else:
        mechanic_combinations = {k: 0.0 for k in mechanic_combinations}

    pairing_counts = {f"{i+1}_mechanic": 0 for i in range(len(env.mechanic_to_action))}

    for combination, usage in mechanic_combinations.items():
        if usage > 0:
            # .get keeps an unexpected combination size from raising KeyError.
            size_key = f"{len(combination)}_mechanic"
            pairing_counts[size_key] = pairing_counts.get(size_key, 0) + 1

    return pairing_counts, mechanic_combinations, all_nodes


def fitness_function(mechanic_combination, usage_stats, pairing_counts):
    """Return the fitness of a mechanic combination: its usage share times its size.

    Args:
        mechanic_combination: iterable of mechanics; it is sorted before the lookup.
        usage_stats: mapping from combination tuple to usage share.
        pairing_counts: accepted for interface compatibility; it does not
            influence the result.

    Returns:
        ``usage * len(combination)`` when the combination has positive usage,
        otherwise 0.

    NOTE(review): the lookup key is the *sorted* tuple. If ``usage_stats`` is
    keyed by tuples in another order, those combinations resolve to usage 0;
    confirm the key convention of the caller.
    """
    combo = tuple(sorted(mechanic_combination))
    usage = usage_stats.get(combo, 0)

    # Only the synergy term (usage weighted by combination size) contributes.
    # The former diversity/size/novelty terms were never part of the result
    # and could raise on an empty or all-zero pairing_counts.
    if usage > 0:
        return usage * len(combo)
    return 0

#for node in all_nodes:
#    print(f"fitness for {node.mechanics}: {fitness_function(node.mechanics, usage_stats, pairing_counts)}")

def calculate_interaction_score(fitnesses, focus_mechanic):
    """Score how strongly ``focus_mechanic`` features in high-fitness combinations.

    Args:
        fitnesses: mapping from a combination (a container of mechanics) to
            its fitness value.
        focus_mechanic: mechanic whose combinations are considered.

    Returns:
        0 when no combination contains ``focus_mechanic``; otherwise the
        maximum fitness among those combinations multiplied by the number of
        them at or above the 99th percentile of their fitness values.
    """
    mechanic_fitnesses = [
        fitness for mechanics, fitness in fitnesses.items() if focus_mechanic in mechanics
    ]

    if not mechanic_fitnesses:
        return 0  # No combination includes the focus mechanic

    max_fitness = np.max(mechanic_fitnesses)

    # High performers are the combinations at or above the 99th percentile.
    threshold = np.percentile(mechanic_fitnesses, 99)
    high_performing_count = sum(1 for fitness in mechanic_fitnesses if fitness >= threshold)

    return max_fitness * high_performing_count


def make(name):
    """Build the demo ``MechEnv`` and wrap it in ``GymCompatibilityWrapper``.

    ``name`` is accepted but not read by this function: the environment is
    instantiated directly instead of being looked up through ``gym.make``.
    Tile sprites are loaded from fixed absolute paths under ``/gmd`` and
    converted to RGBA.

    Returns:
        The wrapped environment, created with ``render_mode="rgb_array"``.
    """
    # Map including characters: '@' is the player tile and '#' the enemy tile.
    str_world = """BBBBBBBBBBB
BAAAAAAAAAB
BAAAOAAAAAB
BA#@OAAAAAB
BA#AAAAAAAB
BBBBBBBBBBB"""
    # Map passed for both "without characters" arguments below.
    str_map_wo_chars = """BBBBBBBBBBB
BAAAAAAAAAB
BAAOOAAAAAB
BAAAOAAAAAB
BAAAAAAAAAB
BBBBBBBBBBB"""

    # Tile character -> sprite file, in the order the sprites are loaded.
    sprite_paths = {
        "A": r"/gmd/world_tileset_data/td_world_floor_grass_c.png",
        "B": r"/gmd/world_tileset_data/td_world_wall_stone_h_a.png",
        "C": r"/gmd/world_tileset_data/td_world_floor_grass_c.png",
        "O": r"/gmd/world_tileset_data/td_world_chest.png",
        "@": r"/gmd/character_sprite_data/td_monsters_archer_d1.png",
        "#": r"/gmd/character_sprite_data/td_monsters_witch_d1.png",
        "&": r"/gmd/character_sprite_data/td_monsters_goblin_captain_d1.png",
    }
    env_image = {
        tile: Image.open(path).convert("RGBA")
        for tile, path in sprite_paths.items()
    }

    base_env = MechEnv(
        walkable_tiles=['A', 'B'],
        tiles_without_char=str_map_wo_chars,
        tiles=env_image,
        str_map_without_chars=str_map_wo_chars,
        str_map=str_world,
        interactive_object_tiles=['O'],
        enemy_tiles=["#"],
        render_mode="rgb_array",
    )
    return GymCompatibilityWrapper(base_env)

def env_dict():
    """Load the tile sprites and return them keyed by tile character.

    Every image is opened from a fixed absolute path under ``/gmd`` and
    converted to RGBA.

    Returns:
        dict mapping each tile character to its RGBA ``PIL`` image.
    """
    # Tile character -> sprite file, in the order the sprites are loaded.
    sprite_paths = {
        "A": r"/gmd/world_tileset_data/td_world_floor_grass_c.png",
        "B": r"/gmd/world_tileset_data/td_world_wall_stone_h_a.png",
        "X": r"/gmd/world_tileset_data/td_world_floor_grass_c.png",
        "O": r"/gmd/world_tileset_data/td_world_chest.png",
        "I": r"/gmd/world_tileset_data/td_world_chest.png",
        "C": r"/gmd/world_tileset_data/td_world_chest.png",
        "@": r"/gmd/character_sprite_data/td_monsters_archer_d1.png",
        "#": r"/gmd/character_sprite_data/td_monsters_witch_d1.png",
        "&": r"/gmd/character_sprite_data/td_monsters_goblin_captain_d1.png",
    }
    return {
        tile: Image.open(path).convert("RGBA")
        for tile, path in sprite_paths.items()
    }

def str_map():
    """Return the default level layout as newline-separated rows of tile characters."""
    rows = (
        "BBBBBBBBBBB",
        "BAAAAAAAAAB",
        "BAA@OAXAAAB",
        "BAAAAAICAAB",
        "BAAAAAAAAAB",
        "BBBBBBBBBBB",
    )
    # Rows are joined without a trailing newline.
    return "\n".join(rows)

def important_tiles():
    """Return the tile-character groups that describe the level.

    Returns:
        A 6-tuple ``(walkables, interactive_object_tiles, collectible_tiles,
        npc_tiles, player_tile, enemy_tiles)``. All entries are lists of tile
        characters except ``player_tile``, which is a single character.
        Fresh lists are created on every call.
    """
    return (
        ['A', 'B'],   # walkables
        ['O'],        # interactive_object_tiles
        ['I', 'C'],   # collectible_tiles
        ["&"],        # npc_tiles
        '@',          # player_tile
        ["#"],        # enemy_tiles
    )


def make_game(name):
    """Build the demo ``GameEnv`` and wrap it in ``GymCompatibilityWrapper``.

    ``name`` is accepted but not read by this function. Tile sprites are
    loaded from fixed absolute paths under ``/gmd`` and converted to RGBA.

    Returns:
        The wrapped environment, created with ``render_mode="rgb_array"``.
    """
    # Map including the player ('@'), the interactive object ('O') and the
    # collectibles ('I', 'C').
    str_world = """BBBBBBBBBBB
BAAAAAAAAAB
BAA@OAXAAAB
BAAAAAICAAB
BAAAAAAAAAB
BBBBBBBBBBB"""

    # Same grid with only wall ('B') and floor ('A') tiles.
    str_map_wo_chars = """BBBBBBBBBBB
BAAAAAAAAAB
BAAAAAAAAAB
BAAAAAAAAAB
BAAAAAAAAAB
BBBBBBBBBBB"""

    # Tile character -> sprite file, in the order the sprites are loaded.
    sprite_paths = {
        "A": r"/gmd/world_tileset_data/td_world_floor_grass_c.png",
        "B": r"/gmd/world_tileset_data/td_world_wall_stone_h_a.png",
        "X": r"/gmd/world_tileset_data/td_world_floor_grass_c.png",
        "O": r"/gmd/world_tileset_data/td_world_chest.png",
        "I": r"/gmd/world_tileset_data/td_world_chest.png",
        "C": r"/gmd/world_tileset_data/td_world_chest.png",
        "@": r"/gmd/character_sprite_data/td_monsters_archer_d1.png",
        "#": r"/gmd/character_sprite_data/td_monsters_witch_d1.png",
        "&": r"/gmd/character_sprite_data/td_monsters_goblin_captain_d1.png",
        # Tiles for an additional interactive object and collectible.
        "D": r"/gmd/world_tileset_data/td_world_chest.png",
        "E": r"/gmd/world_tileset_data/td_world_chest.png",
    }
    env_image = {
        tile: Image.open(path).convert("RGBA")
        for tile, path in sprite_paths.items()
    }

    base_env = GameEnv(
        walkable_tiles=['A', 'B', 'X'],
        tiles_without_char=str_map_wo_chars,
        tiles=env_image,
        str_map_without_chars=str_map_wo_chars,
        str_map=str_world,
        interactive_object_tiles=['O'],
        enemy_tiles=["#"],
        collectible_tiles=['I', 'C'],
        render_mode="rgb_array",
    )
    return GymCompatibilityWrapper(base_env)

class GymCompatibilityWrapper(gym.Wrapper):
    """``gym.Wrapper`` that fixes the ``step``/``reset`` return shapes and adds ``clone``."""

    def step(self, action):
        """Forward ``action`` to the wrapped env and return its five-element result.

        The wrapped env's ``step`` must return exactly five values; anything
        else fails at the unpacking below.
        """
        step_result = self.env.step(action)
        obs, reward, done, truncated, info = step_result
        return (obs, reward, done, truncated, info)

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(obs, {})``.

        Whatever the wrapped env's ``reset`` returns is used as ``obs`` and is
        paired with an empty info dict.
        NOTE(review): presumably the wrapped env returns a bare observation
        rather than an ``(obs, info)`` pair -- confirm.
        """
        return self.env.reset(**kwargs), {}

    def clone(self):
        """Return a new wrapper around ``self.env.clone()``.

        The spaces, reward range and metadata of this wrapper are assigned
        onto the new one (by reference, not copied).
        """
        duplicate = GymCompatibilityWrapper(self.env.clone())
        for attr in ("action_space", "observation_space", "reward_range", "metadata"):
            setattr(duplicate, attr, getattr(self, attr))
        return duplicate
    
# Usage
#env = make('MechEnv-v1')
#env.reset()  
#focus_mechanic = "flip"  
#pairing_counts, usage_stats, all_nodes = evaluate_new_mechanic(env, focus_mechanic)
#print("\n")
#print(f"{focus_mechanic.capitalize()} mechanic pairing results:")
#for i, (key, value) in enumerate(pairing_counts.items()):
#    print(f"Paired with {i} other mechanic{'s' if i != 1 else ''}: {value}")
#print("\nUsage statistics:")
#for mechanic, usage in usage_stats.items():
#    print(f"{mechanic}: {usage:.2%}")


#fitnesses = {}
#for node in all_nodes:
#    fitnesses[str(node.mechanics)] = fitness_function(node.mechanics, usage_stats, pairing_counts)

#for mechanics, fitness in list(fitnesses.items())[:5]:
#    print(f"fitness for {mechanics}: {fitness}")
#print(".")
#print(".")
#print(".")
#print("\n")

#interaction_score = calculate_interaction_score(fitnesses, focus_mechanic)
#print("\n")
#print(f"{focus_mechanic.capitalize()} interaction score: {interaction_score:.4f}")


#def calculate_mechanic_fitness(latest_mechanic_number, best_action_sequence, best_rewards):
#    if latest_mechanic_number not in best_action_sequence:
#        return -1
#    total_reward = sum(best_rewards)
#    if total_reward == 0:
#        return 0
#    score = 0
#    for i, action in enumerate(best_action_sequence):
#        if action == latest_mechanic_number:
#            score += best_rewards[i]
#    score = score / total_reward
#    return score

def calculate_mechanic_fitness(latest_mechanic_numbers, best_action_sequence, best_rewards):
    """Return the share of total reward earned by the latest mechanics.

    Args:
        latest_mechanic_numbers: Iterable of action identifiers to credit.
        best_action_sequence: Sequence of actions taken.
        best_rewards: Rewards indexed in step with ``best_action_sequence``.

    Returns:
        ``-1`` if none of the latest mechanics appears in the action sequence,
        ``0`` if the rewards sum to zero, otherwise the sum of rewards at the
        positions where a latest mechanic was used divided by the total reward.
    """
    # Guard: none of the mechanics of interest was ever used.
    if all(mechanic not in best_action_sequence for mechanic in latest_mechanic_numbers):
        return -1

    reward_sum = sum(best_rewards)
    if reward_sum == 0:
        return 0

    # Rewards are looked up by position so a too-short rewards list still
    # raises IndexError on a credited step.
    credited = sum(
        best_rewards[step]
        for step, action in enumerate(best_action_sequence)
        if action in latest_mechanic_numbers
    )
    return credited / reward_sum

#game_env = make_game('MechEnv-v1')  # Initialize with appropriate parameters

#done = False
#total_reward = 0
#game_env.reset()
#mcts = PlayMCTS(game_env, iterations=1000000)
#start_time = time.time()
#best_action, max_depth, best_action_sequence, action_score, best_rewards = mcts.run()
#end_time = time.time()
#execution_time = end_time - start_time
#latest_mechanic_number = len(game_env.mechanic_to_action) + 3 - 1
#print(f"latest_mechanic_number: {latest_mechanic_number}")
#total_mechanics = game_env.action_space.n - 4
#print(f"total_mechanics: {total_mechanics}")
#print(f"max_depth: {max_depth}")
#print(f"best_action_sequence: {best_action_sequence}")
#print(f"best_rewards_sequence: {best_rewards}")
#latest_mechanic_score = calculate_mechanic_fitness(latest_mechanic_number, best_action_sequence, best_rewards)
#print(f"latest_mechanic_score: {latest_mechanic_score}")
#print(f"total_reward: {sum(best_rewards)}")
#print(f"action_score: {action_score:.4f}")
#print(f"execution_time: {execution_time:.4f} seconds")
#print("total score: ", sum(best_rewards) * max_depth * execution_time)
# Use the best action in your game
#obs, reward, done, _, _ = game_env.step(best_action)
#print(f"done: {done}")
#print(f"best_action: {best_action}")
#print(f"reward: {reward}")
#game_env = make_game('MechEnv-v1')
#done = False
#total_reward = 0
#iterations = 0
#
#mcts = PlayMCTS(game_env, iterations=1000)
#
#while not done and iterations < 1000:
#    best_action = mcts.run()
#    print(f"best_action: {best_action}")
#    
#    obs, reward, done, _, info = game_env.step(best_action)
#    total_reward += reward
#    print(f"reward: {total_reward}")
#    
#    game_env.render()
#    iterations += 1
#    
#    if not done:
#        mcts.update_root(best_action)
#
#print(f"Game over. Total reward: {total_reward}")
