import logging
import torch
import collections
import numpy as np

from transformers import AutoTokenizer, AutoModel

from prompt_compiler.lark_utils import bnf2lark, larkstr2rulelist, skipped_nonterminal_names, rulelist2bnfstr

logger = logging.getLogger("global_logger")

# Sentence-BERT encoder shared by score_by_sentencebert. It is loaded once,
# at import time, and put into inference mode.
_SB_MODEL_NAME = 'sentence-transformers/paraphrase-MiniLM-L6-v2'
sb_tokenizer = AutoTokenizer.from_pretrained(_SB_MODEL_NAME)
sb_model = AutoModel.from_pretrained(_SB_MODEL_NAME)
sb_model.eval()

def score_by_sentencebert(prediction, candidate):
    """Return the cosine similarity of two strings under the Sentence-BERT encoder.

    Both strings are tokenized together as one padded, truncated batch and run
    through the module-level ``sb_model`` without gradient tracking. Each
    sentence embedding is the mean of its token embeddings, with padding
    positions masked out via the attention mask.

    Returns:
        float: cosine similarity between the two sentence embeddings.
    """
    batch = sb_tokenizer([prediction, candidate], padding=True, truncation=True, return_tensors='pt')
    mask = batch['attention_mask']
    with torch.no_grad():
        outputs = sb_model(**batch)
        tokens = outputs[0]  # the first element of the model output holds the per-token embeddings
        weights = mask.unsqueeze(-1).expand(tokens.size()).float()
        summed = torch.sum(tokens * weights, 1)
        counts = torch.clamp(weights.sum(1), min=1e-9)  # guard against division by zero
        pooled = summed / counts
        similarity = torch.cosine_similarity(pooled[0], pooled[1], dim=0)
    return similarity.item()

def predict_program_with_earley_correction(llm, prompt, parser, max_tokens, llm_cache_dir, freq_penalty, seed):
    """Generate a DSL program with the LLM, repairing parse errors via Earley-based correction.

    The LLM is asked for a program. If ``parser`` rejects it, the parser's
    error handler proposes a prefix plus candidate suffixes. The candidate most
    similar (Sentence-BERT cosine similarity) to the rejected program is kept,
    and the LLM is asked to continue generating after that corrected prefix.
    At most ``MAX_NUM_CORRECTION`` LLM calls are made.

    Args:
        llm: object whose ``greedy_completion(...)`` result has a
            ``response_text`` attribute.
        prompt: base prompt sent to the LLM.
        parser: object exposing ``parse(text)`` (raises on invalid input) and
            ``handle_error(exception)``. The result of ``handle_error`` is used
            here as a ``(prefix, suffixes)`` pair.
            NOTE(review): confirm against the parser implementation.
        max_tokens, llm_cache_dir, freq_penalty, seed: forwarded to
            ``llm.greedy_completion``.

    Returns:
        str: the first prediction that parses. If none is found (budget
        exhausted, or no correction candidates), the first raw LLM completion
        is returned if it parses, and ``""`` otherwise.
    """
    MAX_NUM_CORRECTION = 20
    num_correction_left = MAX_NUM_CORRECTION

    def validate_program(prediction):
        """Return True if ``parser`` accepts ``prediction``; otherwise log the error and return False."""
        try:
            parser.parse(prediction)
            return True
        except Exception as runtime_e:
            logger.info(f"Error in prediction: {prediction}")
            logger.info(f"Error: {str(runtime_e)}")
            return False

    def obtain_correction_pairs(prediction):
        """
        Return the parser's correction proposal for ``prediction``.

        Returns ``[]`` if ``prediction`` parses. Otherwise returns whatever
        ``parser.handle_error`` produces, which the caller treats as a
        ``(prefix, suffixes)`` pair.
        """
        try:
            parser.parse(prediction)
            return []
        except Exception as runtime_e:
            return parser.handle_error(runtime_e)

    partial_program_prediction = ""
    ret_prediction, initial_prediction = None, None
    while num_correction_left > 0:
        if partial_program_prediction == "":
            _prompt = prompt
        else:
            # Show the corrected prefix and ask the LLM to continue after it.
            _prompt = prompt + "```DSL\n" + partial_program_prediction + "\n```\nPlease continue generating DSL after the provided portion, without repeating it."
        response = llm.greedy_completion(_prompt, stop_token="\n\n", max_tokens=max_tokens, llm_cache_dir=llm_cache_dir, freq_penalty=freq_penalty, seed=seed)
        residual_program_prediction = response.response_text

        # Remember the first raw completion; it is the fallback if correction fails.
        if initial_prediction is None:
            initial_prediction = residual_program_prediction
        program_prediction = partial_program_prediction + residual_program_prediction

        if validate_program(program_prediction):
            ret_prediction = program_prediction
            break

        pair = obtain_correction_pairs(program_prediction)
        # Explicit checks instead of ``assert``: asserts are stripped under ``python -O``,
        # and running out of candidates should take the regular fallback path below.
        if not pair:
            logger.warning("no correction pairs found; stopping Earley correction")
            break
        logger.info(f"prefix [{pair[0]}] suffix [{pair[1]}]")
        prefix = pair[0]
        suffixes = list(pair[1])
        if not suffixes:
            logger.warning("no candidate suffixes found; stopping Earley correction")
            break

        # Keep the candidate closest to the rejected program; ties go to the first one.
        # (LLM-likelihood scoring via ``llm.evaluate_completion`` is no longer
        # supported by the LLM API, hence the embedding similarity.)
        scores = [score_by_sentencebert(program_prediction, prefix + suffix) for suffix in suffixes]
        best_idx = scores.index(max(scores))
        fixed_prediction = prefix + suffixes[best_idx]
        logger.info(f"fixed prediction: {fixed_prediction}")

        if validate_program(fixed_prediction):
            ret_prediction = fixed_prediction
            break

        partial_program_prediction = fixed_prediction
        num_correction_left -= 1

    if ret_prediction is None:
        logger.info(f"cannot find a valid prediction after {MAX_NUM_CORRECTION} retries")
        ret_prediction = initial_prediction
        if not validate_program(ret_prediction):
            ret_prediction = ""

    return ret_prediction


