import re
import editdistance
import sqlite3
from collections import Counter
from func_timeout import func_timeout, FunctionTimedOut
import multiprocessing as mp
from tqdm import tqdm
from pebble import ProcessPool
from src.utils.utils import load_json
from src.api.mysql_api import MySQLAPI
from src.data_construct.math_veirfy import is_equiv
from src.data_construct.dart_eval import EvaluatorMath

class Verifier:
    """Common interface shared by the answer verifiers in this module.

    Concrete verifiers are expected to override :meth:`execute` and
    :meth:`verify`; the versions defined here only raise
    ``NotImplementedError``.
    """

    def __init__(self, config) -> None:
        """Remember *config* on the instance as ``self.config``."""
        self.config = config

    def verify(self, gold, pred):
        """Compare a gold item with a predicted one (must be overridden).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def execute(self, statement):
        """Run or parse *statement* (must be overridden).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def edit_distance(self, gold, pred):
        """Return ``editdistance.eval(gold, pred)`` for the two inputs."""
        distance = editdistance.eval(gold, pred)
        return distance

class SQLVerifier(Verifier):
    """Verifier that executes SQL through ``MySQLAPI`` and compares result rows."""

    def __init__(self, config):
        """Store *config* and open the MySQL API wrapper built from it."""
        # The base initializer records ``self.config`` so this verifier exposes
        # the same attribute as the other verifiers in this module.
        super().__init__(config)
        self.api = MySQLAPI(config)

    def post_process_sql(self, sql):
        """Extract a single-line SQL statement from raw model output.

        If the text contains a ```` ```sql ```` fence, the content of the first
        complete fenced block is used; an opening fence without a closing one
        yields an empty string. Any remaining triple backticks are removed and
        every run of whitespace is collapsed to one space.

        Args:
            sql: Raw text possibly wrapped in a Markdown code fence.

        Returns:
            The cleaned SQL string (possibly empty).
        """
        sql = sql.strip()
        if '```sql' in sql:
            fenced = re.findall(r'```sql(.*?)```', sql, re.DOTALL)
            sql = fenced[0].strip() if fenced else ''
        sql = sql.replace('```', '')
        # Raw string: '\s' in a normal literal is an invalid escape sequence.
        sql = re.sub(r'\s+', ' ', sql)
        return sql.strip()

    def execute_sql(self, sql):
        """Run *sql* on ``self.api.cur`` and fetch all rows.

        Returns:
            ``(records, "")`` on success, or ``(None, "Error: <message>")`` if
            executing or fetching raised any exception.
        """
        # NOTE(review): ``self.api.cur`` is presumably a DB-API cursor exposed
        # by MySQLAPI -- confirm against that class.
        try:
            self.api.cur.execute(sql)
            records = self.api.cur.fetchall()
            return records, ""
        except Exception as e:
            return None, "Error: " + str(e)

    def verify(self, gold_sql, pred_sql):
        """Return True when both executions succeeded with the same multiset of rows.

        Args:
            gold_sql: Result dict from :meth:`execute` for the reference query.
            pred_sql: Result dict from :meth:`execute` for the predicted query.

        Returns:
            False if either ``'res'`` is ``None`` (execution failed); otherwise
            whether the rows match, ignoring order but respecting duplicates.
        """
        gold_rows = gold_sql['res']
        pred_rows = pred_sql['res']
        # Counter(None) is an empty Counter, so without this guard a failed
        # execution would compare equal to an empty (or another failed) result.
        if gold_rows is None or pred_rows is None:
            return False
        return Counter(gold_rows) == Counter(pred_rows)

    def execute(self, sql):
        """Clean *sql*, execute it, and package the outcome.

        Returns:
            A dict with keys ``'output'`` (cleaned SQL), ``'res'`` (fetched rows
            or ``None``), ``'err'`` (error text or ``""``) and ``'exec_status'``
            (``True`` when no error was reported).
        """
        sql = self.post_process_sql(sql)
        sql_res, sql_err = self.execute_sql(sql)
        return {
            'output': sql,
            'res': sql_res,
            'err': sql_err,
            'exec_status': 'Error' not in sql_err,
        }
    
