import re
import shap
import torch
from typing import List, Tuple


def parse_by_newlines(paragraph: str, return_offsets_mapping: bool = True) -> dict:
    """
    Split a paragraph into blocks separated by two or more consecutive newlines.

    The result is a tokenizer-like dict. Presumably this lets the parser stand in
    where a HuggingFace-style tokenizer is expected (e.g. a shap Text masker);
    TODO confirm with callers.

    Args:
        paragraph (str): The input text.
        return_offsets_mapping (bool): When true (default), include the
            ``'offset_mapping'`` key in the result.

    Returns:
        dict:
            - ``'input_ids'``: list of blocks, each stripped of leading/trailing
              whitespace. Blocks that are blank after stripping are dropped.
            - ``'offset_mapping'`` (only if ``return_offsets_mapping``): list of
              ``(start, end)`` character offsets (end exclusive) of each stripped
              block within ``paragraph``.

        Empty input yields empty lists. Whitespace-only input is returned
        verbatim as a single span covering ``(0, len(paragraph))``.
    """

    def _stripped_slice(start: int, end: int):
        """Return (stripped_text, (abs_start, abs_end)) for paragraph[start:end], or None if blank."""
        raw = paragraph[start:end]
        stripped = raw.strip()
        if not stripped:
            return None
        lead = len(raw) - len(raw.lstrip())
        return stripped, (start + lead, start + lead + len(stripped))

    spans: List[str] = []
    offset_mapping: List[Tuple[int, int]] = []

    if paragraph and paragraph.isspace():
        # Keep whitespace-only input as one span so callers still get a unit to work with.
        spans.append(paragraph)
        offset_mapping.append((0, len(paragraph)))
    elif paragraph:
        # Walk the delimiters so each block's absolute position is known directly.
        block_start = 0
        bounds = []
        for match in re.finditer(r'\n{2,}', paragraph):
            bounds.append((block_start, match.start()))
            block_start = match.end()
        bounds.append((block_start, len(paragraph)))  # block after the last delimiter

        for start, end in bounds:
            piece = _stripped_slice(start, end)
            if piece is not None:
                spans.append(piece[0])
                offset_mapping.append(piece[1])

    result = {'input_ids': spans}
    if return_offsets_mapping:
        result['offset_mapping'] = offset_mapping
    return result


# Regex for a list-item marker ("1.", "*", "-") followed by whitespace.
_GEMINI_LIST_ITEM = r"^\s*(\d+\.|\*|-)\s"


def _gemini_blank_line_bounds(text: str) -> List[Tuple[int, int]]:
    r"""Return raw (start, end) slices of ``text`` separated by exact '\n\n' delimiters."""
    bounds = []
    cursor = 0
    for match in re.finditer(re.escape("\n\n"), text):
        bounds.append((cursor, match.start()))
        cursor = match.end()
    bounds.append((cursor, len(text)))
    return bounds


def _gemini_list_item_bounds(text: str) -> List[Tuple[int, int]]:
    """Return raw (start, end) slices of ``text``, starting a new slice at every list-item line."""
    bounds = []
    block_start = 0
    pos = 0
    for line in text.splitlines(keepends=True):
        # A list-item line closes the current (non-empty) block and opens a new one.
        if pos > block_start and re.match(_GEMINI_LIST_ITEM, line.strip()):
            bounds.append((block_start, pos))
            block_start = pos
        pos += len(line)
    bounds.append((block_start, len(text)))
    return bounds


def _gemini_sentence_bounds(text: str) -> List[Tuple[int, int]]:
    """Return raw (start, end) slices of ``text`` ending right after '.', '?' or '!' that is followed by whitespace."""
    bounds = []
    cursor = 0
    for match in re.finditer(r'[.?!]\s+', text):
        # Keep the punctuation mark in the sentence; the whitespace after it is skipped.
        bounds.append((cursor, match.start() + 1))
        cursor = match.end()
    bounds.append((cursor, len(text)))
    return bounds


