#import gym

import numpy as np
import os
import json
import time
import imageio
from PIL import Image, ImageDraw, ImageFont
from rembg import remove
import random
import io
import functools
import shimmy
import copy
#import pufferlib
#import pufferlib.emulation
#import pufferlib.environments
#import pufferlib.environment
#import pufferlib.postprocess
#import pufferlib.utils
import gc
from graphviz import Digraph
import itertools

#

import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.registration import register

#from graphviz import Digraph


from tqdm import tqdm, trange
from ribs.archives import GridArchive
from ribs.visualize import grid_archive_heatmap

import numpy as np
import random
import sys
import matplotlib.pyplot as plt

import json
import pickle

from init_mechanics import mech_1, mech_2, mech_3, mech_4, mech_5, mech_6, mech_7, mech_8
from init_games import game_class_1, game_class_2, make_game_1, make_game_2

from llm_emitter import MechanicLLMEmitter, GameLLMEmitter
from create_mechanics import mechanics_test
from create_games import get_games_scores

from utils import extract_function_name

from configs import Configs

from utils import extract_function_name, filter_functions_by_name
    
import math
import random
import pprint
import ray

import pandas as pd
from pathlib import Path
from configs import Configs

class LLMMCTS:
    """Monte-Carlo tree search over combinations of game mechanics.

    Every node of the tree holds a list of mechanic names that starts with the
    focus mechanic. Expanding a node appends one more mechanic, either drawn
    from the initial pool (``self.mechanics``) or freshly generated by
    ``mechanic_emitters[0]``. A node is scored by assembling the selected
    mechanics' source (``filter_functions_by_name``) into a game and calling
    ``get_games_scores``; scores are backpropagated towards the root. Results
    for mechanic combinations are cached in a CSV file so that rollouts do not
    re-play combinations they have already scored.
    """

    def __init__(self, mechanic_emitters, focus_mechanic, iterations=10000, simulation_iterations=1000, exploration_weight=1.41, max_mechanics=5, max_tree_depth=5, simulation_depth=1, save_path=None, generation=0):
        """Build the initial mechanic pool and load (or create) the result cache.

        Args:
            mechanic_emitters: Emitters; only ``mechanic_emitters[0]`` is used
                (``ask_random_solutions`` for the initial pool, ``ask`` for
                newly generated mechanics).
            focus_mechanic: Source code of the mechanic under evaluation.
            iterations: Number of MCTS iterations performed by ``run``.
            simulation_iterations: Stored only; not used by this class.
            exploration_weight: UCT exploration constant.
            max_mechanics: Size of the initial pool, focus mechanic included.
                Duplicates returned by the emitter are skipped, so the pool
                may end up smaller.
            max_tree_depth: Nodes holding this many mechanics are not expanded.
            simulation_depth: Number of growing mechanic sets scored per rollout.
            save_path: Directory for the cache CSV and saved trees. When None,
                the cache lives in the current directory.
            generation: Outer-loop generation index (stored only).
        """
        # Set first so that cleanup()/__del__ are safe even if the emitter
        # call below raises.
        self._is_cleaned_up = False
        self.root = None

        self.focus_mechanic = focus_mechanic
        self.iterations = iterations
        self.simulation_iterations = simulation_iterations
        self.exploration_weight = exploration_weight
        self.mechanic_emitters = mechanic_emitters
        self.max_mechanics = max_mechanics
        self.mechanics, self.mechanics_code = self._get_mechanics(num_mechanics=self.max_mechanics)

        self.last_node = False
        self.max_tree_depth = max_tree_depth
        self.simulation_depth = simulation_depth
        self.generation = generation
        # Probability that an expansion asks the LLM for a brand-new mechanic
        # instead of reusing one from the initial pool.
        self.make_new_mechanic_prob = 0.5

        # f"{None}/..." would point at a directory literally named "None".
        cache_dir = save_path if save_path is not None else "."
        self.cache_file = str(Path(cache_dir) / "mechanics_cache.csv")
        self.save_path = save_path
        self.games_created = 0
        self.configs = Configs()
        self.load_cache()

    def cleanup(self):
        """Release the tree and the cache and drop references to large objects.

        Idempotent: calls after the first one return immediately.
        """
        if self._is_cleaned_up:
            return

        # Clear the cache DataFrame
        if hasattr(self, "cache_df"):
            if isinstance(self.cache_df, pd.DataFrame):
                self.cache_df.drop(self.cache_df.index, inplace=True)
            self.cache_df = None

        def clear_node(node):
            """Recursively break parent/child links and reset statistics."""
            if node is None:
                return
            for child in list(node.children.values()):
                clear_node(child)
            node.children.clear()
            node.parent = None
            if hasattr(node, "untried_mechanics"):
                node.untried_mechanics.clear()
            node.visits = 0
            node.value = 0
            node.reward = 0
            node.max_reward = 0

        root = getattr(self, "root", None)
        if root is not None:
            clear_node(root)
        self.root = None

        self.mechanics_code = None
        self.focus_mechanic = None
        self.last_node = None
        self.cache_file = None
        self.save_path = None

        gc.collect()
        self._is_cleaned_up = True

    def __del__(self):
        """Run ``cleanup`` when the object is collected."""
        # Instances created without __init__ (e.g. via __new__) have nothing
        # to clean up.
        if hasattr(self, "_is_cleaned_up"):
            self.cleanup()

    def save_tree(self, filename):
        """Pickle the tree to ``<save_path>/<filename>.pkl`` and build a Graphviz view.

        If ``save_path`` is unset it is set to ``"."``. The graph is not
        rendered here. It is returned so callers may render it; ``None`` is
        returned when there is no tree.
        """
        if self.save_path is None:
            self.save_path = "."

        tree_data = self._serialize_node(self.root)
        save_file = Path(self.save_path) / f"{filename}.pkl"
        with open(save_file, "wb") as f:
            pickle.dump(tree_data, f)

        if self.root is None:
            return None

        dot = Digraph(comment='MCTS Tree')
        dot.attr(rankdir='TB')  # Top to bottom layout
        dot.attr('node', shape='box', style='rounded,filled', fillcolor='lightblue')

        def add_node_to_graph(node, parent_id=None):
            node_id = str(id(node))
            mechanics_str = '\n'.join(node.mechanics) if node.mechanics else 'Root'

            uct_value = "N/A"
            if node.parent and node.visits > 0:
                uct_value = f"{node.uct(self.exploration_weight):.2f}"

            label = f"{mechanics_str}\n\nVisits: {node.visits}\nValue: {node.value:.2f}\nUCT: {uct_value}"

            if node.visits > 0:
                # value / (visits + 1) is not guaranteed to lie in [0, 1]; clamp it
                # so both hex channels stay two digits and non-negative.
                normalized_value = max(0.0, min(1.0, node.value / (node.visits + 1)))
                # Red (low) to green (high) gradient.
                color = f"#{int(255 * (1 - normalized_value)):02x}{int(255 * normalized_value):02x}80"
                dot.node(node_id, label, fillcolor=color)
            else:
                dot.node(node_id, label)

            if parent_id:
                dot.edge(parent_id, node_id)

            for child in node.children.values():
                add_node_to_graph(child, node_id)

        add_node_to_graph(self.root)
        return dot

    def load_tree(self, filename):
        """Load a tree previously written by ``save_tree``.

        Uses ``pickle``, so only load files this program wrote itself.
        """
        load_dir = self.save_path if self.save_path is not None else "."
        load_file = Path(load_dir) / f"{filename}.pkl"

        with open(load_file, 'rb') as f:
            tree_data = pickle.load(f)

        self.root = self._deserialize_node(tree_data)
        print(f"Tree loaded from {load_file}")

    def _serialize_node(self, node):
        """Convert a node and its subtree into nested plain dicts (None for None)."""
        if node is None:
            return None

        return {
            'mechanics': node.mechanics,
            'visits': node.visits,
            'value': node.value,
            'reward': node.reward,
            'max_reward': node.max_reward,
            'is_terminal': node.is_terminal,
            'untried_mechanics': list(node.untried_mechanics),
            'children': {
                mechanic: self._serialize_node(child)
                for mechanic, child in node.children.items()
            },
        }

    def _deserialize_node(self, data, parent=None):
        """Rebuild a node and its subtree from ``_serialize_node`` output."""
        if data is None:
            return None

        node = Node(data['mechanics'], self.mechanics, parent)
        node.visits = data['visits']
        node.value = data['value']
        node.reward = data['reward']
        node.max_reward = data['max_reward']
        node.is_terminal = data['is_terminal']
        node.untried_mechanics = set(data['untried_mechanics'])

        for mechanic, child_data in data['children'].items():
            node.children[mechanic] = self._deserialize_node(child_data, node)

        return node

    def load_cache(self):
        """Load the mechanics cache CSV, or create an empty one on disk."""
        import ast

        if os.path.exists(self.cache_file):
            self.cache_df = pd.read_csv(self.cache_file)
            # The CSV stores each mechanics list as its repr. Parse it back with
            # literal_eval rather than eval, so the file cannot execute code.
            self.cache_df['mechanics'] = self.cache_df['mechanics'].apply(
                lambda value: ast.literal_eval(value) if isinstance(value, str) else value
            )
        else:
            self.cache_df = pd.DataFrame(columns=[
                'mechanics',
                'reward_entropy',
                'depth',
                'dones'
            ])
            self.save_cache()

    def save_cache(self):
        """Write the cache DataFrame to ``self.cache_file``."""
        self.cache_df.to_csv(self.cache_file, index=False)

    def get_cached_results(self, mechanics_set):
        """Look up a mechanic combination, ignoring order.

        Returns:
            ``(reward_entropy, depth, dones)`` of the first matching row, or
            None when the combination is not cached.
        """
        if self.cache_df is None or self.cache_df.empty:
            return None
        target = set(mechanics_set)
        mask = self.cache_df['mechanics'].apply(lambda x: set(x) == target)
        cache_match = self.cache_df[mask]
        if cache_match.empty:
            return None
        row = cache_match.iloc[0]
        return (
            row['reward_entropy'],
            row['depth'],
            row['dones']
        )

    def add_to_cache(self, mechanics_set, reward_entropy, depth, dones):
        """Append one result row and persist the cache to disk."""
        new_row = pd.DataFrame([{
            'mechanics': list(mechanics_set),
            'reward_entropy': reward_entropy,
            'depth': depth,
            'dones': dones
        }])
        if self.cache_df is None or self.cache_df.empty:
            # Avoids pandas' deprecated concat-with-empty-frame path.
            self.cache_df = new_row
        else:
            self.cache_df = pd.concat([self.cache_df, new_row], ignore_index=True)
        self.save_cache()

    def _get_mechanics(self, num_mechanics):
        """Build the initial pool: the focus mechanic plus random archive mechanics.

        Returns:
            ``(names, code)``: the mechanic names, and all their source code
            joined with newlines.
        """
        mechanics_names = [extract_function_name(self.focus_mechanic)]
        mechanics_code = self.focus_mechanic
        for _ in range(num_mechanics - 1):
            selected_mechanic = self.mechanic_emitters[0].ask_random_solutions(1, mechanics_names)
            selected_name = extract_function_name(selected_mechanic)
            # Compare names to names; the selection itself is source code.
            if selected_name in mechanics_names:
                continue
            mechanics_names.append(selected_name)
            mechanics_code = mechanics_code + "\n" + selected_mechanic

        return mechanics_names, mechanics_code

    def _get_new_mechanic(self, mechanics_in_focus, itr):
        """Ask the emitter for a mechanic compatible with ``mechanics_in_focus`` and test it.

        Returns:
            ``(code, data)``, where ``data`` has the keys ``mechanic_name``,
            ``mechanic``, ``standalone_fitness`` and ``behaviour``. Returns
            ``(None, None)`` when nothing was generated, or when the standalone
            test (``mechanics_test`` via Ray) yields no behaviour or game class.
        """
        new_mechanic, _ = self.mechanic_emitters[0].ask("compatibility_mutation", mechanics_in_focus)
        if len(new_mechanic) == 0:
            return None, None
        futures = [mechanics_test.remote(mechanic, generation=itr) for mechanic in new_mechanic]
        results = ray.get(futures)
        # Only the first generated mechanic is used.
        standalone_mech_fitness, mechanics_behaviour, updated_game_mech_class = results[0]
        if mechanics_behaviour is None or updated_game_mech_class is None:
            return None, None
        mechanics_data = {"mechanic_name": extract_function_name(new_mechanic[0]),
                          "mechanic": new_mechanic[0],
                          "standalone_fitness": standalone_mech_fitness,
                          "behaviour": mechanics_behaviour}
        return new_mechanic[0], mechanics_data

    def pretty_print_tree(self, node=None, depth=0):
        """Print the subtree under ``node`` (default: root), indented by depth."""
        if node is None:
            node = self.root
        if node is None:
            return

        indent = "  " * depth
        print(f"{indent}{node}")

        for mechanic, child in node.children.items():
            print(f"{indent}  {mechanic} ->")
            self.pretty_print_tree(child, depth + 1)

    def plot_tree(self, max_depth=3):
        """Return a Graphviz Digraph of the whole tree.

        ``max_depth`` is accepted for compatibility but is currently not
        applied; the full tree is drawn.
        """
        dot = Digraph(comment='MCTS Tree')
        dot.attr(rankdir='TB', size='8,8')

        def add_nodes_edges(node, parent_id=None, depth=0):
            node_id = str(id(node))
            label = f"{node.mechanics[-1] if node.mechanics else 'Root'}\n" \
                    f"Visits: {node.visits}\n" \
                    f"Value: {node.value:.2f}\n" \
                    f"Reward: {node.reward:.2f}"

            dot.node(node_id, label)

            if parent_id:
                dot.edge(parent_id, node_id)

            for child in node.children.values():
                add_nodes_edges(child, node_id, depth + 1)

        add_nodes_edges(self.root)
        return dot

    def run(self):
        """Run ``self.iterations`` MCTS iterations from a fresh root.

        Returns:
            ``(best_child_of_root, games_created, new_mechanic_data)``.
            ``best_child_of_root`` is ``(key, node)`` and is ``(None, None)``
            when the root was never expanded. ``new_mechanic_data`` lists the
            data dicts of all LLM-generated mechanics that were added to the tree.
        """
        self.root = Node([extract_function_name(self.focus_mechanic)], self.mechanics)
        start_time = time.time()
        new_mechanic_data = []
        for itr in range(self.iterations):
            node = self.select(self.root)
            if node.untried_mechanics:
                child, mechanic_data = self.expand(node, itr)
                depth, reward, dones = self.simulate(child, itr)
                self.backpropagate(child, depth, reward, dones)
                if mechanic_data is not None:
                    new_mechanic_data.append(mechanic_data)
            else:
                # No untried pool mechanics left: re-play this node's own set
                # on every visit. This path deliberately bypasses the cache.
                reward, depth, dones = self._play_game(node.mechanics, itr)
                if reward is None:
                    reward = 0
                self.backpropagate(node, depth, reward, dones)

        print(f"time taken to run MCTS: {time.time() - start_time}")
        return self.best_child(self.root), self.games_created, new_mechanic_data

    def expand(self, node, itr):
        """Add one child to ``node`` and return ``(child, new_mechanic_data)``.

        With probability ``make_new_mechanic_prob`` a new mechanic is generated
        via the LLM. Its code is appended to ``mechanics_code`` so that games
        for the child can include it. Otherwise, or if generation fails, an
        untried pool mechanic is used. When ``node`` is too deep, already has
        ``configs.max_children`` children, or has nothing left to try,
        ``(node, None)`` is returned unchanged.
        """
        if len(node.mechanics) >= self.max_tree_depth or len(node.children) >= self.configs.max_children:
            return node, None

        new_mechanic = None
        new_mechanic_data = None
        if random.random() < self.make_new_mechanic_prob:
            new_mechanic, new_mechanic_data = self._get_new_mechanic(node.mechanics, itr)

        if new_mechanic is None:
            # Reuse a mechanic from the initial pool.
            if not node.untried_mechanics:
                return node, None
            new_mechanic = random.choice(list(node.untried_mechanics))
            node.untried_mechanics.remove(new_mechanic)
            child_mechanic_name = new_mechanic
        else:
            # Without this, filter_functions_by_name could never find the new
            # mechanic, so its node would be scored without it.
            self.mechanics_code = self.mechanics_code + "\n" + new_mechanic
            child_mechanic_name = new_mechanic_data["mechanic_name"]

        child = Node(node.mechanics + [child_mechanic_name], self.mechanics, parent=node)
        # Children from generated mechanics are keyed by their source code.
        node.children[new_mechanic] = child
        return child, new_mechanic_data

    def select(self, node):
        """Descend by UCT until a node that is terminal, childless, or worth expanding."""
        while node.children and not node.is_terminal:
            if (node.untried_mechanics and len(node.children) < self.configs.max_children) or node.visits <= 1:
                return node
            node = self.best_uct(node)
        return node

    def _play_game(self, mechanics, itr):
        """Build and score one game from the given mechanic names (no caching).

        Returns:
            ``(reward, depth, dones)``. ``reward`` is the sum of the rewards
            returned by ``get_games_scores`` and ``depth`` is the number of LLM
            actions. Returns ``(None, 0, [False])`` when no reward was produced.
        """
        game_mechs = filter_functions_by_name(self.mechanics_code, list(mechanics))
        (rewards, _game_bds, _game_class_individual, _mechanic_fitness,
         _win_condition, _game_name, llm_actions, _mechanics_to_action,
         _game_data, dones) = get_games_scores([game_mechs], generation=itr)
        if rewards is None:
            return None, 0, [False]
        return sum(rewards), len(llm_actions), dones

    def _evaluate_with_cache(self, mechanics, itr):
        """Score a mechanic set, reusing and filling the CSV cache.

        The cache key is the full, order-independent set of names, the same
        key that is used for lookup. Only combinations of two or more
        mechanics with a successful game are stored.

        Returns:
            ``(reward, depth, dones)``; ``(0, 0, [False])`` if the game failed.
        """
        mechanics_list = list(mechanics)
        cached = self.get_cached_results(mechanics_list)
        if cached is not None:
            reward, depth, dones = cached
            # Cache hits count towards games_created, as they always have.
            self.games_created += 1
            return reward, depth, dones

        reward, depth, dones = self._play_game(mechanics_list, itr)
        if reward is None:
            return 0, 0, dones
        self.games_created += 1
        if len(mechanics_list) > 1:
            self.add_to_cache(mechanics_list, reward, depth, dones)
        return reward, depth, dones

    def simulate(self, node, itr):
        """Roll out from ``node`` by scoring progressively larger mechanic sets.

        Each of up to ``simulation_depth`` steps scores the current set,
        starting from the node's own mechanics, and then adds one random unused
        pool mechanic. The rollout stops early once the pool is exhausted, but
        the current set is still scored first.

        Returns:
            ``(max_depth, max_reward, dones)``: the best depth and reward seen,
            and the ``dones`` of the last scored set. Returns ``(0, -1, [])``
            when ``simulation_depth`` permits no step.
        """
        used_mechanics = set(node.mechanics)
        available_mechanics = set(self.mechanics) - used_mechanics
        max_depth = 0
        max_reward = -1
        dones = []

        step = 0
        while step < self.simulation_depth:
            reward, depth, dones = self._evaluate_with_cache(used_mechanics, itr)
            max_depth = max(max_depth, depth)
            max_reward = max(max_reward, reward)
            if not available_mechanics:
                break
            new_mechanic = random.choice(list(available_mechanics))
            used_mechanics.add(new_mechanic)
            available_mechanics.remove(new_mechanic)
            step += 1

        return max_depth, max_reward, dones

    def backpropagate(self, node, depth, reward, dones):
        """Add ``reward`` to ``node`` and all its ancestors.

        ``depth`` is accepted but not recorded. ``max_reward`` and ``dones``
        are overwritten with the latest values (despite its name,
        ``max_reward`` is not a running maximum).
        """
        while node:
            node.visits += 1
            node.value += reward
            node.reward += reward
            node.max_reward = reward
            node.dones = dones
            node = node.parent

    def best_uct(self, node):
        """Return the child of ``node`` with the highest UCT score."""
        return max(node.children.values(), key=lambda c: c.uct(self.exploration_weight))

    def best_child(self, node):
        """Return ``(key, child)`` for the most-visited child of ``node``.

        Returns ``(None, None)`` when ``node`` has no children.
        """
        if not node.children:
            return None, None
        return max(node.children.items(), key=lambda item: item[1].visits)

    def find_best_node(self, node):
        """Alias for ``best_child``."""
        return self.best_child(node)

    def top_n_children(self, node, number_of_children=5):
        """Return the top children of ``node`` by visits, by value and by reward."""
        if not node.children:
            return [], [], []

        sorted_by_visits = sorted(node.children.items(), key=lambda c: c[1].visits, reverse=True)[:number_of_children]
        sorted_by_value = sorted(node.children.items(), key=lambda c: c[1].value, reverse=True)[:number_of_children]
        sorted_by_reward = sorted(node.children.items(), key=lambda c: c[1].reward, reverse=True)[:number_of_children]

        return sorted_by_visits, sorted_by_value, sorted_by_reward

    def get_all_nodes(self):
        """Return every node of the tree in pre-order ([] when there is no tree)."""
        all_nodes = []
        if self.root is None:
            return all_nodes

        def traverse(node):
            all_nodes.append(node)
            for child in node.children.values():
                traverse(child)

        traverse(self.root)
        return all_nodes

    def calculate_shapley_values(self, node):
        """Return a Shapley-style attribution of ``node.value`` to its mechanics.

        NOTE(review): ``get_value`` ignores which subset it is given and always
        returns ``node.value`` for any non-empty subset. Only the empty
        coalition contributes, so every mechanic receives roughly
        ``node.value / len(node.mechanics)``. A true Shapley value would need
        per-subset scores (e.g. from the cache).
        """
        mechanics = node.mechanics
        n = len(mechanics)
        shapley_values = {m: 0.0 for m in mechanics}

        def get_value(mechanic_subset):
            """Value of a coalition: 0 for the empty set, else ``node.value``."""
            if not mechanic_subset:
                return 0
            value = node.value
            if value is None:
                return 0
            return value

        for mechanic in mechanics:
            marginal_sum = 0
            other_mechanics = [m for m in mechanics if m != mechanic]

            for i in range(len(other_mechanics) + 1):
                for subset in itertools.combinations(other_mechanics, i):
                    s = len(subset)
                    # Standard Shapley weight s!(n-s-1)!/n!.
                    coef = math.factorial(s) * math.factorial(n - s - 1) / math.factorial(n)

                    subset_value = get_value(list(subset))
                    subset_with_mechanic_value = get_value(list(subset) + [mechanic])

                    marginal_sum += coef * (subset_with_mechanic_value - subset_value)

            shapley_values[mechanic] = marginal_sum

        return shapley_values

    def get_shapley_values_for_mechanics(self, mechanics_list):
        """Average ``calculate_shapley_values`` over all non-root nodes.

        Tracks the focus mechanic plus the ``mechanic_name`` of every non-None
        entry of ``mechanics_list``, and prints each average.

        Returns:
            ``(average_shapley_values, appearance_counts)`` keyed by name.
        """
        all_nodes = self.get_all_nodes()

        mechanics_to_track = [extract_function_name(self.focus_mechanic)]
        mechanics_to_track.extend(
            mechanic_data["mechanic_name"] for mechanic_data in mechanics_list if mechanic_data is not None
        )

        shapley_values = {mechanic: 0.0 for mechanic in mechanics_to_track}
        appearance_counts = {mechanic: 0 for mechanic in mechanics_to_track}

        for node in all_nodes:
            if node is self.root:
                continue
            node_shapley_values = self.calculate_shapley_values(node)
            for mechanic in mechanics_to_track:
                if mechanic in node.mechanics and mechanic in node_shapley_values:
                    shapley_values[mechanic] += node_shapley_values[mechanic]
                    appearance_counts[mechanic] += 1

        average_shapley_values = {}
        for mechanic in mechanics_to_track:
            if appearance_counts[mechanic] > 0:
                average_shapley_values[mechanic] = shapley_values[mechanic] / appearance_counts[mechanic]
            else:
                average_shapley_values[mechanic] = 0.0

            print(f"\nAverage Shapley Value for {mechanic}:")
            print(f"Value: {average_shapley_values[mechanic]:.4f} (appeared in {appearance_counts[mechanic]} nodes)")

        return average_shapley_values, appearance_counts

    def get_tree_rank_corr(self):
        """Return the mean of ``value / visits`` over all non-root nodes.

        Unvisited nodes count as 0. Returns 0 when there are no such nodes.
        Despite the name, no rank correlation is computed.
        """
        nodes = [node for node in self.get_all_nodes() if node is not self.root]
        if not nodes:
            return 0
        total_value = sum(
            (node.value / node.visits if node.visits > 0 else 0 for node in nodes),
            0.0,
        )
        return total_value / len(nodes)
            
