# Copyright (c) 2019-2021, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and functions for evaluating a model.
"""
import logging
import numpy as np
from typing import List, Union, Optional
from collections.abc import Iterable 
from sympy import sympify, simplify, symbols
import re
import json
from dataclasses import dataclass
from typing import Dict, List
from collections import OrderedDict
from fractions import Fraction
from copy import deepcopy
import itertools
import evaluate
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.trainer_utils import EvalPrediction

logger = logging.getLogger(__name__)


class Seq2SeqMetricsOnSeqIDs:
    """ Compute evaluation metrics between predicted and reference sequences given as token ids.

    If the sequences are already plain text, simply call the metric directly, e.g.,

    bleu = evaluate.load("bleu")
    results = bleu.compute(predictions=predictions, references=references)

    Supported metrics (case insensitive):
        BLEU (Bilingual Evaluation Understudy)
        ROUGE, or Recall-Oriented Understudy for Gisting Evaluation
        Accuracy
        formula_accuracy (computed by `EvaluateMetricsOnMathFormula(strict=False)`)

    Args:
        metrics (`Union[List[str], str]`, default to `accuracy`):
            A metric name or a list of names of metrics to calculate.
        tokenizer (`Optional[PreTrainedTokenizer]`):
            The tokenizer used by model. Must provide for metrics `bleu`, `rouge`
            and `formula_accuracy`, which are computed on decoded text.
        padding_side (`str`, default to `left`):
            In the input ids, whether the [pad] tokens are on the left side or right.
            The value is validated and stored; this class does not read it afterwards.
    """
    SUPPORT_METRICS = ['accuracy', 'bleu', 'rouge', 'formula_accuracy']
    # Metrics that are computed on decoded text and therefore need a tokenizer.
    _TEXT_METRICS = ('bleu', 'rouge', 'formula_accuracy')

    def __init__(
        self,
        metrics: Union[List[str], str] = 'accuracy',
        tokenizer: Optional[PreTrainedTokenizer] = None,
        padding_side: str = 'left',
        ) -> None:
        # check metrics (a single name is validated exactly like a list of names)
        if isinstance(metrics, str):
            metrics = [metrics]
        metrics = [metric.lower() for metric in metrics]
        unsupported = [metric for metric in metrics if metric not in self.SUPPORT_METRICS]
        assert not unsupported, ValueError(
            'Unsupported metric(s): {}'.format(unsupported))
        # dict.fromkeys removes duplicated names while keeping their order
        self.metrics = dict.fromkeys(metrics)
        for metric in self.metrics:
            if metric == 'formula_accuracy':
                self.metrics[metric] = EvaluateMetricsOnMathFormula(strict=False)
            else:
                self.metrics[metric] = evaluate.load(metric)
        # check tokenizer
        self.tokenizer = tokenizer
        if any(metric in self._TEXT_METRICS for metric in metrics):
            assert self.tokenizer is not None, RuntimeError(
                'To calculate metrics {}, tokenizer must be provided.'.format(metrics))
        # check padding_side
        self.padding_side = padding_side
        assert self.padding_side in ['left', 'right'], ValueError(
            'padding_side can only be left or right, got {}'.format(self.padding_side))

    def flatten_metric_list(self, results):
        """Replace list-valued entries of `results` by scalar entries (in place).

        A list with fewer than 8 elements is replaced by one entry per element, named
        `<metric>_<index>`; a longer list is replaced by its mean.

        Args:
            results (`dict`): metric name -> value. It is modified in place.

        Returns:
            `dict`: the same `results` object.
        """
        metrics_to_remove = []
        metrics_to_add = {}
        for metric, value in results.items():
            if not isinstance(value, list):
                continue
            if len(value) < 8:  # for short list, we add them separately
                metrics_to_remove.append(metric)
                metrics_to_add.update(
                    {'{}_{}'.format(metric, index): _value for index, _value in enumerate(value)})
            else:  # too many elements make the results messy, we get the mean values
                logger.warning(
                    "metric {} is a list of too many elements, thus we get the mean of it.".format(metric))
                # Re-assigning an existing key does not resize the dict, so this is
                # safe while iterating over it.
                results[metric] = sum(value) / len(value)
        for metric in metrics_to_remove:
            results.pop(metric)
        results.update(metrics_to_add)

        return results

    def compute_metrics(self, eval_pred: Optional[EvalPrediction]) -> dict:
        """Compute all configured metrics for one evaluation pass.

        Args:
            eval_pred: indexable with the predictions at position 0 and the label ids at
                position 1. Predictions with 3 dimensions are reduced with an argmax over
                the last axis. Label positions equal to -100 are ignored.

        Returns:
            `dict`: metric name -> value; keys are prefixed with the metric name when they
            do not already contain it, and list values are flattened by
            `flatten_metric_list`.

        Raises:
            ValueError: a text metric is requested and every decoded prediction/reference
                pair contains an empty string.
            NotImplementedError: a configured metric has no branch here.
        """
        prediction_ids, label_ids = eval_pred[0], eval_pred[1]
        if prediction_ids.ndim == 3:
            prediction_ids = np.argmax(prediction_ids, axis=-1)
        valid_indices = label_ids != -100
        if any(metric in self.metrics for metric in self._TEXT_METRICS):
            # Both predictions and labels are decoded at the valid *label* positions.
            pred_seqs = [self.tokenizer.decode(pred[index], skip_special_tokens=True) for pred, index in zip(prediction_ids, valid_indices)]
            ref_seqs = [self.tokenizer.decode(label_id[index], skip_special_tokens=True) for label_id, index in zip(label_ids, valid_indices)]
            # remove pairs in which either decoded string is empty
            is_seq_valid = np.array(
                [(len(pred_seq) > 0 and len(ref_seq) > 0) for pred_seq, ref_seq in zip(pred_seqs, ref_seqs)], dtype=bool)
            num_valid_seqs = is_seq_valid.sum()
            num_all_seqs = len(pred_seqs)
            if num_valid_seqs < num_all_seqs:
                if num_valid_seqs > 0:
                    logger.warning("Among {} sequences to be evaluated, {} are empty and removed.".format(
                        num_all_seqs, num_all_seqs - num_valid_seqs))
                    prediction_ids = prediction_ids[is_seq_valid]
                    label_ids = label_ids[is_seq_valid]
                    valid_indices = label_ids != -100
                    pred_seqs = [pred_seq for pred_seq, is_valid in zip(pred_seqs, is_seq_valid) if is_valid]
                    ref_seqs = [ref_seq for ref_seq, is_valid in zip(ref_seqs, is_seq_valid) if is_valid]
                else:
                    raise ValueError('All {} sequences are empty, evaluation metric cannot be calculated.'.format(num_all_seqs))

        results = []
        for metric, evaluator in self.metrics.items():
            if metric == 'accuracy':
                # The label mask is rotated one step to the left before it selects
                # predictions, so the prediction at position t is compared with the label
                # at position t + 1 (presumably because the model output at t predicts the
                # next token -- TODO confirm against the training setup).
                valid_indices_pred = np.concatenate((valid_indices[:,1:], valid_indices[:,0:1]), axis=1)
                result = evaluator.compute(predictions=prediction_ids[valid_indices_pred], references=label_ids[valid_indices])
            elif metric == 'bleu':
                result = evaluator.compute(predictions=pred_seqs, references=ref_seqs)
            elif metric == 'rouge':
                result = evaluator.compute(
                    predictions=pred_seqs, references=ref_seqs, tokenizer=self.tokenizer.tokenize)
            elif metric == 'formula_accuracy':
                result = evaluator.compute(predictions=pred_seqs, references=ref_seqs)
            else:
                raise NotImplementedError
            # append metric to all keys in result so that different metrics do not mix up
            result = {(_metric if metric in _metric else metric + '_' + _metric): _value for _metric, _value in result.items()}
            results.append(result)
        # merge results
        results = {metric:value for result in results for metric, value in result.items()}
        # if results have list, replace that list with its elements separately.
        results = self.flatten_metric_list(results)

        return results

    def __call__(self, eval_pred: Optional[EvalPrediction]) -> dict:
        """Alias of `compute_metrics` so that an instance can be used as a callback."""
        return self.compute_metrics(eval_pred)


class Seq2SeqMetricsOnGenerationSeqIDs(Seq2SeqMetricsOnSeqIDs):
    """ Variant of `Seq2SeqMetricsOnSeqIDs` whose predictions carry their own -100 mask.

    Predictions and labels are masked independently (positions equal to -100 are
    dropped from each), and accuracy compares the two masked arrays directly, without
    shifting the label mask.

    Supported metrics (case insensitive):
        BLEU (Bilingual Evaluation Understudy)
        ROUGE, or Recall-Oriented Understudy for Gisting Evaluation
        Accuracy

    Args:
        metrics (`Union[List[str], str]`, default to `accuracy`):
            A list of names of metrics to calculate.
        tokenizer (`Optional[PreTrainedTokenizer]`):
            The tokenizer used by model. Must provide for metrics including `bleu` and `rouge`.
        padding_side (`str`, default to `left`):
            In the input ids, whether the [pad] tokens are on the left side or right.
    """

    def compute_metrics(self, eval_pred: Optional[EvalPrediction]) -> dict:
        """Compute every configured metric and return them merged into one dict."""
        prediction_ids, label_ids = eval_pred[0], eval_pred[1]
        if prediction_ids.ndim == 3:
            # Scores per vocabulary entry: keep the best-scoring id.
            prediction_ids = np.argmax(prediction_ids, axis=-1)
        pred_mask = prediction_ids != -100
        label_mask = label_ids != -100

        needs_text = any(name in self.metrics for name in ['bleu', 'rouge', 'formula_accuracy'])
        if needs_text:
            pred_seqs = [
                self.tokenizer.decode(row[mask], skip_special_tokens=True)
                for row, mask in zip(prediction_ids, pred_mask)
            ]
            ref_seqs = [
                self.tokenizer.decode(row[mask], skip_special_tokens=True)
                for row, mask in zip(label_ids, label_mask)
            ]
            # A pair is usable only when both decoded strings are non-empty.
            keep = np.array(
                [(len(pred_text) > 0 and len(ref_text) > 0) for pred_text, ref_text in zip(pred_seqs, ref_seqs)],
                dtype=bool)
            kept_count = keep.sum()
            total_count = len(pred_seqs)
            if kept_count < total_count:
                if not kept_count > 0:
                    raise ValueError('All {} sequences are empty, evaluation metric cannot be calculated.'.format(total_count))
                logger.warning("Among {} sequences to be evaluated, {} are empty and removed.".format(
                    total_count, total_count - kept_count))
                prediction_ids = prediction_ids[keep]
                label_ids = label_ids[keep]
                pred_mask = prediction_ids != -100
                label_mask = label_ids != -100
                pred_seqs = [text for text, usable in zip(pred_seqs, keep) if usable]
                ref_seqs = [text for text, usable in zip(ref_seqs, keep) if usable]

        merged = {}
        for name, evaluator in self.metrics.items():
            if name == 'accuracy':
                scores = evaluator.compute(
                    predictions=prediction_ids[pred_mask], references=label_ids[label_mask])
            elif name in ('bleu', 'formula_accuracy'):
                scores = evaluator.compute(predictions=pred_seqs, references=ref_seqs)
            elif name == 'rouge':
                scores = evaluator.compute(
                    predictions=pred_seqs, references=ref_seqs, tokenizer=self.tokenizer.tokenize)
            else:
                raise NotImplementedError
            # Prefix keys with the metric name (unless already present) so that
            # different metrics do not overwrite each other.
            for key, value in scores.items():
                merged[key if name in key else name + '_' + key] = value
        # List-valued entries are split into separate entries (or averaged).
        return self.flatten_metric_list(merged)


@dataclass
class ObjectiveDeclaration:
    """Parsed objective of an optimization problem: a direction and a formula string."""
    # Optimization direction. `MetricsOnMathFormula.parse_obj` stores 'Minimize',
    # 'Maximize', or '' when neither word is found in the objective text.
    direction: str
    # Objective expression as a string (`parse_obj` stores the first formula
    # returned by `parse_formula`).
    formula: str


class MetricsOnMathFormula:
    """Parse and compare optimization objectives and constraints written as text.

    The class attributes below are regular-expression fragments that are composed into
    progressively larger patterns. All composed patterns are raw f-strings so that regex
    escapes such as `\\*` or `\\s` are not treated as (invalid) Python string escapes.
    """

    # A number: optional minus sign, digits, then an optional decimal or fraction part.
    nump = r'\-?(?:\d+(?:\.\d+|\/\d+)?)'
    # A variable name: a letter followed by letters, digits or underscores.
    varp = r'(?:[a-zA-Z](?:[a-zA-Z]|\d|_)*)'
    # A single term: number*variable, number directly followed by a variable, or a variable.
    token = rf'(?:(?:{nump}\*{varp})|(?:{nump}{varp})|(?:{varp}))'
    l = r'\('
    r = r'\)'

    # A sum/difference of at least two simple terms, and its parenthesised form.
    exp_sp = rf'(?:(?:(?:{nump}\*?{varp})|{varp}|{nump})(?:\s*[+-]\s*(?:(?:{nump}\*?{varp})|{varp}|{nump}))+)'
    exp_part = rf'(?:(?:\({exp_sp}\))|(?:\({nump}\)))'

    # Products of parenthesised parts, terms and numbers.
    exp_mul_1 = rf'(?:{exp_part}(?:\*?(?:{exp_part}|{token}|{nump}))+)'
    exp_mul_2 = rf'(?:(?:{token}|{nump})(?:(?:\*?{exp_part})|(?:\*(?:{token}|{nump})))+)'
    exp_mul = rf'(?:{exp_mul_1}|{exp_mul_2})'

    # Quotients.
    exp_div_1 = rf'(?:(?:{exp_part}|\(?{exp_mul}\)?|{token})\/(?:{exp_part}|\(?{exp_mul}\)?|{token}|{nump}))'
    exp_div_2 = rf'(?:{nump}\/(?:{exp_part}|\(?{exp_mul}\)?|{varp}))'
    exp_div = rf'(?:{exp_div_1}|{exp_div_2})'

    # Products that may also contain quotients.
    exp_mul_cp1 = rf'(?:(?:(?:\(?{exp_div}\)?)|{token}|{nump})(?:(?:\*?(?:\({exp_div}\)))|(?:\*{exp_div})|(?:\*?{exp_part})|(?:\*(?:{token}|{nump})))+)'
    exp_mul_cp2 = rf'(?:{exp_part}(?:\*?(?:(?:\(?{exp_div}\)?)|{exp_part}|{token}|{nump}))+)'
    exp_mul_cp = rf'(?:{exp_mul_cp1}|{exp_mul_cp2})'

    # exp_cp needs at least one "+"/"-" joined operand after the first; exp_all also
    # accepts a single operand.
    exp_cp = rf'(?:(?:{l}?{exp_mul_cp}{r}?|{l}?{exp_div}{r}?|{exp_part}|{l}?{token}{r}?|{nump})(?:[+-](?:{l}?{exp_mul_cp}{r}?|{l}?{exp_div}{r}?|{exp_part}|{l}?{token}{r}?|{nump}))+)'
    exp_all = rf'(?:(?:{l}?{exp_mul_cp}{r}?|{l}?{exp_div}{r}?|{exp_part}|{l}?{token}{r}?|{nump})(?:[+-](?:{l}?{exp_mul_cp}{r}?|{l}?{exp_div}{r}?|{exp_part}|{l}?{token}{r}?|{nump}))*)'

    # Implicit multiplication around parentheses: ")x", ")2", ")(" and "x(", "2(", ")(".
    mul_1 = r'(\))([a-zA-Z]|\d|\()'
    mul_2 = r'([a-zA-Z]|\d|\))(\()'

    # "<numerator>/<denominator> <op> <limit>" with five capture groups:
    # numerator, "/", denominator, comparison operator, right-hand side.
    div_1f = rf'(?:({exp_part}|(?:\(?{exp_mul}\)?)|{token})(\/)({exp_part}|(?:\(?{exp_mul}\)?)|{token}|{nump})(<=|>=|=)({exp_all}))'
    div_2f = rf'(?:({nump})(\/)({exp_part}|(?:\(?{exp_mul}\)?)|{varp})(<=|>=|=)({exp_all}))'

    # Comma-separated left-hand sides sharing one bound ("x,y>=0") and chained
    # comparisons ("0<=x,y<=10").
    multil = rf'({varp}(?:,{varp})+)(<=|>=|=)({exp_all})'
    multil2 = rf'({exp_all})(<=|>=|=)({varp}(?:,{varp})*)(<=|>=|=)({exp_all})'
    multil3 = rf'({exp_all})(<=|>=|=)({exp_all}(?:,{exp_all})*)(<=|>=|=)({exp_all})'

    # A parenthesised phrase of at least two whitespace-separated words.
    explain = r'\(\s*[\w\']+((\s)[\w\']+)+\s*\)'
    # A parenthesised phrase (ASCII or full-width parentheses) containing a CJK character.
    explain_ch= r'[(（]([^\(\)]*[一-龥][^\(\)]*)+[)）]'


    def __init__(self, strict=True, double_check_threshold=4):
        """Store the options and build the symbol table handed to `sympify`.

        Args:
            strict: when False, `parse_formula` first normalises its input (whitespace,
                unicode comparison signs, implicit multiplication, ...).
            double_check_threshold: largest number of equal-coefficient variables for
                which `evaluate_formula` still tries `double_check`.
        """
        self.strict = strict
        self.double_check_threshold = double_check_threshold
        # One SymPy symbol per upper-case letter, passed as `locals` to `sympify`
        # (presumably so that names such as E or I parse as plain variables rather
        # than SymPy built-ins -- confirm).
        self.d = {letter: symbols(letter) for letter in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'}
    
    def parse_formula(self, formula):
        if not self.strict: 
            formula = re.sub(r'\s+','', formula)
            formula = re.sub('≤', '<=', formula)
            formula = re.sub('≥', '>=', formula)
            formula = re.sub('==', '=', formula)

            formula = re.sub('%', '*0.01', formula)
            formula = re.sub('\$', '', formula)
            formula = re.sub(r'(?<=\d),(\d\d\d)', r'\g<1>',formula)

            mul_pattern_1 = re.compile(r'(?<=[^a-zA-Z_\d])(-?\d+(?:\.\d+|\/\d+)?)([a-zA-Z](?:[a-zA-Z]|\d|_)*)')
            mul_pattern_2 = re.compile(r'^(-?\d+(?:\.\d+|\/\d+)?)([a-zA-Z](?:[a-zA-Z]|\d|_)*)')
            formula = re.sub(mul_pattern_1, r'\g<1>*\g<2>', formula)
            formula = re.sub(mul_pattern_2, r'\g<1>*\g<2>', formula)
            formula = re.sub(self.mul_1, r'\g<1>*\g<2>', formula)
            formula = re.sub(self.mul_2, r'\g<1>*\g<2>', formula)

        if re.match(self.div_1f,formula):
            numerator,div, denominator, operator, limit = re.findall(self.div_1f,formula)[0]
            formula = f'{numerator}{operator}({limit})*{denominator}'
        
        if re.match(self.div_2f,formula):
            numerator,div, denominator, operator, limit = re.findall(self.div_2f,formula)[0]
            formula = f'{numerator}{operator}({limit})*{denominator}'
        
        if re.match(self.multil3, formula):
            lvalue, lopt, vars, ropt, rvalue = re.findall(self.multil3, formula)[0]
            vars = vars.split(',')
            formula = [f'{lvalue} {lopt} {var}' for var in vars]
            formula = formula + [f'{var} {ropt} {rvalue}' for var in vars]
        elif re.match(self.multil2, formula):
            lvalue, lopt, vars, ropt, rvalue = re.findall(self.multil2, formula)[0]
            vars = vars.split(',')
            formula = [f'{lvalue} {lopt} {var}' for var in vars]
            formula = formula + [f'{var} {ropt} {rvalue}' for var in vars]
        elif re.match(self.multil, formula):
            vars, opt,value = re.findall(self.multil, formula)[0]
            vars = vars.split(',')
            formula = [f'{var} {opt} {value}' for var in vars]
        else:
            formula = formula

        formulas = formula if type(formula) == list else [formula]

        formulas_copy = []
        for formula in formulas:
            equal = r'(\b\s*=\s*\b)|((?<=\))=\b)|(\b=(?=\())'
            if re.search(equal,formula):
                formulas_copy = formulas_copy+[re.sub(equal, ' >= ', formula), re.sub(equal, ' <= ', formula)]
        
        if len(formulas_copy)>len(formulas):
            formulas = formulas_copy

        return formulas
    
    def parse_obj(self, obj):
        """Turn an objective string of the form `<direction>,<formula>` into an `ObjectiveDeclaration`.

        The direction is 'Minimize' or 'Maximize' when the corresponding word occurs in
        `obj` ('Maximize' wins if both occur). The abbreviations `Min`/`Max` (any case),
        which `extract_obj` and `parse_true_obj` also accept, are recognised at the start
        of `obj`. If nothing is recognised the direction is ''.

        The formula is the text after the last comma, with whitespace removed and passed
        through `parse_formula`; only the first resulting formula is kept.

        Args:
            obj (`str`): e.g. 'maximize,50*x+30*y'.

        Returns:
            `ObjectiveDeclaration`
        """
        obj_direction = ''
        if re.search(r'Minimize|minimize', obj):
            obj_direction = 'Minimize'
        if re.search(r'Maximize|maximize', obj):
            obj_direction = 'Maximize'
        if not obj_direction:
            # Without this fallback 'Min,...' and 'Max,...' both produced an empty
            # direction and were therefore treated as the same direction.
            short = re.match(r'\s*(min|max)(?:imize)?\b', obj, re.IGNORECASE)
            if short:
                obj_direction = 'Minimize' if short.group(1).lower() == 'min' else 'Maximize'
        formula = re.split(',', obj)[-1]
        formula = re.sub(r'\s+', '', formula)
        formula = self.parse_formula(formula)

        return ObjectiveDeclaration(direction=obj_direction, formula=formula[0])

    def parse_cons(self, cons):
        parsed_cons = []
        for con in cons:
            con = re.sub(r'\s+','',con)
            parsed_cons = parsed_cons + self.parse_formula(con)
        return parsed_cons

    def evaluate_objective(self,pred_obj, true_obj):
        """Check whether the predicted objective is equivalent to the true one.

        When the two directions differ, the predicted formula is multiplied by -1
        before the comparison. Both formulas are parsed with `sympify` (using the
        symbol table `self.d`), simplified, and compared with `equals`.
        """
        sign = 1 if pred_obj.direction == true_obj.direction else -1
        pred_expr = simplify(sympify(f'{sign}*({pred_obj.formula})', self.d), rational=True)
        true_expr = simplify(sympify(true_obj.formula, self.d), rational=True)
        return pred_expr.equals(true_expr)
      
    def evaluate_cons(self, parsed_pred_cons, parsed_true_cons):
        """Return True when predicted and true constraints match one-to-one.

        Redundant simple bounds are removed from both lists first
        (`del_redundant_cons`). Each true constraint is then paired with the first
        remaining predicted constraint for which `equals` is truthy; the result is True
        only if no constraint is left unpaired on either side.
        """
        net_pred_cons = self.del_redundant_cons([parsed_pred_cons])[0][0]
        net_true_cons = self.del_redundant_cons([parsed_true_cons])[0][0]
        remaining_pred = [simplify(sympify(con, self.d), rational=True) for con in net_pred_cons]
        true_exprs = [simplify(sympify(con, self.d), rational=True) for con in net_true_cons]

        unmatched_true = 0
        for true_expr in true_exprs:
            for position, pred_expr in enumerate(remaining_pred):
                if true_expr.equals(pred_expr):
                    # Each predicted constraint may be used for one match only.
                    del remaining_pred[position]
                    break
            else:
                unmatched_true += 1

        return not (unmatched_true > 0 or len(remaining_pred) > 0)
    
    def del_redundant_cons(self, conses):
        """Collapse repeated simple bounds on the same variable in each constraint list.

        For every list in `conses`, constraints of the exact form `var>=number` and
        `var<=number` are grouped by variable; only the largest lower bound and the
        smallest upper bound of each variable are kept. All other constraints are kept
        as they are. Operators in the kept constraints are then surrounded by single
        spaces, except that a `/` between two digits is written without spaces.

        Args:
            conses: a list of constraint lists (each a list of strings).

        Returns:
            A pair `(parsed_conses, parsed_cons_maps)`:
              - `parsed_conses[i]` is the reduced, re-spaced constraint list for `conses[i]`;
              - `parsed_cons_maps[i]` maps each kept constraint to the list of original
                constraints it stands for (itself for untouched constraints, all
                collected bounds of the variable for a kept bound).
        """
        nump = r'(\d+(?:\.\d+|\/\d+)*)'
        varp = r'([a-zA-Z]+\d*)'
        # Whole-string simple lower / upper bound.
        exp1 = f'^{varp}(>=){nump}$'
        exp2 = f'^{varp}(<=){nump}$'
        # The four patterns below are defined but not used in this method.
        exp1_1 = f'[^*\/+-]{varp}(>=){nump}[^*\/+-]'
        exp1_2 = f'^{varp}(>=){nump}[^*\/+-]'
        exp2_1 = f'[^*\/+-]{varp}(<=){nump}[^*\/+-]'
        exp2_2 = f'^{varp}(<=){nump}[^*\/+-]'

        parsed_conses = []
        cons_maps = []
        for cons in conses:
            cons = [re.sub(r'\s+','',con) for con in cons]
            # l_* collect lower bounds (>=), s_* collect upper bounds (<=):
            # *_exps hold the matching constraint strings, *_exp the regex groups,
            # lcons/scons the single bound kept per variable.
            l_exps, s_exps, l_exp, s_exp, lcons, scons = [],[],[],[],[],[]
            for con in cons:
                le = re.match(exp1, con)
                if le:
                    l_exp.append(le.groups())
                    l_exps.append(con)
                se = re.match(exp2, con)
                if se:
                    s_exp.append(se.groups())
                    s_exps.append(con)
            # l_exp is always a list here, so this branch always runs and
            # lbound_map is always defined for the code further down.
            if l_exp is not None:
                lbound_map = {}
                for le in l_exp: 
                    lvar, opt, lbound = le
                    if lvar not in lbound_map:
                        lbound_map[lvar] = []
                    lbound_map[lvar].append(lbound)
                for var, bounds in lbound_map.items():
                    # NOTE(review): eval() runs on text taken from the constraint; the
                    # nump pattern restricts it to digits, '.' and '/', but inputs such
                    # as "1/0" or "1.2.3" still raise here.
                    bounds_value = [eval(bound) for bound in bounds]
                    # Keep the tightest (largest) lower bound, in its original spelling.
                    max_index = bounds_value.index(max(bounds_value))
                    lcon = f'{var}>={bounds[max_index]}'
                    lcons.append(lcon)
            # Same remark as above: s_exp is always a list.
            if s_exp is not None:
                sbound_map = {}
                for se in s_exp: 
                    svar, _, sbound = se
                    if svar not in sbound_map:
                        sbound_map[svar] = []
                    sbound_map[svar].append(sbound)
                for var, bounds in sbound_map.items():
                    # NOTE(review): same eval() caveat as for the lower bounds.
                    bounds_value = [eval(bound) for bound in bounds]
                    # Keep the tightest (smallest) upper bound, in its original spelling.
                    min_index = bounds_value.index(min(bounds_value))
                    scon = f'{var}<={bounds[min_index]}'
                    scons.append(scon)
            # Constraints that are not simple bounds, followed by the kept bounds.
            cons_net = [con for con in cons if con not in l_exps and  con not in s_exps]
            cons = cons_net + lcons + scons
            # Put single spaces around operators, then undo that for digit/digit fractions.
            cons = [re.sub(r'\s*(\+|\-|\*|\/|<=|>=|==|>|<|=)\s*', r' \g<1> ', con) for con in cons]
            cons = [re.sub(r'(\d)(\s\/\s)(\d)', r'\g<1>/\g<3>', con) for con in cons]
            parsed_conses.append(cons)
            # Map every kept constraint (still without spaces) to the originals it replaces.
            cons_map = {nc : [nc] for nc in cons_net}
            for lc in lcons:
                var = re.findall(varp,lc)[0]
                cons_map[lc] = [f'{var}>={lb}' for lb in lbound_map[var]]
            for sc in scons:
                var = re.findall(varp,sc)[0]
                cons_map[sc] = [f'{var}<={sb}' for sb in sbound_map[var]]
            cons_maps.append(cons_map)
        # Apply the same operator spacing to the keys and values of every map.
        parsed_cons_maps = []
        for cons_map in cons_maps:
            parsed_cons_map = {}
            for k, vs in cons_map.items():
                k = re.sub(r'\s*(\+|\-|\*|\/|<=|>=|==|>|<|=)\s*', r' \g<1> ', k)
                k = re.sub(r'(\d)(\s\/\s)(\d)', r'\g<1>/\g<3>',k)
                vs = [re.sub(r'\s*(\+|\-|\*|\/|<=|>=|==|>|<|=)\s*', r' \g<1> ', con) for con in vs]
                vs = [re.sub(r'(\d)(\s\/\s)(\d)', r'\g<1>/\g<3>', con) for con in vs]
                parsed_cons_map[k] = vs
            parsed_cons_maps.append(parsed_cons_map)

        return parsed_conses, parsed_cons_maps

    def align_obj_cons(self, true_obj, pred_obj, pred_cons):
        """Rename the predicted variables so that they line up with the true objective.

        Both objectives are simplified and decomposed with `as_coefficients_dict`. When
        the two decompositions differ but have the same number of terms, each predicted
        term is paired with the first not-yet-used true term whose coefficient differs by
        less than 1e-5, and the resulting name mapping is applied (as whole-word
        replacements) to copies of the predicted objective formula and constraints.

        Args:
            true_obj (`ObjectiveDeclaration`): reference objective.
            pred_obj (`ObjectiveDeclaration`): predicted objective (not modified).
            pred_cons (`List[str]`): predicted constraints (not modified).

        Returns:
            `dict` with keys:
              - 'fail_flag': True when the two objectives have a different number of terms;
              - 'aligned_obj': copy of `pred_obj` with renamed variables;
              - 'aligned_cs': copy of `pred_cons` with renamed variables;
              - 'pred_tokens' / 'true_tokens': the coefficient dictionaries.
        """
        # A differing direction is compensated by negating the predicted formula.
        flag = 1 if pred_obj.direction == true_obj.direction else -1

        aligned_pred_formula = simplify(sympify(f'{flag}*({pred_obj.formula})', self.d), rational=True)
        true_obj_formula = simplify(sympify(true_obj.formula, self.d), rational=True)

        pred_tokens = aligned_pred_formula.as_coefficients_dict()
        true_tokens = true_obj_formula.as_coefficients_dict()

        fail_flag = False
        aligned_obj = deepcopy(pred_obj)
        aligned_cs = deepcopy(pred_cons)
        if pred_tokens != true_tokens:
            if len(pred_tokens) != len(true_tokens):
                fail_flag = True
            else:
                visited = dict.fromkeys(true_tokens, False)
                mapping = {}
                for pk in pred_tokens:
                    for tk in true_tokens:
                        if not visited[tk] and (abs(pred_tokens[pk] - true_tokens[tk]) < 1e-5):
                            mapping[str(pk)] = str(tk)
                            visited[tk] = True
                            break

                if mapping:
                    # Term names are escaped so that a name containing regex
                    # metacharacters (e.g. a product "x*y") is matched literally; all
                    # names are replaced in a single pass so that swaps (x->y, y->x)
                    # work.
                    pattern = re.compile(
                        r'\b' + r'\b|\b'.join(re.escape(name) for name in mapping) + r'\b')

                    def rename(match):
                        return mapping[match.group()]

                    aligned_obj.formula = pattern.sub(rename, aligned_obj.formula)
                    aligned_cs = [pattern.sub(rename, con) for con in aligned_cs]

        rlt = {
            'fail_flag' : fail_flag,
            'aligned_obj' : aligned_obj,
            'aligned_cs' : aligned_cs,
            'pred_tokens' : pred_tokens,
            'true_tokens' : true_tokens
        }

        return rlt
        
    def get_same_symbols(self, token_map):
        pkeys = list(token_map.keys())
        lpk = len(token_map)
        same_symbols = []
        for i in range(lpk):
            for j in range(i+1,lpk):
                if token_map[pkeys[i]] == token_map[pkeys[j]]:
                    if pkeys[i] not in same_symbols:
                        same_symbols.append(pkeys[i])
                    if pkeys[j] not in same_symbols:
                        same_symbols.append(pkeys[j])
        return same_symbols

    def double_check(self, same_symbols, pred_cs, true_cons):
        double_check = False
        if len(same_symbols) > 1:
            ssyb = tuple(same_symbols)
            for order in itertools.permutations(same_symbols):
                pred_cs_copy = deepcopy(pred_cs)
                if order == ssyb:
                    continue
                mapping = {}
                for s, o in zip(ssyb,order):
                    mapping[str(s)] = str(o)
                p = r"\b|\b".join(mapping.keys())
                p = r'\b'+p+r'\b'
                for i in range(len(pred_cs_copy)):
                    pattern = re.compile(p)
                    pred_cs_copy[i]= pattern.sub(lambda x: mapping[x.group()], pred_cs_copy[i])
                if self.evaluate_cons(pred_cs_copy, true_cons):
                    double_check = True
                    break
        return double_check
        
    def evaluate_formula(self, pred_obj, true_obj, pred_cons, true_cons):
        try:
            if pred_obj=='' or pred_cons == []:
                raise Exception(f'pred_obj=='' or pred_cons == []')
            if true_obj=='' or true_cons == []:
                raise Exception(f'true_obj=='' or true_cons == []')

            error_flag = False

            pred_objd = self.parse_obj(pred_obj)
            true_objd = self.parse_obj(true_obj)
           
            parsed_pred_cons = self.parse_cons(pred_cons)
            parsed_true_cons =self.parse_cons(true_cons)

            rlt = self.align_obj_cons(true_objd,pred_objd,parsed_pred_cons)
            aligned_cs = rlt['aligned_cs']
            aligned_obj = rlt['aligned_obj']

            if rlt['fail_flag']:
                obj_flag = False
            else:
                obj_flag = self.evaluate_objective(aligned_obj,true_objd)
            cons_flag = self.evaluate_cons(aligned_cs, parsed_true_cons)
            
            true_tokens = rlt['true_tokens']

            # if objectve flag is true but cons_flag is false and
            # the number of variables with same coefficent is less than 
            # self.double_check_threshold then do double check.
            if obj_flag and not cons_flag:
                same_symbols = self.get_same_symbols(deepcopy(true_tokens))
                if len(same_symbols) > 1 and len(same_symbols) <= self.double_check_threshold:
                    cons_flag = self.double_check(same_symbols, deepcopy(aligned_cs), deepcopy(parsed_true_cons))

        except:
            error_flag = True
            if not ('obj_flag' in locals()):
                obj_flag = False
            if not ('cons_flag' in locals()):
                cons_flag = False
            if not ('pred_obj' in locals()):
                pred_obj = ''
            if not ('pred_cons' in locals()):
                pred_cons = []
        
        eval_info = {
            "obj_flag" : obj_flag,
            "cons_flag" : cons_flag,
            "true_obj" : true_obj,
            "pred_obj" : pred_obj,
            "parsed_true_cons" : true_cons,
            "parsed_pred_cons" : pred_cons
        }

        return error_flag, obj_flag and cons_flag , eval_info

    def compute_metric(self, pred_objs, true_objs, pred_conses, true_conses, index_list=None):

        error_logs = []
        right_logs = []
        index = 0
        true_num = 0
        error_num = 0

        num_instances_processed = 0
        num_all_instances = min(len(pred_objs), len(pred_conses),len(true_objs), len(true_conses))
        if index_list == None:
            index_list = [i for i in range(0, num_all_instances)]
        assert isinstance(index_list, list), TypeError(
            "Expect index_list to be list; got {}".format(type(index_list)))
        for index in index_list:
            pred_obj, true_obj, pred_cons, true_cons = pred_objs[index], true_objs[index], pred_conses[index], true_conses[index]
            error_flag, all_flag, eval_info = self.evaluate_formula(pred_obj, true_obj, pred_cons, true_cons)

            error_num += int(error_flag)
            if not all_flag:
                error_log = {
                    'error' : error_flag,
                    'index' : index,
                    'eval_info' : eval_info
                    }
                error_logs.append(error_log)
            else :
                true_num +=1
                right_log = {
                    'index' : index,
                    'eval_info' : eval_info
                    }
                right_logs.append(right_log)

            num_instances_processed += 1
            if num_instances_processed % 50 == 0:
                logger.info('{}/{} instances processed.'.format(num_instances_processed, num_all_instances))
        
        acc= true_num/len(index_list)

        return acc, error_logs, right_logs
    
    def extract_obj(self, answer):
        pred_obj_pattern = r'(?:The objective is|目标)[:：]?\s*\b(Minimize|Maximize|minimize|maximize|Min|Max|min|max):?\s*(.*?)(?:(?=(?:约束|The constraints|(?:s\.t\.?)|(?:[Ss]ubject to)))|(?:(?:\.[^0-9])|。|\n))'
        pred_obj_match = re.findall(pred_obj_pattern, answer)
        obj_match = pred_obj_match[0]
    
        obj_dir = obj_match[0]
        obj_fm = obj_match[1]
        obj_fm = re.sub(self.explain_ch, '',obj_fm)
        obj_fm = re.sub(self.explain, '',obj_fm)
        index = obj_fm.find('=')
        if index != -1 :
            obj_fm = obj_fm[index+1:]
        obj = re.sub(r'\s+','',(obj_dir + ',' + obj_fm))
        
        # str
        return obj

    def parse_true_obj(self, true_obj):
        true_obj_pattern = r'^(Minimize|Maximize|minimize|maximize|Min|Max|min|max):?\s*(.*)'
        true_obj_match = re.findall(true_obj_pattern, true_obj)
        obj_match = true_obj_match[0]
        obj_dir = obj_match[0]
        obj_fm = obj_match[1]
        obj = re.sub(r'\s+','',(obj_dir + ',' + obj_fm))
        
        # str
        return obj
    
    def extract_cons(self,answer):
        cons_pattern = r'(?:(?:The constraints are:?|约束条件[：:]|(?:s\.t\.?)|(?:[Ss]ubject to))\s*,?\s*(.*?)(?:(?:\.[^0-9]|。)))|(?:(?:The constraints are:?|约束条件[：:]|(?:s\.t\.?)|(?:[Ss]ubject to))\s*,?\s*(.*)\.$)|(?:(?:The constraints are:?|约束条件[：:]|(?:s\.t\.?)|(?:[Ss]ubject to))\s*,?\s*(.*)$)'
        cons_match = re.findall(cons_pattern,answer)[0]
        for cons in cons_match:
            if cons != '':
                cons_match = cons.strip()
                break

        cons_match = re.sub(r'(?<=\d),(\d\d\d)', r'\g<1>',cons_match)
        cons_match = re.sub(self.explain_ch, '',cons_match)
        cons_match = re.sub(self.explain, '',cons_match)
        cons = re.split(r'[,;]\s+and|and|\s*[,，]\s*|;|；',cons_match)
        cons = [re.sub(r'\s+','',c) for c in cons]
        for con in cons:
            if re.search(r'[一-龥]{2,}',con):
                cons.remove(con)
        return cons

    def parse_formula_from_answer(self, answer):
        try:
            if re.search(r'[一-龥]', answer):
                if answer[-1] == '，':
                    answer = answer[:-1]
                if answer[-1] != '。':
                    answer = answer + '。'
            else:
                if answer[-1] == ',':
                    answer = answer[:-1]
                if answer[-1] != '.':
                    answer = answer + '.'

            pred_obj = self.extract_obj(answer)
            pred_cons = self.extract_cons(answer)
        except:
            if not ('pred_obj' in locals()):
                pred_obj = ''
            if not ('pred_cons' in locals()):
                pred_cons = []

        return pred_obj, pred_cons


class EvaluateMetricsOnMathFormula(MetricsOnMathFormula):
    """Adapter exposing the formula metric through a ``compute`` method that
    takes ``predictions`` and ``references`` keyword-style arguments.
    """

    def compute(self, predictions, references):
        """Score predicted answers against reference answers.

        Both the prediction and the reference of every pair are parsed with
        ``parse_formula_from_answer``; the parsed objectives and constraints
        are then passed to ``compute_metric``.

        Args:
            predictions: Iterable of predicted answer strings.
            references: Iterable of reference answer strings, aligned with
                ``predictions`` (pairs are formed with ``zip``).

        Returns:
            `dict`: ``{'formula_accuracy': acc}`` where ``acc`` is the first
            value returned by ``compute_metric``.
        """
        parsed_predictions = {'objs': [], 'conses': []}
        parsed_references = {'objs': [], 'conses': []}
        for prediction, reference in zip(predictions, references):
            objective, constraints = self.parse_formula_from_answer(prediction)
            parsed_predictions['objs'].append(objective)
            parsed_predictions['conses'].append(constraints)

            objective, constraints = self.parse_formula_from_answer(reference)
            parsed_references['objs'].append(objective)
            parsed_references['conses'].append(constraints)

        accuracy, _error_logs, _right_logs = self.compute_metric(
            parsed_predictions['objs'],
            parsed_references['objs'],
            parsed_predictions['conses'],
            parsed_references['conses'],
        )
        return {'formula_accuracy': accuracy}

        
class SimpleVotingEnsemble:
    """ This class conducts simple voting for a list of predictions. "predictions" is a list of list of dict
    (one inner list per expert, aligned by position) with the following format:
    {
        ... 
        "index": 0,
        "instruction": "...",
        "output": "The variables are: .... Define them as: ....\nThe objective is: ....\nThe constraints are: ...."
        ...
    }
    or (Chinese output)
    {
        ... 
        "index": 0,
        "instruction": "...",
        "output": "变量：...。分别定义为：...。\n目标：...。\n约束条件：...。"
        ...
    }
    Alternatively, "predictions" can also be a list of dict, with "output" value a list of strings
    (one string per expert).

    Voting is deterministic: ties are resolved in favour of the candidate that appears first,
    and voted constraints are emitted in order of first appearance.

    voted_predictions is returned.

    Args:
        threshold (`Optional[int]`, default to `None`):
            Minimum number of experts that must propose a constraint for it to be kept.
            A falsy value selects a simple majority (half of the experts, rounded up).
    """
    def __init__(self, threshold=None):
        self.threshold = threshold

    def detect_language(self, text):
        """ Return "cn" if `text` contains a character in U+4E00..U+9FFF, otherwise "en".
        """
        return "cn" if re.search(u'[\u4e00-\u9fff]', text) else "en"

    @staticmethod
    def _split_en_items(text):
        """ Split an English enumeration on ', ' or, failing that, on ' and '.

        Raises:
            RuntimeError: if `text` contains neither separator.
        """
        if ', ' in text:
            return text.split(', ')
        if ' and ' in text:
            return text.split(' and ')
        raise RuntimeError("Cannot parse: {}".format(text))

    def parse_instance(self, instance: dict):
        """ Parse the "output" text of one prediction into its components.

        Sample outputs that can be parsed (Chinese, then English):
        # 变量：餐桌数量，椅子数量。分别定义为：x，y。\\n目标：maximize 350 * x + 75 * y。\\n约束条件：y >= 0.7 * (x + y)，8 * x + 2 * y <= 500，x >= 0，y >= 0。
        # The variables are: number of rickshaws used, number of ox carts used. Define them as: x, y.\\nThe objective is: maximize 50 * x + 30 * y.\\nThe constraints are: 10 * x + 8 * y <= 200, x <= y, x >= 0, y >= 0.

        Args:
            instance (`dict`): must provide "index", "instruction" and "output";
                "output" is either a string or a list of strings.

        Returns:
            A dict with keys "index", "instruction", "var_description" (symbol -> description),
            "objective_description", "constraint_description" and "language";
            or a list of such dicts when "output" is a list.

        Raises:
            TypeError: if "output" is neither a string nor a list.
            RuntimeError: if an English variable list cannot be split.
            AttributeError: if the text does not follow the expected template
                (a regular expression fails to match).
        """
        output = instance['output']
        if isinstance(output, list):
            # One parsed instance per candidate output string.
            return [self.parse_instance({
                "index": instance["index"],
                "instruction": instance["instruction"],
                "input": "",
                "output": output_
                }) for output_ in output]
        if not isinstance(output, str):
            raise TypeError("instance output {} cannot be parsed.".format(output))

        language = self.detect_language(output)
        if language == 'cn':
            # Normalise ASCII punctuation to full-width before matching.
            output = output.replace(', ', '，').replace(',', '，').replace(': ', "：").replace(':', "：")
            var_des = re.search(r'变量：?(.*?)。分别定义为：?', output).group(1).split('，')
            var_def = re.search(r'分别定义为：?(.*?)。\n目标：?', output).group(1).split('，')
            objective = re.search(r'目标：?(.*?)。\n约束条件：?', output).group(1)
            constraints = re.search(r'约束条件：?(.*?)。', output).group(1).split('，')
        else:
            # Normalise full-width punctuation to ASCII before matching.
            output = output.replace("“", '"').replace("”", '"').replace("，", ', ').replace("：", ': ').replace("。", '. ').replace("  ", ' ')
            var_des = self._split_en_items(
                re.search(r'The variables are:? (.*?)\. Define them as:? ', output).group(1))
            var_def = self._split_en_items(
                re.search(r'Define them as:? (.*?)\.\nThe objective is:?', output).group(1))
            objective = re.search(r'The objective is:? (.*?)\.\nThe constraints are:?', output).group(1)
            constraints = re.search(r'The constraints are:? (.*?)$', output).group(1).split(', ')
            constraints[-1] = constraints[-1].rstrip('.')

        return {
            "index": instance['index'],
            "instruction": instance['instruction'],
            "var_description": dict(zip(var_def, var_des)),
            "objective_description": ["", objective],
            "constraint_description": constraints,
            "language": language
        }

    @staticmethod
    def max_of_list(some_list):
        """ Return `(index, maximum)` of a list; the first maximum wins on ties.
        """
        index_std, maximum = max(enumerate(some_list), key=lambda x: x[1])
        return index_std, maximum

    @staticmethod
    def _most_common(items):
        """ Return `(item, count)` for the most frequent hashable item.

        Counting goes through a dict, which keeps first-appearance order, so ties
        are resolved in favour of the item seen first regardless of the hash seed.
        Raises ValueError on an empty input.
        """
        counts = {}
        for item in items:
            counts[item] = counts.get(item, 0) + 1
        winner = max(counts, key=counts.get)
        return winner, counts[winner]

    def vote_language(self, languages):
        """ Return the most frequent language tag (earliest wins on ties).
        """
        language, _ = self._most_common(languages)
        return language

    @staticmethod
    def LCSubStr(str1: Union[str, List[str]], str2: Union[str, List[str]]):
        """find the longest common contiguous run of two sequences using dynamic programming

        Works on strings (characters) and on lists (elements). Returns a slice of `str1`;
        when nothing is shared the empty string "" is returned, also for list inputs.
        An empty input is returned as is.
        """
        if not str1:
            return str1
        if not str2:
            return str2

        N = len(str1)
        M = len(str2)

        # Shortcut when one input is contained in the other. This is a substring
        # test for strings; for lists `in` is element membership.
        if N >= M and str2 in str1:
            return str2
        elif N < M and str1 in str2:
            return str1

        # Row-by-row DP: cur_row[j] is the length of the common suffix of
        # str1[:i] and str2[:j]. Only the previous row is needed.
        prev_row = [0] * (M + 1)
        best_len, best_end = 0, 0
        for i in range(1, N + 1):
            cur_row = [0] * (M + 1)
            for j in range(1, M + 1):
                if str1[i-1] == str2[j-1]:
                    cur_row[j] = prev_row[j-1] + 1
                    if cur_row[j] > best_len:  # strict: the earliest longest run wins
                        best_len, best_end = cur_row[j], i
            prev_row = cur_row

        if best_len > 0:
            return str1[best_end-best_len:best_end]
        return ""

    def check_char(self, char):
        """ Classify a single character as 'cn', 'en', 'space', 'newline', 'num' or 'punct'.

        Raises:
            RuntimeError: for any character outside the recognised ranges.
        """
        code_point = ord(char)
        if 19968 <= code_point <= 40959:  # U+4E00..U+9FFF
            return 'cn'
        elif (65 <= code_point <= 90) or (97 <= code_point <= 122):  # A-Z, a-z
            return 'en'
        elif code_point == 32:
            return 'space'
        elif code_point == 10:
            return 'newline'
        elif 48 <= code_point <= 57:  # 0-9
            return 'num'
        elif code_point in [65288, 65289, 8220, 8221]:  # full-width parentheses, curly double quotes
            return 'punct'
        else:
            raise RuntimeError("Cannot determine language of {}.".format(char))

    def _segment_cn_text(self, text: str):
        """ Split text into segments: each Chinese character, digit and punctuation
        mark is its own segment, runs of Latin letters form one segment, and
        spaces/newlines only act as separators.
        """
        segments = []
        chunk = ''
        for char in text:
            lang_char = self.check_char(char)
            if lang_char == 'en':
                chunk += char
                continue
            # Anything else terminates the current Latin word.
            if chunk:
                segments.append(chunk)
                chunk = ''
            if lang_char in ('cn', 'num', 'punct'):
                segments.append(char)
        # Flush a Latin word that ends the text; without this a trailing
        # word (e.g. the "A" in "产品A") was silently dropped.
        if chunk:
            segments.append(chunk)

        return segments

    def _join_cn_text(self, segments: list):
        """ Inverse of `_segment_cn_text`: concatenate segments, inserting a single
        space between two consecutive non-Chinese segments.
        """
        text = ''
        is_last_en = False
        for segment in segments:
            if len(segment) == 1 and self.check_char(segment) == 'cn':
                text += segment
                is_last_en = False
            else:
                text += ' ' + segment if is_last_en else segment
                is_last_en = True

        return text
        
    def get_diff_substr_cn(self, var_des_out):
        """ Repeatedly remove the longest segment run shared by all descriptions
        (Chinese segmentation) until nothing is shared; return what is left.
        """
        assert len(var_des_out) >= 2
        common_substr = self.LCSubStr(
            self._segment_cn_text(var_des_out[0]), self._segment_cn_text(var_des_out[1]))
        for var_des_out_i in var_des_out[2:]:
            common_substr = self.LCSubStr(common_substr, self._segment_cn_text(var_des_out_i))
        if common_substr:
            common_substr = self._join_cn_text(common_substr)
            reduced = [var_des_out_i.replace(common_substr, "") for var_des_out_i in var_des_out]
            if reduced == var_des_out:
                # No progress (shared text not found verbatim): stop instead of recursing forever.
                return var_des_out
            return self.get_diff_substr_cn(reduced)
        return var_des_out

    def get_diff_substr_en(self, var_des_out):
        """ Repeatedly remove the longest word run shared by all descriptions
        (split on single spaces) until nothing is shared; return what is left.
        """
        assert len(var_des_out) >= 2
        common_substr = self.LCSubStr(var_des_out[0].split(' '), var_des_out[1].split(' '))
        for var_des_out_i in var_des_out[2:]:
            common_substr = self.LCSubStr(common_substr, var_des_out_i.split(' '))
        # common_substr[0] is '' when the only shared "word" is the empty string.
        if common_substr and common_substr[0]:
            common_substr = ' '.join(common_substr)
            reduced = [var_des_out_i.replace(common_substr, "").replace("  ", " ").strip() for var_des_out_i in var_des_out]
            if reduced == var_des_out:
                # No progress: stop instead of recursing forever.
                return var_des_out
            return self.get_diff_substr_en(reduced)
        return var_des_out

    def get_diff_substr(self, var_des_out, language):
        """ Reduce variable descriptions to their distinguishing parts ("tags").

        With fewer than two descriptions there is nothing to compare against, so the
        descriptions are returned unchanged.
        """
        if len(var_des_out) < 2:
            return list(var_des_out)
        if language == 'cn':
            # Normalise spacing first so that later substring removal matches.
            var_des_out = [self._join_cn_text(self._segment_cn_text(var_des_out_i)) for var_des_out_i in var_des_out]
            var_des_out = self.get_diff_substr_cn(var_des_out)
        else:
            var_des_out = self.get_diff_substr_en(var_des_out)
        
        return var_des_out

    @staticmethod
    def str_in_str(var_des0, var_des1):
        """ Return True when every Chinese character and every space-separated
        non-Chinese word of `var_des0` occurs in `var_des1` (substring test).
        """
        cn_words = re.findall(u'[\u4e00-\u9fff]', var_des0)
        en_words = [word for word in re.sub(u'[\u4e00-\u9fff]', '', var_des0).split(" ") if word]

        return all(word in var_des1 for word in cn_words + en_words)

    def vote_variables(self, variables, objectives, nested_constraints, languages, instance_index):
        """ Pick the standard variable definition and rename the variables of the
        other experts to match it.

        Args:
            variables: list of dict (symbol -> description), one per expert.
            objectives: list of str, one per expert. Modified in place.
            nested_constraints: list of list of str, one per expert. Modified in place.
            languages: list of language tags, one per expert.
            instance_index: only used in the diagnostic message.

        Returns:
            `(variable_standard, objectives, nested_constraints)`.
        """
        # 0. get all info
        num_experts = len(variables)
        all_var_def = [list(variable.keys()) for variable in variables]
        all_var_des = [list(variable.values()) for variable in variables]
        all_var_tags = [self.get_diff_substr(
            var_des, languages[index]) for index, var_des in enumerate(all_var_des)]

        # 1. get most common variable definition (including number of variables).
        # Tuples are hashable and, unlike "".join, keep ['x', 'y'] distinct from ['xy'].
        def_keys = [tuple(var_def) for var_def in all_var_def]
        def_key_std, maximum = self._most_common(def_keys)
        index_std = def_keys.index(def_key_std)

        # 2. among experts sharing that definition, prefer the most common description
        if maximum > 1:
            candidate_indices = [index for index, def_key in enumerate(def_keys) if def_key == def_key_std]
            des_keys = [tuple(all_var_des[index]) for index in candidate_indices]
            des_key_std, _ = self._most_common(des_keys)
            index_std = candidate_indices[des_keys.index(des_key_std)]

        # 3. get standard variable definition, description and tags
        variable_standard = variables[index_std]
        var_def_std = all_var_def[index_std]
        var_des_std = all_var_des[index_std]
        var_tags_std = all_var_tags[index_std]

        # 4. check from variable description if variable orders can be matched
        for index in range(num_experts):
            if index == index_std or (
                all_var_def[index] == var_def_std and all_var_tags[index] == var_tags_std
            ):
                continue

            mapping = {}

            # A variable of this expert matches a standard variable when its tag occurs in the
            # standard description, or the standard tag occurs in this expert's description.
            is_matches = [[self.str_in_str(var_tag, var_des) for var_des in var_des_std] for var_tag in all_var_tags[index]]
            is_matches_std = [[self.str_in_str(var_tag, var_des) for var_des in all_var_des[index]] for var_tag in var_tags_std]
            is_matches = np.logical_or(np.array(is_matches, dtype=bool), np.array(is_matches_std, dtype=bool).T)

            if all(is_matches.sum(axis=1) == np.ones(len(all_var_tags[index]))):
                # for each variable, it has a unique mapping in variable_standard
                indices0, indices1 = is_matches.nonzero()
                for ind0, ind1 in zip(indices0.tolist(), indices1.tolist()):
                    if all_var_def[index][ind0] != var_def_std[ind1]:
                        mapping[all_var_def[index][ind0]] = var_def_std[ind1]
            else:
                print("ID {}: cannot infer mapping from variable names: {} vs {}".format(
                    instance_index, all_var_des[index], var_des_std))

            if mapping:
                # One alternation so all renames happen in a single pass (x<->y swaps work);
                # names are escaped so regex metacharacters in a symbol are literal.
                pattern = re.compile("|".join(r"(\b" + re.escape(name) + r"\b)" for name in mapping))
                objectives[index] = pattern.sub(lambda match: mapping[match.group(0)], objectives[index])
                nested_constraints[index] = [
                    pattern.sub(lambda match: mapping[match.group(0)], const) for const in nested_constraints[index]
                ]

        return variable_standard, objectives, nested_constraints

    def vote_objective(self, objectives):
        """ Return the most frequent objective string (earliest wins on ties).
        """
        objective, _ = self._most_common(objectives)
        return objective

    def vote_constraints(self, constraints):
        """ Keep every constraint proposed by at least `threshold` experts.

        Args:
            constraints: list of list of str, one inner list per expert.

        Returns:
            The kept constraints, in order of first appearance.
        """
        num_predictions = len(constraints)
        # Default: half of the experts, rounded up.
        major_threshold = self.threshold or (num_predictions // 2 + num_predictions % 2)

        # dict.fromkeys de-duplicates while keeping first-appearance order.
        unique_constraints = dict.fromkeys(const_ for const in constraints for const_ in const)
        return [
            const for const in unique_constraints
            if sum(1 for constraint in constraints if const in constraint) >= major_threshold
        ]

    def __call__(self, predictions):
        """ Vote over aligned predictions and return one merged prediction per instance.

        Args:
            predictions: a non-empty list of lists of dict (one inner list per expert),
                or a list of dict whose "output" is a list of strings.

        Returns:
            A list of dict with keys "index", "instruction" and "output".

        Raises:
            TypeError: if `predictions` is not a non-empty list of list/dict.
        """
        if not (isinstance(predictions, list) and predictions
                and isinstance(predictions[0], (list, dict))):
            try:
                first_type = type(predictions[0])
            except (TypeError, IndexError, KeyError):
                first_type = "not iterable"
            raise TypeError(
                "Expect predictions to be a list of list/dict, got {} and {}.".format(
                    type(predictions), first_type
                )
            )

        if isinstance(predictions[0], dict):
            # A plain list of dict is a single "expert" whose dicts carry several outputs;
            # wrap it so zip(*...) walks the instances rather than the dict keys.
            predictions = [predictions]

        voted_predictions = []
        for aligned in zip(*predictions):
            instances = []
            for raw_instance in aligned:
                parsed = self.parse_instance(raw_instance)
                if isinstance(parsed, list):
                    instances.extend(parsed)
                else:
                    instances.append(parsed)

            index = instances[0]['index']
            variables = [instance['var_description'] for instance in instances]
            objectives = [instance['objective_description'][1] for instance in instances]
            nested_constraints = [instance['constraint_description'] for instance in instances]
            languages = [instance['language'] for instance in instances]

            language = self.vote_language(languages)
            variables, objectives, nested_constraints = self.vote_variables(
                variables, objectives, nested_constraints, languages, index)
            objective = self.vote_objective(objectives)
            constraints = self.vote_constraints(nested_constraints)

            if language == 'cn':
                var_des = '，'.join(variables.values())
                var_def = '，'.join(variables.keys())
                const_des = '，'.join(constraints)
                output = "变量：{}。分别定义为：{}。\n目标：{}。\n约束条件：{}。".format(
                    var_des, var_def, objective, const_des)
            else:
                var_des = ', '.join(variables.values())
                var_def = ', '.join(variables.keys())
                const_des = ', '.join(constraints)
                output = "The variables are: {}. Define them as: {}.\nThe objective is: {}.\nThe constraints are: {}.".format(
                    var_des, var_def, objective, const_des)
            voted_predictions.append({
                "index": index,
                "instruction": instances[0]['instruction'],
                "output": output
            })
        
        return voted_predictions


def test_Seq2SeqMetricsOnSeqIDs():
    """Manual smoke test for `Seq2SeqMetricsOnSeqIDs`.

    Loads a LLaMA tokenizer and a saved tokenized prompt from hard-coded
    local paths, feeds the prompt's `input_ids` as predictions and its
    `labels` as references, and prints the computed metrics.
    """
    import torch
    from transformers import LlamaTokenizer

    side = 'left'
    llama_tokenizer = LlamaTokenizer.from_pretrained(
        '/root/model_ckpt/llama-7b/',
        padding_side=side,  # allow batched inference
    )
    llama_tokenizer.pad_token_id = 0  # set pad_token to '<unk>'

    metric_fn = Seq2SeqMetricsOnSeqIDs(
        ['accuracy', 'bleu', 'rouge'], tokenizer=llama_tokenizer, padding_side=side)

    sample = torch.load('../output/llama-7b_finetuning/debug/tokenized_prompt.pt')
    # Wrap in a list so both arrays get a leading batch axis of size one.
    batch_preds = np.array([sample['input_ids']])
    batch_labels = np.array([sample['labels']])
    print(metric_fn(EvalPrediction(predictions=batch_preds, label_ids=batch_labels)))


def test_MetricsOnMathFormula(true_path, log_path, error_path=None, eval_indexes=None, output_name='output'):
    """Evaluate logged model answers against ground-truth formulations.

    Args:
        true_path: Path to a JSON file holding a list of ground-truth records.
            Each record provides "objective_description" (the element at
            index 1 is used) and "constraint_description" (a string that
            evaluates to a dict whose keys are the constraints).
        log_path: Path to a JSON file holding a list of model outputs. The
            answer is read from `output[output_name]`, or from its
            "response" entry when that value is a dict.
        error_path: Optional path; when given, the error logs are written
            there as UTF-8 JSON.
        eval_indexes: Optional list of instance indexes to evaluate;
            defaults to all instances.
        output_name: Key of the answer inside each logged output.

    Prints the accuracy and the error statistics; returns nothing.
    """
    with open(true_path, 'r', encoding='utf-8') as f:
        trues = json.load(f)
    true_objs = [true["objective_description"][1] for true in trues]
    # NOTE(security): eval() executes arbitrary code contained in the file.
    # Only run this on trusted ground-truth files; consider ast.literal_eval
    # if the stored values are plain dict literals.
    true_conses = [list(eval(true['constraint_description']).keys()) for true in trues]
    with open(log_path, 'r', encoding='utf-8') as f:
        outputs = json.load(f)

    answers = [
        output[output_name]["response"] if isinstance(output[output_name], dict) else output[output_name]
        for output in outputs
    ]

    metric_d = MetricsOnMathFormula(strict=False)
    pred_objs = []
    pred_conses = []
    parsed_true_objs = []
    for answer, true_obj in zip(answers, true_objs):
        try:
            parsed_true_obj = metric_d.parse_true_obj(true_obj)
        except Exception:
            # Reset explicitly: the former `'parsed_true_obj' in locals()`
            # check silently reused the previous instance's objective once
            # any earlier instance had been parsed successfully.
            parsed_true_obj = ''
        pred_obj, pred_cons = metric_d.parse_formula_from_answer(answer)

        parsed_true_objs.append(parsed_true_obj)
        pred_objs.append(pred_obj)
        pred_conses.append(pred_cons)

    num_all_instances = min(len(pred_objs), len(pred_conses), len(parsed_true_objs), len(true_conses))
    if eval_indexes is None:
        eval_indexes = list(range(num_all_instances))
    acc, error_logs, right_logs = metric_d.compute_metric(pred_objs, parsed_true_objs, pred_conses, true_conses, eval_indexes)

    print('acc:',acc)
    print('true_num:', len(right_logs))
    print('error_num:', len(error_logs))
    format_err_num, obj_err_num, cons_err_num = 0, 0, 0
    for err_log in error_logs:
        # Equality (not truthiness) is kept on purpose: the flag types are
        # set by compute_metric, which is not visible here.
        if err_log["error"] == True:
            format_err_num+=1
        if err_log["eval_info"]["obj_flag"] == False:
            obj_err_num+=1
        if err_log["eval_info"]["cons_flag"] == False:
            cons_err_num+=1
    print('format error num:', format_err_num)
    print('objective error num:', obj_err_num)
    print('constraints error num:', cons_err_num)

    if error_path is not None:
        print('Save error log to {}.'.format(error_path))
        with open(error_path, 'w', encoding='utf-8') as f: 
            f.write(json.dumps(error_logs, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    # Script entry point: evaluate a log of model answers against the
    # ground truth. test_MetricsOnMathFormula requires both paths, so they
    # are taken from the command line (calling it without arguments raised
    # TypeError).
    # test_Seq2SeqMetricsOnSeqIDs()
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate predicted math formulations against the ground truth.")
    parser.add_argument("true_path", help="JSON file with the ground-truth records")
    parser.add_argument("log_path", help="JSON file with the model outputs")
    parser.add_argument("--error_path", default=None,
                        help="optional path to write the error log (JSON)")
    parser.add_argument("--output_name", default="output",
                        help="key of the answer inside each logged output")
    args = parser.parse_args()

    test_MetricsOnMathFormula(
        args.true_path, args.log_path,
        error_path=args.error_path, output_name=args.output_name)
    
    print('test over')