def parse_paragraph_gemini(paragraph: str, return_offsets_mapping: bool = True) -> dict:
    r"""
    Split a paragraph into smaller spans using a single prioritized strategy.

    The first matching rule (tested on the whitespace-stripped paragraph) is applied:

    1. If the text contains '\n\n', split on every exact '\n\n' delimiter.
    2. Otherwise, if any line looks like a list item ('1.', '*' or '-' followed by
       whitespace), start a new span at each such line. The marker stays in the span.
    3. Otherwise, split after '.', '?' or '!' when followed by whitespace.

    Args:
        paragraph (str): The input text.
        return_offsets_mapping (bool): When true (default), include the
            ``'offset_mapping'`` key in the result.

    Returns:
        dict:
            - ``'input_ids'``: list of spans, each stripped of surrounding whitespace.
              Blank spans are dropped.
            - ``'offset_mapping'`` (only if ``return_offsets_mapping``): list of
              ``(start, end)`` offsets (end exclusive) of each stripped span
              within ``paragraph``.

        Empty input yields empty lists. Whitespace-only input is returned
        verbatim as a single span covering ``(0, len(paragraph))``.
    """
    spans: List[str] = []
    offset_mapping: List[Tuple[int, int]] = []

    if paragraph and paragraph.isspace():
        spans.append(paragraph)
        offset_mapping.append((0, len(paragraph)))
    elif paragraph:
        # Strategy selection looks at the stripped text; splitting uses the original
        # so that offsets refer to the caller's string.
        core = paragraph.strip()
        if "\n\n" in core:
            bounds = _gemini_blank_line_bounds(paragraph)
        elif re.search(_GEMINI_LIST_ITEM, core, re.MULTILINE):
            bounds = _gemini_list_item_bounds(paragraph)
        else:
            bounds = _gemini_sentence_bounds(paragraph)

        for start, end in bounds:
            raw = paragraph[start:end]
            stripped = raw.strip()
            if stripped:
                lead = len(raw) - len(raw.lstrip())
                spans.append(stripped)
                offset_mapping.append((start + lead, start + lead + len(stripped)))

    result = {'input_ids': spans}
    if return_offsets_mapping:
        result['offset_mapping'] = offset_mapping
    return result


def parse_sentence(paragraph, return_offsets_mapping=True):
    """
    Split text after every clause/sentence delimiter (``. , ; ? !``).

    Each span runs from the end of the previous delimiter up to and including
    the next one, and is stripped of surrounding whitespace. Any text after the
    last delimiter becomes a final span.

    Args:
        paragraph (str): Text to split.
        return_offsets_mapping (bool): When true (default), also return offsets.

    Returns:
        dict: ``{'input_ids': spans}``, plus ``'offset_mapping'`` when requested.
            ``spans`` are the stripped strings. Each offset pair is the
            ``(start, end)`` of the *unstripped* slice, so it may include leading
            whitespace. A whitespace-only remainder after the last delimiter is
            still emitted, as an empty string with its raw offsets.
    """
    cut_points = [hit.end() for hit in re.finditer(r"[.,;?!]", paragraph)]

    pieces = []
    bounds = []
    prev = 0
    for cut in cut_points:
        piece = paragraph[prev:cut].strip()
        if piece:
            pieces.append(piece)
            bounds.append((prev, cut))
        prev = cut

    total = len(paragraph)
    if prev < total:
        # The tail is kept even when blank (it becomes '').
        pieces.append(paragraph[prev:].strip())
        bounds.append((prev, total))

    output = {'input_ids': pieces}
    if return_offsets_mapping:
        output['offset_mapping'] = bounds
    return output


def first_true_indices(bools, dtype=torch.long):
    """
    Return, for each row along the last axis, the index of the first True value.

    ``bools`` is an N-D boolean tensor; the result has N-1 dims and dtype ``dtype``.
    Rows containing no True value map to the row length, ``bools.size(-1)``.
    """
    width = bools.size(-1)
    positions = torch.arange(width, dtype=dtype, device=bools.device)
    # False entries get pushed past every real index, so the row minimum lands on
    # the first True, or on `width` when the row has no True at all.
    penalty = (~bools).type(dtype) * width
    return (penalty + positions).min(dim=-1).values


def get_shap_rewards(model, query_response, tokenizer, context_length, masker=None, device="cuda"):
    """
    Compute SHAP attributions of a reward model's score over the response text.

    ``query_response`` is split at ``context_length`` into a query prefix and a
    response. Both are decoded with special tokens skipped. SHAP then perturbs
    the response text. Each perturbed variant is scored as
    ``query + [space] + variant + "<|endoftext|>"``, and the explained scalar is
    ``logits[:, 1] - logits[:, 0]``.

    Args:
        model: classifier called with ``input_ids``/``attention_mask``/``return_dict``.
            Its ``.logits`` must have at least two columns. (Presumably a
            HuggingFace sequence-classification model; TODO confirm.)
        query_response: 1-D sequence of token ids, query followed by response.
        tokenizer: tokenizer used for decoding, for batch encoding, and (by
            default) as the SHAP masker.
        context_length (int): number of leading tokens that belong to the query.
        masker: optional SHAP masker. Defaults to ``tokenizer``.
        device: device the encoded batch is moved to (default ``"cuda"``).

    Returns:
        Whatever the SHAP explainer returns for ``[response_text]``.
    """
    query_text_clean = tokenizer.decode(query_response[:context_length], skip_special_tokens=True)
    response_text_clean = tokenizer.decode(query_response[context_length:], skip_special_tokens=True)

    def f(x):
        # `x` is a batch of (partially masked) response strings supplied by SHAP.
        inputs = []
        for _x in x:
            # Insert a space between query and response unless the response already starts with one.
            separator = "" if len(_x) > 0 and _x[0] == " " else " "
            inputs.append(query_text_clean + separator + _x + "<|endoftext|>")

        with torch.no_grad():
            input_ids = tokenizer(inputs, padding="longest", return_tensors="pt")["input_ids"].to(device)
            # NOTE(review): if tokenizer.pad_token_id equals the id of "<|endoftext|>",
            # the appended end-of-text token is masked out here as well; confirm the pad token is distinct.
            attention_mask = input_ids != tokenizer.pad_token_id
            reward_logits = model(
                input_ids=torch.masked_fill(input_ids, ~attention_mask, 0),
                attention_mask=attention_mask,
                return_dict=True,
            )
            output = reward_logits.logits[:, 1] - reward_logits.logits[:, 0]

        return output.detach().cpu().float().numpy()

    if masker is None:
        masker = tokenizer
    explainer = shap.Explainer(f, masker, algorithm="auto")

    shap_values = explainer([response_text_clean])

    return shap_values


