import logging, re, requests
import wikipedia, dateparser
import numpy as np
from typing import Any, Dict
from datetime import datetime
from wikipedia.exceptions import DisambiguationError, PageError

from .utils import exact_match_score
from .constants import *
from .constrained_decoding import DecodingMonitor
from .config import Config

###
# Module-level constants and state handlers
###

# Key of the monitor state whose output constraints the tool handlers in this
# file add to or remove from (looked up in `monitor.states`).
_ACT_SN = 'action'

# Module logger. LOGGER_NAME is not defined in this file; presumably it comes
# from the star import of .constants -- TODO confirm.
logger = logging.getLogger(LOGGER_NAME)

class StateHandler:
    """Base class for state handlers.

    A handler is a callable that receives a ``DecodingMonitor``. The two
    ``adjust_*`` hooks are optional extension points and do nothing here.
    """

    def __init__(self, config : Config):
        # Keep the configuration around for subclasses.
        self.config = config

    def __call__(self, monitor : DecodingMonitor, *args: Any, **kwds: Any) -> str:
        """Run the handler; the base class always raises NotImplementedError."""
        handler_type = type(self)
        raise NotImplementedError(f'Method __call__ for {handler_type} is not implemented!')

    def adjust_prompt_args(self, prompt_kwargs : Dict):
        """Hook for editing prompt keyword arguments; no-op in the base class."""
        return None

    def adjust_monitor(self, monitor : DecodingMonitor):
        """Hook for editing the decoding monitor; no-op in the base class."""
        return None

class ToolHandler(StateHandler):
    """State handler that runs a tool and returns what the tool produced.

    One tool is instantiated per entry of ``config.action_types`` and stored in
    ``self._action_types`` (action-type key -> tool instance).
    """

    def __init__(self, config : Config):
        super().__init__(config)
        self._action_types = {action_type: self.get_tool(action_type) for action_type in config.action_types}
        # Tool display names ordered by action-type key; used for the prompt
        # labels and as output constraints on the action state.
        self._action_names = [type(act).NAME.strip() for _, act in sorted(self._action_types.items())]

    def __call__(self, monitor : DecodingMonitor, *args: Any, **kwds: Any) -> str:
        """Run the tool described by the last two history entries.

        ``history[-2][1]`` is used as the action name and ``history[-1][1]`` as
        the action input; extra arguments are forwarded to the tool.
        """
        history = monitor.history
        action_name, action_input = history[-2][1], history[-1][1]
        return self._get_observation(action_name, action_input, *args, **kwds)

    def _get_observation(self, action_name : str, action_input : str, *args: Any, **kwds: Any):
        """Execute the named tool and let every known tool update its state.

        Raises:
            ValueError: if ``action_name`` is not a known tool type.
        """
        if action_name not in self._action_types:
            # Lazily create tools that were not listed in config.action_types.
            # get_tool raises ValueError for unknown names, so no further
            # validation is needed (and none that `python -O` could strip).
            # Note that _action_names is not extended for such tools.
            self._action_types[action_name] = self.get_tool(action_name)
        action = self._action_types[action_name]

        action_result = action(action_input, *args, **kwds)
        # Every tool is told which tool just ran (e.g. Lookup follows Search).
        for action_obj in self._action_types.values():
            action_obj.update_state(action)

        return action_result

    def adjust_prompt_args(self, prompt_kwargs : Dict):
        """Add ``tool_labels`` and ``tool_descriptions`` to ``prompt_kwargs``.

        The dict is modified in place. It is also returned when at least one
        tool is configured; otherwise nothing is changed and None is returned.
        """
        if not self._action_names:
            return None
        prompt_kwargs['tool_labels'] = '[' + ', '.join(self._action_names) + ']'
        prompt_kwargs['tool_descriptions'] = '\n'.join(
            f'{type(act).NAME}: {type(act).DESCRIPTION}' for _, act in sorted(self._action_types.items())
        )
        return prompt_kwargs

    def adjust_monitor(self, monitor : DecodingMonitor):
        """Constrain the action state of ``monitor`` to the configured tool names."""
        if not self._action_names:
            return
        monitor_states = monitor.states
        if _ACT_SN in monitor_states:
            logger.debug(f'Adding output constraints [{", ".join(self._action_names)}] to state {_ACT_SN}')
            monitor_states[_ACT_SN].add_constraints(self._action_names)

    def get_tool(self, action_type : str):
        """Instantiate the tool registered under ``action_type``.

        Raises:
            ValueError: if ``action_type`` is not a known tool type.
        """
        # Built at call time because the tool classes are defined further down
        # in this module.
        tool_classes = {
            'Search': SearchTool,
            'Lookup': LookupTool,
            'Calculator': CalculatorTool,
            'CalendarTodayDate': CalendarTodayDateTool,
            'CalendarNumberBusinessdaysBetween': CalendarNumberBusinessdaysBetweenTool,
            'CalendarNumberDaysBetween': CalendarNumberDaysBetweenTool,
        }
        try:
            tool_cls = tool_classes[action_type]
        except KeyError:
            raise ValueError(f'Unknown tool type {action_type}') from None
        return tool_cls(self.config)

