# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# Apache 2.0 License for more details.

import re
import ast
import json
import ipdb
import distance
from apted import APTED, Config
from itertools import product
from apted.helpers import Tree
from lxml import etree, html
from collections import deque
from parallel import parallel_process
from tqdm import tqdm
from zss import simple_distance, Node
import string
from typing import Any, Callable, Optional, Sequence
import numpy as np
import Levenshtein
import editdistance


class TableTree(Tree):
    """Tree node describing one HTML table element.

    Each node stores the element tag, the cell spans and content (only
    meaningful for 'td' nodes) and the list of child nodes.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            # Cells also expose their spans and content.
            head = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag, self.colspan, self.rowspan, self.content)
        else:
            head = '"tag": %s' % self.tag
        nested = ''.join(child.bracket() for child in self.children)
        return "{{{}}}".format(head + nested)


class CustomConfig(Config):
    """Node comparison costs used when computing the tree edit distance."""

    @staticmethod
    def maximum(*sequences):
        """Return the length of the longest of the given sequences."""
        return max(len(seq) for seq in sequences)

    def normalized_distance(self, *sequences):
        """Return the Levenshtein distance divided by the longest sequence length."""
        edits = float(distance.levenshtein(*sequences))
        return edits / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Return the cost of turning node1 into node2.

        Nodes whose tag, colspan or rowspan differ cost 1. Matching 'td'
        nodes cost the normalized distance between their contents when at
        least one of them has content. Everything else costs 0.
        """
        attributes_match = (
            node1.tag == node2.tag
            and node1.colspan == node2.colspan
            and node1.rowspan == node2.rowspan
        )
        if not attributes_match:
            return 1.
        if node1.tag == 'td' and (node1.content or node2.content):
            return self.normalized_distance(node1.content, node2.content)
        return 0.


class TEDS(object):
    ''' Tree Edit Distance based Similarity

        Both HTML tables are converted to TableTree structures, the APTED
        edit distance between them is computed with CustomConfig costs, and
        the score is 1 - distance / max(number of descendant elements).
    '''
    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        '''
            @params structure_only: when True, cell contents are ignored and
                                    only the table structure is compared
            @params n_jobs: number of jobs used by batch_evaluate (1 = sequential)
            @params ignore_nodes: tag names stripped from both tables before
                                  scoring, or None to keep every tag
        '''
        assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        ''' Tokenizes table cells

            Appends to self.__tokens__ the opening tag, every character of
            the node text, the tokens of all children, the closing tag
            (except for 'unk' nodes) and, for non-'td' nodes, every character
            of the tail text.
        '''
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterating an element yields its children (getchildren() is deprecated).
        for n in node:
            self.tokenize(n)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        ''' Converts HTML tree to the format required by apted

            Returns the converted root when called without a parent; nested
            calls attach the new node to `parent` and return None.
        '''
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing <td> ... </td> tokens, keep the content.
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell)
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        # Cells are leaves of the converted tree; their markup is in `cell`.
        if node.tag != 'td':
            for n in node:
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        ''' Computes TEDS score between the prediction and the ground truth of a
            given sample

            Returns 0.0 when either input is empty or when either document
            has no body/table element.
        '''
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath('body/table')
        true_tables = true.xpath('body/table')
        if not (pred_tables and true_tables):
            return 0.0
        pred = pred_tables[0]
        true = true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        n_nodes = max(n_nodes_pred, n_nodes_true)
        if n_nodes == 0:
            # Both tables have no descendants, so the trees are identical;
            # return directly instead of dividing by zero.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        edit_distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
        return 1.0 - (float(edit_distance) / n_nodes)

    @staticmethod
    def _true_html(entry):
        ''' Returns the ground-truth HTML of a true_json entry, which may be
            either {'html': 'HTML CODE'} or the HTML string itself
        '''
        return entry['html'] if isinstance(entry, dict) else entry

    def batch_evaluate(self, pred_json, true_json):
        ''' Computes TEDS score between the prediction and the ground truth of
            a batch of samples
            @params pred_json: {'FILENAME': 'HTML CODE', ...}
            @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
                               (a plain 'HTML CODE' string per file is also accepted)
            @output: {'FILENAME': 'TEDS SCORE', ...}
        '''
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(filename, ''), self._true_html(true_json[filename]))
                      for filename in tqdm(samples)]
        else:
            # Same ground-truth extraction as the sequential branch.
            inputs = [{'pred': pred_json.get(filename, ''), 'true': self._true_html(true_json[filename])}
                      for filename in samples]
            scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores


