import os
import re
import json
import string
import time
import difflib
from collections import OrderedDict

import pandas as pd
import nltk
from thefuzz import fuzz
from dotenv import load_dotenv

# Import the OpenAI (or Gemini) client library.
from openai import OpenAI, Client

# Load environment variables (API keys, etc.)
load_dotenv()


# Title used to select this book's rows from `dataset` (matched against the
# "Book Name" column). NOTE(review): "xxx" looks like a placeholder — set a
# real title before running.
book_name = "xxx"
# Paragraph-level corpus, read once at import time from the hf:// path below.
# The functions in this file read its "Book Name", "Chapter" and "Chunk" columns.
dataset = pd.read_parquet("hf://datasets/LumberChunker/GutenQA_Paragraphs/GutenQA_paragraphs.parquet")


# =============================================================================
# Helper: Instantiate a client given a model name.
# =============================================================================
def get_llm_client(model_name: str, gemini_keys: list = None) -> "Client":
    """
    Return an OpenAI-compatible client configured for ``model_name``.

    A model whose name contains "gemini" (case-insensitive) gets an ``OpenAI``
    client pointed at the Gemini OpenAI-compatibility base URL, authenticated
    with the key read from the environment variable named by ``gemini_keys[0]``.
    Any other model gets a default ``OpenAI`` client whose key is read from the
    environment variable "xxx".

    Args:
        model_name: Name of the model the client will be used with.
        gemini_keys: Names of environment variables holding Gemini API keys.
            Only the first entry is used here; required for Gemini models.

    Returns:
        A configured ``OpenAI`` client instance.

    Raises:
        ValueError: If a Gemini model is requested and ``gemini_keys`` is
            missing or empty, or the environment variable it names is unset.
    """
    if "gemini" in model_name.lower():
        # `not gemini_keys` also rejects an empty list, which would otherwise
        # fail with an IndexError on gemini_keys[0].
        if not gemini_keys:
            raise ValueError("Gemini keys must be provided for a Gemini model.")
        api_key = os.getenv(gemini_keys[0])
        if not api_key:
            # Fail early with a clear message instead of building a client
            # that has no usable key for the Gemini endpoint.
            raise ValueError(
                f"Environment variable '{gemini_keys[0]}' is not set; "
                "cannot authenticate the Gemini client."
            )
        return OpenAI(api_key=api_key, base_url="https://generativelanguage.googleapis.com/v1beta/openai/")

    api_key = os.getenv("xxx")
    return OpenAI(api_key=api_key)
    


# =============================================================================
# Utility Functions (common to both summarization and paraphrasing)
# =============================================================================
def extract_json_content(response_str: str) -> str:
    """
    Extract the JSON payload from an LLM response string.

    Handled shapes, in order:
      1. A block opened by a ```json marker anywhere in the response: the text
         after the marker up to the next ``` (or to the end if unterminated).
      2. A response that *starts* with a bare or differently-tagged fence
         (``` or e.g. ```JSON): the fence line is dropped and the text up to
         the closing ``` is returned.
      3. Anything else: the response with surrounding whitespace stripped.

    Args:
        response_str: Raw text returned by the model.

    Returns:
        The extracted text, stripped of surrounding whitespace. It is not
        validated as JSON here.
    """
    fence = "```"
    json_marker = "```json"

    json_start = response_str.find(json_marker)
    if json_start != -1:
        json_text = response_str[json_start + len(json_marker):].strip()
    else:
        stripped = response_str.strip()
        # Only treat the response as fenced when it begins with a fence, so raw
        # JSON that merely contains backticks inside a string is left untouched.
        if not stripped.startswith(fence):
            return stripped
        json_text = stripped[len(fence):]
        first_line, newline, remainder = json_text.partition("\n")
        if newline and first_line.strip().isalpha():
            # Drop a language tag such as "JSON" that follows the opening fence.
            json_text = remainder
        json_text = json_text.strip()

    json_end = json_text.find(fence)
    if json_end != -1:
        json_text = json_text[:json_end].strip()
    return json_text


