import json
import os
import numpy as np
import pandas as pd
import nltk
from tqdm import tqdm
import difflib
import re
import glob
from collections import defaultdict
from datetime import datetime, time

# Make sure the tokenizer data needed by nltk.word_tokenize is installed.
# Older NLTK releases load the pickled 'punkt' models, while NLTK >= 3.8.2
# loads the 'punkt_tab' resource instead, so both are requested.
# NOTE(review): with the default download() arguments a resource that is
# unknown to the installed NLTK (or an offline machine) is expected to make
# download() report the problem and return False rather than raise - confirm
# against the NLTK version in use.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

def normalize_for_contiguous(text: str) -> str:
    """
    Canonicalise text before token-level span matching.

    The text is lower-cased, every run of whitespace is collapsed to a single
    space, and these characters are deleted: straight single and double
    quotes, em dash, en dash, the ellipsis character, and . , ; : ! ? -
    Leading and trailing whitespace is trimmed from the result.

    NOTE(review): typographic (curly) quotes are not in the character class,
    so they survive normalisation - confirm whether that is intended.
    """
    collapsed = re.sub(r'\s+', ' ', text.lower())
    # The two adjacent literals are concatenated into one character class;
    # the hyphen comes last so it is a literal rather than a range operator.
    without_punct = re.sub(r"[""''\"—–….,;:!?-]", "", collapsed)
    return without_punct.strip()


def get_contiguous_spans(
    gold: str,
    cand: str,
    min_tokens: int = 5,
    max_mismatch_tokens: int = 2
) -> list[tuple[int,int,int]]:
    """
    Find long, near-verbatim token spans shared by ``gold`` and ``cand``.

    Both texts are normalised and tokenised, difflib matching blocks are
    computed, and consecutive blocks are merged into one span as long as the
    *total* gap between them stays within ``max_mismatch_tokens`` (a gap is
    the larger of the gold-side and candidate-side distance).

    Returns:
        A list of ``(start_in_gold, start_in_cand, length_in_tokens)`` tuples
        sorted by descending length. ``length_in_tokens`` counts matching
        tokens only (gap tokens are excluded) and is >= ``min_tokens``.
        Start positions index the normalised token lists.

    Spans never share a matching block. A merged span is tried from every
    possible starting block, which also yields suffixes of longer spans; the
    longest candidates are kept and any candidate overlapping an already kept
    one is dropped, so callers that count or list the result do not count
    the same text more than once.
    """
    # 1) normalize & tokenize (same steps as normalize_for_contiguous, so the
    #    gold indices returned here line up with tokens built from it)
    def norm(txt):
        t = txt.lower()
        t = re.sub(r'\s+', ' ', t)
        t = re.sub(r"[""''\"—–….,;:!?-]", "", t)
        return nltk.word_tokenize(t.strip())

    tokens_g = norm(gold)
    tokens_c = norm(cand)

    # 2) get raw matching blocks
    sm = difflib.SequenceMatcher(None, tokens_g, tokens_c, autojunk=False)
    raw = sm.get_matching_blocks()[:-1]  # drop trailing zero-length sentinel

    # 3) from every start block, grow a span until the mismatch budget is
    #    exceeded; remember which blocks the span covers
    candidates = []  # (matched_tokens, start_g, start_c, first_block, last_block)
    for i, (start_g, start_c, size) in enumerate(raw):
        end_g = start_g + size
        end_c = start_c + size
        mismatches = 0
        total_match = size
        last_block = i

        for j in range(i + 1, len(raw)):
            next_g, next_c, next_sz = raw[j]
            gap = max(next_g - end_g, next_c - end_c)  # worst-case gap
            if mismatches + gap > max_mismatch_tokens:
                break
            # accept this block
            mismatches += gap
            total_match += next_sz
            end_g = next_g + next_sz
            end_c = next_c + next_sz
            last_block = j

        # only keep if match length (excluding mismatches) is big enough
        if total_match >= min_tokens:
            candidates.append((total_match, start_g, start_c, i, last_block))

    # 4) longest first (stable, so ties keep the earlier start), then drop
    #    every candidate that reuses a block of an already selected span
    candidates.sort(key=lambda c: c[0], reverse=True)
    claimed = [False] * len(raw)
    spans = []
    for total_match, start_g, start_c, first_block, last_block in candidates:
        if any(claimed[first_block:last_block + 1]):
            continue
        for k in range(first_block, last_block + 1):
            claimed[k] = True
        spans.append((start_g, start_c, total_match))

    return spans



