import stanza
import nltk
from nltk.tree import Tree

# Module-level cache for the Stanza pipeline built by initialize_stanza(); it is
# None until that function has built one (and is reset to None if building fails).
# initialize_stanza() downloads and loads models, so build the pipeline once and
# pass it to sentence_tokenizer() instead of re-creating it for every call.
NLP_PIPELINE = None


# (pipeline, lang, processors) for the pipeline most recently built by
# initialize_stanza(); lets the cache check avoid relying on Pipeline internals.
_NLP_PIPELINE_CONFIG = None


def initialize_stanza(lang='en', processors='tokenize,pos,constituency'):
    """Initializes and returns the Stanza pipeline.

    The pipeline is cached in the module-level ``NLP_PIPELINE``. It is rebuilt
    only when nothing is cached, when the cached pipeline was built with a
    different ``lang``/``processors`` combination, or when ``NLP_PIPELINE`` was
    replaced by something this function did not build.

    Args:
        lang: Stanza language code passed to ``stanza.download``/``stanza.Pipeline``.
        processors: Comma-separated processor names passed to Stanza.

    Returns:
        The (possibly cached) Stanza pipeline.

    Raises:
        Exception: Any error raised by Stanza while downloading models or
            building the pipeline is printed and re-raised, after the cache
            has been reset to None.
    """
    global NLP_PIPELINE, _NLP_PIPELINE_CONFIG
    # Compare against the configuration recorded at build time rather than
    # reading it back from the Pipeline object (e.g. a `processors_str`
    # attribute is not part of Stanza's documented Pipeline API). The identity
    # check makes an externally reassigned NLP_PIPELINE count as a cache miss.
    cached = _NLP_PIPELINE_CONFIG
    if (NLP_PIPELINE is not None
            and cached is not None
            and cached[0] is NLP_PIPELINE
            and cached[1:] == (lang, processors)):
        return NLP_PIPELINE

    print(f"Initializing Stanza pipeline for {lang} with {processors}...")
    try:
        stanza.download(lang=lang, processors=processors, verbose=False)
        # Models were just downloaded explicitly, so skip Pipeline's own download step.
        NLP_PIPELINE = stanza.Pipeline(lang=lang, processors=processors, verbose=False, download_method=None)
        print("Stanza pipeline initialized.")
    except Exception as e:
        print(f"Error initializing Stanza pipeline: {e}")
        NLP_PIPELINE = None  # Ensure it's None if failed
        _NLP_PIPELINE_CONFIG = None
        raise
    _NLP_PIPELINE_CONFIG = (NLP_PIPELINE, lang, processors)
    return NLP_PIPELINE


def _token_fallback_spans(stz_sentence) -> tuple:
    """Return ``(spans, offsets)`` with one span per Stanza token of *stz_sentence*.

    Used whenever a usable constituency tree is not available.
    """
    tokens = stz_sentence.tokens
    spans = [token.text for token in tokens]
    offsets = [(token.start_char, token.end_char) for token in tokens]
    return spans, offsets


def _span_leaf_ranges(tree: Tree, target_nltk_height: int) -> list:
    """Return half-open leaf-index ranges ``(start, end)`` of the spans selected from *tree*.

    A subtree is selected when its NLTK height is between 2 and
    ``target_nltk_height`` inclusive. Taller subtrees are split into their
    children, so every leaf lands in exactly one range, in left-to-right order.
    """
    ranges = []

    def visit(node, leaf_idx):
        # Returns the leaf index just past the last leaf covered by ``node``.
        if not isinstance(node, Tree):
            # A bare leaf string directly under a node taller than the target
            # (only possible with mixed Tree/string children): one-leaf span.
            ranges.append((leaf_idx, leaf_idx + 1))
            return leaf_idx + 1
        height = node.height()
        if height > target_nltk_height:
            for child in node:
                leaf_idx = visit(child, leaf_idx)
            return leaf_idx
        # At or below the target height: take the whole subtree as a span.
        # "Flatter" nodes (height < target) are kept whole to ensure full coverage.
        num_leaves = len(node.leaves())
        if height >= 2 and num_leaves:
            ranges.append((leaf_idx, leaf_idx + num_leaves))
        # Height-1 nodes are childless Trees and contribute no leaves.
        return leaf_idx + num_leaves

    visit(tree, 0)
    return ranges