def extract_event_texts_robust(chapter_text: str, events: list) -> list:
    """
    Attach to each event the slice of ``chapter_text`` delimited by its boundaries.

    Each event must carry ``event['segmentation_boundaries']`` with
    ``'first_sentence'`` and ``'last_sentence'`` strings. Both are located in
    the chapter with a cascade of strategies (exact match, quoted-fragment
    match, normalized match, sentence-level match, fuzzy match, difflib). An
    unlocatable first sentence falls back to offset 0 and an unlocatable last
    sentence to the end of the chapter. Afterwards an event's end is pulled
    back to the start of the following event whenever the two would overlap.

    Args:
        chapter_text: Full text of the chapter.
        events: Event dicts as described above; they are not modified.

    Returns:
        A new list of shallow copies of the events, each with an added
        ``"text_segment"`` key.
    """
    try:
        sentences = nltk.sent_tokenize(chapter_text)
    except LookupError:
        # The tokenizer data is missing. Older NLTK releases load it from
        # "punkt", newer ones from "punkt_tab"; fetch both so either works.
        nltk.download('punkt')
        nltk.download('punkt_tab')
        sentences = nltk.sent_tokenize(chapter_text)

    def normalize_text(text: str) -> str:
        """Collapse whitespace and unify curly quotes, dot runs and dash runs."""
        text = re.sub(r'\s+', ' ', text).strip()
        text = text.replace("“", '"').replace("”", '"').replace("‘", "'").replace("’", "'")
        text = re.sub(r'\.{2,}', '...', text)
        text = re.sub(r'--+', '—', text)
        return text

    normalized_sentences = [normalize_text(s) for s in sentences]
    normalized_chapter = normalize_text(chapter_text)

    def find_boundary_index(boundary: str, start: int = 0) -> tuple:
        """
        Locate ``boundary`` in ``chapter_text`` at or after offset ``start``.

        Returns ``(index, length)`` of the match, or ``(-1, 0)`` if every
        strategy fails.
        """
        # 1. Exact match.
        idx = chapter_text.find(boundary, start)
        if idx != -1:
            return idx, len(boundary)
        # 2. Handle quoted text.
        if '"' in boundary:
            quote_matches = re.findall(r'"([^"]*)"', boundary)
            if quote_matches:
                for quoted_text in quote_matches:
                    if not quoted_text.strip():
                        continue
                    quoted_idx = chapter_text.find(quoted_text, start)
                    if quoted_idx != -1:
                        # NOTE(review): this scans forward, so the earliest
                        # '"' within the 50-character window wins, not the one
                        # nearest to the quoted text — confirm intended.
                        for i in range(max(0, quoted_idx - 50), quoted_idx):
                            if i < len(chapter_text) and chapter_text[i] == '"':
                                return i, (quoted_idx - i) + len(quoted_text) + 1
                        return max(0, quoted_idx - 1), len(quoted_text) + 2
        # 3. Normalized matching.
        norm_boundary = normalize_text(boundary)
        # NOTE(review): `start` is an offset into the raw text but is applied to
        # the normalized text here, and the hit is then searched for within
        # +/-10 characters of the same offset in the raw text. The two offsets
        # drift apart when normalization changes the length — confirm intended.
        idx = normalized_chapter.find(norm_boundary, start)
        if idx != -1:
            for i in range(max(0, idx - 10), min(len(chapter_text), idx + 10)):
                if normalize_text(chapter_text[i:i+len(norm_boundary)]) == norm_boundary:
                    return i, len(norm_boundary)
        # 4. Sentence-level matching.
        for sentence, norm_sentence in zip(sentences, normalized_sentences):
            idx = chapter_text.find(sentence, start)
            if idx == -1:
                continue
            if norm_sentence == norm_boundary:
                return idx, len(sentence)
            elif norm_boundary in norm_sentence:
                # The offset is measured in the normalized sentence, so the
                # returned raw index is approximate.
                offset = norm_sentence.find(norm_boundary)
                return idx + offset, len(norm_boundary)
        # 5. Fuzzy matching.
        best_score = 0
        best_idx = -1
        best_len = 0
        for sentence in sentences:
            idx = chapter_text.find(sentence, start)
            if idx == -1:
                continue
            score = fuzz.ratio(norm_boundary, normalize_text(sentence))
            if score > best_score:
                best_score = score
                best_idx = idx
                best_len = len(sentence)
        if best_score > 80:
            return best_idx, best_len
        # 6. Difflib as last resort.
        relevant_sentences = [s for s in sentences if chapter_text.find(s) >= start]
        matches = difflib.get_close_matches(boundary, relevant_sentences, n=1, cutoff=0.6)
        if matches:
            match_idx = chapter_text.find(matches[0], start)
            return match_idx, len(matches[0])
        return -1, 0

    event_boundaries = []
    # Raw lookup result (-1 when not found) for every event's first sentence.
    # Kept so the overlap pass below does not re-run the matching cascade.
    first_sentence_hits = []
    for event in events:
        first_sentence = event['segmentation_boundaries']['first_sentence']
        found_idx, _ = find_boundary_index(first_sentence, start=0)
        first_sentence_hits.append(found_idx)
        start_idx = found_idx if found_idx != -1 else 0
        last_sentence = event['segmentation_boundaries']['last_sentence']
        end_idx, end_len = find_boundary_index(last_sentence, start=start_idx)
        if end_idx != -1:
            end_idx += end_len
        else:
            end_idx = len(chapter_text)
        event_boundaries.append((start_idx, end_idx))

    # Adjust boundaries to avoid overlaps: an event may not run past the
    # located start of the next one. Unlocated starts (-1) are ignored here.
    for i in range(len(event_boundaries) - 1):
        next_start = first_sentence_hits[i + 1]
        if next_start != -1 and next_start < event_boundaries[i][1]:
            event_boundaries[i] = (event_boundaries[i][0], next_start)

    modified_events = []
    for event, (start_idx, end_idx) in zip(events, event_boundaries):
        event_copy = event.copy()
        event_copy['text_segment'] = chapter_text[max(0, start_idx):min(len(chapter_text), end_idx)]
        modified_events.append(event_copy)

    return modified_events