def get_shap_rewards_imdb(model, query_response, tokenizer, context_length, masker=None, device="cuda"):
    """
    Compute SHAP attributions of a classifier's score over the response text.

    ``query_response`` is split at ``context_length`` into a query prefix and a
    response. Both are decoded with special tokens skipped. Each SHAP-perturbed
    response is appended to the query and followed by ``tokenizer.eos_token``.
    The explained scalar is ``logits[:, 1] - logits[:, 0]``.

    Args:
        model: classifier called with the tokenizer's encoded batch (``**inputs``).
            Its ``.logits`` must have at least two columns.
        query_response: 1-D sequence of token ids, query followed by response.
        tokenizer: tokenizer used for decoding, for batch encoding, and (by
            default) as the SHAP masker.
        context_length (int): number of leading tokens that belong to the query.
        masker: optional SHAP masker. Defaults to ``tokenizer``.
        device: device the encoded batch is moved to (default ``"cuda"``).

    Returns:
        Whatever the SHAP explainer returns for ``[response_text]``.
    """
    query_text_clean = tokenizer.decode(query_response[:context_length], skip_special_tokens=True)
    response_text_clean = tokenizer.decode(query_response[context_length:], skip_special_tokens=True)

    def f(x):
        # `x` is a batch of (partially masked) response strings supplied by SHAP.
        inputs = []
        for _x in x:
            if len(_x) == 0:
                # Fully masked response: score the query alone.
                concatenated = query_text_clean
            elif _x[0] == " ":
                concatenated = query_text_clean + _x
            else:
                concatenated = query_text_clean + " " + _x
            concatenated += tokenizer.eos_token
            inputs.append(concatenated)

        with torch.no_grad():
            encoded = tokenizer(inputs, padding="longest", return_tensors="pt").to(device)
            reward_logits = model(**encoded)
            output = reward_logits.logits[:, 1] - reward_logits.logits[:, 0]

        return output.detach().cpu().float().numpy()

    if masker is None:
        masker = tokenizer
    explainer = shap.Explainer(f, masker, algorithm="auto")

    shap_values = explainer([response_text_clean])

    return shap_values


def get_shap_rewards_openllama(model, query_text, response_text, tokenizer, masker=None, device="cuda"):
    """
    Compute SHAP attributions of a model's raw logits over ``response_text``.

    Each SHAP-perturbed response is appended to ``query_text``, with a space
    inserted unless the response already starts with one, and followed by
    ``tokenizer.eos_token``. The model's ``.logits`` are returned unmodified as
    the explained output.

    Args:
        model: model called with the tokenizer's encoded batch (``**inputs``).
        query_text (str): prompt text kept fixed during masking.
        response_text (str): response text whose parts are attributed.
        tokenizer: tokenizer used for batch encoding and (by default) as the
            SHAP masker. It must support ``padding="longest"``, so it presumably
            needs a pad token; confirm for LLaMA-family tokenizers.
        masker: optional SHAP masker. Defaults to ``tokenizer``.
        device: device the encoded batch is moved to (default ``"cuda"``).

    Returns:
        Whatever the SHAP explainer returns for ``[response_text]``.
    """

    def f(x):
        # `x` is a batch of (partially masked) response strings supplied by SHAP.
        inputs = []
        for _x in x:
            separator = "" if len(_x) > 0 and _x[0] == " " else " "
            inputs.append(query_text + separator + _x + tokenizer.eos_token)

        with torch.no_grad():
            encoded = tokenizer(inputs, padding="longest", return_tensors="pt").to(device)
            # Raw logits (presumably shape (batch, num_labels); confirm) are the explained output.
            output = model(**encoded).logits

        return output.detach().cpu().float().numpy()

    if masker is None:
        masker = tokenizer
    explainer = shap.Explainer(f, masker, algorithm="auto")

    shap_values = explainer([response_text])

    return shap_values


