import logging
import re, inspect
from typing import List
import torch
import torch.nn as nn
from transformers.generation.configuration_utils import GenerationConfig
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers import PreTrainedModel

from .config import Config
from .constants import *
from .monitor import *
from .utils import *

"""
The rules for monitors are
    1) can only be in one state at a time
    2) any user-input state is also a state that the user can cancel from
"""

# Module-level logger shared by everything in this file. LOGGER_NAME is not
# defined here; presumably it arrives through one of the star imports above
# (most likely `.constants`) -- verify.
logger = logging.getLogger(LOGGER_NAME)

###
# Module constants
###

# Label used in place of a state object for the very first history entry,
# which pairs it with the caller's original text.
_INPUT_SN = 'input'

class DecodingMonitor:

    def __init__(self, 
        orig_text : str,
        input_text : str, 
        specification : tuple, 
        config : Config, 
        generation_config : GenerationConfig, 
        model : PreTrainedModel, 
        tokenizer : PreTrainedTokenizer, 
        hard_stop_sequences : List[str],
        **kwargs):
        """Set up stop sequences, the transition monitor and the decoder state.

        The middle of this constructor (see the inline notes) replicates the
        preparation phase of the model's `.generate()` so that tokens can later
        be produced one step at a time by `_generation_step`.

        Args:
            orig_text: Stored verbatim as the first history entry, labelled
                with `_INPUT_SN`.
            input_text: Prompt that is tokenized into the initial input ids.
                A stripped copy is kept on `self.input_text`; the tokenizer
                call itself receives the argument as passed.
            specification: Passed whole to `TransitionMonitor`; element 1 is
                kept as the monitor name.
            config: Read for `agent_type` (an 'ablation' substring disables
                soft stop sequences) and `max_state_tokens`.
            generation_config: `max_new_tokens` is read from it, then it is
                deep-copied and the copy is updated with `kwargs`.
            model: Model whose generation helpers and forward pass are used.
            tokenizer: Tokenizer used for the prompt, stop sequences and
                forced text.
            hard_stop_sequences: Strings that are tokenized twice each, once
                stripped and once stripped with a leading newline.
            **kwargs: Forwarded to `generation_config.update`; whatever that
                call returns is treated as model kwargs.
        """
        # Chosen from CUDA availability only; it is not read from the model.
        self.device = ('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = model
        self.tokenizer = tokenizer

        self.input_text = input_text
        self._monitor_name = specification[1]

        # Hard stops are stored as token-id lists, in a bare and a
        # newline-prefixed variant of every supplied string.
        self._hard_stop_sequences = []
        self._hard_stop_sequences.extend([tokenizer.convert_tokens_to_ids(tokenizer.tokenize(stop_seq.strip())) for stop_seq in hard_stop_sequences])
        self._hard_stop_sequences.extend([tokenizer.convert_tokens_to_ids(tokenizer.tokenize('\n' + stop_seq.strip())) for stop_seq in hard_stop_sequences])
        # The only soft stop is a newline, and it is skipped when the tokenizer
        # produces the same ids for '\n' as for ' '.
        self._soft_stop_sequences = []
        newline = tokenizer.convert_tokens_to_ids(tokenizer.tokenize('\n'))
        if newline != tokenizer.convert_tokens_to_ids(tokenizer.tokenize(' ')):
            self._soft_stop_sequences.append(newline)

        # for ablations, things are a lot less... structured...
        if 'ablation' in config.agent_type: self._soft_stop_sequences.clear()

        self._transition_monitor = TransitionMonitor(specification, tokenizer)

        self.input_text = self.input_text.strip()
        # The first valid state reported by the monitor is taken as the start state.
        init_state = self._transition_monitor.get_valid_states()[0]
        self.states = self._transition_monitor.states
        self._transition_monitor.step(init_state)
        self.max_new_tokens, self.max_state_length = generation_config.max_new_tokens, config.max_state_tokens
        # Total number of tokens appended so far; `_generation_step` increments
        # it on every call, including the forced steps triggered further down.
        self._tgen_ct = 0

        """
        This next part is taken from the initialization of .generate()
        """
        inputs = self.tokenizer(input_text, return_tensors='pt', truncation=True).to(self.device)['input_ids']

        # NOTE(review): `copy` is not imported by name in this file; presumably
        # it is supplied by one of the star imports -- verify.
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self.model._validate_model_kwargs(model_kwargs.copy())

        # 3. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self.model._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        # 4. Define other model kwargs
        model_kwargs["output_attentions"] = generation_config.output_attentions
        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
        # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
        # generating the first new token or not, and we only want to use the embeddings for the first new token)
        if not self.model.config.is_encoder_decoder and model_input_name == "inputs_embeds":
            model_kwargs["use_cache"] = True
        else:
            model_kwargs["use_cache"] = generation_config.use_cache

        # Build an attention mask only when none was supplied, the model's
        # forward() accepts one, and no precomputed encoder outputs were given.
        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.model.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self.model._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
            )

        # decoder-only models should use left-padding for generation
        if not self.model.config.is_encoder_decoder:
            # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
            # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
            if (
                generation_config.pad_token_id is not None
                and len(inputs_tensor.shape) == 2
                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        if self.model.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self.model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        if self.model.config.is_encoder_decoder:
            self.input_ids, model_kwargs = self.model._prepare_decoder_input_ids_for_generation(
                batch_size=batch_size,
                model_input_name=model_input_name,
                model_kwargs=model_kwargs,
                decoder_start_token_id=generation_config.decoder_start_token_id,
                bos_token_id=generation_config.bos_token_id,
                device=inputs_tensor.device,
            )
        else:
            self.input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        self.model_kwargs = model_kwargs

        # Note: these two ids come from the model's own generation config, not
        # from the deep-copied `generation_config` updated above.
        self.pad_token_id = self.model.generation_config.pad_token_id
        self.eos_token_id = self.model.generation_config.eos_token_id

        """
        This is back to our code
        """
        self.generation_config = generation_config
        # Per-layer cache of which tensor axis is the sequence axis; filled in
        # lazily by `_roll_back_decoder`.
        self.seq_dims = []

        # initialize history
        # `_init_index` marks where the prompt ends, so everything after it is
        # what `get_final_response` decodes.
        self._init_index = self.input_ids.size(1)
        self.history = [(_INPUT_SN, orig_text)]
        # Force-decode the start state's text on a new line.
        self._add_text('\n' + init_state.text)
        # (current state, index in input_ids at which its content begins)
        self.in_progress = (init_state, self.input_ids.size(1))

    ###
    #
    ###

    @staticmethod
    def load_spec_from_config(config : Config):
        """Read the agent file named by `config.agent_file` and parse it.

        Each line is stripped and cut at its first '#'; lines that are empty
        after that are dropped. The survivors are joined with newlines and
        handed to `parse_specification_to_tuple`, whose result is returned.
        """
        with open(config.agent_file, 'r') as spec_file:
            raw_lines = list(spec_file.readlines())

        kept_lines = []
        for raw_line in raw_lines:
            # Strip first, then cut the comment (same order as always).
            without_comment = raw_line.strip().split('#')[0]
            if without_comment:
                kept_lines.append(without_comment)

        return parse_specification_to_tuple('\n'.join(kept_lines))

    ###
    #
    ###

    def get_final_response(self):
        return self.tokenizer.decode(self.input_ids[0, self._init_index:], skip_special_tokens=True).strip()

    def submit_input(self, text : str):

        text = text.strip()

        tokenized = self.tokenizer(' ' + text, return_tensors='pt')['input_ids']
        self._force_decoder(tokenized)
        self._process_state_change(force_change=True)

    def _add_text(self, text : str):
        tokenized = self.tokenizer(text, return_tensors='pt')['input_ids']
        assert tokenized.size(0) == 1
        if tokenized[0, -1] == self.eos_token_id: tokenized = tokenized[:, :-1]
        self._force_decoder(tokenized)

    def get_current_state(self):
        return self.in_progress[0]

    def in_env_response_state(self):
        return self.get_current_state().defer_to_env

    ###
    # Generate
    ###

    def generate(self):

        self._sgen_ct = 0

        # auto-regressive generation
        while self._tgen_ct < self.max_new_tokens:

            target_state, _ = self.in_progress

            force_change = False
            if target_state.constraints:
                logger.debug(f'Constraints found for state [{target_state.name}], will decode to one of specified options')
                self._disjunctive_decode(target_state.constraints)
                force_change = True
            else:
                self._generation_step()
                self._sgen_ct += 1

            # now we do state validation
            return_to_user, await_reply = self._process_state_change(force_change=force_change)
            if return_to_user: return await_reply

        return False

    def _generation_step(self, forced_dec : List[int]=None):
        assert forced_dec is None or type(forced_dec) == list, 'Forced decoder input invalid'

        # prepare model inputs
        model_inputs = self.model.prepare_inputs_for_generation(self.input_ids, **self.model_kwargs)

        # forward pass to get next token
        outputs = self.model(
            **model_inputs,
            return_dict=True,
        )

        if type(forced_dec) == list and len(forced_dec) == 1:
            next_tokens = torch.tensor(forced_dec).to(self.input_ids.device)
        else:
            # sample
            next_token_logits = outputs.logits[:, -1, :]

            if type(forced_dec) == list:
                logit_mask = torch.zeros_like(next_token_logits) + float('-inf')
                for dec_id in forced_dec: logit_mask[:, dec_id] = 0.0
                next_token_logits += logit_mask

            if self.eos_token_id is not None and not self._transition_monitor.exit_reached(): 
                next_token_logits[:, self.eos_token_id] = float('-inf')

            if self.generation_config.do_sample:
                probs = nn.functional.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_logits, dim=-1)

        # update generated ids, model inputs, and length for next step
        self.input_ids = torch.cat([self.input_ids, next_tokens[:, None]], dim=-1)
        self.model_kwargs = self.model._update_model_kwargs_for_generation(
            outputs, self.model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder
        )

        self._tgen_ct += 1

    def _hard_stop_too_early(self):
        if not self._transition_monitor.exit_reached():
            matched_seq = seq_matches_any(self._hard_stop_sequences, self.input_ids)
            if matched_seq is not None:
                logger.debug(f'Hard stop sequence matched in a non-terminating state, rolling back prediction')
                self._roll_back_decoder(matched_seq)
                return True
        return False

    def _end_by_stop_seq(self):
        if self._transition_monitor.exit_reached():
            # hard stop
            if seq_matches_any(self._hard_stop_sequences, self.input_ids) is not None:
                logger.debug(f'Hard stop sequence matched, terminating!')
                return True

            # soft stop
            if seq_matches_any(self._soft_stop_sequences, self.input_ids) is not None:
                logger.debug(f'Soft stop sequence matched while in exit state, terminating!')
                return True

    def _process_state_change(self, force_change : bool=False):

        def _update_history(next_state=None):
            prior_state, prior_start_ind = self.in_progress
            history_string = self.tokenizer.decode(self.input_ids[0, prior_start_ind:], skip_special_tokens=True)
            if next_state is not None: history_string = history_string[:-len(next_state.text)]
            self.history.append((prior_state, history_string.strip()))
            logger.debug(f'Adding state [{self.history[-1][0].name}] to history with response \"{self.history[-1][1]}\"')

        force_change = force_change or self._hard_stop_too_early()

        if self._end_by_stop_seq():
            _update_history()
            return True, False

        proposed_state, matched_seq = self._transition_monitor.matches_state(self.input_ids)

        if force_change or self._sgen_ct >= self.max_state_length or proposed_state is not None:

            _update_history(proposed_state)

            # reset state token generation count
            self._sgen_ct = 0

            if proposed_state is not None:
                if (not self._transition_monitor.exit_reached()) and self._transition_monitor.accept_state(proposed_state):
                    self._transition_monitor.step(proposed_state)
                    self.in_progress = (proposed_state, self.input_ids.size(1))

                    logger.debug(f'Current state is now [{proposed_state.name}]')

                    return proposed_state.defer_to_env, proposed_state.defer_to_env
                else:
                    # roll back decoder predictions
                    logger.debug(f'Predicted state [{proposed_state.name}] violates specification, rolling back prediction')
                    self._roll_back_decoder(matched_seq)
                    if self._transition_monitor.exit_reached():
                        return True, False

            possible_states = self._transition_monitor.get_valid_states(incl_op=True)

            if possible_states:

                logger.debug(f'Possible target states are [{", ".join([ps.name for ps in possible_states])}]')

                proposed_state = self._decode_to_state(possible_states)

                # update decoder to have chosen that state
                self._transition_monitor.step(proposed_state)

                self.in_progress = (proposed_state, self.input_ids.size(1))

                return proposed_state.defer_to_env, proposed_state.defer_to_env
            else:
                logger.debug(f'Exit reached!')
                assert self._transition_monitor.exit_reached()
                return True, False

        return False, False

    def _decode_to_state(self, possible_states : List[DecodingState]):
        disjunctive_alternatives = [ps.tokens for ps in possible_states]

        if len(possible_states) == 1: logger.debug(f'Constraining decoding to [{possible_states[0]}]')
        else: logger.debug(f'Constraining decoding to one of [{", ".join([ps.name for ps in possible_states])}]')

        self._disjunctive_decode(disjunctive_alternatives)
        proposed_state, _ = self._transition_monitor.matches_state(self.input_ids)
        return proposed_state
    
    def _disjunctive_decode(self, disj_alt : List[List[int]]):
        i = 0
        while i < max([len(x) for x in disj_alt]):
            tokens = [tok_lst[i] for tok_lst in disj_alt]
            self._generation_step(tokens)
            chosen_token = self.input_ids[:, -1][0]
            disj_alt = [tok_lst for tok_lst in disj_alt if tok_lst[i] == chosen_token]
            i += 1

    def _roll_back_decoder(self, matched_seq : torch.Tensor):
        if 'attention_mask' in self.model_kwargs and self.model_kwargs['attention_mask'] is not None and not self.model.config.is_encoder_decoder:
            self.model_kwargs['attention_mask'] = self.model_kwargs['attention_mask'][..., :-len(matched_seq)]

        self.input_ids = self.input_ids[..., :-len(matched_seq)]

        if 'past_key_values' in self.model_kwargs and self.model_kwargs['past_key_values'] is not None:
            new_pkv = []
            if not self.seq_dims: self.seq_dims = [[None, None] for _ in range(len(self.model_kwargs['past_key_values']))]
            seq_len = self.input_ids.size(-1) + len(matched_seq) - 1
            for i in range(len(self.model_kwargs['past_key_values'])):
                new_pkv.append([])
                for j in range(len(self.model_kwargs['past_key_values'][i])):
                    if j > 1: 
                        assert self.model.config.is_encoder_decoder
                        new_pkv[-1].append(self.model_kwargs['past_key_values'][i][j])
                    else:
                        if self.seq_dims[i][j] is None:
                            self.seq_dims[i][j] = next(iter([k for k in range(1, len(self.model_kwargs['past_key_values'][i][j].size()))
                                                             if self.model_kwargs['past_key_values'][i][j].size(k) == seq_len]))
                        if self.seq_dims[i][j] == 1: new_pkv[-1].append(self.model_kwargs['past_key_values'][i][j][:, :-len(matched_seq)])
                        elif self.seq_dims[i][j] == 2: new_pkv[-1].append(self.model_kwargs['past_key_values'][i][j][:, :, :-len(matched_seq)])
                        elif self.seq_dims[i][j] == 3: new_pkv[-1].append(self.model_kwargs['past_key_values'][i][j][:, :, :, :-len(matched_seq)])
            self.model_kwargs['past_key_values'] = new_pkv

    def _force_decoder(self, new_tokens : torch.Tensor):
        if len(new_tokens.size()) > 1: new_tokens = new_tokens[0]
        for tok in new_tokens: self._generation_step([int(tok)])

    ###
    # validate
    ###

    def validate_sequence_from_string(self, string : str):
        return self._transition_monitor.validate_sequence_from_string(string)

    def validate_sequence(self, state_history : List[Tuple[str, str]]):
        return self._transition_monitor.validate_sequence(state_history)