import pynisher
import pandas as pd
import numpy as np
import sklearn

from sklearn.ensemble import RandomForestRegressor
from typing import List, Union, Dict, Any
from time import time

from grammar.constants import MAXIMIZE, MINIMIZE



class State:
    """
    Base interface for a search state.

    Every method here must be overridden by subclasses; the base
    implementations only raise ``NotImplementedError``.
    """
    def clone(self) -> 'State':
        """
        Return a new instance equivalent to the current state.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError('Subclasses should implement this.')

    def get_possible_actions(self) -> List[str]:
        """
        Return the list of actions that can be taken from the current state.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError('Subclasses should implement this.')

    def take_action(self, action: Union[str, List[str]]) -> 'State':
        """
        Apply ``action`` (a single action or a list of actions) and return the
        resulting state.

        Whether the current instance is mutated or a new instance is returned
        is up to the subclass.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError('Subclasses should implement this.')

    def is_terminal(self) -> bool:
        """
        Return whether the current state is a terminal state.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError('Subclasses should implement this.')

    def get_reward(self, X_sample: np.ndarray, y_sample: np.ndarray, cv: Union[None, int]) -> float:
        """
        Compute the reward for the current state.

        Args:
            X_sample: Feature sample used to evaluate the state.
            y_sample: Target sample used to evaluate the state.
            cv: Presumably the number of cross-validation folds, or ``None``;
                the exact meaning is defined by each subclass.

        Returns:
            The reward as a float.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError('Subclasses should implement this.')
        
    
class TreeState(State):
    """
    A state wrapping one node of a tree, identified by its path of node names.

    ``sequence`` is the concatenation (no separator) of the node names from the
    root down to ``node``; it defines equality and hashing. Rewards are not
    computed but looked up in a pre-evaluated sample (see ``get_reward``).

    Attributes:
        node: Current tree node. Must expose ``name`` (str), ``children``
            (mapping action -> child node) and ``is_terminal()``.
        sequence: Concatenated node names from the root to ``node``.
        key: Column of the target frame holding the score.
        y: Target rows matched by the last successful ``get_reward`` call,
            or ``None`` if the state has not been evaluated.
        PENALTY: ``-inf``; not read inside this class (presumably used by
            callers as the score of invalid states — TODO confirm).
    """

    def __init__(self, node, parent_sequence, key, y=None):
        self.node = node
        # Names are concatenated without a separator; an empty or None parent
        # sequence marks the root state.
        self.sequence = (parent_sequence or "") + node.name
        self.key = key
        self.y = y
        self.PENALTY = float("-inf")

    def clone(self) -> 'TreeState':
        """Return a new state with the same node, sequence, key and ``y``."""
        # Passing ``self.sequence`` as the parent sequence would append
        # ``node.name`` a second time, producing a state that is not equal to
        # this one. Build from an empty prefix and copy the full sequence.
        cloned = TreeState(self.node, "", self.key, y=self.y)
        cloned.sequence = self.sequence
        return cloned

    def get_possible_actions(self) -> List[str]:
        """Return the keys of the current node's children."""
        return list(self.node.children.keys())

    def take_action(self, action: str) -> 'TreeState':
        """
        Return the child state reached through ``action``.

        The current state is not modified, and ``y`` is not carried over (the
        child has not been evaluated yet).

        Raises:
            ValueError: If ``action`` is not a child key of the current node.
        """
        next_node = self.node.children.get(action)
        if next_node is None:
            # Message (Spanish): "Action '<action>' is not valid from the current state."
            raise ValueError(f"Acción '{action}' no válida desde el estado actual.")
        return TreeState(next_node, self.sequence, self.key)

    def is_terminal(self) -> bool:
        """Return whether the wrapped node is terminal."""
        return self.node.is_terminal()

    def is_valid(self) -> bool:
        """Return True once target rows have been attached to this state."""
        return self.y is not None

    def get_reward(self, X_sample: pd.DataFrame, y_sample: pd.DataFrame, cv: Union[None, int]) -> float:
        """
        Look up the pre-computed score of this state's sequence.

        ``self.sequence`` is first normalised in place (see PATCH below). Then
        the rows of ``X_sample`` whose ``exploration_step_name`` equals it are
        selected, and the rows of ``y_sample`` with the same index labels are
        stored in ``self.y``.

        Args:
            X_sample: Frame with an ``exploration_step_name`` column.
            y_sample: Frame whose index contains the matched labels of
                ``X_sample`` and which has a ``self.key`` column.
            cv: Unused; accepted for compatibility with ``State.get_reward``.

        Returns:
            The ``self.key`` value of the first matching row (also stored in
            ``self.score``).

        Raises:
            IndexError: If no row of ``X_sample`` matches the sequence.
            KeyError: If ``self.key`` or the matched index labels are missing
                from ``y_sample``.
            On failure ``self.y`` and ``self.score`` are left unchanged.
        """
        # PATCH: collapse known duplicated names that appear in built sequences
        # (presumably from overlapping parent/child node names — TODO confirm).
        # NOTE: this mutates self.sequence and therefore changes __eq__/__hash__;
        # do not call it on a state already used as a set member or dict key.
        self.sequence = self.sequence.replace('StandardScalerStandardScalerComponent', 'StandardScalerComponent')
        self.sequence = self.sequence.replace('PCAPCA', 'PCA')

        matches = X_sample[X_sample.exploration_step_name == self.sequence]
        if matches.empty:
            # Validate before touching self.y so is_valid() stays False.
            raise IndexError(
                f"No sample row with exploration_step_name == {self.sequence!r}"
            )

        targets = y_sample.loc[matches.index]
        score = targets[self.key].values[0]
        self.y = targets
        self.score = score
        return self.score

    def __eq__(self, other: Any) -> bool:
        """Tree states are equal when their sequences are equal."""
        if not isinstance(other, TreeState):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError on objects without ``sequence``.
            return NotImplemented
        return self.sequence == other.sequence

    def __hash__(self) -> int:
        """Hash on ``sequence``, consistently with ``__eq__``."""
        return hash(self.sequence)