class Node:
    """A search-tree node holding a list of mechanic names.

    Attributes:
        mechanics: Mechanic names of this node. A child's list is its parent's
            list plus one name.
        parent / children: Tree links. ``children`` maps the expansion key
            (a mechanic name, or generated source code) to the child node.
        visits, value, reward: Visit count and accumulated reward.
        is_terminal: Stops ``select`` from descending past this node.
        untried_mechanics: Names from ``all_mechanics`` not in ``mechanics``;
            shrinks as the search reuses them.
        max_reward: Starts at ``np.inf`` and is overwritten by backpropagation.
        dones: Latest ``dones`` value written by backpropagation.
    """

    def __init__(self, mechanics, all_mechanics, parent=None):
        self.mechanics = mechanics
        self.mechanic_to_action = {}
        self.parent = parent
        self.children = {}
        self.visits = 0
        self.value = 0
        self.is_terminal = False
        self.untried_mechanics = {m for m in all_mechanics if m not in self.mechanics}
        self.reward_per_mechanic = 0
        self.reward = 0
        self.max_reward = np.inf
        self.dones = []

    def uct(self, exploration_weight):
        """Return the UCT score ``value/n + c * sqrt(ln(N_parent) / n)``.

        Unvisited nodes score ``inf`` so that they are tried first. A node with
        no parent, or whose parent has no recorded visits, has no exploration
        term and scores its mean value.
        """
        if self.visits == 0:
            return float('inf')
        exploitation = self.value / self.visits
        if self.parent is None or self.parent.visits <= 0:
            return exploitation
        return exploitation + exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)

    def __repr__(self):
        return f"Node(mechanics={self.mechanics}, visits={self.visits}, value={self.value:.2f}, reward={self.reward:.2f})"