def _constituency_spans(stz_sentence, target_nltk_height: int) -> tuple:
    """Return ``(spans, offsets)`` for one Stanza sentence using its constituency tree.

    Falls back to one span per token (with a printed warning) when the sentence
    has no constituency tree, when the tree string cannot be parsed by NLTK, or
    when the parsed leaves do not line up one-to-one with the sentence's words.
    """
    if not stz_sentence.constituency:
        print("Warning: No constituency tree found in Stanza sentence. Falling back to tokenization.")
        return _token_fallback_spans(stz_sentence)

    try:
        tree_str = str(stz_sentence.constituency)
        # The Stanza constituency output might be a simple string if parsing failed,
        # or a LISP-style tree string.
        if not tree_str.startswith("("):  # Basic check for tree structure
            raise ValueError("Constituency output does not look like a tree.")
        nltk_tree = Tree.fromstring(tree_str)
    except ValueError as e:
        print(f"Warning: Could not parse constituency tree string: '{stz_sentence.constituency}'. "
              f"Error: {e}. Falling back to tokenization.")
        return _token_fallback_spans(stz_sentence)

    # Constituency leaves correspond to Stanza *words*, not tokens. A multi-word
    # token (e.g. French "du" -> "de" + "le") yields several words/leaves, so
    # indexing tokens by leaf position would misalign every later span.
    words = stz_sentence.words
    num_leaves = len(nltk_tree.leaves())
    if num_leaves != len(words):
        # Tree.fromstring splits leaves on whitespace and brackets, so a word
        # containing either changes the leaf count; refuse to guess the alignment.
        print(f"Warning: Constituency tree has {num_leaves} leaves but the sentence has "
              f"{len(words)} words. Falling back to tokenization.")
        return _token_fallback_spans(stz_sentence)

    # Character offsets come from each word's parent token.
    leaf_tokens = [word.parent for word in words]
    spans, offsets = [], []
    for start, end in _span_leaf_ranges(nltk_tree, target_nltk_height):
        spans.append(" ".join(word.text for word in words[start:end]))
        # Stanza's token.end_char is already exclusive for slicing.
        offsets.append((leaf_tokens[start].start_char, leaf_tokens[end - 1].end_char))
    return spans, offsets


def sentence_tokenizer(sentence_text: str, target_nltk_height: int, nlp: stanza.Pipeline,
                       *, all_sentences: bool = False) -> dict:
    """
    Parses a sentence into spans based on a specified NLTK tree height.

    Args:
        sentence_text: The input sentence string.
        target_nltk_height: The desired NLTK height of subtrees to be extracted as spans.
                            - NLTK height of a pre-terminal (e.g., (DT The)) is 2.
                              Using target_nltk_height = 2 will generally result in token-level spans.
                            - NLTK height of a phrase whose children are only pre-terminals
                              (e.g., (NP (DT a) (NN dog))) is 3.
                              Using target_nltk_height = 3 will group words under such "flat" phrases.
                            - Must be an integer >= 2.
        nlp: An initialized Stanza Pipeline object.
        all_sentences: If False (default), only the first sentence Stanza finds
                       is processed and a warning is printed when more exist.
                       If True, every sentence is processed and the results
                       are concatenated in document order.

    Returns:
        A dictionary with:
        - 'input_ids': A list of strings, where each string is a span
                       (its words joined by single spaces).
        - 'offset_mapping': A list of (start_char, end_char) tuples for each span,
                            relative to sentence_text, where end_char is exclusive.
        A sentence whose constituency tree is missing, unparsable, or not
        aligned with its words falls back to one span per token.

    Raises:
        ValueError: If nlp is falsy or target_nltk_height is not an integer >= 2
                    (checked only for non-blank input).
    """
    if not sentence_text.strip():
        return {'input_ids': [], 'offset_mapping': []}
    if not nlp:
        raise ValueError("Stanza pipeline (nlp) not initialized or provided.")
    if not isinstance(target_nltk_height, int) or target_nltk_height < 2:
        raise ValueError("target_nltk_height must be an integer >= 2.")

    doc = nlp(sentence_text)
    if not doc.sentences:
        return {'input_ids': [], 'offset_mapping': []}

    if all_sentences:
        stz_sentences = doc.sentences
    else:
        stz_sentences = doc.sentences[:1]
        if len(doc.sentences) > 1:
            # Previously the remaining sentences were dropped silently.
            print(f"Warning: Input was split into {len(doc.sentences)} sentences; only the first "
                  "is processed (pass all_sentences=True to process all of them).")

    input_ids, offset_mapping = [], []
    for stz_sentence in stz_sentences:
        spans, offsets = _constituency_spans(stz_sentence, target_nltk_height)
        input_ids.extend(spans)
        offset_mapping.extend(offsets)

    return {'input_ids': input_ids, 'offset_mapping': offset_mapping}