import numpy as np
# from fld.metrics.PrecisionRecall import PrecisionRecall
import torch
import xml.etree.ElementTree as ET

from pe.callback.callback import Callback
from pe.metric_item import FloatMetricItem
from pe.logging import execution_logger
from pe.constant.data import TEXT_DATA_COLUMN_NAME


class ComputeFormatMatch(Callback):
    """The callback that computes format match ratio of the synthetic data.

    Each synthetic sample is parsed as XML and validated against the structure selected by ``format_type``:

    * ``'archehrqa'``: the root element is ``<case>``; it has the direct children ``<patient_narrative>``,
      ``<patient_question>``, ``<clinician_question>`` and ``<note_excerpt>``; and ``<patient_question>`` has at
      least one direct ``<phrase>`` child, each carrying the ``id`` and ``start_char_index`` attributes.
    * ``'archehrqa2'``: everything required by ``'archehrqa'``, plus a direct ``<introduction>`` child whose
      stripped text equals a fixed introduction sentence.

    Any other ``format_type`` matches no sample, so the reported ratio is 0.
    """

    # Direct children of <case> required by every supported format.
    _REQUIRED_TAGS = (
        "patient_narrative",
        "patient_question",
        "clinician_question",
        "note_excerpt",
    )

    # Attributes every <phrase> inside <patient_question> must carry.
    _REQUIRED_PHRASE_ATTRIBUTES = ("id", "start_char_index")

    # Exact (whitespace-stripped) text that <introduction> must contain for 'archehrqa2'.
    _INTRODUCTION_TEXT = (
        "This XML includes a patient narrative, key phrases identifying focal points of the question, "
        "one related rephrased clinician question, and supporting evidence."
    )

    def __init__(self, format_type='archehrqa'):
        """Constructor.

        :param format_type: The type of format to check for the task, defaults to xml format of 'archehrqa'
        :type format_type: str, optional
        """
        self._format_type = format_type
        self._format_metric_name = f"format_match_{self._format_type}"

    def _check_case_xml(self, text, require_introduction):
        """Validate one sample against the ``<case>`` XML structure.

        :param text: The sample to validate
        :type text: str
        :param require_introduction: Whether the ``<introduction>`` element with the fixed introduction sentence is
            required (the 'archehrqa2' format)
        :type require_introduction: bool
        :return: Whether the sample satisfies the structure
        :rtype: bool
        """
        # Non-text cells (e.g. None or NaN) would make the parser raise TypeError; they simply do not match.
        if not isinstance(text, (str, bytes)):
            execution_logger.debug("Failed at non-text sample of type %s", type(text).__name__)
            return False

        # Parse once. ValueError is caught as well so that text the parser cannot encode (presumably e.g. lone
        # surrogate characters) counts as a non-match instead of aborting the whole metric computation.
        try:
            root = ET.fromstring(text)
        except (ET.ParseError, ValueError):
            execution_logger.debug("Failed at XML parsing")
            return False

        if root.tag != "case":
            execution_logger.debug("Failed at root tag: %s", root.tag)
            return False

        required_tags = self._REQUIRED_TAGS
        if require_introduction:
            required_tags = ("introduction",) + required_tags
        for tag in required_tags:
            # Compare with None explicitly: an element without children is falsy.
            if root.find(tag) is None:
                execution_logger.debug("Failed at missing tag: %s", tag)
                return False

        if require_introduction:
            intro_text = root.find("introduction").text
            if intro_text is None:
                execution_logger.debug("Failed at missing text in <introduction>")
                return False
            if intro_text.strip() != self._INTRODUCTION_TEXT:
                execution_logger.debug("Failed at incorrect introduction text: **%s**", intro_text.strip())
                return False

        # <patient_question> is guaranteed to exist by the required-tags loop above.
        phrases = root.find("patient_question").findall("phrase")
        if not phrases:
            execution_logger.debug("Failed at missing phrases in patient_question")
            return False
        for phrase in phrases:
            if any(attribute not in phrase.attrib for attribute in self._REQUIRED_PHRASE_ATTRIBUTES):
                execution_logger.debug(
                    "Failed at missing attributes in phrase: %s, but should have 'id' and 'start_char_index' in it",
                    phrase.attrib,
                )
                return False

        return True

    def _check_format(self, text):
        """Check whether one sample matches the configured format.

        :param text: The sample to check
        :type text: str
        :return: Whether the sample matches; always False for an unsupported format type
        :rtype: bool
        """
        if self._format_type == 'archehrqa':
            return self._check_case_xml(text, require_introduction=False)
        if self._format_type == 'archehrqa2':
            return self._check_case_xml(text, require_introduction=True)

        # Add more format checks as needed
        return False

    def __call__(self, syn_data):
        """This function is called after each PE iteration that computes the fraction of synthetic samples matching
        the configured format.

        :param syn_data: The synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        :return: The format match ratio of the synthetic data (0.0 when there are no samples)
        :rtype: list[:py:class:`pe.metric_item.FloatMetricItem`]
        """
        execution_logger.info(
            f"Computing format match ({self._format_type})"
        )
        texts = syn_data.data_frame[TEXT_DATA_COLUMN_NAME].values
        num_texts = len(texts)
        if num_texts == 0:
            # Avoid dividing by zero on an empty data frame.
            format_match = 0.0
        else:
            num_matches = sum(1 for text in texts if self._check_format(text))
            format_match = num_matches / num_texts

        format_metric_item = FloatMetricItem(name=self._format_metric_name, value=format_match)
        execution_logger.info(
            f"Finished computing format match ({self._format_type}), {format_metric_item.value=}"
        )
        return [format_metric_item]