def has_ngram_overlap(chapter_summary: dict, n: int) -> bool:
    """
    Report whether two different events share an n-gram of whitespace tokens.

    The ``text_segment`` of every entry in ``chapter_summary["events"]`` is
    split on whitespace and scanned with a sliding window of ``n`` tokens.
    Repeats inside a single event do not count.

    Returns:
        True as soon as an n-gram first seen in one event shows up in another,
        otherwise False.
    """
    # Maps each n-gram to the index of the event in which it was first seen.
    first_owner = {}
    for event_no, event in enumerate(chapter_summary.get("events", [])):
        tokens = event.get("text_segment", "").split()
        window_count = len(tokens) - n + 1
        for begin in range(window_count):
            gram = " ".join(tokens[begin:begin + n])
            if first_owner.setdefault(gram, event_no) != event_no:
                return True
    return False


def compute_coverage(summary: dict, dataset: pd.DataFrame, book_name: str) -> dict:
    """
    Compare the length of the extracted event segments with the full book text.

    Both texts are space-joined, whitespace-collapsed and have curly quotes
    replaced by straight ones before their lengths are measured. The full text
    is built from the "Chunk" column of the rows whose "Book Name" equals
    ``book_name``.

    Returns:
        A dict with ``extracted_len``, ``full_len``, ``missing_len``
        (full minus extracted), ``coverage_ratio`` (0 when the book text is
        empty), and the two normalized texts.
    """
    straight_quotes = str.maketrans({"“": '"', "”": '"', "‘": "'", "’": "'"})

    def _canonical(raw: str) -> str:
        # Collapse whitespace first, then straighten the quote characters.
        return re.sub(r'\s+', ' ', raw).strip().translate(straight_quotes)

    segments = []
    for chapter in summary.get("chapters", []):
        for event in chapter.get("events", []):
            segments.append(event.get("text_segment", ""))
    extracted_text = _canonical(" ".join(segments))

    rows_for_book = dataset[dataset["Book Name"] == book_name]
    full_text = _canonical(" ".join(rows_for_book["Chunk"].tolist()))

    extracted_len = len(extracted_text)
    full_len = len(full_text)
    if full_len > 0:
        coverage_ratio = extracted_len / full_len
    else:
        coverage_ratio = 0

    return {
        "extracted_len": extracted_len,
        "full_len": full_len,
        "missing_len": full_len - extracted_len,
        "coverage_ratio": coverage_ratio,
        "extracted_text": extracted_text,
        "full_text": full_text,
    }