def predict_rules_with_earley_correction(llm, prompt, ruleset, delimiter, max_tokens, llm_cache_dir, freq_penalty, seed, use_action_list_flag):
    """
    Predict grammar rules with earley correction.

    The LLM is asked for a BNF grammar. Every predicted rule must be a member
    of ``ruleset``, unless its left-hand side is in
    ``skipped_nonterminal_names`` (or is ``action_name`` while
    ``use_action_list_flag`` is false).

    When an invalid rule is found, replacement rules are drawn from
    ``ruleset``. The replacement whose result is most similar (Sentence-BERT
    cosine similarity) to the rejected grammar is kept. The LLM is then asked
    to regenerate the grammar, given that corrected prefix and the error
    message. At most ``MAX_NUM_CORRECTION`` LLM calls are made.

    Args:
        llm: object whose ``greedy_completion(...)`` result has a
            ``response_text`` attribute.
        prompt: base prompt sent to the LLM.
        ruleset: collection of allowed rules. Rules are read through
            ``origin``, ``expansion`` and ``to_bnf()``.
        delimiter: the separator between rule and program. Only the text
            before its first occurrence in the completion is used.
        max_tokens, llm_cache_dir, freq_penalty, seed: forwarded to
            ``llm.greedy_completion``.
        use_action_list_flag: when false, rules whose left-hand side is
            ``action_name`` are not checked against ``ruleset``.

    Returns:
        str: the first valid grammar, re-serialized with ``rulelist2bnfstr``.
        If none is found, returns the first LLM completion (text before
        ``delimiter``, with ``"BNF:"`` removed), which may be invalid.
        ``%import common.ESCAPED_STRING`` is prepended whenever the returned
        text mentions ``ESCAPED_STRING``.
    """
    MAX_NUM_CORRECTION = 10
    CANDIDATE_NUM_THRESHOLD = 16
    num_correction_left = MAX_NUM_CORRECTION

    rules_by_origin = collections.defaultdict(list)
    for rule in ruleset:
        rules_by_origin[rule.origin].append(rule)

    # Left-hand sides whose rules are accepted without checking ``ruleset``.
    unchecked_origins = tuple(skipped_nonterminal_names)
    if not use_action_list_flag:
        unchecked_origins += ("action_name", )

    def is_invalid_rule(pred_rule):
        """Return True if ``pred_rule`` is not in ``ruleset`` and its left-hand side is not exempt."""
        return pred_rule not in ruleset and pred_rule.origin not in unchecked_origins

    def validate_rule(prediction):
        """Return ``(True, "")`` if every rule is valid, else ``(False, message)`` for the first invalid rule."""
        pred_rulelist = larkstr2rulelist(bnf2lark(prediction))  # an ordered list of rules
        for pred_rule in pred_rulelist:
            if is_invalid_rule(pred_rule):
                logger.debug(f"found an invalid rule: {pred_rule}")
                return False, f"found an invalid rule: {pred_rule}"
        return True, ""

    def filter_candidates(pred_rule, candidates):
        """Keep at most ``CANDIDATE_NUM_THRESHOLD`` candidates: those whose text is most similar to ``pred_rule``."""
        if len(candidates) <= CANDIDATE_NUM_THRESHOLD:
            return candidates
        pred_rulename = str(pred_rule)
        scores = [score_by_sentencebert(pred_rulename, str(candidate)) for candidate in candidates]
        return [candidates[i] for i in np.argsort(scores)[-CANDIDATE_NUM_THRESHOLD:]]

    def obtain_correction_pairs(prediction):
        """
        Return correction candidates for the first invalid rule of ``prediction``.

        The result is a list of ``(prefix, suffix)`` string pairs:
        - ``prefix`` serializes the valid rules, with the rules that share the
          invalid rule's left-hand side placed last.
        - ``suffix`` adds one candidate rule: as an extra ``| alternative`` when
          it extends that last-serialized nonterminal, otherwise as a new rule
          line.

        Returns ``[(prefix, "")]`` when there is no candidate, and ``[]`` when
        no rule is invalid.
        """
        pred_rulelist = larkstr2rulelist(bnf2lark(prediction))  # an ordered list of rules

        lhs_set = set()
        partial_rule_list = []  # valid rules, deduplicated, in order of appearance
        error_rule = None       # first invalid rule; only it is corrected this round
        for pred_rule in pred_rulelist:
            if is_invalid_rule(pred_rule):
                if error_rule is None:
                    error_rule = pred_rule
            elif pred_rule not in partial_rule_list:
                lhs_set.add(pred_rule.origin)
                partial_rule_list.append(pred_rule)

        if error_rule is None:
            return []

        # Find candidates considering the origin of the rule: if its nonterminal is
        # already defined by a valid rule, only propose alternatives for it.
        if error_rule.origin in lhs_set:
            candidates = [r for r in rules_by_origin[error_rule.origin] if r not in partial_rule_list]
        else:
            candidates = [r for r in ruleset if r not in partial_rule_list]
        candidates = filter_candidates(error_rule, candidates)

        logger.info(f"number of candidates for correction: {len(candidates)}")
        for candidate_idx, candidate in enumerate(candidates):
            logger.debug(f"candidate {candidate_idx}: [{candidate}]")

        # Serialize the partial rule list with the error rule's nonterminal last,
        # so that a " | ..." suffix extends that nonterminal.
        other_origin_rules = [x for x in partial_rule_list if x.origin != error_rule.origin]
        same_origin_rules = [x for x in partial_rule_list if x.origin == error_rule.origin]
        prefix = rulelist2bnfstr(other_origin_rules + same_origin_rules)

        ret_pairs = []
        for candidate in candidates:
            # Only the error rule's own (already defined) nonterminal is serialized
            # last. Any other candidate must start its own rule line, otherwise its
            # alternative would be attached to an unrelated nonterminal.
            extends_last_rule = candidate.origin == error_rule.origin and candidate.origin in lhs_set
            if extends_last_rule:
                suffix = " | " + ' '.join(candidate.expansion)
            else:
                suffix = "\n" + candidate.to_bnf()
            ret_pairs.append((prefix, suffix))
        return ret_pairs or [(prefix, "")]

    partial_rule_prediction = ""
    error_info = ""
    ret_prediction, initial_prediction = None, None

    while num_correction_left > 0:
        if partial_rule_prediction == "":
            _prompt = prompt + partial_rule_prediction
        else:
            _prompt = prompt + "```BNF\n" + partial_rule_prediction + "\n```\nPlease regenerate BNF grammar (A non-terminal can appear on the multiple production rules' left side).\nYour previous output had the following error:\n" + error_info
        response = llm.greedy_completion(_prompt, stop_token="\n\n",
                                         max_tokens=max_tokens, llm_cache_dir=llm_cache_dir, freq_penalty=freq_penalty, seed=seed)
        residual_rule_prediction = response.response_text.split(delimiter)[0].replace("BNF:", "")

        if initial_prediction is None:
            initial_prediction = residual_rule_prediction
        rule_prediction = partial_rule_prediction + residual_rule_prediction
        pred_rulelist = larkstr2rulelist(bnf2lark(rule_prediction))
        # Deduplicate while keeping first-occurrence order. The previous set() round
        # trip depended on hashing, so the serialized grammar, the prompts and the LLM
        # cache keys could differ between runs.
        # Rules whose whole expansion is ESCAPED_STRING are dropped.
        pred_rulelist = [x for x in dict.fromkeys(pred_rulelist) if x.expansion != ("ESCAPED_STRING", )]
        rule_prediction = rulelist2bnfstr(pred_rulelist)
        logger.debug(f"partial rule prediction: {rule_prediction}")

        flag, info = validate_rule(rule_prediction)
        if flag:
            ret_prediction = rule_prediction
            break
        error_info = info

        pairs = obtain_correction_pairs(rule_prediction)
        logger.debug(f"number of candidates: {len(pairs)}")
        if not pairs:
            # Defensive: validate_rule just reported an invalid rule, so this should not happen.
            logger.warning("no correction pairs found; stopping Earley correction")
            break

        # Scoring by LLM likelihood (llm.evaluate_completion) is no longer supported
        # due to an API change; use embedding similarity instead. Ties go to the first pair.
        scores = [score_by_sentencebert(rule_prediction, prefix + suffix) for prefix, suffix in pairs]
        best_prefix, best_suffix = pairs[scores.index(max(scores))]
        fixed_rule_prediction = best_prefix + best_suffix
        logger.debug(f"fixed rule: {best_suffix}")
        logger.debug(f"fixed partial rule prediction:\n{fixed_rule_prediction}")

        partial_rule_prediction = fixed_rule_prediction
        num_correction_left -= 1

    if ret_prediction is None:
        logger.warning(f"cannot find a valid rule prediction after {MAX_NUM_CORRECTION} retries")
        ret_prediction = initial_prediction

    if "ESCAPED_STRING" in ret_prediction:
        return "%import common.ESCAPED_STRING\n" + ret_prediction
    return ret_prediction