#configs = Configs()
#MODEL = configs.model
#EXPERIMENT = configs.experiment
#total_itrs = configs.generations
#max_num_mechs = configs.max_num_mechs_to_add
#init_games = 1
#
#no_game_archive = True
#
#random_selection = False
#
#
#batch_size = 1
#
#min_bound_1 = 0
#max_bound_1 = 45
#min_bound_2 = 0
#max_bound_2 = 1
#
#game_min_bound_1 = -1
#game_max_bound_1 = 1
#game_min_bound_2 = 0
#game_max_bound_2 = 25
#
#mechanic_archive = GridArchive(solution_dim=1,
#                      dims=(25, 25),
#                      ranges=[(min_bound_1, max_bound_1), (min_bound_2, max_bound_2)],
#                      dtype={"solution": np.dtype('O'), "objective": np.float32, "measures": np.float32})
#
#
#game_archive = GridArchive(solution_dim=1,
#                      dims=(25, 25),
#                      ranges=[(game_min_bound_1, game_max_bound_1), (game_min_bound_2, game_max_bound_2)],
#                      dtype={"solution": np.dtype('O'), "objective": np.float32, "measures": np.float32})
#
#
#
#
#mutation_individuals = 3
#mechanic_emitters = [
#            MechanicLLMEmitter(mechanic_archive, 
#            #x0=np.array(["persian cats"]), 
#            initial_solutions=np.array([[mech_1], [mech_2], [mech_3], [mech_4], [mech_5], [mech_6], [mech_7], [mech_8]]), 
#            #initial_solutions=np.array([[mech_1]]), 
#            bounds=None,
#            mutation_individuals=mutation_individuals, 
#            batch_size=batch_size,
#            operator="openai", 
#            operator_kwargs={"temperature": 1.0}, 
#            mutation_prompt="", 
#            model=MODEL) #for _ in range(4)
#        ]