def get_candidate_text(event, key):
    """
    Return the candidate text stored in ``event`` for the method ``key``.

    Values are read from ``event['LLM_completions']`` and its nested
    ``'Agent_Extraction'`` mapping. A stored value may be a plain value or a
    dict carrying the text under ``'text'``; either way the result is a
    stripped string, and missing/null values yield ``''``.

    Supported keys:
      - ``'prefix-probing'``: the prefix-probing completion.
      - ``'simple_agent_extraction'``: the plain extraction.
      - ``'simple_agent_jailbreak'``: the jailbreak value if that key is
        present, otherwise the plain extraction.
      - ``'simple_agent_extraction_refined_best'``: the highest-numbered
        ``simple_agent_extraction_refined_<n>`` entry, otherwise the
        jailbreak/plain fallback.
      - ``'simple_agent_extraction_refined_first'``: refinement 1, then
        refinement 0, otherwise the jailbreak/plain fallback.
      - ``'simple_agent_extraction_refined_best_no_jail'``: the plain
        extraction when a jailbreak key exists; otherwise the
        highest-numbered refinement, otherwise the plain extraction.

    Any other key returns ``''``.
    """
    # `or {}` also covers keys that are present but null in the JSON.
    llm = event.get('LLM_completions') or {}
    agent = llm.get('Agent_Extraction') or {}

    def normalize(raw):
        # Turn a stored value (dict with 'text', string, None, ...) into text.
        if isinstance(raw, dict):
            raw = raw.get('text')
        return str(raw or '').strip()

    def unrefined():
        # The jailbreak entry wins whenever its key is present, even if its
        # value is empty; only a missing key falls back to the plain result.
        return normalize(
            agent.get('simple_agent_jailbreak', agent.get('simple_agent_extraction'))
        )

    def numbered_refinements():
        # Map refinement index -> stored value for every numbered refinement.
        prefix = 'simple_agent_extraction_refined_'
        found = {}
        for name, value in agent.items():
            suffix = name.rsplit('_', 1)[-1]
            # isdecimal() only accepts characters that int() can parse.
            if name.startswith(prefix) and suffix.isdecimal():
                found[int(suffix)] = value
        return found

    if key == 'prefix-probing':
        return normalize(llm.get('prefix-probing'))

    if key == 'simple_agent_extraction':
        return normalize(agent.get(key))

    if key == 'simple_agent_jailbreak':
        return unrefined()

    if key == 'simple_agent_extraction_refined_best':
        refined = numbered_refinements()
        if refined:
            return normalize(refined[max(refined)])
        return unrefined()

    if key == 'simple_agent_extraction_refined_first':
        for name in ('simple_agent_extraction_refined_1',
                     'simple_agent_extraction_refined_0'):
            if name in agent:
                return normalize(agent[name])
        return unrefined()

    if key == 'simple_agent_extraction_refined_best_no_jail':
        if 'simple_agent_jailbreak' not in agent:
            refined = numbered_refinements()
            if refined:
                return normalize(refined[max(refined)])
        # Either a jailbreak was used (ignore it and its refinements) or no
        # refinement exists: use the plain extraction.
        return normalize(agent.get('simple_agent_extraction'))

    return ''


