import logging
from typing import Dict, Union
import torch

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
from peft import PeftModel

from transformers import GenerationConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING

from .constrained_decoding import DecodingMonitor
from .prompt import Prompt
from .state_handlers import STATE_HANDLERS
from .config import Config
from .utils import *
from .states import *

# Module logger, named by LOGGER_NAME (pulled in through the wildcard imports above).
logger = logging.getLogger(name=LOGGER_NAME)

###
#
###

class AgentOutput:
    """Result returned by ``LLMAgent.predict``.

    Attributes:
        text: the final response text.
        history: the decoding history; ``LLMAgent.predict`` fills it with
            ``(state_name, response)`` pairs.
    """

    # Annotated with the builtin ``list`` so this class does not depend on
    # ``typing.List`` leaking in through a wildcard import.
    def __init__(self, text : str, history : list):
        self.text = text
        self.history = history

    def __repr__(self):
        return f'{type(self).__name__}(text={self.text!r}, history={self.history!r})'

###
#
###

class LLMAgent:
    """Run constrained LLM decoding, deferring some states to environment handlers.

    Decoding is driven by a ``DecodingMonitor`` built from the specification in
    ``config``. For every state flagged ``defer_to_env`` a state handler is
    instantiated; when the monitor waits on such a state, the handler is asked
    for a response, which is submitted to the monitor as input.
    """

    # NOTE(review): not referenced inside this class — presumably read by callers.
    START_TOKENS = ['Question:', 'Claim:']

    def __init__(self, config : Config, custom_state_handlers : Dict=None):
        """Build the prompt, model, tokenizer and decoding spec described by ``config``.

        Args:
            config: agent configuration (model, prompt, decoding and generation settings).
            custom_state_handlers: optional mapping of state name -> handler class.
                Each class is instantiated with ``config``. An entry overrides the
                built-in ``STATE_HANDLERS`` entry of the same name and may also
                supply handlers for states that have no built-in handler.
        """

        self.config = config
        self.prompt_obj = Prompt(config)

        self.device = ('cuda' if torch.cuda.is_available() else 'cpu')
        self.generation_config = self._get_gen_params()
        self.model = self._get_auto_model()
        # left-side truncation keeps the end of over-long inputs
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_type, truncation_side='left')
        self.specification = DecodingMonitor.load_spec_from_config(config)
        self.states = extract_states(self.specification, self.tokenizer)
        self.custom_state_handlers = custom_state_handlers
        self._reset_state_handlers()

    def _reset_state_handlers(self):
        """(Re)create one handler instance per environment-deferred state.

        Custom handlers take precedence over built-in ones and are resolved first,
        so a custom handler can serve a state that has no entry in
        ``STATE_HANDLERS``. Custom handlers for names that are not
        ``defer_to_env`` states are still registered, as before.

        Raises:
            KeyError: a ``defer_to_env`` state has neither a custom nor a built-in handler.
        """
        custom = self.custom_state_handlers or {}
        handlers = {}
        for state in self.states.values():
            if not state.defer_to_env:
                continue
            handler_cls = custom.get(state.name)
            if handler_cls is None:
                try:
                    handler_cls = STATE_HANDLERS[state.name]
                except KeyError:
                    raise KeyError(f'No state handler available for environment state {state.name!r}') from None
            handlers[state.name] = handler_cls(self.config)
        # custom handlers for names not covered above are appended afterwards
        for state_name, handler_cls in custom.items():
            if state_name not in handlers:
                handlers[state_name] = handler_cls(self.config)
        self.state_handlers = handlers

    def predict(self, input_text : str, **kw_args):
        """Run one full constrained-decoding episode for ``input_text``.

        Args:
            input_text: the task input; it is stripped and passed to the prompt as ``input``.
            **kw_args: extra prompt fields forwarded to the prompt template.

        Returns:
            AgentOutput with the monitor's final response and a list of
            ``(state_name, response)`` pairs taken from the monitor's history.
        """

        def _env_response():
            # Let handlers answer while the monitor waits on the environment.
            # Returns True as soon as a handler asks to end generation.
            while monitor.in_env_response_state():
                response, end_generation = self._get_response(monitor)
                if end_generation:
                    # Nothing is submitted in this case, so the monitor presumably
                    # remains in the same state; stop instead of re-querying the
                    # handler (which could loop forever or drop the end signal).
                    return True
                monitor.submit_input(response)
            return False

        # fresh handler instances for every prediction
        self._reset_state_handlers()

        with torch.no_grad():

            logger.debug(f'Beginning prediction for input \"{input_text}\"')

            monitor = self._init_monitor(input_text, **kw_args)

            # the very first state may already be an env-response state
            end_generation = _env_response()

            # alternate model generation and environment responses until done
            while not end_generation:
                awaiting_response = monitor.generate()
                # either generation is complete OR it is waiting on a response from the environment
                end_generation = _env_response() if awaiting_response else True
            final_resp = monitor.get_final_response()

            # history entries may carry either a raw string or a state object
            history = [(state if isinstance(state, str) else state.name, resp) for state, resp in monitor.history]

            return AgentOutput(final_resp, history)

    def _init_prompt(self, input_text : str, **kw_args):
        """Fill the prompt template for ``input_text``.

        Every state handler may adjust the prompt arguments in place first; all
        values are then converted to ``str`` before the template is filled.
        """
        prompt_kw_args = dict(input=input_text.strip(), **kw_args)
        for state_handler in self.state_handlers.values():
            state_handler.adjust_prompt_args(prompt_kw_args)
        # converted in place: handlers received this same dict object
        for k, v in prompt_kw_args.items():
            prompt_kw_args[k] = str(v)

        return self.prompt_obj.fill_prompt(**prompt_kw_args)

    def _init_monitor(self, input_text : str, **kw_args):
        """Create the DecodingMonitor for one prediction.

        The prompt's ``bop_text`` is used as the stop sequence, each state handler
        may adjust the monitor, and unless ``config.agent_type`` contains
        ``'ablation'`` every prompt example is validated against the specification.
        """
        filled_text = self._init_prompt(input_text, **kw_args)
        stop_sequences = [self.prompt_obj.bop_text]
        monitor = DecodingMonitor(input_text, filled_text, self.specification, self.config, self.generation_config, self.model, self.tokenizer, stop_sequences)
        for state_handler in self.state_handlers.values():
            state_handler.adjust_monitor(monitor)

        # validate that prompt examples match decoding monitor specification
        if 'ablation' not in self.config.agent_type:
            for ex in self.prompt_obj.examples:
                monitor.validate_sequence_from_string(ex)

        return monitor

    def _get_response(self, monitor : DecodingMonitor):
        """Ask the handler of the monitor's current state for a response.

        A handler returns either a plain ``str`` or a ``(str, bool)`` pair whose
        bool requests that generation end.

        Returns:
            ``(response, end_generation)``.

        Raises:
            KeyError: no handler is registered for the current state.
            AssertionError: the handler's return value is malformed.
        """
        curr_state = monitor.get_current_state()
        try:
            state_handler = self.state_handlers[curr_state.name]
        except KeyError:
            raise KeyError(f'No state handler registered for state {curr_state.name!r}') from None
        state_handler_response = state_handler(monitor)
        if isinstance(state_handler_response, str):
            return state_handler_response, False
        assert isinstance(state_handler_response, (tuple, list)), 'State handler response must either be a string or a pair of string and boolean!'
        assert len(state_handler_response) == 2 and isinstance(state_handler_response[0], str) and isinstance(state_handler_response[1], bool), 'Malformed state handler response'
        response, end_generation = state_handler_response
        return response, end_generation

    def _get_gen_params(self):
        """Build the GenerationConfig; sampling is enabled only for a positive temperature."""
        return GenerationConfig(
            temperature=self.config.temperature,
            output_scores=True,
            return_dict_in_generate=True,
            output_attentions=True,
            max_new_tokens=self.config.max_tokens,
            do_sample=self.config.temperature > 0,
            early_stopping=True,
        )

    def _get_auto_model(self):
        """Load ``config.model_type`` with the matching Auto* class in eval mode.

        Causal-LM configs (and model names containing 'falcon' or 'mosaic') use
        ``AutoModelForCausalLM``; seq2seq configs use ``AutoModelForSeq2SeqLM``.
        If ``config.model_path`` is set, the model is loaded by ``load_local_model``;
        otherwise it is loaded in 8-bit with ``device_map='auto'``.

        Raises:
            AssertionError: the model name contains a restricted model family.
            ValueError: the config maps to neither a causal nor a seq2seq LM.
        """
        restricted_models = ['llama']
        assert not any(restr_model in self.config.model_type for restr_model in restricted_models), f'Not allowed to download {self.config.model_type} locally!'

        kw_args = dict(trust_remote_code=True)

        auto_config = AutoConfig.from_pretrained(self.config.model_type, **kw_args)

        # 'falcon'/'mosaic' are forced onto the causal-LM path; presumably their
        # remote-code configs are not registered in the auto mapping.
        if type(auto_config) in MODEL_FOR_CAUSAL_LM_MAPPING or any(out_model in self.config.model_type for out_model in ['falcon', 'mosaic']):
            auto_model_type = AutoModelForCausalLM
        elif type(auto_config) in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
            auto_model_type = AutoModelForSeq2SeqLM
        else:
            raise ValueError(f'Unhandled model type {self.config.model_type}')

        if self.config.model_path:
            model = self.load_local_model(auto_model_type)
        else:
            # 8-bit loading goes through quantization_config; the bare
            # `load_in_8bit` kwarg is deprecated in recent transformers releases.
            model = auto_model_type.from_pretrained(
                self.config.model_type,
                device_map='auto',
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                torch_dtype=torch.bfloat16,
                config=auto_config,
                **kw_args,
            )

        model.eval()

        return model

    def load_local_model(self, auto_model_type : Union[AutoModelForCausalLM, AutoModelForSeq2SeqLM], qlora=True, max_memory_per_gpu='24000MB'): #TODO: do the same for Seq2Seq
        """Load ``config.model_type`` in 4-bit NF4, optionally with adapters from ``config.model_path``.

        Args:
            auto_model_type: the Auto* model class to load with.
            qlora: if True, wrap the base model with the PEFT adapters stored at
                ``config.model_path``.
            max_memory_per_gpu: memory budget given to each visible CUDA device.

        Returns:
            The loaded (and possibly adapter-wrapped) model.
        """

        logger.info(f'Loading a local model from {self.config.model_path}')

        adapters_name = self.config.model_path
        # pass None rather than an empty budget when no CUDA device is visible
        max_memory = {i: max_memory_per_gpu for i in range(torch.cuda.device_count())} or None
        model = auto_model_type.from_pretrained(
            self.config.model_type,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            max_memory=max_memory,
            # 4-bit loading is requested only via quantization_config: recent
            # transformers versions raise if `load_in_4bit` is also passed as a kwarg.
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4'
            ),
            trust_remote_code=True
        )
        if qlora:
            model = PeftModel.from_pretrained(model, adapters_name)
        return model