def evaluate_new_mechanic(mechanic_emitters, mechanic_archive, focus_mechanic, llm_mcts_iterations, simulation_depth, save_path=None, generation=None):
    """Run an ``LLMMCTS`` search around ``focus_mechanic`` and summarise the tree.

    Args:
        mechanic_emitters: Forwarded to ``LLMMCTS``.
        mechanic_archive: Not used by this function; kept for caller compatibility.
        focus_mechanic: The mechanic the search is built around (its name is taken
            with ``extract_function_name``; presumably its source code — confirm).
        llm_mcts_iterations: Forwarded as ``LLMMCTS(iterations=...)``.
        simulation_depth: Forwarded to ``LLMMCTS``.
        save_path: Forwarded to ``LLMMCTS``.
        generation: Forwarded to ``LLMMCTS``.

    Returns:
        ``(pairing_counts, mechanic_combinations, total_nodes, stats,
        games_created, new_mechanic_data, shapley_fitness, tree_rank_corr)``:

        * ``mechanic_combinations`` maps ``tuple(node.mechanics)`` to that
          combination's depth-weighted visit share. The shares are normalised to
          sum to 1 when the total weight is non-zero.
        * ``pairing_counts`` counts combinations with positive share, grouped
          under ``"<len>_mechanic"`` keys.
        * ``stats`` holds tree statistics.
        * ``new_mechanic_data`` entries are mutated in place to gain a
          ``"shapley_fitness"`` key.

    ``mcts.cleanup()`` is always called if the ``LLMMCTS`` object was created.
    """
    mcts = None
    try:
        mcts = LLMMCTS(mechanic_emitters, focus_mechanic, iterations=llm_mcts_iterations, simulation_iterations=75, exploration_weight=1.81, max_mechanics=5, max_tree_depth=5, simulation_depth=simulation_depth, save_path=save_path, generation=generation)
        _best_node, games_created, new_mechanic_data = mcts.run()
        shapley_values, _appearance_counts = mcts.get_shapley_values_for_mechanics(new_mechanic_data)

        focus_name = extract_function_name(mcts.focus_mechanic)
        shapley_fitness = shapley_values[focus_name]

        # Attach each newly generated mechanic's Shapley value to its record.
        for mechanic_data in new_mechanic_data:
            if mechanic_data is not None:
                mechanic_data["shapley_fitness"] = shapley_values[mechanic_data["mechanic_name"]]

        tree_rank_corr = mcts.get_tree_rank_corr()

        mcts.save_tree(f"{focus_name}_mcts_tree")
        all_nodes = mcts.get_all_nodes()

        def get_depth(node):
            # Number of parent links between ``node`` and the root.
            depth = 0
            while node.parent:
                depth += 1
                node = node.parent
            return depth

        node_depths = [(node, get_depth(node)) for node in all_nodes]

        # Weight each non-root node's visits by depth (+10% per level).
        mechanic_combinations = {}
        for node, depth in node_depths:
            if node is mcts.root:
                continue
            combination = tuple(node.mechanics)
            weight = node.visits * (1 + depth / 10)
            mechanic_combinations[combination] = mechanic_combinations.get(combination, 0) + weight

        total_value = sum(mechanic_combinations.values())
        if total_value:
            # Guard: with no visits at all the shares stay 0 instead of dividing by zero.
            mechanic_combinations = {k: v / total_value for k, v in mechanic_combinations.items()}

        pairing_counts = {f"{i+1}_mechanic": 0 for i in range(mcts.max_mechanics)}
        for combination, usage in mechanic_combinations.items():
            if usage > 0:
                # .get: tolerate combination sizes outside 1..max_mechanics.
                key = f"{len(combination)}_mechanic"
                pairing_counts[key] = pairing_counts.get(key, 0) + 1

        # Main tree stats. Depth stats are now computed from actual node depth;
        # they were previously computed from node.value.
        total_nodes = len(all_nodes)
        depths = [depth for _, depth in node_depths]
        visits = [node.visits for node in all_nodes]
        avg_depth = sum(depths) / total_nodes if total_nodes > 0 else 0
        max_depth = max(depths, default=0)
        avg_visits = sum(visits) / total_nodes if total_nodes > 0 else 0
        max_visits = max(visits, default=0)

        print(f"\nMain Tree Stats for {extract_function_name(focus_mechanic)}:")
        stats = {
            "total_nodes": total_nodes,
            "games_created": games_created,
            "avg_depth": round(avg_depth, 2),
            "max_depth": max_depth,
            "avg_visits": round(avg_visits, 2),
            "max_visits": max_visits,
            "root_value": mcts.root.value,
            "root_visits": mcts.root.visits
        }
        print("\n".join(f"{k}: {v}" for k,v in stats.items()))

        return pairing_counts, mechanic_combinations, total_nodes, stats, games_created, new_mechanic_data, shapley_fitness, tree_rank_corr

    finally:
        if mcts is not None:
            mcts.cleanup()


