{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXX-3/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See XXXX\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "from pathlib import Path\n",
    "from typing import Any, Dict, List, Set, Tuple\n",
    "\n",
    "import pandas as pd\n",
    "from datasets import load_from_disk\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_dataset_params(dataset_name: str) -> Tuple[str, str, str, str]:\n",
    "    \"\"\"Split a dataset name into its naming components.\n",
    "\n",
    "    The name is expected to look like ``<base>_<source>_<cutoff>_<lookback>``,\n",
    "    i.e. at least four ``'_'``-separated parts; any parts after the fourth are\n",
    "    ignored.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: Name of a saved dataset (its directory name).\n",
    "\n",
    "    Returns:\n",
    "        ``(base_name, source, cutoff, lookback)`` as strings.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the name has fewer than four ``'_'``-separated parts.\n",
    "    \"\"\"\n",
    "    parts = dataset_name.split('_')\n",
    "    if len(parts) < 4:\n",
    "        # Name the offending dataset instead of surfacing a bare IndexError.\n",
    "        raise ValueError(\n",
    "            f\"Dataset name {dataset_name!r} does not match \"\n",
    "            f\"'<base>_<source>_<cutoff>_<lookback>'\"\n",
    "        )\n",
    "    base_name, source, cutoff, lookback = parts[:4]\n",
    "    return base_name, source, cutoff, lookback\n",
    "\n",
    "def load_all_datasets(directory: str) -> Dict[str, Any]:\n",
    "    \"\"\"Load every Hugging Face dataset saved under ``directory``.\n",
    "\n",
    "    Each immediate subdirectory is passed to ``load_from_disk`` and stored under\n",
    "    its directory name. Subdirectories are visited in sorted order so the result\n",
    "    (and any pairwise comparison built from its keys) is reproducible; plain\n",
    "    files are ignored. A subdirectory that fails to load is reported and skipped.\n",
    "\n",
    "    Args:\n",
    "        directory: Path containing one saved dataset per subdirectory.\n",
    "\n",
    "    Returns:\n",
    "        Mapping from subdirectory name to the loaded dataset. Empty (with a\n",
    "        printed warning) when ``directory`` is not an existing directory.\n",
    "    \"\"\"\n",
    "    loaded = {}\n",
    "    root = Path(directory)\n",
    "    if not root.is_dir():\n",
    "        print(f\"Warning: {directory} is not a directory; no datasets loaded\")\n",
    "        return loaded\n",
    "\n",
    "    # Path.glob yields entries in filesystem order, which is not guaranteed to be stable.\n",
    "    for path in sorted(root.glob(\"*\")):\n",
    "        if not path.is_dir():\n",
    "            continue\n",
    "        dataset_name = path.name\n",
    "        try:\n",
    "            print(f\"Loading dataset: {dataset_name}\")\n",
    "            dataset = load_from_disk(str(path))\n",
    "            loaded[dataset_name] = dataset\n",
    "            print(f\"  - Loaded {len(dataset)} entries\")\n",
    "        except Exception as e:\n",
    "            # Best-effort: one unreadable dataset should not abort loading the rest.\n",
    "            print(f\"  - Error loading dataset {dataset_name}: {str(e)}\")\n",
    "\n",
    "    return loaded\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration. Set RETRIEVAL_DIR in the environment to point at another copy of the data.\n",
    "directory = os.environ.get(\"RETRIEVAL_DIR\", \"/is/cluster/fast/XXXX-11/forecasting/news/retrieval\")\n",
    "output = \"retrieval_comparison_results.csv\"  # NOTE(review): not referenced in the cells above\n",
    "samples = 3  # NOTE(review): not referenced in the cells above\n",
    "\n",
    "# load_all_datasets would quietly return {} for a missing directory; fail loudly instead.\n",
    "if not Path(directory).is_dir():\n",
    "    raise FileNotFoundError(f\"Dataset directory not found: {directory}\")\n",
    "\n",
    "print(f\"Loading datasets from {directory}\")\n",
    "datasets = load_all_datasets(directory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_article_identifier(article: Dict) -> str:\n",
    "    \"\"\"Return a string key for an article built from its ``url`` and ``title``.\n",
    "\n",
    "    Missing fields default to empty strings and the two values are joined with\n",
    "    a single underscore, so an article with neither field maps to ``\"_\"``.\n",
    "    NOTE(review): URLs/titles may themselves contain underscores, so two\n",
    "    different (url, title) pairs can in principle produce the same key.\n",
    "    \"\"\"\n",
    "    url_part = article.get('url', '')\n",
    "    title_part = article.get('title', '')\n",
    "    return \"{}_{}\".format(url_part, title_part)\n",
    "\n",
    "def compare_article_lists(list1: List[Dict], list2: List[Dict]) -> Tuple[int, Set[str], Set[str], float]:\n",
    "    \"\"\"\n",
    "    Diff two ranked article lists by article identifier.\n",
    "    \n",
    "    Each article is keyed with ``get_article_identifier`` and ranked by its\n",
    "    0-based index in its list. When an identifier occurs more than once in a\n",
    "    list, its last index in that list is the rank used.\n",
    "    \n",
    "    Returns:\n",
    "        - Number of identifiers present in both lists\n",
    "        - Set of identifiers found only in list1\n",
    "        - Set of identifiers found only in list2\n",
    "        - Mean absolute rank difference over the shared identifiers (0 if none)\n",
    "    \"\"\"\n",
    "    # identifier -> rank; a later duplicate overwrites an earlier one\n",
    "    rank1 = {}\n",
    "    for rank, article in enumerate(list1):\n",
    "        rank1[get_article_identifier(article)] = rank\n",
    "    rank2 = {}\n",
    "    for rank, article in enumerate(list2):\n",
    "        rank2[get_article_identifier(article)] = rank\n",
    "\n",
    "    keys1 = set(rank1)\n",
    "    keys2 = set(rank2)\n",
    "    shared = keys1 & keys2\n",
    "\n",
    "    shifts = [abs(rank1[key] - rank2[key]) for key in shared]\n",
    "    mean_shift = sum(shifts) / len(shifts) if shifts else 0\n",
    "\n",
    "    return len(shared), keys1 - keys2, keys2 - keys1, mean_shift\n",
    "\n",
    "def _compare_dataset_pair(ds1: Any, ds2: Any, max_samples: int = 10) -> Dict[str, Any]:\n",
    "    \"\"\"Compute retrieval-difference statistics for one pair of datasets.\n",
    "\n",
    "    Rows are matched by index, so both datasets are assumed to hold the same\n",
    "    questions in the same order (NOTE(review): nothing here verifies that).\n",
    "    Only the first ``min(len(ds1), len(ds2))`` rows are compared, and a row is\n",
    "    skipped when either side lacks ``retrieved_articles`` or has an empty list.\n",
    "\n",
    "    Args:\n",
    "        ds1: First dataset; ``ds1[i]`` must return a dict-like row.\n",
    "        ds2: Second dataset, indexed the same way.\n",
    "        max_samples: Maximum number of differing questions kept as samples.\n",
    "\n",
    "    Returns:\n",
    "        Dict of the statistics columns reported by ``run_comparison``, in column order.\n",
    "    \"\"\"\n",
    "    total_questions = 0\n",
    "    questions_with_different_articles = 0\n",
    "    questions_with_reranking = 0\n",
    "    total_reranked_articles = 0\n",
    "    sum_position_changes = 0\n",
    "    different_article_samples = []\n",
    "\n",
    "    common_length = min(len(ds1), len(ds2))\n",
    "    for q_idx in tqdm(range(common_length), desc=\"Comparing questions\"):\n",
    "        # Index each dataset once per question; for Hugging Face datasets every\n",
    "        # ds[i] lookup re-materializes the row, so repeated indexing is wasted work.\n",
    "        row1 = ds1[q_idx]\n",
    "        row2 = ds2[q_idx]\n",
    "        if \"retrieved_articles\" not in row1 or \"retrieved_articles\" not in row2:\n",
    "            continue\n",
    "\n",
    "        articles1 = row1[\"retrieved_articles\"]\n",
    "        articles2 = row2[\"retrieved_articles\"]\n",
    "        # Skip if either side retrieved nothing.\n",
    "        if not articles1 or not articles2:\n",
    "            continue\n",
    "\n",
    "        total_questions += 1\n",
    "        common_count, unique_to_1, unique_to_2, avg_position_change = compare_article_lists(articles1, articles2)\n",
    "\n",
    "        if unique_to_1 or unique_to_2:\n",
    "            questions_with_different_articles += 1\n",
    "            if len(different_article_samples) < max_samples:\n",
    "                different_article_samples.append({\n",
    "                    \"q_idx\": q_idx,\n",
    "                    \"question\": row1.get(\"question\", \"\"),\n",
    "                    \"unique_to_1\": list(unique_to_1),\n",
    "                    \"unique_to_2\": list(unique_to_2),\n",
    "                    \"articles1\": articles1,\n",
    "                    \"articles2\": articles2\n",
    "                })\n",
    "\n",
    "        if common_count > 0 and avg_position_change > 0:\n",
    "            questions_with_reranking += 1\n",
    "            # Counts every shared article of a reranked question, not only the ones that moved.\n",
    "            total_reranked_articles += common_count\n",
    "            sum_position_changes += avg_position_change\n",
    "\n",
    "    def pct(count: int) -> float:\n",
    "        \"\"\"Percentage of ``count`` over compared questions (0 when none were compared).\"\"\"\n",
    "        return count / total_questions * 100 if total_questions > 0 else 0\n",
    "\n",
    "    return {\n",
    "        \"total_questions\": total_questions,\n",
    "        \"questions_with_different_articles\": questions_with_different_articles,\n",
    "        \"questions_with_different_articles_pct\": pct(questions_with_different_articles),\n",
    "        \"questions_with_reranking\": questions_with_reranking,\n",
    "        \"questions_with_reranking_pct\": pct(questions_with_reranking),\n",
    "        \"total_reranked_articles\": total_reranked_articles,\n",
    "        # Mean of per-question average position changes, over reranked questions only.\n",
    "        \"avg_position_change\": sum_position_changes / questions_with_reranking if questions_with_reranking > 0 else 0,\n",
    "        \"samples\": different_article_samples\n",
    "    }\n",
    "\n",
    "def run_comparison(datasets: Dict[str, Any], max_samples: int = 10) -> pd.DataFrame:\n",
    "    \"\"\"Compare retrieved articles between every pair of loaded datasets.\n",
    "\n",
    "    Only pairs whose names share the same base name and source (as parsed by\n",
    "    ``extract_dataset_params``) are compared. Pairs are visited in the key order\n",
    "    of ``datasets``, each unordered pair once.\n",
    "\n",
    "    Args:\n",
    "        datasets: Mapping from dataset name to dataset, e.g. from ``load_all_datasets``.\n",
    "        max_samples: Maximum number of differing questions stored per pair in the\n",
    "            ``samples`` column (default 10).\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per compared pair: names, cutoffs, lookbacks and the\n",
    "        statistics from ``_compare_dataset_pair``. Empty when no pair qualifies.\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    dataset_names = list(datasets.keys())\n",
    "\n",
    "    for i, name1 in enumerate(dataset_names):\n",
    "        for name2 in dataset_names[i + 1:]:\n",
    "            base1, source1, cutoff1, lookback1 = extract_dataset_params(name1)\n",
    "            base2, source2, cutoff2, lookback2 = extract_dataset_params(name2)\n",
    "\n",
    "            # Only compare datasets with the same base name and source.\n",
    "            if base1 != base2 or source1 != source2:\n",
    "                continue\n",
    "\n",
    "            ds1 = datasets[name1]\n",
    "            ds2 = datasets[name2]\n",
    "            print(f\"Comparing {name1} vs {name2}\")\n",
    "            if len(ds1) != len(ds2):\n",
    "                print(f\"  - Warning: lengths differ ({len(ds1)} vs {len(ds2)}); \"\n",
    "                      f\"comparing the first {min(len(ds1), len(ds2))} rows by index\")\n",
    "\n",
    "            stats = _compare_dataset_pair(ds1, ds2, max_samples=max_samples)\n",
    "            results.append({\n",
    "                \"dataset1\": name1,\n",
    "                \"dataset2\": name2,\n",
    "                \"cutoff1\": cutoff1,\n",
    "                \"cutoff2\": cutoff2,\n",
    "                \"lookback1\": lookback1,\n",
    "                \"lookback2\": lookback2,\n",
    "                **stats\n",
    "            })\n",
    "\n",
    "    return pd.DataFrame(results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "minir1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