class BIRDVerifier(Verifier):
    """Verifier for BIRD-style text-to-SQL samples run against SQLite files.

    Each SQL input is a string of the form
    ``"<sql>\\t----- bird -----\\t<db_name>"``; the database file is resolved as
    ``config['db_root_path'] + db_name + '/' + db_name + '.sqlite'``.
    """

    # Separator between the SQL text and the database name in each input.
    _SEPARATOR = '\t----- bird -----\t'

    def __init__(self, config, pbar=False, timeout=20) -> None:
        """Configure the verifier.

        Args:
            config: Mapping providing ``'num_cpus'`` and ``'db_root_path'``.
            pbar: Show tqdm progress bars during parallel runs when true.
            timeout: Timeout value handed to ``pool.map`` for each phase.
        """
        self.config = config
        self.timeout = timeout
        self.n_procs = config['num_cpus']
        self.pbar = pbar

    def post_process_sql(self, sql):
        """Split the input into cleaned SQL and the SQLite file path.

        Markdown fences are handled as follows: the first complete
        ```` ```sql ```` block is used if such a fence is present (an
        unterminated fence yields an empty SQL string), stray triple backticks
        are removed, and whitespace runs are collapsed to single spaces.

        Returns:
            ``(sql, db_path)``.

        Raises:
            ValueError: if the separator does not occur exactly once.
        """
        sql, db_name = sql.split(self._SEPARATOR)
        db_path = self.config['db_root_path'] + db_name + '/' + db_name + '.sqlite'
        sql = sql.strip()
        if '```sql' in sql:
            fenced = re.findall(r'```sql(.*?)```', sql, re.DOTALL)
            sql = fenced[0].strip() if fenced else ''
        sql = sql.replace('```', '')
        # Raw string: '\s' in a normal literal is an invalid escape sequence.
        sql = re.sub(r'\s+', ' ', sql)
        return sql.strip(), db_path

    def execute_sql(self, sql, db_path):
        """Execute *sql* on the SQLite database at *db_path*.

        Returns:
            ``(rows, "")`` on success, or ``(None, "Error: <message>")`` if
            connecting, executing or fetching raised any exception.
        """
        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            cursor.execute(sql)
            return cursor.fetchall(), ""
        except Exception as e:
            return None, "Error: " + str(e)
        finally:
            # Always release the connection; it was previously left open.
            if conn is not None:
                conn.close()

    def execute(self, sql):
        """Clean and run one ``<sql><separator><db_name>`` input.

        Returns:
            A dict with keys ``'output'`` (cleaned SQL), ``'res'`` (rows or
            ``None``), ``'err'`` (error text or ``""``) and ``'exec_status'``.
        """
        sql, db_path = self.post_process_sql(sql)
        sql_res, sql_err = self.execute_sql(sql, db_path)
        return {
            'output': sql,
            'res': sql_res,
            'err': sql_err,
            'exec_status': 'Error' not in sql_err,
        }

    def _collect(self, results, total, desc, fallback):
        """Drain a ``pool.map`` result iterator, using *fallback* for failed tasks."""
        collected = []
        pbar = tqdm(total=total, desc=desc) if self.pbar else None
        try:
            while True:
                try:
                    collected.append(next(results))
                except StopIteration:
                    break
                except Exception:
                    # Timeouts and worker errors surface here, one per task;
                    # record the fallback and keep iterating.
                    collected.append(fallback)
                if pbar is not None:
                    pbar.update(1)
        finally:
            if pbar is not None:
                pbar.close()
        return collected

    def parallel_execute_and_verify(self, sqls):
        """Execute predicted SQL and check it against gold SQL in a process pool.

        Args:
            sqls: Sequence of records where index 1 is the gold input and
                index 2 is the predicted input (both in the
                ``<sql><separator><db_name>`` form).

        Returns:
            A list of booleans, one per record; tasks that time out or raise
            count as incorrect.
        """
        n_samples = len(sqls)
        if n_samples == 0:
            # Nothing to do; also avoids requesting a pool with zero workers.
            return []
        all_gt = [com[1] for com in sqls]
        all_sqls = [com[2] for com in sqls]
        with ProcessPool(max_workers=min(self.n_procs, n_samples), max_tasks=1024) as pool:
            iterator = pool.map(self.execute, all_sqls, timeout=self.timeout).result()
            executed = self._collect(iterator, n_samples, "Extracting", None)
            # A failed task is recorded as None (not ""), so it can never
            # compare equal to an empty gold result set in verify().
            exec_answers = [None if out is None else out['res'] for out in executed]

            eval_samples = list(zip(all_gt, exec_answers))
            iterator = pool.map(self.verify, eval_samples, timeout=self.timeout).result()
            corrects = self._collect(iterator, n_samples, "Evaluating", False)
        return [bool(correct) for correct in corrects]

    def verify(self, gold_pred):
        """Return True when the gold query's rows equal *pred* as a set.

        Args:
            gold_pred: ``(gold, pred)`` where *gold* is a gold input string to
                execute and *pred* is the already-fetched rows of the
                prediction, or ``None`` if it failed.
        """
        gold, pred = gold_pred
        if pred is None:
            # A failed prediction cannot be correct; skip running the gold query.
            return False
        ans = self.execute(gold)['res']
        return ans is not None and set(ans) == set(pred)