def fitness_function(mechanic_combination, usage_stats, pairing_counts):
    """Return the recorded usage of a mechanic combination as its fitness.

    Args:
        mechanic_combination: Iterable of mechanics. It is sorted and turned
            into a tuple to form the lookup key.
        usage_stats: Mapping from mechanic tuples to usage values.
        pairing_counts: Not used; accepted for interface compatibility.

    Returns:
        The usage value when it is strictly positive, otherwise ``0``.
    """
    lookup_key = tuple(sorted(mechanic_combination))
    recorded_usage = usage_stats.get(lookup_key, 0)
    # Only positive usage counts; missing, zero, or negative usage scores 0.
    return recorded_usage if recorded_usage > 0 else 0

#for node in all_nodes:
#    print(f"fitness for {node.mechanics}: {fitness_function(node.mechanics, usage_stats, pairing_counts)}")

def calculate_interaction_score(fitnesses, focus_mechanic, percentile=99):
    """Score how strongly ``focus_mechanic`` appears in high-fitness combinations.

    The score is ``max_fitness * high_performing_count``, where both are taken
    over the combinations whose key contains ``focus_mechanic``.
    ``high_performing_count`` counts those fitnesses at or above the
    ``percentile``-th percentile.

    Args:
        fitnesses: Mapping from a mechanics key to a fitness value. Selection
            uses ``focus_mechanic in key``. For string keys this is a substring
            match; for tuple keys it is element membership.
        focus_mechanic: The mechanic to look for in each key.
        percentile: Percentile (0-100) used as the "high-performing" cutoff.
            Defaults to 99.

    Returns:
        ``0`` if no key contains ``focus_mechanic``, otherwise the score above.
    """
    mechanic_fitnesses = [fitness for mechanics, fitness in fitnesses.items() if focus_mechanic in mechanics]

    if not mechanic_fitnesses:
        return 0  # No combination includes the focus mechanic.

    max_fitness = np.max(mechanic_fitnesses)
    threshold = np.percentile(mechanic_fitnesses, percentile)
    high_performing_count = sum(1 for fitness in mechanic_fitnesses if fitness >= threshold)

    return max_fitness * high_performing_count