# Function to scan the directory structure and find all available book-model combinations
def scan_books_and_models(base_dir, feedback_models, categories=None, models=None):
    """
    Locate every extraction file available under ``base_dir``.

    The expected layout is
    ``<base_dir>/<category>/<book>/Extractions/<book>_extraction_<model>_feedback_<feedback_model>.json``.

    Args:
        base_dir: Root directory holding one sub-directory per category.
        feedback_models: Feedback model names to look for.
        categories: Category directory names to scan. Defaults to the three
            categories used by this project.
        models: Extraction model names to look for. Defaults to the models
            used by this project.

    Returns:
        ``(all_combinations, books_by_category)`` where ``all_combinations``
        is a list of dicts with keys ``category``, ``book``, ``model``,
        ``feedback_model`` and ``path`` (one per existing file), and
        ``books_by_category`` maps each scanned category to all of its book
        directories, including books without any extraction file.

    Book directories are visited in sorted order: ``os.listdir`` returns
    entries in arbitrary order, which would otherwise make the row order of
    downstream tables (and the seeded bootstrap run on them) depend on the
    file system.
    """
    if categories is None:
        categories = ['Public_Domain', 'Copyrighted_Books', 'Non-Training-Data']
    if models is None:
        models = ['gpt-4.1-2025-04-14', 'claude-3-7-sonnet-20250219', 'deepseek-chat', 'gemini-2.5-flash-preview-04-17', 'gemini-2.5-pro-preview-05-06', 'gpt-4.1-mini-2025-04-14', 'gpt-4.1-nano-2025-04-14', 'Qwen_Qwen3-32B']

    all_combinations = []
    books_by_category = defaultdict(list)

    for category in categories:
        category_path = os.path.join(base_dir, category)
        # isdir (not exists) so that a stray regular file cannot reach listdir
        if not os.path.isdir(category_path):
            print(f"Warning: Category path {category_path} does not exist. Skipping...")
            continue

        book_dirs = sorted(
            d for d in os.listdir(category_path)
            if os.path.isdir(os.path.join(category_path, d))
        )

        for book in book_dirs:
            # Recorded before the Extractions check, so the mapping lists
            # every book directory of the category.
            books_by_category[category].append(book)
            extraction_dir = os.path.join(category_path, book, 'Extractions')

            if not os.path.isdir(extraction_dir):
                print(f"Warning: Extraction dir {extraction_dir} does not exist. Skipping book {book}...")
                continue

            for model in models:
                for feedback_model in feedback_models:
                    extraction_file = f"{book}_extraction_{model}_feedback_{feedback_model}.json"
                    full_path = os.path.join(extraction_dir, extraction_file)

                    if os.path.exists(full_path):
                        all_combinations.append({
                            'category': category,
                            'book': book,
                            'model': model,
                            'feedback_model': feedback_model,
                            'path': full_path
                        })

    return all_combinations, books_by_category