class RedundancyAwareToolHandler(ToolHandler):
    """Tool handler that removes the just-used action from the action state's constraints."""

    def __init__(self, config : Config):
        # ToolHandler.__init__ already builds _action_types and _action_names;
        # building them a second time here only instantiated every tool twice.
        super().__init__(config)

    def __call__(self, monitor : DecodingMonitor, *args: Any, **kwds: Any) -> str:
        """Remove the chosen action from the constraints, then run the tool.

        ``history[-2][1]`` is used as the action name and ``history[-1][1]`` as
        the action input.
        """
        history = monitor.history
        action_name, action_input = history[-2][1], history[-1][1]

        if self._action_names:
            monitor_states = monitor.states
            if _ACT_SN in monitor_states:
                logger.debug(f'Removing output constraints [{action_name}] from state {_ACT_SN}')
                monitor_states[_ACT_SN].remove_constraints([action_name])

        # NOTE(review): unlike ToolHandler.__call__, extra *args/**kwds are not
        # forwarded to the tool -- confirm this is intentional.
        return self._get_observation(action_name, action_input)

class AblationToolHandler(ToolHandler):
    """Tool handler that parses action name and input out of one history entry."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, monitor: DecodingMonitor, *args: Any, **kwds: Any) -> str:
        """Split the last history entry on its markers and run the named tool.

        The text after the final marker is the action input and the piece
        before it is the action name. Returns 'Error in parsing action' when
        that fails with an IndexError or ValueError.
        """
        markers = ('Action:', 'Action Input:')
        pattern = '|'.join(map(re.escape, markers))
        pieces = re.split(pattern, monitor.history[-1][1])
        try:
            tool_name = pieces[-2].strip()
            tool_input = pieces[-1].strip()
            return self._get_observation(tool_name, tool_input)
        except (IndexError, ValueError) as err:
            # Only the exact exception types count as a parse failure;
            # subclasses of them propagate unchanged.
            if type(err) in (IndexError, ValueError):
                return 'Error in parsing action'
            raise

class EvaluatorHandler(StateHandler):
    """State handler that scores the last history entry against a reference answer.

    The reference answer must be supplied through ``update_answer`` before the
    handler is called.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def update_answer(self, answer : str):
        """Set the reference answer used by subsequent calls."""
        self._answer = answer

    def __call__(self, monitor: DecodingMonitor, *args: Any, **kwds: Any) -> tuple:
        """Return ``('Correct', True)`` or ``('Incorrect', False)``.

        The text compared against the reference answer is ``monitor.history[-1][1]``.
        """
        is_correct = self._check_equiv(monitor.history[-1][1])
        return ('Correct', True) if is_correct else ('Incorrect', False)

    def _check_equiv(self, x):
        """Compare ``x`` with the reference answer using ``exact_match_score``.

        Raises:
            AttributeError: if ``update_answer`` has not been called yet.
        """
        try:
            reference = self._answer
        except AttributeError:
            # Same exception type as before, but with an actionable message.
            raise AttributeError(
                f'{type(self).__name__} has no reference answer; call update_answer() first'
            ) from None
        return exact_match_score(x, reference)