class GSM8KMathVerifier(Verifier):
    """Verifier that compares the last number found in GSM8K-style completions."""

    def __init__(self, config):
        """Store *config* and build the number-matching regular expressions."""
        super().__init__(config)
        self.ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
        # Either digits grouped with thousands separators ("1,234,567") or a
        # plain digit run, with an optional decimal part. Without the grouped
        # alternative "1,234" would be read as the separate numbers 1 and 234.
        self.regex_pattern = r"(-?(?:[0-9]{1,3}(?:,[0-9]{3})+|[0-9]+)(?:\.[0-9]+)?)"
        self.regex = re.compile(self.regex_pattern)
        self.INVALID_ANS = "[invalid]"
        # Not used by the methods below; an evaluator-based extract/eq path was
        # previously sketched here as an alternative.
        self.evaluator = EvaluatorMath()

    def post_process_answer(self, completion):
        """Return the last number in *completion* as a comma-grouped integer string.

        Decimal values are truncated toward zero by ``int(float(...))``.

        Returns:
            E.g. ``"1,234"``, or ``self.INVALID_ANS`` when no number is found.
        """
        candidates = self.regex.findall(completion)
        if not candidates:
            return self.INVALID_ANS
        number = candidates[-1].replace(',', '')
        value = int(float(number)) if '.' in number else int(number)
        # Format as a string with thousands separators.
        return f"{value:,}"

    def execute(self, completion):
        """Extract the final numeric answer from *completion*.

        Returns:
            A dict with ``'output'`` (the original completion), ``'res'`` (an
            ``int``, or ``self.INVALID_ANS`` if nothing was found) and
            ``'exec_status'`` (``False`` only when nothing was found).
        """
        answer = self.post_process_answer(completion)
        exec_status = answer != self.INVALID_ANS
        if exec_status:
            answer = int(answer.replace(',', ''))
        return {
            'output': completion,
            'res': answer,
            'exec_status': exec_status,
        }

    def verify(self, gold_answer, pred_answer):
        """Return True when the prediction holds a valid answer equal to the gold one.

        Args:
            gold_answer: Result dict from :meth:`execute` for the reference.
            pred_answer: Result dict from :meth:`execute` for the prediction.
        """
        pred = pred_answer['res']
        # Two failed extractions both carry INVALID_ANS and would otherwise
        # compare equal; a prediction without an answer is never correct.
        if pred == self.INVALID_ANS:
            return False
        return gold_answer['res'] == pred
    