# -------------------------------------------
# Chapter processing for structured summarization.
# -------------------------------------------
def process_chapter(chapter_text: str, chapter_title: str, model_name: str, client) -> dict:
    """
    Ask the LLM for a structured event summary of one chapter and attach text segments.

    The chapter is sent in a single chat-completion request with a strict JSON
    schema (``events`` -> title, characters, detailed_summary,
    segmentation_boundaries). The returned boundaries are then resolved
    against ``chapter_text`` by ``extract_event_texts_robust`` and each event
    is rebuilt with a fixed key order.

    Args:
        chapter_text: Full text of the chapter.
        chapter_title: Title stored in the result and used in log messages.
        model_name: Model identifier passed to the chat-completions API.
        client: Object exposing ``chat.completions.create`` (OpenAI-style).

    Returns:
        An ``OrderedDict`` with ``chapter_title`` and ``events``; every event
        has ``title``, ``characters``, ``detailed_summary``,
        ``segmentation_boundaries`` and ``text_segment``.

    Raises:
        ValueError: If the response carries no message content.
        json.JSONDecodeError: If the response content is not valid JSON.
        Exception: Whatever the client raises for the API call is propagated.
    """
    # NOTE(review): the prompt speaks of a "description" field while the schema
    # below names it "detailed_summary" — confirm whether the prompt should be
    # aligned. The prompt text is left untouched here.
    system_prompt = """You will be given a full book chapter. Your task is to process the chapter in a single step and return a detailed, structured summary in a JSON-like format.

Your output must consist of a list of **key events**, each represented as an object containing:

- a **brief and descriptive title** clearly indicating the event's significance.
- a **list of characters involved** (or explicitly note 'No direct characters involved').
- a **detailed description** in bullet points, carefully capturing important interactions, narrative shifts, emotional nuances, or thematic elements. Avoid brevity or overly generic summaries. Do not include direct quotes here—paraphrase creatively and insightfully.
  - **Ensure that the bullet points follow the same sequence as in the original text**, preserving the order in which events and ideas unfold.
- the **exact first and last sentences** from the text that mark the event boundaries (verbatim quotes only here).

---
### Summary Guidelines:

- The number of events is **flexible**. Choose a number that makes sense based on the chapter's length and content (target 2-10).
- Events must be **independent**, **contiguous**, and **non-overlapping**.
- Each event should capture:
    - a clear action or narrative moment,
    - important character interactions,
    - iconic lines or descriptions (even if no characters are involved).
- Do **not** include any quotations in the summary section — only in the segmentation boundaries.
- Each event should represent a cohesive and meaningful narrative unit, not just isolated lines, brief observations, or transitions. If a passage seems too short to stand as its own event, it likely belongs as part of a surrounding, broader event.

### Additional Instructions:

- Each "description" field should provide enough detail for someone to understand what happens without reading the chapter.
- Use as many bullet points as necessary to convey the details.
- Do **not** copy any dialogue or narration into the "description". It must be paraphrased.
- Use **verbatim quotes only** for the "first_sentence" and "last_sentence" fields — they must match the chapter exactly - no changes allowed. Always include the full sentence.
- All event boundaries should **cover the full chapter without gaps or overlap**.
"""

    user_prompt = (f"Please summarize the key events in the chapter using the specified structure. "
                f"Aim for insightful, detailed descriptions of each event, capturing significant character "
                f"interactions, narrative subtleties, or thematic elements clearly. Here is the Book Chapter:\n{chapter_text}")

    response = client.chat.completions.create(
        model=model_name,
        temperature=1,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "key_events",
                "schema": {
                    "type": "object",
                    "properties": {
                        "events": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "title": {"type": "string"},
                                    "characters": {"type": "array", "items": {"type": "string"}},
                                    "detailed_summary": {"type": "array", "items": {"type": "string"}},
                                    "segmentation_boundaries": {
                                        "type": "object",
                                        "properties": {
                                            "first_sentence": {"type": "string"},
                                            "last_sentence": {"type": "string"}
                                        },
                                        "required": ["first_sentence", "last_sentence"],
                                        "additionalProperties": False
                                    }
                                },
                                "required": ["title", "characters", "detailed_summary", "segmentation_boundaries"],
                                "additionalProperties": False
                            }
                        }
                    },
                    "required": ["events"],
                    "additionalProperties": False
                },
                "strict": True
            }
        }
    )

    raw_response = response.choices[0].message.content
    if raw_response is None:
        # Without this guard the failure surfaces as an opaque AttributeError
        # inside extract_json_content.
        raise ValueError(f"Empty response content for chapter: {chapter_title}")

    event_json_str = extract_json_content(raw_response)
    try:
        event_json = json.loads(event_json_str)
    except json.JSONDecodeError as e:
        print("Error parsing JSON for chapter:", chapter_title, e)
        raise
    print("JSON successfully parsed for chapter:", chapter_title)

    # Extract text segments from chapter based on the event boundaries.
    located_events = extract_event_texts_robust(chapter_text, event_json["events"])

    # Rebuild every event with a fixed key order so the saved JSON is uniform.
    ordered_events = []
    for event in located_events:
        seg_boundaries = event["segmentation_boundaries"]
        ordered_seg_boundaries = OrderedDict()
        ordered_seg_boundaries["first_sentence"] = seg_boundaries["first_sentence"]
        ordered_seg_boundaries["last_sentence"] = seg_boundaries["last_sentence"]

        ordered_event = OrderedDict()
        ordered_event["title"] = event["title"]
        ordered_event["characters"] = event["characters"]
        ordered_event["detailed_summary"] = event["detailed_summary"]
        ordered_event["segmentation_boundaries"] = ordered_seg_boundaries
        ordered_event["text_segment"] = event["text_segment"]
        ordered_events.append(ordered_event)

    ordered_summary = OrderedDict()
    ordered_summary["chapter_title"] = chapter_title
    ordered_summary["events"] = ordered_events

    return ordered_summary


