import sys
# from networkx import hub_matrix
import numpy as np
from typing import Optional
import ast
import inspect
import types
# import ObservedState for ZeroAgent
from src.GOPS.ground_truth_models import ObservedState
from src.searchlight.gameplay.agents import MCTSAgent
from ..llm.model import LLMModel, SystemMessage, HumanMessage, OpenAIModel
from .prompts import *
from src.searchlight.headers import ActorActionEnumerator, ForwardTransitor, InformationFunction, InformationPrior, PolicyPredictor, ValueHeuristic

class ZeroAgent(MCTSAgent):
    '''
    MCTS agent that uses a LLM to construct all internal models.

    At construction time the agent asks ``llm_model`` to write, in order, a
    ``HiddenState`` dataclass, a value heuristic, an actor-action enumerator,
    a forward transitor, an information function and an information prior.
    Each generated class is extracted from the LLM response, executed,
    instantiated, and handed to ``MCTSAgent``.

    NOTE(review): generated code is run with ``exec`` and is not sandboxed.
    '''

    # Number of LLM attempts per generated component before construction fails.
    MAX_RETRIES = 5

    def __init__(self, players: set[int], player: int, is_partial: bool, game_description: str, observed_state_description: str, llm_model: LLMModel, *args, rng: Optional[np.random.Generator] = None, **kwargs):
        '''
        Generate every search component with the LLM, then initialise ``MCTSAgent``.

        Args:
            players: ids of all players; must contain ``player``.
            player: id of the player this agent controls; always first in ``self.player_order``.
            is_partial: stored as ``self.is_partial``.
            game_description: natural-language game description included in every prompt.
            observed_state_description: description of an observed state (information set),
                included in the hidden-state and information-function/prior prompts.
            llm_model: model whose ``generate`` writes each component.
            *args: forwarded to ``MCTSAgent.__init__``.
            rng: generator used to shuffle the other players and passed to ``MCTSAgent``;
                a fresh ``np.random.default_rng()`` is created when omitted.
            **kwargs: forwarded to ``MCTSAgent.__init__``.

        Raises:
            ValueError: if ``player`` is not in ``players``.
            RuntimeError: if a component still fails after ``MAX_RETRIES`` attempts.
        '''
        # A default argument of np.random.default_rng() would be created once and
        # shared (and advanced) by every agent constructed without an explicit rng.
        if rng is None:
            rng = np.random.default_rng()

        self.llm_model = llm_model
        self.is_partial = is_partial
        self.game_description = game_description
        self.observed_state_description = observed_state_description

        # player order is random but with the current player first
        others = list(players)
        others.remove(player)
        rng.shuffle(others)
        self.player_order = (player, *others)

        # The hidden state comes first: every other component's prompt includes its source.
        self.hidden_state_class, hidden_state_code_string = self._build_with_retries(
            "Hidden state",
            lambda: self.construct_hidden_state_class(game_description, observed_state_description),
        )
        # Kept for backward compatibility: exposes the latest generated class as a module global.
        globals()["HiddenState"] = self.hidden_state_class

        self.value_heuristic = self._build_with_retries(
            "Value heuristic",
            lambda: self.construct_value_heuristic(
                game_description, hidden_state_code_string, inspect.getsource(ValueHeuristic)
            )(),
        )
        self.actor_action_enumerator = self._build_with_retries(
            "Actor action enumerator",
            lambda: self.construct_actor_action_enumerator(
                game_description, hidden_state_code_string, inspect.getsource(ActorActionEnumerator)
            )(player_order=self.player_order),
        )
        self.forward_transitor = self._build_with_retries(
            "Forward transitor",
            lambda: self.construct_forward_transitor(
                game_description, hidden_state_code_string, inspect.getsource(ForwardTransitor)
            )(),
        )
        self.information_function = self._build_with_retries(
            "Information function",
            lambda: self.construct_information_function(
                game_description, hidden_state_code_string, observed_state_description,
                inspect.getsource(InformationFunction)
            )(),
        )
        self.information_prior = self._build_with_retries(
            "Information prior",
            lambda: self.construct_information_prior(
                game_description, hidden_state_code_string, inspect.getsource(InformationPrior),
                observed_state_description
            )(),
        )

        super().__init__(
            players=players,
            player=player,
            actor_action_enumerator=self.actor_action_enumerator,
            forward_transitor=self.forward_transitor,
            information_function=self.information_function,
            information_prior=self.information_prior,
            value_heuristic=self.value_heuristic,
            *args,
            rng=rng,
            **kwargs
        )

    def _build_with_retries(self, label: str, build):
        '''
        Call ``build()`` until it succeeds, at most ``MAX_RETRIES`` times.

        Each failure is printed and triggers a fresh attempt (i.e. a new LLM call).
        On success prints ``==={label} constructed!===`` and returns ``build()``'s result.

        Raises:
            RuntimeError: if every attempt raised; chained to the last error.
        '''
        last_error: Optional[BaseException] = None
        for _ in range(self.MAX_RETRIES):
            try:
                result = build()
            except Exception as e:
                print(e)
                last_error = e
                continue
            print(f"==={label} constructed!===")
            return result
        raise RuntimeError(
            f"Could not construct {label} after {self.MAX_RETRIES} attempts"
        ) from last_error

    def _component_namespace(self) -> dict:
        '''
        Extra names made available to generated component code (on top of the
        ``execute_class`` defaults): this agent's ``HiddenState`` class, ``ObservedState``
        and ``np``.
        '''
        return {
            "HiddenState": self.hidden_state_class,
            "ObservedState": ObservedState,
            "np": np,
        }

    def _generate_class(self, prompt: str, class_name: str, failure_message: str, namespace: Optional[dict] = None) -> tuple[type, str]:
        '''
        Ask the LLM for code defining ``class_name``, then extract and execute it.

        Args:
            prompt: human-message content sent after ``SYS_PROMPT``.
            class_name: name of the class to extract from the response.
            failure_message: message of the exception raised when extraction/execution fails.
            namespace: extra names passed to ``execute_class``.

        Returns:
            The executed class and the code string it was built from.

        Raises:
            Exception: ``failure_message``, chained to the underlying parse/exec error.
                Errors from ``llm_model.generate`` propagate unchanged.
        '''
        messages = [SystemMessage(content=SYS_PROMPT), HumanMessage(content=prompt)]
        response = self.llm_model.generate(messages)
        try:
            code = ZeroAgent.parse_class_from_code_ast_with_imports(response, class_name)
            if code is None:
                raise ValueError(f"Could not find class '{class_name}' in the LLM response")
            return ZeroAgent.execute_class(code, class_name, **(namespace or {})), code
        except Exception as e:
            print(e)
            raise Exception(failure_message) from e

    @staticmethod
    def extract_valid_python_code(code: str) -> str:
        """
        Return the contents of all fenced Python code blocks in ``code``.

        A block opens on a line whose stripped text is ```` ``` ```` followed by an
        info string starting with ``python`` (any case) or equal to ``py``, and
        closes on the next line starting with ```` ``` ````. Text outside such
        blocks (including blocks tagged with other languages) is dropped; the
        contents of several blocks are concatenated in order.
        """
        python_code_lines = []
        inside_code_block = False

        for line in code.splitlines():
            stripped = line.strip()
            if stripped.startswith("```"):
                info = stripped[3:].strip().lower()
                if info.startswith("python") or info.split()[:1] == ["py"]:
                    inside_code_block = True
                elif inside_code_block:
                    inside_code_block = False
                continue
            if inside_code_block:
                python_code_lines.append(line)

        return "\n".join(python_code_lines)

    @staticmethod
    def extract_imports(code: str) -> str:
        """
        Return every line of ``code`` that, once stripped, starts with ``import ``
        or ``from ``, joined with newlines.

        This is a purely textual scan: it also matches such lines inside function
        bodies or string literals, and keeps only the first line of a
        parenthesised multi-line import.
        """
        stripped_lines = (line.strip() for line in code.splitlines())
        return "\n".join(line for line in stripped_lines if line.startswith(("import ", "from ")))

    @staticmethod
    def parse_class_from_code_ast_with_imports(code: str, class_name: str) -> Optional[str]:
        """
        Extract the source of class ``class_name`` from an LLM response.

        Only fenced Python code blocks are considered (see ``extract_valid_python_code``).
        The result is the module-level import statements of that code, followed by
        the class definition *including its decorators* (e.g. ``@dataclass``).

        Returns:
            The class source, or None if there is no Python code, the code does not
            parse, or no class named ``class_name`` is defined.
        """
        import textwrap

        python_code = ZeroAgent.extract_valid_python_code(code)
        if not python_code:
            return None

        try:
            tree = ast.parse(python_code)
        except (SyntaxError, ValueError) as e:
            print(f"Error parsing code: {e}")
            return None

        lines = python_code.splitlines()
        for node in ast.walk(tree):
            if not (isinstance(node, ast.ClassDef) and node.name == class_name):
                continue
            # ClassDef.lineno points at the `class` keyword, so start at the first decorator.
            start_line = min([node.lineno] + [d.lineno for d in node.decorator_list]) - 1
            class_definition = "\n".join(lines[start_line:node.end_lineno])
            if node.col_offset:
                # a nested class must be dedented before it can be executed on its own
                class_definition = textwrap.dedent(class_definition)

            # Imports are kept at module level: names bound inside a class body are not
            # visible to its methods, and `from __future__` imports must come first.
            imports = [
                ast.get_source_segment(python_code, stmt)
                for stmt in tree.body
                if isinstance(stmt, (ast.Import, ast.ImportFrom))
            ]
            imports = [segment for segment in imports if segment]
            if imports:
                return "\n".join(imports) + "\n\n" + class_definition
            return class_definition

        return None

    def construct_hidden_state_class(self, game_description: str, observed_state_description: str) -> tuple[type, str]:
        '''
        Construct the hidden state class from the observation class description.

        Returns:
            class: the executed ``HiddenState`` class
            code_string: ``HiddenState`` source (module-level imports followed by the class)

        Raises:
            Exception: "Could not construct hidden state class" if the response has no
                parseable ``HiddenState`` class or executing it fails.
        '''
        prompt = f"{game_description}\n\nAn observed state (information set) in the game is defined as follows:\n{observed_state_description}\n\n Create a frozen `@dataclass` called `HiddenState` that fully captures hidden information in the game, including unobserved simultaneous actions. Ensure it's hashable."
        return self._generate_class(prompt, "HiddenState", "Could not construct hidden state class")

    def construct_actor_action_enumerator(self, game_description: str, hidden_state_description: str, action_class_description: str) -> type:
        '''
        Construct the actor action enumerator from the action class description.

        Returns the ``CustomActorActionEnumerator`` class; it is executed with this
        agent's ``HiddenState`` (plus ``ObservedState`` and ``np``) in scope.
        '''
        prompt = f"{game_description}\n\nA hidden state in the game is defined as follows:\n{hidden_state_description}\n\nWrite an actor-action enumerator `CustomActorActionEnumerator` for this game that inherits from the `ActorActionEnumerator` class. Include all docstings from the parent class:\n\n{action_class_description}"
        return self._generate_class(
            prompt, "CustomActorActionEnumerator", "Could not construct actor action enumerator",
            self._component_namespace(),
        )[0]

    def construct_forward_transitor(self, game_description: str, hidden_state_description: str, transitor_class_description: str) -> type:
        '''
        Construct the forward transitor from the transitor class description.

        Returns the ``CustomForwardTransitor`` class; it is executed with this
        agent's ``HiddenState`` (plus ``ObservedState`` and ``np``) in scope.
        '''
        prompt = f"{game_description}\n\nA hidden state in the game is defined as follows:\n{hidden_state_description}\n\nWrite a forward transitor `CustomForwardTransitor` for this game that inherits from the `ForwardTransitor` class. Include all docstings from the parent class:\n\n{transitor_class_description}"
        return self._generate_class(
            prompt, "CustomForwardTransitor", "Could not construct forward transitor",
            self._component_namespace(),
        )[0]

    def construct_value_heuristic(self, game_description: str, hidden_state_description: str, value_heuristic_class_description: str) -> type:
        '''
        Construct the value heuristic from the value heuristic class description.

        Returns the ``CustomValueHeuristic`` class; it is executed with this
        agent's ``HiddenState`` (plus ``ObservedState`` and ``np``) in scope.
        '''
        prompt = f"{game_description}\n\nA hidden state in the game is defined as follows:\n{hidden_state_description}\n\nWrite a value heuristic `CustomValueHeuristic` for this game that inherits from the `ValueHeuristic` class. Include all docstings from the parent class:\n\n{value_heuristic_class_description}"
        return self._generate_class(
            prompt, "CustomValueHeuristic", "Could not construct value heuristic",
            self._component_namespace(),
        )[0]

    def construct_information_function(self, game_description: str, hidden_state_description: str, observed_state_description: str, information_function_class_description: str) -> type:
        '''
        Construct the information function from the information function description.

        Returns the ``CustomInformationFunction`` class. It is executed in an
        explicit namespace (``HiddenState``, ``ObservedState``, ``np``) instead of
        via ``exec`` + ``locals()``, which does not expose new names on Python 3.13+.
        '''
        prompt = f"{game_description}\n\nA hidden state in the game is defined as follows:\n{hidden_state_description}\n\nAn observed state (information set) in the game is defined as follows:\n{observed_state_description}\n\nWrite an information function `CustomInformationFunction` for this game that inherits from the `InformationFunction` class. Include all docstings from the parent class:\n\n{information_function_class_description}"
        return self._generate_class(
            prompt, "CustomInformationFunction", "Could not construct information function",
            self._component_namespace(),
        )[0]

    @staticmethod
    def execute_class(code_string: str, class_name: str, **kwargs) -> type:
        """
        Execute a code string in a fresh namespace and return a specified class.

        The namespace is pre-populated with ``dataclass``/``field``, common ``typing``
        aliases and the searchlight header base classes; ``kwargs`` are added on top
        and override defaults with the same name.

        NOTE(review): ``code_string`` is run with ``exec`` and is not sandboxed.

        Args:
        code_string (str): The Python code as a string.
        class_name (str): The name of the class to return.

        Returns:
        type: The class object specified by class_name.

        Raises:
        AttributeError: If the specified class is not found in the executed code.
        Any exception raised while executing ``code_string`` propagates.
        """
        import dataclasses
        import typing

        namespace = {
            'dataclass': dataclasses.dataclass,
            'field': dataclasses.field,
            'Tuple': typing.Tuple,
            'FrozenSet': typing.FrozenSet,
            'List': typing.List,
            'Dict': typing.Dict,
            'Set': typing.Set,
            'Optional': typing.Optional,
            'Any': typing.Any,
            'ForwardTransitor': ForwardTransitor,
            'ActorActionEnumerator': ActorActionEnumerator,
            'InformationFunction': InformationFunction,
            'InformationPrior': InformationPrior,
            'PolicyPredictor': PolicyPredictor,
            'ValueHeuristic': ValueHeuristic,
        }
        namespace.update(kwargs)

        # The dict serves as the module globals of the executed code, so its
        # top-level imports and definitions are visible to the class's methods.
        exec(code_string, namespace)

        try:
            return namespace[class_name]
        except KeyError:
            raise AttributeError(f"Class '{class_name}' not found in the executed code.") from None

    def construct_information_prior(self, game_description: str, hidden_state_description: str, prior_class_description: str, observation_class_description: str) -> type:
        '''
        Construct the information prior from the prior class description.

        Returns the ``CustomInformationPrior`` class. It is executed in an explicit
        namespace (``HiddenState``, ``ObservedState``, ``np``) instead of via
        ``exec`` + ``locals()``, which does not expose new names on Python 3.13+.
        '''
        prompt = f"{game_description}\n\nA hidden state in the game is defined as follows:\n{hidden_state_description}\n\nAn observation in the game is defined as follows:\n{observation_class_description}\n\nWrite an information prior `CustomInformationPrior` for this game that inherits from the `InformationPrior` class. Include all docstings from the parent class:\n\n{prior_class_description}"
        return self._generate_class(
            prompt, "CustomInformationPrior", "Could not construct information prior",
            self._component_namespace(),
        )[0]