def convert_table_to_html_str(table_row_list=None):
    """
    Given a list of table rows, build the corresponding html string, which is used to compute the TEDS score.
    We use the official code of PubTabNet to compute TEDS score, it does not consider '<th>' label.
    All '\n' characters are removed from the result, including any inside a cell, as they
    will influence the TEDS score.

    Parameters:
    - table_row_list (list | None): Rows of the table, each an iterable of cell values.
      None is treated as an empty table.

    Returns:
    - str: '<html><body><table>...</table></body></html>' with one <tr> per row and one <td> per cell.
    """
    rows = table_row_list if table_row_list is not None else []
    parts = ["<html><body><table>"]
    for data_row in rows:
        cells = "".join(f"<td>{cell_str}</td>" for cell_str in data_row)
        parts.append(f"<tr>{cells}</tr>")
    parts.append("</table></body></html>")
    return "".join(parts).replace('\n', '')


def convert_markdown_table_to_html(markdown_table):
    """
    Converts a markdown table to the corresponding html string for TEDS computation.

    The first line is used as the header row, the second line is dropped, and
    every remaining line is a data row. Cells are the pieces between the
    outermost '|' characters of each line.
    """
    # Remove code block tokens such as '```markdown' and '```'.
    # NOTE(review): str.strip() treats its argument as a set of characters,
    # not as a prefix/suffix, so any leading or trailing run of the characters
    # in '`markdown' is removed. Confirm this is acceptable for the inputs.
    table_text = markdown_table.strip('```markdown').strip('```').strip()
    lines = table_text.split('\n')
    # Keep the header line and skip the line that follows it.
    kept_lines = lines[:1] + lines[2:]
    table_rows = [
        [cell.strip() if set(cell) != {' '} else ' '
         for cell in line.strip().split('|')[1:-1]]
        for line in kept_lines
    ]
    # Build html string based on table rows.
    return convert_table_to_html_str(table_rows)


def dict_to_html(data):
    """
    Render a dictionary as a two-column html table, one row per key.

    Non-string values are converted with str(). The characters of each value
    are then joined with single spaces before being written to the cell.
    """
    rows = []
    for key, value in data.items():
        text = value if isinstance(value, str) else str(value)
        # NOTE(review): joining a string with ' ' puts a space between every
        # character ("abc" -> "a b c"); confirm this is intended.
        spaced = ' '.join(text)
        rows.append(f"  <tr><td>{key}</td><td>{spaced}</td></tr>\n")
    return "<html><body><table>\n" + "".join(rows) + "</table></body></html>"


def convert_str_to_dict(predict_str: str):
    """
    Parses the 'predict' string and returns a flat dictionary of strings.
    Missing or unparseable content is handled gracefully.

    The content of the first ``` code fence (or the whole stripped string when
    there is none) is parsed with json.loads, then ast.literal_eval, then a
    key/value regex; the first parser that yields a dictionary wins.

    Parameters:
    - predict_str (str): The prediction string containing the output dict.

    Returns:
    - dict: Stripped keys mapped to stripped str(value). An empty dict is
      returned when nothing could be extracted or when a key is not a string.
    """
    # Remove code fences like ```python\n...\n```
    code_fence_pattern = r'```(?:python|json)?\n(.*?)\n```'
    match = re.search(code_fence_pattern, predict_str, re.DOTALL | re.IGNORECASE)
    content = match.group(1) if match else predict_str.strip()

    data = None

    # try parsing with JSON
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        data = parsed

    # try parsing with ast.literal_eval
    if data is None:
        try:
            parsed = ast.literal_eval(content)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval documents all of these for malformed input,
            # e.g. TypeError for an unhashable dict key.
            parsed = None
        if isinstance(parsed, dict):
            data = parsed

    # try parsing with regex; a non-dict literal from the parsers above is
    # discarded so the fallback always fills a fresh dictionary
    if data is None:
        key_value_pattern = r'["\']?([\w\s]+)["\']?\s*[:=]\s*["\']?([^\n,"\'{}]+)["\']?'
        data = {key.strip(): value.strip()
                for key, value in re.findall(key_value_pattern, content)}

    if not data:
        return {}

    try:
        return {k.strip(): str(v).strip() for k, v in data.items()}
    except AttributeError:
        # A key without .strip() (e.g. an int key from a Python literal).
        return {}