# -------------------------------------------
# Paraphrasing function (used as a second pass).
# -------------------------------------------
def paraphrase_text(book_parsed_name, original_text, client, paraphrase_model, num_retries=3):
    """
    Ask the LLM for a paraphrase of one text segment, retrying on any error.

    Up to ``num_retries`` requests are made. For model names containing
    "gemini" (case-insensitive) a 4-second pause follows every request,
    successful or not.

    Returns:
        The ``"text_segment"`` string from the model's JSON reply ("" if the
        key is absent). If every attempt fails, the error message of the last
        attempt is returned as a string instead; with ``num_retries`` of 0 the
        result is ``None``.
    """
    def _pause_if_gemini():
        if "gemini" in paraphrase_model.lower():
            time.sleep(4)

    failure_note = None
    for attempt_no in range(1, num_retries + 1):
        try:
            system_prompt = (
                f'You are provided with an original passage from the book "{book_parsed_name}".\n'
                "Generate a complete paraphrase of the presented text.\n\n"
            )
            user_prompt = f'The text to be paraphrased is:\n\n{original_text}\n\n'
            chat_messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            # Strict schema: a single required string field, nothing else.
            reply_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": "memory_task",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "text_segment": {"type": "string"}
                        },
                        "required": ["text_segment"],
                        "additionalProperties": False,
                    },
                    "strict": True
                }
            }
            reply = client.chat.completions.create(
                model=paraphrase_model,
                temperature=1,
                messages=chat_messages,
                timeout=60,
                response_format=reply_format
            )
            _pause_if_gemini()
            payload = extract_json_content(reply.choices[0].message.content)
            parsed = json.loads(payload)
            return parsed.get("text_segment", "")
        except Exception as exc:
            failure_note = f"Error during paraphrase extraction (attempt {attempt_no}/{num_retries}): {exc}"
            print(failure_note)
            _pause_if_gemini()
    return failure_note



