from typing import List, Tuple, Any
from dotenv import load_dotenv
load_dotenv()

import os
import sys
sys.path.insert(0, os.environ["PROJECT_PATH"])

import matplotlib.pyplot as plt
import eval_configs.global_config as gloabl_config

import re
import numpy as np


class LegendEvaluator:

    def __init__(self, use_position=True) -> None:
        self.use_position = use_position
        self.metrics = {
            "precision": 0,
            "recall": 0,
            "f1": 0
        }

    def __call__(self, generation_code_file, golden_code_file):
        generation_texts = self._log_legends(generation_code_file)
        golden_texts = self._log_legends(golden_code_file)
        
        self._calculate_metrics(generation_texts, golden_texts)

        # redunant_file = os.environ["PROJECT_PATH"] + "/" + os.path.basename(golden_code_file).replace(".py", ".pdf")
        # os.remove(redunant_file)
        # print(self.metrics)


    def _log_legends(self, code_file):
        """
        Get legend objects of the code
        """

        with open(code_file, 'r') as f:
            lines = f.readlines()
        code = ''.join(lines)
        code = code.replace("plt.savefig", "# plt.savefig")
        prefix = self._get_prefix()
        output_file = code_file.replace(".py", ".txt")
        code_file = code_file
        suffix = self._get_suffix(output_file)
        code = prefix + code + suffix
        # print('Code:', code)
        # breakpoint()
        code_log_texts_file = code_file.replace(".py", "_log_legends.py")
        with open(code_log_texts_file, 'w') as f:
            f.write(code)
        # print('Code:', code_log_texts_file)

        os.system(f"python3 {code_log_texts_file}")
        # breakpoint()
        with open(output_file, 'r') as f:
            texts = f.read()
            print('Texts:', texts, output_file)
            texts = eval(texts)

        os.remove(code_log_texts_file)
        os.remove(output_file)
        
        # pdf_file = re.findall(r"plt\.savefig\('(.*)'\)", code)
        # if len(pdf_file) != 0:
            # pdf_file = pdf_file[0]
            # if os.path.basename(pdf_file) == pdf_file:
                # os.remove(pdf_file)

        return texts

    def _calculate_metrics(self, generation_texts: List[Tuple], golden_texts: List[Tuple]):
        """
        Calculate the metrics

        Args:
            - generation_texts: List of tuples of texts, [(x, y, x_rel, y_rel, text), ...]
            - golden_texts: List of tuples of texts, [(x, y, x_rel, y_rel, text), ...]
        """
        if len(generation_texts) == 0 or len(golden_texts) == 0:
            self.metrics["precision"] = 0
            self.metrics["recall"] = 0
            self.metrics["f1"] = 0
            return

        len_generation = len(generation_texts)
        len_golden = len(golden_texts)

        if self.use_position == False:
            generation_texts = [ t[-1] for t in generation_texts]
            golden_texts = [ t[-1] for t in golden_texts]

            n_correct = 0
            for t in golden_texts:
                if t in generation_texts:
                    n_correct += 1
                    generation_texts.remove(t)

        else:
            generation_texts = [ t[2:] for t in generation_texts]
            golden_texts = [ t[2:] for t in golden_texts]

            n_correct = 0
            for t1 in golden_texts:
                for t2 in generation_texts:
                    # text must be equal, but x_rel and y_rel can be in a range
                    if t1[-1] == t2[-1] and abs(t1[0] - t2[0]) <= 10 and abs(t1[1] - t2[1]) <= 10:
                        # print("matched:", t2)
                        n_correct += 1
                        generation_texts.remove(t2)
                        break

        self.metrics["precision"] = n_correct / len_generation
        self.metrics["recall"] = n_correct / len_golden
        if self.metrics["precision"] + self.metrics["recall"] == 0:
            self.metrics["f1"] = 0
        else:
            self.metrics["f1"] = 2 * self.metrics["precision"] * self.metrics["recall"] / (self.metrics["precision"] + self.metrics["recall"])

        return

    def _get_prefix(self):
        return f"""
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import sys
sys.path.insert(0, '{os.environ['PROJECT_PATH']}')
import eval_configs.global_config as global_config
global_config.reset_texts()
from matplotlib.backends.backend_pdf import RendererPdf
import matplotlib.pyplot as plt

drawed_legend_texts = []
drawed_texts = []

def log_function(func):
    def wrapper(*args, **kwargs):
        global drawed_texts

        object = args[0]
        x = args[2]
        y = args[3]
        x_rel = ( x / object.width / 72 ) * 100
        y_rel = ( y / object.height / 72 ) * 100
        s = args[4]

        drawed_texts.append( (x, y, x_rel, y_rel, s) )
        return func(*args, **kwargs)

    return wrapper

RendererPdf.draw_text = log_function(RendererPdf.draw_text)
"""
    
    def _get_suffix(self, output_file):
        return f"""

all_axes = plt.gcf().get_axes()

if len(all_axes) == 0:
    all_axes = plt.gcf().get_axes()

legends = [ax.get_legend() for ax in all_axes if ax.get_legend() is not None]
if len(legends) == 0:
    legends = fig.legends

for legend in legends:
    for t in legend.get_texts():
        drawed_legend_texts.append(t.get_text())

plt.savefig('generated.pdf')

new_drawed_legend_texts = []
for t1 in drawed_legend_texts:
    for t2 in drawed_texts:
        if t1 == t2[-1]:
            new_drawed_legend_texts.append(t2)
            break
drawed_legend_texts = new_drawed_legend_texts

with open('{output_file}', 'w') as f:
    f.write(str(drawed_legend_texts))
"""


if __name__ == "__main__":
    import sys
    # sys.path.insert(0, '/home/yc21/project/Princess-s-CHI')

    evaluator = LegendEvaluator()

    generation_code_file = "/home/yc21/project/Princess-s-CHI/dataset/line/line_9.py"
    golden_code_file = "/home/yc21/project/Princess-s-CHI/results/chart2code_gpt_DirectAgent_results/direct/line_9.py"

    evaluator(generation_code_file, golden_code_file)