class AblationEvaluatorHandler(EvaluatorHandler):
    """Evaluator that first extracts the answer text from the last history entry."""

    def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)

    def __call__(self, monitor: DecodingMonitor, *args: Any, **kwds: Any) -> tuple:
        """Return ``('Correct', True)`` or ``('Incorrect', False)``.

        Only the text after the final 'Answer:' marker of ``monitor.history[-1][1]``
        is scored; if the marker is absent the whole entry is used.
        """
        last_entry = monitor.history[-1][1]
        answer = last_entry.split('Answer:')[-1].strip()
        # Score the extracted answer. The previous code computed it and then
        # scored the raw entry, marker and preceding text included.
        return ('Correct', True) if self._check_equiv(answer) else ('Incorrect', False)

class UserInputHandler(StateHandler):
    """State handler that shows the latest history entry and reads a reply from the console."""

    def __init__(self, config : Config):
        super().__init__(config)

    def __call__(self, monitor : DecodingMonitor, *args: Any, **kwds: Any) -> str:
        """Print the text of the last history entry and return the stripped console input."""
        # The message to show is expected to be the most recent history item.
        bot_message = monitor.history[-1][1]
        print(f'\n\nChat Bot:\n{bot_message}\n')
        prompt = "Type your response below. Press 'Enter' once your response is complete.\n\n"
        typed = input(prompt)
        return typed.strip()

class RestateInputHandler(StateHandler):
    """State handler that echoes the first history entry wrapped in double quotes."""

    def __init__(self, config : Config):
        super().__init__(config)

    def __call__(self, monitor : DecodingMonitor, *args: Any, **kwds: Any) -> str:
        """Return the last element of the first history entry, double-quoted."""
        # The text to restate is expected to be the first history item.
        first_entry = monitor.history[0]
        quote = '"'
        return quote + first_entry[-1] + quote

###
# Tools
###