# Usage
#env = make('MechEnv-v1')
#env.reset()  

#if mechanic_emitters[0]._initial_solutions is not None and mechanic_archive.empty:
#    sol_list = []
#    for init_solution in mechanic_emitters[0]._initial_solutions:
#        sol_list.append(init_solution[0])
#    objective, measure, _ = mechanics_test(sol_list, init_mechs=True,generation="initial")
#    mechanic_archive.add(mechanic_emitters[0]._initial_solutions, objective, measure)   
#
#random_mechs = mechanic_emitters[0].ask_random_solutions(1)
#focus_mechanic = random_mechs
#pairing_counts, usage_stats, all_nodes = evaluate_new_mechanic(mechanic_emitters, focus_mechanic)
#
#print("\n")
#print(f"{extract_function_name(focus_mechanic).capitalize()} mechanic pairing results:")
#
#for i, (key, value) in enumerate(pairing_counts.items()):
#    print(f"Paired with {i} other mechanic{'s' if i != 1 else ''}: {value}")
#
#print("\nUsage statistics:")
#for mechanic, usage in usage_stats.items():
#    print(f"{mechanic}: {usage:.2%}")
#
#
#fitnesses = {}
#for node in all_nodes:
#    fitnesses[str(node.mechanics)] = fitness_function(node.mechanics, usage_stats, pairing_counts)
#
##for mechanics, fitness in list(fitnesses.items())[:5]:
##    print(f"fitness for {mechanics}: {fitness}")
##print(".")
##print(".")
##print(".")
##print("\n")
#
#interaction_score = calculate_interaction_score(fitnesses, extract_function_name(focus_mechanic))
#print("\n")
#print(f"{extract_function_name(focus_mechanic).capitalize()} interaction score: {interaction_score:.4f}")

@ray.remote
def run_mechanic_mcts(mechanic_emitters, mechanic_archive, focus_mechanic, llm_mcts_iterations, simulation_depth, save_path=None, generation=None):
    """Ray task wrapper around ``evaluate_new_mechanic``.

    Runs the evaluation, prints the focus mechanic's Shapley fitness, and
    returns ``(stats, total_nodes, games_created, new_mechanic_data,
    shapley_fitness)``, a subset of ``evaluate_new_mechanic``'s result.
    """
    evaluation = evaluate_new_mechanic(
        mechanic_emitters,
        mechanic_archive,
        focus_mechanic,
        llm_mcts_iterations,
        simulation_depth,
        save_path,
        generation,
    )
    (
        _pairing_counts,
        _combination_usage,
        node_total,
        tree_stats,
        num_games,
        mechanic_records,
        shapley_score,
        _rank_corr,
    ) = evaluation

    print("\n")
    print(f"{extract_function_name(focus_mechanic).capitalize()} Average Shapley fitness: {shapley_score}")

    return tree_stats, node_total, num_games, mechanic_records, shapley_score