class MathVerifier(Verifier):
    """Verifier for math completions, delegating extraction and equality to ``EvaluatorMath``."""

    # Ordered textual normalizations applied to boxed content; order matters
    # because later rules operate on the output of earlier ones.
    _REPLACEMENTS = (
        ("\n", ""),
        ("\\!", ""),
        ("\\\\", "\\"),
        ("tfrac", "frac"),
        ("dfrac", "frac"),
        ("\\left", ""),
        ("\\right", ""),
        ("^{\\circ}", ""),
        ("^\\circ", ""),
        ("\\$", ""),
        ("\\%", ""),
        ("%", ""),
        (" .", " 0."),
        ("{.", "{0."),
    )

    def __init__(self, config, pbar=False, n_procs=2, timeout=5) -> None:
        """Configure the verifier.

        Args:
            config: Stored as ``self.config``.
            pbar: Show tqdm progress bars during parallel runs when true.
            n_procs: Upper bound on worker processes for the parallel run.
            timeout: Timeout value handed to ``pool.map`` for each phase.
        """
        self.config = config
        self.n_procs = n_procs
        self.timeout = timeout
        self.pbar = pbar
        self.evaluator = EvaluatorMath()

    def post_process_answer(self, completion):
        """Return the normalized content of the first ``boxed{...}`` in *completion*.

        The closing brace is found by brace matching, so nested groups such as
        ``\\frac{1}{2}`` stay intact and text after the box is excluded. If the
        braces are unbalanced, the last ``}`` in the text is used as a best
        effort. Returns ``""`` when there is no box or it is empty.
        """
        def _strip_string(string: str) -> str:
            for old, new in self._REPLACEMENTS:
                string = string.replace(old, new)
            if string.startswith("."):
                string = "0" + string
            return re.sub(r'\s+', '', string)

        marker = "boxed{"
        found = completion.find(marker)
        if found == -1:
            return ""
        start = found + len(marker)

        end = -1
        depth = 1
        for pos in range(start, len(completion)):
            char = completion[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    end = pos
                    break
        if end == -1:
            # No matching brace: fall back to the last '}' in the text.
            end = completion.rfind("}")
        if end == -1 or start >= end:
            return ""
        return _strip_string(completion[start:end].strip())

    def execute(self, completion):
        """Extract the answer from *completion* with ``self.evaluator.extract_ans``.

        Returns:
            A dict with ``'output'`` (the original completion), ``'res'`` (the
            extracted answer) and ``'exec_status'`` (``False`` when the
            extracted answer is the empty string).
        """
        answer = self.evaluator.extract_ans(completion)
        return {
            'output': completion,
            'res': answer,
            'exec_status': answer != "",
        }

    def _collect(self, results, total, desc, fallback):
        """Drain a ``pool.map`` result iterator, using *fallback* for failed tasks."""
        collected = []
        pbar = tqdm(total=total, desc=desc) if self.pbar else None
        try:
            while True:
                try:
                    collected.append(next(results))
                except StopIteration:
                    break
                except Exception:
                    # Timeouts and worker errors surface here, one per task;
                    # record the fallback and keep iterating.
                    collected.append(fallback)
                if pbar is not None:
                    pbar.update(1)
        finally:
            if pbar is not None:
                pbar.close()
        return collected

    def parallel_execute_and_verify(self, completions):
        """Extract and grade answers in a process pool.

        Args:
            completions: Sequence of ``(i, gt, answer, type)`` records; index 1
                is the gold text and index 2 the predicted text.

        Returns:
            A list of booleans, one per record. A failed extraction is recorded
            as ``""`` and a failed comparison as ``False``.
        """
        n_samples = len(completions)
        if n_samples == 0:
            # Nothing to do; also avoids requesting a pool with zero workers.
            return []
        all_answers = [com[2] for com in completions]
        all_gt = [com[1] for com in completions]
        with ProcessPool(max_workers=min(self.n_procs, n_samples), max_tasks=1024) as pool:
            # execute
            iterator = pool.map(self.evaluator.extract_ans, all_answers, timeout=self.timeout).result()
            exec_answers = self._collect(iterator, n_samples, "Extracting", "")

            # verify
            eval_samples = list(zip(all_gt, exec_answers))
            iterator = pool.map(self.verify, eval_samples, timeout=self.timeout).result()
            corrects = self._collect(iterator, n_samples, "Evaluating", False)
        return [bool(correct) for correct in corrects]

    def verify(self, gold_pred):
        """Compare an already-extracted prediction with the answer extracted from gold.

        Args:
            gold_pred: ``(gold, pred)`` where *gold* is raw gold text and
                *pred* is an extracted prediction.

        Returns:
            The result of ``self.evaluator.eq(pred, ans)``.
        """
        gold, pred = gold_pred
        ans = self.evaluator.extract_ans(gold)
        return self.evaluator.eq(pred, ans)

class CodeVerifier(Verifier):
    """Placeholder verifier for code outputs; no behaviour is implemented yet."""

    def __init__(self, config):
        """Accept *config* for interface parity; nothing is stored."""

    def verify(self, gold_code, pred_code):
        """Unimplemented stub: performs no comparison and returns ``None``."""
        return None