# =============================================================================
# Main Processing Pipeline: Summaries and/or Paraphrases
# =============================================================================
# Prefixes that mark a stored paraphrase value as a failed attempt: the first
# is written by process_book itself, the second is the message paraphrase_text
# returns once all of its retries are exhausted.
_PARAPHRASE_ERROR_PREFIXES = ("Error code:", "Error during paraphrase extraction")


def _paraphrase_pending(value) -> bool:
    """Return True when a stored paraphrase is missing or records a failed attempt."""
    if value is None:
        return True
    return isinstance(value, str) and value.startswith(_PARAPHRASE_ERROR_PREFIXES)


def _write_checkpoint(path: str, payload: dict) -> None:
    """Write ``payload`` as JSON to ``path`` through a temp file and an atomic rename."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    # An interrupted run can then never leave a truncated checkpoint behind.
    os.replace(tmp_path, path)


def process_book(
    book_name: str,
    summary_model: str,
    paraphrase_models: list,
    n_retries_overlap: int,
    n_grams: int,
    summarization_client,
    gemini_keys: list = None,
    do_summaries: bool = True,
    do_paraphrases: bool = True,
    paraphrase_retries: int = 3
) -> dict:
    """
    Summarize a book chapter by chapter and/or paraphrase every summarized event.

    Summaries and paraphrases live in one checkpoint file inside the book
    directory. Its name is fixed to
    "{book_name}_summary_gemini-2.5-pro-exp-03-25.json" regardless of
    ``summary_model``; it is rewritten after every chapter and after every
    paraphrased event, so an interrupted run can be resumed.

    Args:
        book_name: Value matched against the dataset's "Book Name" column.
        summary_model: Model used by ``process_chapter``.
        paraphrase_models: Models to paraphrase with; results are stored per
            event under "paraphrase_{model}".
        n_retries_overlap: Maximum summarization attempts per chapter when
            overlapping n-grams are detected. Switching API key after a rate
            limit does not consume an attempt.
        n_grams: n-gram size used for the overlap check.
        summarization_client: Client used for summarization; replaced
            internally when a Gemini key switch occurs.
        gemini_keys: Names of environment variables holding Gemini API keys.
        do_summaries: Summarize chapters missing from the checkpoint.
        do_paraphrases: Paraphrase events whose paraphrase is missing or failed.
        paraphrase_retries: Retries per event passed to ``paraphrase_text``.

    Returns:
        The summary dict as stored in the checkpoint file.

    Raises:
        Exception: If ``do_summaries`` is False and the checkpoint is missing
            or incomplete; errors from summarization are propagated as well.
    """
    # Use the existing book directory structure
    book_dir = f"/Users/xxx/Copyrighted_Books/{book_name}"
    os.makedirs(book_dir, exist_ok=True)

    # Keep the checkpoint filename hardcoded to the original summary model
    checkpoint_filename = os.path.join(book_dir, f"{book_name}_summary_gemini-2.5-pro-exp-03-25.json")

    # Filter chapters for the book.
    book_df = dataset[dataset['Book Name'] == book_name].reset_index(drop=True)
    all_chapters = book_df['Chapter'].drop_duplicates().tolist()

    # --- Part 1: Summarization ---
    if do_summaries:
        if os.path.exists(checkpoint_filename):
            print("Checkpoint found. Resuming summarization...")
            with open(checkpoint_filename, "r", encoding="utf-8") as f:
                final_summary = json.load(f)
            processed_chapters = {ch["chapter_title"] for ch in final_summary.get("chapters", [])}
            remaining_chapters = [ch for ch in all_chapters if ch not in processed_chapters]
        else:
            final_summary = {"book_name": book_name, "chapters": []}
            remaining_chapters = all_chapters

        current_key_index = 0 if gemini_keys is not None else None

        for chapter in remaining_chapters:
            chapter_df = book_df[book_df['Chapter'] == chapter]
            chapter_text = '\n'.join(chapter_df['Chunk'].tolist())
            print(f"\nProcessing chapter: {chapter}")
            best_summary = None
            attempt = 1
            while attempt <= n_retries_overlap:
                print(f"Attempt {attempt} for chapter: {chapter}")
                try:
                    chapter_summary = process_chapter(chapter_text, chapter, summary_model, summarization_client)
                except Exception as e:
                    # If using Gemini for summarization, check for rate-limit errors and attempt key switching.
                    rate_limited = "429" in str(e) or "RESOURCE_EXHAUSTED" in str(e)
                    if gemini_keys is None or not rate_limited:
                        raise
                    if current_key_index >= len(gemini_keys) - 1:
                        print("All Gemini keys exhausted for summarization. Stopping processing.")
                        raise
                    current_key_index += 1
                    new_api_key = os.getenv(gemini_keys[current_key_index])
                    summarization_client = OpenAI(api_key=new_api_key,
                                                  base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
                    print(f"Rate limit reached for summarization; switched to key: {gemini_keys[current_key_index]}")
                    # Retry the same attempt: a key switch must not use up an
                    # overlap retry, otherwise a switch on the last attempt
                    # would drop the chapter. The loop still ends because the
                    # key list is finite.
                    continue

                if not has_ngram_overlap(chapter_summary, n=n_grams):
                    print(f"{chapter}: No overlapping {n_grams}-grams found.")
                    best_summary = chapter_summary
                    break
                if attempt == n_retries_overlap:
                    print(f"{chapter}: Max attempts reached; accepting summary despite overlaps.")
                    best_summary = chapter_summary
                else:
                    print(f"{chapter}: Overlapping {n_grams}-grams detected; reprocessing chapter.")
                attempt += 1

            if best_summary is not None:
                final_summary.setdefault("chapters", []).append(best_summary)
                _write_checkpoint(checkpoint_filename, final_summary)
            else:
                print(f"Failed to process chapter {chapter} after multiple attempts.")
        print("Summarization complete.")

        # --- Final: Coverage Check and Logging ---
        print("Running coverage check...")
        coverage = compute_coverage(final_summary, dataset, book_name)
        log_filename = os.path.join(book_dir, f"{book_name}_coverage_{summary_model}.log")
        with open(log_filename, "w", encoding="utf-8") as log_file:
            log_file.write("=== Coverage Statistics ===\n")
            log_file.write(f"Extracted length:     {coverage['extracted_len']} characters\n")
            log_file.write(f"Full book length:     {coverage['full_len']} characters\n")
            log_file.write(f"Missing characters:   {coverage['missing_len']}\n")
            log_file.write(f"Coverage ratio:       {coverage['coverage_ratio']:.2%}\n\n")
        print(f"Coverage log saved to: {log_filename}")
    else:
        # If not running summarization, the checkpoint must exist and be complete.
        if not os.path.exists(checkpoint_filename):
            raise Exception("No summary file exists. Cannot run paraphrases without summaries.")
        with open(checkpoint_filename, "r", encoding="utf-8") as f:
            final_summary = json.load(f)
        processed_chapters = [ch["chapter_title"] for ch in final_summary.get("chapters", [])]
        if len(processed_chapters) < len(all_chapters):
            raise Exception("Summary file is incomplete. Run summarization before paraphrasing.")

    # --- Part 2: Paraphrasing ---
    if do_paraphrases and paraphrase_models:
        from tqdm import tqdm
        book_parsed_name = book_name.split("_-_")[0].replace("_", " ")
        # Iterate over each paraphrase model in the list.
        for paraphrase_model in paraphrase_models:
            print(f"\nStarting paraphrase generation for events using model: {paraphrase_model}")
            paraphrase_field = f"paraphrase_{paraphrase_model}"
            # Instantiate a dedicated client for this paraphrase model.
            paraphrase_client = get_llm_client(
                paraphrase_model,
                gemini_keys if "gemini" in paraphrase_model.lower() else None
            )
            # Count only events needing paraphrase or that previously errored.
            # The same predicate drives the skip below, so the bar total
            # matches the work actually done.
            total_events = sum(
                1 for chapter in final_summary.get("chapters", [])
                for event in chapter.get("events", [])
                if _paraphrase_pending(event.get(paraphrase_field))
            )
            pbar = tqdm(total=total_events, desc=f"Paraphrasing [{paraphrase_model}]", unit="event")
            for chapter in final_summary.get("chapters", []):
                for event in chapter.get("events", []):
                    # Skip events already paraphrased successfully
                    if not _paraphrase_pending(event.get(paraphrase_field)):
                        continue
                    original_text = event.get("text_segment", "") or ""
                    if not original_text.strip():
                        event[paraphrase_field] = "No text available to paraphrase."
                    else:
                        try:
                            # paraphrase_text reports exhausted retries by
                            # returning its error message; that message is
                            # stored as-is and recognised as pending on the
                            # next run.
                            event[paraphrase_field] = paraphrase_text(
                                book_parsed_name,
                                original_text,
                                paraphrase_client,
                                paraphrase_model,
                                num_retries=paraphrase_retries
                            )
                        except Exception as e:
                            # On error, record and let retry in next run
                            event[paraphrase_field] = f"Error code: {str(e)}"
                    # Save checkpoint after each event
                    _write_checkpoint(checkpoint_filename, final_summary)
                    pbar.update(1)
            pbar.close()
            print(f"Paraphrase generation for model {paraphrase_model} complete.")

    return final_summary

# =============================================================================
# Main Block
# =============================================================================
if __name__ == "__main__":
    # ---- Run configuration ----
    # Model that produces the chapter summaries.
    summary_model = "gemini-2.5-pro-preview-05-06"
    # Models that each paraphrase every summarized event.
    paraphrase_models = [
        "gemini-2.5-flash-preview-04-17",
        "gpt-4.1-2025-04-14",
    ]
    # Overlap check: attempts per chapter and n-gram size.
    n_retries_overlap = 3
    n_grams = 100
    # Phase switches: summarization is off, so the summary checkpoint must
    # already be complete; paraphrasing is on.
    do_summaries = False
    do_paraphrases = True
    paraphrase_retries = 3

    # Names of the environment variables that hold the Gemini API keys.
    gemini_keys = ["xxx"]

    # Client for the summarization phase; Gemini keys are only handed over
    # when the summary model is a Gemini model.
    summary_is_gemini = "gemini" in summary_model.lower()
    summarization_client = get_llm_client(
        summary_model,
        gemini_keys if summary_is_gemini else None,
    )

    # Run the pipeline; paraphrase clients are created inside process_book.
    summary_result = process_book(
        book_name=book_name,
        summary_model=summary_model,
        paraphrase_models=paraphrase_models,
        n_retries_overlap=n_retries_overlap,
        n_grams=n_grams,
        summarization_client=summarization_client,
        gemini_keys=gemini_keys,
        do_summaries=do_summaries,
        do_paraphrases=do_paraphrases,
        paraphrase_retries=paraphrase_retries,
    )