class Tool:
    """Base class for tools.

    Subclasses implement ``__call__``; ``update_state`` is an optional hook that
    does nothing here.
    """

    def __init__(self, config : Config):
        # Keep the configuration around for subclasses.
        self.config = config

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Execute the tool; the base class always raises NotImplementedError."""
        tool_name = type(self).__name__
        raise NotImplementedError(f'Action type {tool_name} does not have action execution implemented!')

    def update_state(self, action_type : Any):
        """Hook for reacting to another object passed by the caller; no-op here."""
        return None

class SearchTool(Tool):
    """Tool that looks an entity up on Wikipedia.

    The most recent successfully fetched summary is kept in ``self._context``.
    """

    NAME = 'Search'
    DESCRIPTION = 'which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.'

    def __init__(self, config : Config):
        super().__init__(config)
        self._context = None

    def __call__(self, action_input : str, attempts=3, *args: Any, **kwds: Any) -> Any:
        """Return the first two sentences of the summary for ``action_input``.

        Args:
            action_input: Title to look up (``auto_suggest`` is disabled).
            attempts: Number of retries left for "too busy" and JSON-decoding
                failures.

        Returns:
            The first two sentences of the summary, or a message describing why
            nothing was found (with up to five similar titles when available).
        """
        try:
            self._context = wikipedia.summary(action_input, auto_suggest=False)
        except DisambiguationError as e:
            return f'Could not find {action_input}. Similar: ' + ', '.join(e.options[:5])
        except PageError:
            search_results = wikipedia.search(action_input)
            if search_results:
                return f'Could not find {action_input}. Similar: ' + ', '.join(search_results[:5])
            return f'Could not find {action_input}. No similar options can be found!'
        except KeyError:
            return f'Could not find {action_input}. No similar options can be found!'
        except wikipedia.exceptions.WikipediaException as e:
            if 'Search is currently too busy' in str(e):
                if attempts > 0:
                    # `attempts` is passed positionally: combining a keyword
                    # `attempts=` with non-empty *args raises TypeError.
                    return self(action_input, attempts - 1, *args, **kwds)
                return 'Search is currently too busy!'
            # Any other Wikipedia failure used to fall through to an unbound
            # `response`; report it as an observation instead.
            logger.warning('Wikipedia search for %r failed: %s', action_input, e)
            return f'Search failed: {e}'
        except requests.exceptions.JSONDecodeError:
            if attempts > 0:
                # Retry with a regex-escaped query, as the original code did.
                return self(re.escape(action_input), attempts - 1, *args, **kwds)
            return 'JSON decoding error!'

        text = self._context.replace('\n', ' ')
        sentences = [s.strip() for s in text.split('. ') if s.strip()]
        # Splitting on '. ' drops the period of every sentence except the last
        # one, so only add a period where it is missing.
        return ' '.join(s if s.endswith('.') else s + '.' for s in sentences[:2])

class LookupTool(Tool):
    """Tool that steps through the sentences of the current passage containing a keyword.

    The passage is copied from a ``SearchTool`` in ``update_state``.
    """

    NAME = 'Lookup'
    DESCRIPTION = 'which returns the next sentence containing keyword in the current passage.'

    def __init__(self, config : Config):
        super().__init__(config)
        self._context = ''
        # None means "no keyword looked up in the current passage yet".
        self._keyword = None
        self._lookup_count = 0
        self._lookup_list = []

    def __call__(self, action_input : str, *args: Any, **kwds: Any) -> Any:
        """Return the next sentence containing ``action_input``.

        Repeating the same keyword advances to the next match; a new keyword
        starts again from the first match.
        """
        if action_input != self._keyword:
            self._keyword = action_input
            self._lookup_count = 0
            self._construct_lookup_list()
        else:
            self._lookup_count += 1

        if self._lookup_count >= len(self._lookup_list):
            return 'No more results.' if self._context else f'Cannot use {type(self).NAME} without using {SearchTool.NAME} first!'
        return f'(Result {self._lookup_count + 1} / {len(self._lookup_list)}) {self._lookup_list[self._lookup_count]}'

    def update_state(self, action_tool : Tool):
        """Adopt the passage of a search tool that has just been run."""
        if isinstance(action_tool, SearchTool):
            self._context = action_tool._context
            # Forget the previous keyword and its matches. Otherwise repeating
            # the same keyword after a new search would keep serving sentences
            # from the old passage and would skip the first result.
            self._keyword = None
            self._lookup_count = 0
            self._lookup_list = []

    def _construct_lookup_list(self):
        """Collect the sentences of the current passage that contain the keyword."""
        if not self._context:
            self._lookup_list = []
            return
        paragraphs = [p.strip() for p in self._context.split("\n") if p.strip()]
        sentences = [sentence.strip() for paragraph in paragraphs for sentence in paragraph.split('. ')]
        self._lookup_list = [sentence for sentence in sentences if sentence and self._keyword in sentence]

class CalculatorTool(Tool):
    """Tool that evaluates an arithmetic expression."""

    NAME = 'Calculator'
    DESCRIPTION = 'which performs numerical computations'

    # Only these characters of the input are kept before evaluation.
    _ALLOWED_CHARS = frozenset('0123456789*+-/.()')

    def __init__(self, config : Config):
        super().__init__(config)

    def __call__(self, action_input : str, *args: Any, **kwds: Any) -> Any:
        """Evaluate ``action_input`` and return the result as a string.

        Every character other than digits and ``* + - / . ( )`` is dropped
        first. Floats with an integral value are printed as integers. Failures
        are reported as a message string rather than raised.
        """
        only_formula = ''.join(c for c in action_input if c in self._ALLOWED_CHARS)
        try:
            # NOTE(security): this eval()s externally supplied text. The
            # character filter above keeps identifiers out and builtins are
            # emptied as a second line of defence, but expressions such as
            # 9**9**9**9 can still consume unbounded CPU or memory. Consider a
            # dedicated arithmetic parser.
            res = eval(only_formula, {'__builtins__': {}}, {})
            if isinstance(res, float) and res.is_integer(): res = int(res)
            return str(res)
        except SyntaxError:
            return '<<SYNTAX ERROR>>'
        except TypeError:
            return '<<TYPE ERROR>>'
        except ZeroDivisionError:
            return 'Cannot divide by zero'
        except (OverflowError, ValueError):
            # e.g. 2.5**1000 (float overflow), or an integer too large for
            # str() under the interpreter's integer-string length limit.
            return '<<OVERFLOW ERROR>>'

class CalendarTodayDateTool(Tool):
    """Tool that reports the current local date."""

    NAME = 'CalendarTodayDate'
    DESCRIPTION = 'which gets the current date'

    def __init__(self, config : Config):
        super().__init__(config)

    def __call__(self, action_input : str, *args: Any, **kwds: Any) -> Any:
        """Return today's date formatted as MM/DD/YYYY; ``action_input`` is ignored."""
        now = datetime.today()
        return format(now, '%m/%d/%Y')