def compute_book_metrics(json_file_path, text_keys, gold_key, min_tokens=40, max_mismatch_tokens=3):
    """
    Compute ROUGE-L and contiguous-span statistics for one extraction file.

    Args:
        json_file_path: Path to a JSON file with a top-level ``'chapters'``
            list; each chapter holds ``'events'`` and each event holds the
            gold text under ``gold_key`` plus the LLM completions.
        text_keys: Candidate method names to evaluate (``gold_key`` itself is
            ignored if listed).
        gold_key: Event key of the reference text.
        min_tokens: Minimum matched tokens for a contiguous span; also the
            passage size used for ``passage_count``.
        max_mismatch_tokens: Mismatch budget passed to get_contiguous_spans.

    Returns:
        ``None`` if the file cannot be read or has no ``'chapters'``.
        Otherwise ``(metrics_for_json, top_spans)``: ``metrics_for_json``
        holds ``'rouge_scores'`` and ``'contiguous_spans'`` (parameters plus
        per-method ``span_count``, ``passage_count``, ``avg_span_length`` and
        ``max_span_length``); ``top_spans`` is the list of up to ten longest
        ``(length, snippet, chapter_index, event_index, key)`` tuples for the
        ``'simple_agent_extraction_refined_best'`` method.

    The ROUGE-L value of a method is the sum of its per-event scores weighted
    by gold word count, divided by the word count of *all* gold events, so an
    event without a candidate for that method contributes zero.
    """
    # Project-local dependency, imported here (as before) so that importing
    # this module does not require it.
    from metrics_utils import TextMetricsCalculator

    # Only this method's spans are collected as text snippets.
    target_span_key = 'simple_agent_extraction_refined_best'
    methods = [key for key in text_keys if key != gold_key]

    # Load and validate the file first so a bad file costs nothing else.
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            jc = json.load(f)
    except Exception as e:
        print(f"Error loading JSON file {json_file_path}: {e}")
        return None

    if not isinstance(jc, dict) or 'chapters' not in jc:
        print(f"No 'chapters' in JSON file {json_file_path}")
        return None

    # Initialize metrics calculator (only ROUGE is needed here)
    metrics_calc = TextMetricsCalculator(
        use_rouge=True,
        use_cosine=False,
        use_reconstruction=False,
        device="cpu"
    )

    span_counts = {key: 0 for key in methods}
    passage_counts = {key: 0 for key in methods}
    span_lengths = {key: [] for key in methods}
    rouge_scores = {key: [] for key in methods}
    all_spans = []
    total_words = 0

    # `or []` / `or {}` tolerate keys that are present but null in the JSON.
    for ch_idx, ch in enumerate(jc.get('chapters') or []):
        for ev_idx, ev in enumerate(ch.get('events') or []):
            gold = ev.get(gold_key, "")
            if not isinstance(gold, str) or not gold.strip():
                continue

            # Drop the leading sentence when the gold text starts with
            # (nearly) the same characters as the recorded first sentence.
            first_sentence = (ev.get("segmentation_boundaries") or {}).get("first_sentence", "")
            if first_sentence:
                gold_prefix = gold[:len(first_sentence)]
                matcher = difflib.SequenceMatcher(None, first_sentence, gold_prefix)
                if matcher.ratio() > 0.9:
                    gold = gold[len(first_sentence):].lstrip()

            # Nothing left to compare once the prefix is removed. Such an
            # event carries zero words and no spans, so skipping it leaves
            # every reported number unchanged.
            if not gold:
                continue

            word_count = len(nltk.word_tokenize(gold))
            total_words += word_count

            for key in methods:
                cand = get_candidate_text(ev, key)
                if not cand:
                    continue

                # ROUGE-L, remembered together with its weight
                m = metrics_calc.compute(gold, cand)
                rouge_scores[key].append((m.get('rougeL', 0.0), word_count))

                matches = get_contiguous_spans(
                    gold, cand,
                    min_tokens=min_tokens,
                    max_mismatch_tokens=max_mismatch_tokens
                )

                span_counts[key] += len(matches)
                span_lengths[key].extend(length for _, _, length in matches)
                # A span of k * min_tokens matched tokens counts as k passages.
                passage_counts[key] += sum(length // min_tokens for _, _, length in matches)

                if key == target_span_key and matches:
                    tokens = nltk.word_tokenize(normalize_for_contiguous(gold))
                    for a, _, length in matches:
                        if length >= min_tokens:
                            # NOTE(review): `length` counts matched tokens
                            # only, so for a span merged across gaps the
                            # snippet may stop short of the span's real end.
                            snippet = " ".join(tokens[a:a+length])
                            all_spans.append((length, snippet, ch_idx, ev_idx, key))

    # Word-count-weighted ROUGE-L per method
    weighted_rouge = {}
    for key, scores in rouge_scores.items():
        if not scores or total_words <= 0:
            weighted_rouge[key] = 0.0
            continue
        weighted_rouge[key] = sum(score * wc for score, wc in scores) / total_words

    avg_span_lengths = {}
    max_span_lengths = {}
    for key, lengths in span_lengths.items():
        avg_span_lengths[key] = sum(lengths) / len(lengths) if lengths else 0
        max_span_lengths[key] = max(lengths) if lengths else 0

    # Keep just the ten longest spans
    all_spans.sort(key=lambda x: x[0], reverse=True)
    top_spans = all_spans[:10]

    metrics_for_json = {
        'rouge_scores': weighted_rouge,
        'contiguous_spans': {
            'parameters': {
                'min_tokens': min_tokens,
                'max_mismatch_tokens': max_mismatch_tokens
            },
            'methods': {
                key: {
                    'span_count': span_counts[key],
                    'passage_count': passage_counts[key],
                    'avg_span_length': avg_span_lengths[key],
                    'max_span_length': max_span_lengths[key]
                } for key in methods
            }
        }
    }

    # JSON-ready metrics plus the top spans (used for the TXT report only)
    return metrics_for_json, top_spans


def _load_reusable_metrics(metrics_file, methods, min_tokens, max_mismatch_tokens):
    """Return cached metrics from ``metrics_file`` if they fit this run, else None."""
    try:
        with open(metrics_file, 'r', encoding='utf-8') as f:
            metrics = json.load(f)
    except Exception as e:
        print(f"Error loading metrics file {metrics_file}: {e}")
        return None

    # A cache is only valid for the span parameters it was computed with and
    # must contain every method the aggregation step will look up.
    try:
        spans = metrics['contiguous_spans']
        params = spans['parameters']
        reusable = (
            params['min_tokens'] == min_tokens
            and params['max_mismatch_tokens'] == max_mismatch_tokens
            and all(m in spans['methods'] and m in metrics['rouge_scores'] for m in methods)
        )
    except (KeyError, TypeError):
        reusable = False

    if not reusable:
        print(f"Cached metrics in {metrics_file} do not match the current span parameters or methods; recomputing")
        return None
    return metrics


def _write_book_report(report_path, book, model, feedback_model_name, metrics, top_spans, methods):
    """Write the plain-text per-book report (scores, span statistics, top spans)."""
    # Explicit UTF-8: snippets come from book text and may contain characters
    # the platform default encoding cannot represent.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"Metrics Report for {book} ({model}) with feedback {feedback_model_name}\n")
        f.write("=" * 80 + "\n\n")

        f.write("ROUGE-L Scores:\n")
        for key, score in metrics['rouge_scores'].items():
            f.write(f"- {key}: {score:.4f}\n")
        f.write("\n")

        params = metrics['contiguous_spans']['parameters']
        f.write(f"Span Parameters: min_tokens={params['min_tokens']}, max_mismatch_tokens={params['max_mismatch_tokens']}\n\n")

        f.write("Contiguous Span Statistics:\n")
        for key in methods:
            method_stats = metrics['contiguous_spans']['methods'][key]
            f.write(f"- {key}:\n")
            f.write(f"  * {method_stats['span_count']} merged spans, covering {method_stats['passage_count']} passages\n")
            f.write(f"  * Avg span length: {method_stats['avg_span_length']:.2f} tokens\n")
            f.write(f"  * Max span length: {method_stats['max_span_length']} tokens\n")
        f.write("\n")

        # Top spans exist only for the target method and only in this report
        f.write("Top Spans for 'simple_agent_extraction_refined_best':\n")
        for i, (length, snippet, ch_idx, evt_idx, _method) in enumerate(top_spans):
            f.write(f"{i+1}. ({length} tokens) Chapter {ch_idx}, Event {evt_idx}\n")
            f.write(f"   \"{snippet}\"\n\n")


def process_all_books(base_dir, feedback_models, output_dir=None, min_tokens=40, max_mismatch_tokens=3, random_seed=42):
    """
    Process all books across all categories and models, and generate reports for each feedback model

    For every extraction file found, per-book metrics are computed (or loaded
    from the book's ``Metrics`` directory when a cached file exists that was
    produced with the same span parameters and covers all methods), a
    per-book TXT report is written for freshly computed metrics, and the
    aggregate reports are generated once per feedback model.

    Args:
        base_dir: Base directory containing the book data
        feedback_models: List of feedback model names to process
        output_dir: Directory for aggregate outputs (default: base_dir/Aggregate_Metrics)
        min_tokens: Minimum tokens for contiguous spans
        max_mismatch_tokens: Maximum mismatch tokens for span merging
        random_seed: Seed for reproducible bootstrap sampling

    Returns:
        True, once every feedback model has been handled.
    """
    if output_dir is None:
        output_dir = os.path.join(base_dir, 'Aggregate_Metrics')

    os.makedirs(output_dir, exist_ok=True)

    # Define text keys and gold key
    text_keys = [
        'prefix-probing',
        'simple_agent_extraction',
        'simple_agent_jailbreak',
        'simple_agent_extraction_refined_first',
        'simple_agent_extraction_refined_best_no_jail',
        'simple_agent_extraction_refined_best'
    ]
    gold_key = 'text_segment'
    methods = [key for key in text_keys if key != gold_key]

    # Process each feedback model separately
    for feedback_model in feedback_models:
        print(f"\nProcessing feedback model: {feedback_model}")

        # Scan for all book-model combinations for this feedback model
        combinations, books_by_category = scan_books_and_models(base_dir, [feedback_model])

        if not combinations:
            print(f"No book-model combinations found for feedback model {feedback_model}!")
            continue

        print(f"Found {len(combinations)} book-model combinations to process for {feedback_model}")

        all_results = []

        for combo in tqdm(combinations, desc=f"Processing books for {feedback_model}"):
            category = combo['category']
            book = combo['book']
            model = combo['model']
            feedback_model_name = combo['feedback_model']
            json_path = combo['path']

            metrics_dir = os.path.join(base_dir, category, book, 'Metrics')
            os.makedirs(metrics_dir, exist_ok=True)
            metrics_file = os.path.join(metrics_dir, f"{book}_{model}_metrics_feedback_{feedback_model_name}.json")

            # Reuse previously computed metrics when they are still valid;
            # delete the metrics file to force a recomputation.
            metrics = None
            if os.path.exists(metrics_file):
                print(f"Loading pre-computed metrics for {book} ({model}) with feedback {feedback_model_name}")
                metrics = _load_reusable_metrics(metrics_file, methods, min_tokens, max_mismatch_tokens)

            if metrics is not None:
                # Add book and model info to metrics (for aggregation)
                metrics['category'] = category
                metrics['book'] = book
                metrics['model'] = model
                metrics['feedback_model'] = feedback_model_name
                all_results.append(metrics)
                continue

            print(f"Computing metrics for {book} ({model}) with feedback {feedback_model_name}")
            metrics_and_spans = compute_book_metrics(
                json_path, text_keys, gold_key,
                min_tokens=min_tokens,
                max_mismatch_tokens=max_mismatch_tokens
            )

            if not metrics_and_spans:
                print(f"Failed to compute metrics for {book} ({model}) with feedback {feedback_model_name}")
                continue

            metrics, top_spans = metrics_and_spans

            # Save the metrics before the bookkeeping fields are added, so the
            # cache file holds only the computed values.
            try:
                with open(metrics_file, 'w', encoding='utf-8') as f:
                    json.dump(metrics, f, indent=2)
            except Exception as e:
                print(f"Error saving metrics to {metrics_file}: {e}")

            # Add book and model info to metrics (for aggregation)
            metrics['category'] = category
            metrics['book'] = book
            metrics['model'] = model
            metrics['feedback_model'] = feedback_model_name
            all_results.append(metrics)

            # Individual book report (txt file including top spans)
            book_report_path = os.path.join(metrics_dir, f"{book}_{model}_feedback_{feedback_model_name}_report.txt")
            _write_book_report(
                book_report_path, book, model, feedback_model_name,
                metrics, top_spans, methods
            )

        # Generate aggregate reports for this feedback model
        generate_aggregate_reports(all_results, output_dir, books_by_category, text_keys, gold_key, feedback_model, random_seed=random_seed)

    return True



def _bootstrap_mean_std(values, columns, n_reps, rng):
    """Bootstrap the column means of ``values``; return (mean, std) dicts keyed by column."""
    n_rows = values.shape[0]
    replicate_means = {col: [] for col in columns}

    for _ in range(n_reps):
        # Draw n_rows row indices in [0, n_rows), with replacement
        idx = rng.randint(low=0, high=n_rows, size=n_rows)
        means_b = values[idx, :].mean(axis=0)
        for i, col in enumerate(columns):
            replicate_means[col].append(means_b[i])

    means = {col: np.mean(vals) for col, vals in replicate_means.items()}
    stds = {col: np.std(vals, ddof=1) for col, vals in replicate_means.items()}
    return means, stds


def _write_summary_report(summary_path, feedback_model, span_params, books_by_category,
                          df_books, df_category_model, rouge_cols, methods):
    """Write the human-readable summary, one section per category."""
    # Explicit UTF-8: the report text itself contains non-ASCII characters.
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(f"AGGREGATE METRICS SUMMARY - Feedback Model: {feedback_model}\n")
        f.write("=" * 80 + "\n\n")

        # Span parameters
        f.write(
            f"Span Parameters: min_tokens={span_params['min_tokens']}, "
            f"max_mismatch_tokens={span_params['max_mismatch_tokens']}\n\n"
        )

        # Books listed per category
        f.write("books Processed by Category:\n")
        for category, books in books_by_category.items():
            f.write(f"- {category}: {len(books)} book{'s' if len(books)>1 else ''}\n")
            for bk in books:
                f.write(f"  - {bk}\n")
        f.write("\n")

        for category in books_by_category.keys():
            f.write(f"Category: {category}\n")
            f.write("-" * (10 + len(category)) + "\n\n")

            cat_df = df_category_model[df_category_model["Category"] == category]

            # (i) Models evaluated, with the number of distinct books each
            f.write("Models Evaluated:\n")
            model_counts = (
                df_books[df_books["Category"] == category]
                .groupby("Model")["book"]
                .nunique()
                .to_dict()
            )
            for mdl, count in model_counts.items():
                f.write(f"- {mdl}: {count} book{'s' if count>1 else ''}\n")
            f.write("\n")

            # (ii) Bootstrap ROUGE-L mean and std per model
            f.write("Average (Bootstrap) ROUGE‐L Scores by Model:\n")
            for mdl in cat_df["Model"].unique():
                f.write(f"\n{mdl}:\n")
                row = cat_df[cat_df["Model"] == mdl].iloc[0]
                for col in rouge_cols:
                    mean_val = row[f"{col}_mean"]
                    std_val  = row[f"{col}_std"]
                    f.write(f"- {col}: {mean_val:.4f} ± {std_val:.4f}\n")
            f.write("\n")

            # (iii) Exact span statistics per model
            f.write("Span Statistics by Model (sums/means/max):\n")
            for mdl in cat_df["Model"].unique():
                f.write(f"\n{mdl}:\n")
                row = cat_df[cat_df["Model"] == mdl].iloc[0]
                for method in methods:
                    sc = row[f"Spans_{method}"]
                    pc = row[f"Passages_{method}"]
                    al = row[f"AvgSpanLength_{method}"]
                    mx = row[f"MaxSpanLength_{method}"]
                    f.write(f"- {method}:\n")
                    f.write(f"  * Total merged spans: {sc}\n")
                    f.write(f"  * Total passages covered:  {pc}\n")
                    f.write(f"  * Avg span length:        {al:.2f} tokens\n")
                    f.write(f"  * Max span length:        {mx} tokens\n")
                f.write("\n")

            f.write("\n\n")


def generate_aggregate_reports(
    all_results: list[dict],
    output_dir: str,
    books_by_category: dict[str, list[str]],
    text_keys: list[str],
    gold_key: str,
    feedback_model: str,
    B: int = 1000,
    random_seed: int = 42
):
    """
    Aggregate per-book metrics and write three files into ``output_dir``:

      1) all_results_feedback_{feedback_model}.csv - one row per book/model,
      2) category_model_stats_feedback_{feedback_model}.csv - one row per
         (Category, Model): ROUGE-L columns get a bootstrap mean and std
         (over B resamples of the group's books), span columns are exact
         (Spans_/Passages_ summed, AvgSpanLength_ averaged, MaxSpanLength_
         maximised),
      3) summary_report_feedback_{feedback_model}.txt - a readable summary
         split by category.

    Args:
      - all_results: list of per-book dicts, each with 'category', 'book',
        'model', 'feedback_model', 'rouge_scores' ({method: score}) and
        'contiguous_spans' ({'parameters': {'min_tokens', 'max_mismatch_tokens'},
        'methods': {method: {'span_count', 'passage_count',
        'avg_span_length', 'max_span_length'}}}).
      - output_dir: directory the files are written to (created if missing).
      - books_by_category: maps category name to its list of book names.
      - text_keys: method names, e.g. ['prefix-probing',
        'simple_agent_extraction', 'simple_agent_jailbreak',
        'simple_agent_extraction_refined_first',
        'simple_agent_extraction_refined_best_no_jail',
        'simple_agent_extraction_refined_best'].
      - gold_key: key of the reference text; skipped if present in text_keys.
      - feedback_model: feedback model name used in the file names.
      - B: number of bootstrap replicates per (Category, Model) group.
      - random_seed: seed for reproducible bootstrap sampling.

    Returns:
      (df_books, df_category_model, summary_path), or (None, None, None)
      when all_results is empty.

    The bootstrap uses its own ``np.random.RandomState(random_seed)``. The
    legacy ``np.random.seed``/``np.random.randint`` functions drive a hidden
    global RandomState, so the private instance produces the same draws as
    seeding the global generator would, without disturbing random state used
    elsewhere in the process.
    """
    if not all_results:
        print(f"No results to aggregate for feedback model {feedback_model}!")
        return None, None, None

    methods = [method for method in text_keys if method != gold_key]

    # 1) Per-book table
    rows = []
    for r in all_results:
        row = {
            "Category": r["category"],
            "book":     r["book"],
            "Model":    r["model"],
            "Feedback_Model": r["feedback_model"],
        }
        for method, score in r["rouge_scores"].items():
            row[f"RougeL_{method}"] = score

        for method in methods:
            stats = r["contiguous_spans"]["methods"][method]
            row[f"Spans_{method}"]         = stats["span_count"]
            row[f"Passages_{method}"]      = stats["passage_count"]
            row[f"AvgSpanLength_{method}"] = stats["avg_span_length"]
            row[f"MaxSpanLength_{method}"] = stats["max_span_length"]

        rows.append(row)

    df_books = pd.DataFrame(rows)

    os.makedirs(output_dir, exist_ok=True)

    all_results_path = os.path.join(output_dir, f"all_results_feedback_{feedback_model}.csv")
    df_books.to_csv(all_results_path, index=False)
    print(f"Wrote per‐book results to:\n   {all_results_path}")

    # Span parameters are taken from the first result; they are expected to
    # be identical for every book.
    span_params = all_results[0]["contiguous_spans"]["parameters"]

    # 2) Per (Category, Model) table
    rouge_cols = [col for col in df_books.columns if col.startswith("RougeL_")]
    span_sum_cols = [f"Spans_{method}"         for method in methods]
    pass_sum_cols = [f"Passages_{method}"      for method in methods]
    avg_len_cols  = [f"AvgSpanLength_{method}" for method in methods]
    max_len_cols  = [f"MaxSpanLength_{method}" for method in methods]

    # One generator, seeded once, shared by all groups in groupby order
    rng = np.random.RandomState(random_seed)

    aggregated_rows = []
    for (category, model), group_df in df_books.groupby(["Category", "Model"]):
        # Exact span statistics (no bootstrapping)
        span_sums    = group_df[span_sum_cols].sum()
        passage_sums = group_df[pass_sum_cols].sum()
        avg_lens     = group_df[avg_len_cols].mean()
        max_lens     = group_df[max_len_cols].max()

        # Bootstrap over the books of this group
        rouge_matrix = group_df[rouge_cols].to_numpy()  # (n_books, n_rouge_cols)
        rouge_mean, rouge_std = _bootstrap_mean_std(rouge_matrix, rouge_cols, B, rng)

        out_row = {
            "Category": category,
            "Model":    model,
        }
        for col in rouge_cols:
            out_row[f"{col}_mean"] = float(rouge_mean[col])
            out_row[f"{col}_std"]  = float(rouge_std[col])
        for col in span_sum_cols:
            out_row[col] = int(span_sums[col])
        for col in pass_sum_cols:
            out_row[col] = int(passage_sums[col])
        for col in avg_len_cols:
            out_row[col] = float(avg_lens[col])
        for col in max_len_cols:
            out_row[col] = float(max_lens[col])

        aggregated_rows.append(out_row)

    df_category_model = pd.DataFrame(aggregated_rows)

    category_model_stats_path = os.path.join(output_dir, f"category_model_stats_feedback_{feedback_model}.csv")
    df_category_model.to_csv(category_model_stats_path, index=False)
    print(f"Wrote bootstrap‐based category_model_stats to:\n   {category_model_stats_path}")

    # 3) Summary report
    summary_path = os.path.join(output_dir, f"summary_report_feedback_{feedback_model}.txt")
    _write_summary_report(
        summary_path, feedback_model, span_params, books_by_category,
        df_books, df_category_model, rouge_cols, methods
    )
    print(f"Wrote summary report to:\n   {summary_path}")

    # 4) Return DataFrames and summary path
    return df_books, df_category_model, summary_path



# Main execution function
def main():
    """
    Run the full metrics pipeline with the project's fixed configuration.

    Processes every book under the base directory for the configured feedback
    model(s), using spans of at least 40 tokens with a mismatch budget of 5
    and a fixed bootstrap seed, then prints where the reports were written.
    """
    base_dir = '/Users/xxx/Agent_Copyright'
    aggregate_dir = os.path.join(base_dir, 'Aggregate_Metrics')

    process_all_books(
        base_dir,
        ['gpt-4.1-2025-04-14'],       # feedback models to process
        output_dir=aggregate_dir,
        min_tokens=40,
        max_mismatch_tokens=5,
        random_seed=2319,             # reproducible bootstrap sampling
    )

    print("\nProcessing complete!")
    print("Individual reports have been saved to each book's Metrics directory")
    print(f"Aggregate reports have been saved to {aggregate_dir}")


if __name__ == "__main__":
    main()