def convert_str_to_multi_dict(predict_str: str):
    """
    Parses the 'predict' string and returns a (possibly nested) dictionary.
    Unparseable content and non-dictionary literals yield an empty dict.

    Parameters:
    - predict_str (str): The prediction string containing the output dict.

    Returns:
    - dict: A copy of the parsed dictionary in which nested dicts and lists
      are rebuilt; other values are kept as they are.
    """
    # Use the longest ``` fenced section when there is one.
    fenced_sections = re.findall(r'```(?:python|json)?\n(.*?)\n```', predict_str,
                                 re.DOTALL | re.IGNORECASE)
    content = max(fenced_sections, key=len) if fenced_sections else predict_str.strip()

    # Drop a leading "name =" assignment.
    content = re.sub(r'^\s*\w+\s*=\s*', '', content.strip(), count=1)
    # Drop everything from each '#' to the end of its line.
    content = re.sub(r'#.*', '', content)

    # Cut off anything that follows the last closing brace.
    closing = content.rfind('}')
    if closing != -1:
        content = content[:closing + 1]

    try:
        parsed = ast.literal_eval(content)
    except (ValueError, SyntaxError, TypeError):
        return {}
    if not isinstance(parsed, dict):
        return {}

    def rebuild(obj):
        if isinstance(obj, dict):
            return {key: rebuild(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [rebuild(item) for item in obj]
        return obj

    return rebuild(parsed)


def generate_combinations(input_dict):
    """
    Function to generate all possible combinations of values from a dictionary.

    A list value is treated as the alternatives for its key; any other value
    is treated as a single alternative. The result is the Cartesian product
    over all keys.

    Parameters:
    - input_dict (dict | str): The dictionary, or a string holding it as JSON
      or as a Python literal (optionally wrapped in double quotes).

    Returns:
    - list: One dictionary per combination of values. When a string input
      cannot be parsed at all, a message is printed and an empty dict is returned.

    Raises:
    - ValueError: If a string input parses to something other than a dictionary.
    """
    kie_answer = input_dict
    if not isinstance(kie_answer, dict):
        kie_answer = kie_answer.strip('"')
        try:
            kie_answer = json.loads(kie_answer)
        except json.JSONDecodeError:
            try:
                kie_answer = ast.literal_eval(kie_answer)
                if not isinstance(kie_answer, dict):
                    # Second pass for a literal that itself evaluates to a string.
                    kie_answer = ast.literal_eval(kie_answer)
            except (ValueError, SyntaxError):
                print(f"Unable to parse 'answers' field: {kie_answer}")
                return {}

        # Ensure the parsed result is a dictionary.
        if not isinstance(kie_answer, dict):
            print("Parsed 'answers' is still not a dictionary.")
            raise ValueError("Input could not be parsed into a dictionary.")

    keys = list(kie_answer)
    # Wrap scalars so that product() never iterates over the characters of a
    # string value; this applies to dict and string inputs alike.
    value_lists = [
        value if isinstance(value, list) else [value]
        for value in (kie_answer[key] for key in keys)
    ]

    # Create a dictionary for each element of the Cartesian product.
    return [dict(zip(keys, values)) for values in product(*value_lists)]


def compute_f1_score(preds, gts, ignores=None):
    """Compute the average F1-score for KIE task between predicted and ground truth dictionaries.

    Values are compared after lower-casing, stripping and removing all spaces
    and newlines. A key scores 1.0 when both sides hold a value and the
    normalized values are equal, and 0.0 otherwise. Keys whose value is None
    (or missing) on both sides are skipped.

    Args:
        preds (dict): The predicted key-value pairs.
        gts (dict): The ground truth key-value pairs.
        ignores (list): The list of keys to ignore during evaluation.

    Returns:
        float: The mean of the per-key scores, or 0 when no key was scored.
    """
    def normalize(value):
        # Falsy values (None, '') are passed through unchanged.
        if value:
            return value.lower().strip().replace("\n", " ").replace(" ", "")
        return value

    ignored = set(ignores) if ignores is not None else set()
    keys = (set(preds) | set(gts)) - ignored

    scores = []
    for key in keys:
        pred_value = normalize(preds.get(key))
        gt_value = normalize(gts.get(key))

        if pred_value is None and gt_value is None:
            continue
        # With exact matching, precision and recall are both 1 or both 0,
        # so the per-key F1 is simply 1.0 on a match and 0.0 otherwise.
        is_match = (pred_value is not None and gt_value is not None
                    and pred_value == gt_value)
        scores.append(1.0 if is_match else 0.0)

    if not scores:
        return 0
    return sum(scores) / len(scores)


def pre_clean(text):
    """Normalize a line of text before it is placed in a tree node.

    Applies a fixed sequence of regex substitutions (special tokens, '##'
    word-piece joins, spacing around backslashes, braces and \\begin / \\end)
    and finally turns every newline, '*' and '_' into a space.
    """
    # Order matters. The '\s}' rule is listed twice on purpose: each pass
    # removes only one whitespace character in front of a '}'.
    substitutions = (
        (r'<bos>|<eos>|<pad>|<unk>', ''),
        (r'\s##(\S)', r'\1'),
        (r'\\\s', r'\\'),
        (r'\s\*\s\*\s', r'**'),
        (r'{\s', r'{'),
        (r'\s}', r'}'),
        (r'\s}', r'}'),
        (r'\\begin\s', r'\\begin'),
        (r'\\end\s', r'\\end'),
        (r'\\end{table}', r'\\end{table} \n\n'),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.translate(str.maketrans({'\n': ' ', '*': ' ', '_': ' '}))


def get_tree(input_str):
    """Build a tree of Node objects from a text, one node per line.

    The root is 'ROOT' with a first child 'TITLE'. A cleaned line starting
    with '#' becomes a child of ROOT (with every '#' removed) and is the
    current section from then on. Any other line becomes a child of the node
    found under the current section label, or of 'TITLE' while no section
    with a non-empty label has been seen.
    """
    root = Node('ROOT').addkid(Node('TITLE'))
    current_title = ''
    for raw_line in input_str.split("\n"):
        text = pre_clean(raw_line)
        if text.startswith('#'):
            heading = text.replace('#', '')
            root.get('ROOT').addkid(Node(heading))
            current_title = heading
            continue
        parent_label = current_title if current_title != '' else 'TITLE'
        # get() looks the parent up by label.
        root.get(parent_label).addkid(Node(text))
    return root

def STEDS(pred_tree, ref_tree):
    """Return 1 - (tree edit distance / node count of the larger tree).

    The label distance passed to simple_distance is 1 when either label has
    no non-whitespace content and 0 otherwise.
    """
    def blank_label_cost(pred, ref):
        return 0 if pred.split() and ref.split() else 1

    edit_cost = simple_distance(pred_tree, ref_tree, label_dist=blank_label_cost)
    pred_size = sum(1 for _ in pred_tree.iter())
    ref_size = sum(1 for _ in ref_tree.iter())
    return 1 - edit_cost / max(pred_size, ref_size)


def doc_parsing_evaluation(pred, gt):
    """Score a predicted document against the ground truth with STEDS.

    Both texts are converted with get_tree first. A prediction that is not a
    string scores 0.
    """
    if not isinstance(pred, str):
        return 0
    return STEDS(get_tree(pred), get_tree(gt))


def wrap_html_table(html_table):
    """
    The TEDS computation from PubTabNet code requires that the input html table should have <html>, <body>, and <table> tags.
    Add them if they are missing. Newlines are removed first.

    Opening tags are detected by prefix ('<table', '<body', '<html'), so tags
    that carry attributes are recognised and not wrapped a second time.
    """
    html_table = html_table.replace('\n', '')
    # add the missing half (or both halves) of the <table> pair
    has_open_table = "<table" in html_table
    has_close_table = "</table>" in html_table
    if not has_open_table:
        html_table = "<table>" + html_table
    if not has_close_table:
        html_table = html_table + "</table>"
    # add <body> and <html> tags if missing
    if '<body' not in html_table:
        html_table = '<body>' + html_table + '</body>'
    if '<html' not in html_table:
        html_table = '<html>' + html_table + '</html>'
    return html_table
    

def get_anls(s1, s2):
    """Return a case-insensitive similarity between s1 and s2.

    Equal inputs score 1.0; otherwise the score is
    1 - edit_distance / max(len(s1), len(s2)).
    """
    try:
        s1 = s1.lower()
        s2 = s2.lower()
    except AttributeError:
        # Inputs without .lower() (non-strings) are compared as they are.
        pass
    if s1 == s2:
        return 1.0
    return 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2))


def ocr_eval(references,predictions):
    """
    Average a per-sample text similarity (computed with get_anls) over paired
    references and predictions.

    Each reference / prediction may be a single string or a sequence of
    strings; a single string is treated as a one-element list. The literal
    string "None" marks an absent value.

    Returns 9999 when every sample was excluded (see None_num below),
    otherwise the mean sample score rounded to 5 decimals.
    """
    socre_=0.0
    # Number of samples excluded from the average because all of their
    # references were "None".
    None_num=0
    for idx,ref_value in enumerate(references):
        pred_value = predictions[idx]
        pred_values, ref_values = [], []
        if isinstance(pred_value, str):
            pred_values.append(pred_value)
        else:
            pred_values = pred_value
        if isinstance(ref_value, str):
            ref_values.append(ref_value)
        else:
            ref_values = ref_value
        
        temp_score = 0.0
        # Denominator of this sample's score; reduced for each "None" reference.
        temp_num = len(ref_values)
        
        for tmpidx, tmpref in enumerate(ref_values):
            # Pair by position; fall back to the first prediction when there
            # are fewer predictions than references.
            tmppred = pred_values[tmpidx] if tmpidx < len(pred_values) else pred_values[0]
            if len(pred_values) == 1 and tmppred != "None" and "None" not in ref_values:  # pred 1, and not None
                # Single prediction: keep the best similarity over all references.
                # NOTE(review): the best score is still divided by
                # len(ref_values) below — confirm this is intended.
                temp_score = max(temp_score, get_anls(tmppred, tmpref))
                temp_num = len(ref_values)
            else:
                if tmppred=='None' and tmpref!='None':
                    # Missing prediction for an existing reference scores 0.
                    temp_score += 0.0
                elif tmpref=='None':
                    # Nothing to compare against: leave this pair out.
                    temp_num -= 1
                else:
                    temp_score += get_anls(tmppred, tmpref)
        if temp_num == 0:
            ocr_score = 0.0
            None_num += 1
        else:
            ocr_score = temp_score / (temp_num)
        socre_ += ocr_score
    if None_num == len(references):
        # Sentinel value: no sample could be scored.
        return 9999
    else:
        return round(socre_ / (len(references)-None_num), 5)


def csv_eval(predictions,references,easy, pred_type='json'):
    """
    Compare predicted and reference data triples under three tolerance levels.

    Every prediction / reference is turned into (entity, header, value)
    triples. For each sample the similarity is
    len(tolerant intersection) / len(tolerant union) of the two triple sets,
    and a precision is the share of samples whose similarity reaches a
    threshold.

    Args:
        predictions (list): One item per sample. Dictionaries (optionally with
            one level of nesting) when pred_type is 'json'; otherwise strings
            whose rows and cells are separated by the literal two-character
            sequences backslash-n and backslash-t.
        references (list): One dictionary per sample, same layout as above.
        easy (int): 1 selects the tighter numeric tolerances (0 / 0.05 / 0.1);
            any other value selects 0.1 / 0.3 / 0.5.
        pred_type (str): 'json' (default) or anything else for the CSV format.

    Returns:
        tuple: (em, map_strict, map_slight, map_high,
                ap_50_strict, ap_75_strict, ap_90_strict,
                ap_50_slight, ap_75_slight, ap_90_slight,
                ap_50_high, ap_75_high, ap_90_high)
        where em is the strict precision at threshold 1 and each map_* is the
        mean precision over the thresholds 0.5, 0.55, ..., 0.95.
    """
    separator = "\\t"
    delimiter = "\\n"

    def is_float(val):
        try:
            float(val)
            return True
        except ValueError:
            return False

    def keep_numeric_chars(value):
        # Keep only digits, '.' and '-' (drops units, '%', '$', ...).
        return re.sub(r'[^\d.-]', '', str(value))

    def strip_colon(text):
        return text[:-1] if text.endswith(':') else text

    def convert_dict_to_list(data):
        """Flatten a simple or one-level nested dictionary into triples."""
        converted_list = []
        for key, value in data.items():
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    converted_list.append((key, subkey, keep_numeric_chars(subvalue)))
            else:
                converted_list.append((key, "value", keep_numeric_chars(value)))
        return converted_list

    def csv2triples(csv_text, separator='\\t', delimiter='\\n'):
        """Turn a table with a header row into (row entity, column header, value) triples."""
        lines = csv_text.strip().split(delimiter)
        header = lines[0].split(separator)
        triples = []
        for line in lines[1:]:
            if not line:
                continue
            values = line.split(separator)
            entity = strip_colon(values[0].strip())
            # zip() stops at the shorter row, so cells without a header are ignored.
            for column, raw_value in zip(header[1:], values[1:]):
                triples.append((entity,
                                strip_colon(column.strip()),
                                keep_numeric_chars(raw_value.strip())))
        return triples

    def process_triplets(triplets):
        """Lower-case the names and convert numeric values to float."""
        new_triplets = []
        for triplet in triplets:
            if len(triplet) > 2:
                if is_float(triplet[2]):
                    third = float(triplet[2])
                else:
                    third = triplet[2].lower()
            else:
                third = "no meaning"
            new_triplets.append((triplet[0].lower(), triplet[1].lower(), third))
        return new_triplets

    def intersection_with_tolerance(a, b, tol_word, tol_num):
        """Return the elements of a that match some element of b within the tolerances."""
        a = set(a)
        b = set(b)
        matched = set()
        for elem1 in a:
            for elem2 in b:
                if is_float(elem1[-1]) and is_float(elem2[-1]):
                    # Numeric values: the relative error must be within tol_num
                    # and the names must be close or contain one another.
                    relative_error = abs(elem1[-1] - elem2[-1]) / (abs(elem2[-1]) + 0.000001)
                    name1 = ''.join(elem1[:-1])
                    name2 = ''.join(elem2[:-1])
                    is_match = relative_error <= tol_num and (
                        Levenshtein.distance(name1, name2) <= tol_word
                        or name1 in name2
                        or name2 in name1
                    )
                else:
                    # Otherwise compare the whole triples as text.
                    text1 = ''.join([str(i) for i in elem1])
                    text2 = ''.join([str(j) for j in elem2])
                    is_match = Levenshtein.distance(text1, text2) <= tol_word
                if is_match:
                    matched.add(elem1)
                    break
        return list(matched)

    def build_triple_pairs():
        """Parse all samples once and align the heads of same-index triples."""
        if pred_type == 'json':
            pred_triple_list = [process_triplets(convert_dict_to_list(it)) for it in predictions]
        else:
            pred_triple_list = [
                process_triplets(csv2triples(it, separator=separator, delimiter=delimiter))
                for it in predictions
            ]
        label_triple_list = [process_triplets(convert_dict_to_list(it)) for it in references]

        pairs = []
        for pred, label in zip(pred_triple_list, label_triple_list):
            # Only positions present on both sides can be aligned.
            for idx in range(min(len(pred), len(label))):
                if label[idx][1] == "value" and "value" not in pred[idx][:2]:
                    pred[idx] = (pred[idx][0], "value", pred[idx][2])
                # Sort the two names so their order does not matter.
                pred_head = sorted(pred[idx][:2])
                label_head = sorted(label[idx][:2])
                pred[idx] = (pred_head[0], pred_head[1], pred[idx][2])
                label[idx] = (label_head[0], label_head[1], label[idx][2])
            pairs.append((pred, label))
        return pairs

    def similarity(pred, label, tol_word, tol_num):
        matched = set(intersection_with_tolerance(pred, label, tol_word=tol_word, tol_num=tol_num))
        pred_set = set(pred)
        label_set = set(label)
        # Tolerant union: predictions that matched only approximately are not
        # counted a second time next to the label they matched.
        union = (pred_set | label_set) - (matched - (pred_set & label_set))
        if not union:
            # Both sides are empty: treated as identical instead of dividing by zero.
            return 1.0
        return len(matched) / len(union)

    # (word tolerance, numeric tolerance) per tolerance level.
    tolerance_settings = {
        'strict': (0, 0 if easy == 1 else 0.1),
        'slight': (2, 0.05 if easy == 1 else 0.3),
        'high': (5, 0.1 if easy == 1 else 0.5),
    }

    # The similarities depend only on the tolerance level, not on the
    # threshold, so they are computed once per level and reused below.
    triple_pairs = build_triple_pairs()
    sim_lists = {
        name: [similarity(pred, label, tol_word, tol_num) for pred, label in triple_pairs]
        for name, (tol_word, tol_num) in tolerance_settings.items()
    }

    def get_ap(tolerance, sim_threhold):
        sim_list = sim_lists[tolerance]
        return len([num for num in sim_list if num >= sim_threhold]) / (len(sim_list) + 1e-16)

    map_strict = 0
    map_slight = 0
    map_high = 0
    for sim_threhold in np.arange(0.5, 1, 0.05):
        map_strict += get_ap('strict', sim_threhold) / 10
        map_slight += get_ap('slight', sim_threhold) / 10
        map_high += get_ap('high', sim_threhold) / 10

    em = get_ap('strict', 1)
    ap_50_strict = get_ap('strict', 0.5)
    ap_75_strict = get_ap('strict', 0.75)
    ap_90_strict = get_ap('strict', 0.90)
    ap_50_slight = get_ap('slight', 0.5)
    ap_75_slight = get_ap('slight', 0.75)
    ap_90_slight = get_ap('slight', 0.90)
    ap_50_high = get_ap('high', 0.5)
    ap_75_high = get_ap('high', 0.75)
    ap_90_high = get_ap('high', 0.90)

    return em, map_strict, map_slight, map_high, ap_50_strict, ap_75_strict, ap_90_strict, ap_50_slight, ap_75_slight, ap_90_slight, ap_50_high, ap_75_high, ap_90_high

def draw_SCRM_table(em, map_strict, map_slight, map_high, ap_50_strict, ap_75_strict, ap_90_strict, ap_50_slight, ap_75_slight, ap_90_slight, ap_50_high, ap_75_high, ap_90_high,title_ocr_socre,source_ocr_socre,x_title_ocr_socre,y_title_ocr_socre,structure_accuracy):
    """Render the given metric values as a fixed-layout ASCII table.

    Every argument is formatted with ``'%.4f'`` and placed in its own row.
    The rows are labelled, top to bottom: ``mPrecison`` (threshold column
    ``0.5:0.05:0.95``), ``Precison`` at ``0.5``, ``0.75`` and ``0.9`` (each
    with ``strict`` / ``slight`` / ``high`` tolerance rows), followed by the
    single-value rows ``Precison(EM)``, ``Title(EM)``, ``Source(EM)``,
    ``X_title(EM)``, ``Y_title(EM)`` and ``structure_acc``.

    Returns:
        One string. It starts with a newline, every table row is indented by
        12 spaces and followed by an empty line, and the text ends with two
        extra newlines plus a 12-space run.
    """
    def fmt(value):
        # Same four-decimal rendering for every cell.
        return '%.4f' % value

    margin = ' ' * 12
    rule = '-----------------------------------------------------------'
    sub_rule = '|             |                  ---------------------------'

    # Row texts are emitted verbatim (label spellings and the trailing spaces
    # on the first data row included), in display order.
    rows = [
        rule,
        '|  Metrics   |  Sim_threshold  |  Tolerance  |    Value    |',
        rule,
        f'|             |                 |   strict    |    {fmt(map_strict)}    |     ',
        '|             |                 ----------------------------',
        f'|  mPrecison  |  0.5:0.05:0.95  |   slight    |    {fmt(map_slight)}    |',
        sub_rule,
        f'|             |                 |    high     |    {fmt(map_high)}    |',
        rule,
        f'|             |                 |   strict    |    {fmt(ap_50_strict)}    |',
        sub_rule,
        f'|  Precison   |       0.5       |   slight    |    {fmt(ap_50_slight)}    |',
        sub_rule,
        f'|             |                 |    high     |    {fmt(ap_50_high)}    |',
        rule,
        f'|             |                 |   strict    |    {fmt(ap_75_strict)}    |',
        sub_rule,
        f'|  Precison   |      0.75       |   slight    |    {fmt(ap_75_slight)}    |',
        sub_rule,
        f'|             |                 |    high     |    {fmt(ap_75_high)}    |',
        rule,
        f'|             |                 |   strict    |    {fmt(ap_90_strict)}    |',
        sub_rule,
        f'|  Precison   |       0.9       |   slight    |    {fmt(ap_90_slight)}    |',
        sub_rule,
        f'|             |                 |    high     |    {fmt(ap_90_high)}    |',
        rule,
        f'|Precison(EM) |                                    {fmt(em)}    |',
        rule,
        f'|Title(EM)    |                                    {fmt(title_ocr_socre)}    |',
        rule,
        f'|Source(EM)   |                                    {fmt(source_ocr_socre)}    |',
        rule,
        f'|X_title(EM)  |                                    {fmt(x_title_ocr_socre)}    |',
        rule,
        f'|Y_title(EM)  |                                    {fmt(y_title_ocr_socre)}    |',
        rule,
        f'|structure_acc|                                    {fmt(structure_accuracy)}    |',
        rule,
    ]

    # Each row is followed by two newlines, so a blank line separates rows.
    body = ''.join(f'{margin}{row}\n\n' for row in rows)
    return f'\n{body}\n\n{margin}'


if __name__ == '__main__':
    import pprint

    # markdown structure for Table Parsing task
    # The prediction holds only the six data rows; the ground truth adds the
    # header row and the markdown separator row in front of the same data.
    table_header_rows = [
        "| week | date | opponent | result | record |",
        "| --- | --- | --- | --- | --- |",
    ]
    table_data_rows = [
        "| 1 | august 5 , 1972 | detroit lions | l 23 - 31 | 0 - 1 |",
        "| 2 | august 12 , 1972 | green bay packers | l 13 - 14 | 0 - 2 |",
        "| 3 | august 19 , 1972 | cincinnati bengals | w 35 - 17 | 1 - 2 |",
        "| 4 | august 25 , 1972 | atlanta falcons | w 24 - 10 | 2 - 2 |",
        "| 5 | august 31 , 1972 | washington redskins | l 24 - 27 | 2 - 3 |",
        "| 6 | september 10 , 1972 | minnesota vikings | w 21 - 19 | 3 - 3 |",
    ]
    pred_markdown = "\n".join(table_data_rows)
    true_markdown = "\n".join(table_header_rows + table_data_rows)

    table_teds = TEDS(n_jobs=4)
    pred_table_html = convert_markdown_table_to_html(pred_markdown)
    true_table_html = convert_markdown_table_to_html(true_markdown)
    table_scores = table_teds.evaluate(pred_table_html, true_table_html)
    pprint.PrettyPrinter().pprint(table_scores)

    # dict structure for Key Information Extraction task
    # "company" and "date" are truncated in the prediction; "address" and
    # "total" match the ground truth exactly.
    pred_dict = {
        "company": ["OLD TOWN "],
        "date": ["2024"],
        "address": ["SRI RAMPAI"],
        "total": ["30"],
    }
    true_dict = {
        "company": ["OLD TOWN KOPITAM SND BHD"],
        "date": ["2024/9/27"],
        "address": ["SRI RAMPAI"],
        "total": ["30"],
    }

    dict_teds = TEDS(n_jobs=4)
    pred_dict_html = dict_to_html(pred_dict)
    true_dict_html = dict_to_html(true_dict)
    # Show the generated HTML for both dicts before scoring them.
    for rendered_html in (pred_dict_html, true_dict_html):
        print(rendered_html)

    dict_scores = dict_teds.evaluate(pred_dict_html, true_dict_html)
    pprint.PrettyPrinter().pprint(dict_scores)