class CalendarNumberBusinessdaysBetweenTool(Tool):
    """Tool that counts the business days between two dates."""

    NAME = 'CalendarNumberBusinessdaysBetween'
    DESCRIPTION = 'which gets the number of business days between two dates'

    def __init__(self, config : Config):
        super().__init__(config)

    @staticmethod
    def _split_date_range(action_input : str):
        """Split "<date1> to <date2>" into its date strings."""
        # Prefer a whitespace-delimited "to" so that a "to" inside a word
        # (e.g. "October") is not mistaken for the separator.
        parts = re.split(r'\s+to\s+', action_input.strip())
        if len(parts) != 2:
            # Fall back to the bare split for inputs such as "1/1/2020to1/5/2020".
            parts = action_input.split('to')
        return parts

    def __call__(self, action_input : str, *args: Any, **kwds: Any) -> Any:
        """Count business days from the first date up to the second.

        The input is expected to have the form "<date1> to <date2>".

        Returns:
            The count as an int (negative when date2 precedes date1, as
            ``np.busday_count`` defines it), or '<<ERROR>>' when the input
            cannot be parsed.
        """
        date_components = self._split_date_range(action_input)
        if len(date_components) != 2:
            return '<<ERROR>>'
        try:
            date1, date2 = dateparser.parse(date_components[0]), dateparser.parse(date_components[1])
            if date1 is None or date2 is None:
                return '<<ERROR>>'
            # busday_count works on day-resolution dates. dateparser yields
            # datetime objects, which NumPy refuses to cast safely, so pass the
            # calendar dates instead.
            return int(np.busday_count(date1.date(), date2.date()))
        except ValueError:
            return '<<ERROR>>'

class CalendarNumberDaysBetweenTool(Tool):
    """Tool that counts the calendar days between two dates."""

    NAME = 'CalendarNumberDaysBetween'
    DESCRIPTION = 'which gets the number of days between two dates'

    def __init__(self, config : Config):
        super().__init__(config)

    @staticmethod
    def _split_date_range(action_input : str):
        """Split "<date1> to <date2>" into its date strings."""
        # Prefer a whitespace-delimited "to" so that a "to" inside a word
        # (e.g. "October") is not mistaken for the separator.
        parts = re.split(r'\s+to\s+', action_input.strip())
        if len(parts) != 2:
            # Fall back to the bare split for inputs such as "1/1/2020to1/5/2020".
            parts = action_input.split('to')
        return parts

    def __call__(self, action_input : str, *args: Any, **kwds: Any) -> Any:
        """Return the absolute number of calendar days between two dates.

        The input is expected to have the form "<date1> to <date2>".

        Returns:
            A non-negative int, or '<<ERROR>>' when the input cannot be parsed.
        """
        date_components = self._split_date_range(action_input)
        if len(date_components) != 2:
            return '<<ERROR>>'
        try:
            date1, date2 = dateparser.parse(date_components[0]), dateparser.parse(date_components[1])
            if date1 is None or date2 is None:
                return '<<ERROR>>'
            # Compare calendar dates rather than datetimes. With times of day
            # attached, abs(timedelta.days) depends on the argument order, and
            # subtracting an aware from a naive datetime raises TypeError.
            delta = date2.date() - date1.date()
            return abs(delta.days)
        except ValueError:
            return '<<ERROR>>'

###
# Handler registry
###

# Registry mapping a handler key to the StateHandler subclass registered for it.
# The values are classes, not instances.
STATE_HANDLERS = {
    'observation' : ToolHandler,
    'redundancy-observation' : RedundancyAwareToolHandler,
    'user' : UserInputHandler, 
    'restate' : RestateInputHandler,
    'ablation-observation' : AblationToolHandler,
    'evaluator' : EvaluatorHandler,
    'ablation-evaluator' : AblationEvaluatorHandler,
}
