#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Evaluation data preparation script that:
1. Loads annotations_train.json (train data, dropping annotations)
2. Loads test_4.json (already complete with questions and answers)
3. Downloads CORAL dataset from HuggingFace and joins with coral_test.json
4. Downloads dc767 dataset from GitHub and joins with dc767_test.json  
5. Merges all four datasets into evaluation.json with train data first
"""

import json
import os
import sys
import csv
import urllib.request
from pathlib import Path
from typing import Dict, List, Any


def download_coral_dataset(commit_id: str = "e6b8fb6e5d6300faed70984eb445eb7ef13c0056") -> List[Any]:
    """Download the CORAL dataset from HuggingFace, pinned to one revision.

    If the ``datasets`` library cannot be imported, one attempt is made to
    install it with pip for the running interpreter, then the import is
    retried.

    Args:
        commit_id: Revision of the ``ariya2357/CORAL`` dataset repository to
            load.

    Returns:
        The entries of the ``train`` split as a list, in dataset order.

    Raises:
        ImportError: If ``datasets`` is still not importable after the
            installation attempt.
    """
    try:
        from datasets import load_dataset
    except ImportError:
        # Function-scope imports: only needed on this rarely taken path.
        import importlib
        import subprocess

        print("Error: datasets library not installed. Installing...")
        # An argument list (no shell) keeps interpreter paths that contain
        # spaces or shell metacharacters intact. A failed install is not
        # fatal here; the retried import below reports it as ImportError.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "datasets"],
            check=False,
        )
        # The import system caches directory listings; refresh them so a
        # package installed while this process is running can be found.
        importlib.invalidate_caches()
        from datasets import load_dataset

    print(f"Downloading CORAL dataset from HuggingFace (commit: {commit_id})...")

    # Pin the download to the requested revision.
    dataset = load_dataset(
        "ariya2357/CORAL",
        split="train",
        revision=commit_id,
    )

    print(f"Downloaded {len(dataset)} entries")

    # Materialise the split as a plain list for the join step.
    return list(dataset)


def download_dc767_csv(csv_url: str, output_path: str) -> List[Dict]:
    """Download the dc767 annotations CSV and parse it into row dictionaries.

    Args:
        csv_url: URL of the CSV file to fetch.
        output_path: Local path the downloaded file is written to. The file
            is left in place; removing it is up to the caller.

    Returns:
        One dictionary per CSV data row, keyed by the column names from the
        header row, in file order.
    """
    print("\nDownloading dc767 dataset from GitHub...")
    print(f"URL: {csv_url}")

    # Fetch the CSV to disk first, then parse the local copy.
    urllib.request.urlretrieve(csv_url, output_path)
    print(f"Downloaded CSV to: {output_path}")

    # newline='' hands line-ending handling to the csv module, so line
    # breaks embedded in quoted fields reach the caller unchanged.
    with open(output_path, 'r', encoding='utf-8', newline='') as f:
        data_list = list(csv.DictReader(f))

    print(f"Loaded {len(data_list)} entries from CSV")
    return data_list


def join_coral_data(coral_test_data: List[Dict], train_array_data: List[Dict]) -> List[Dict]:
    """Attach CORAL questions and reference answers to the coral test items.

    Each test item's ``item_name`` is looked up as a ``conv_id`` in the
    downloaded CORAL entries. The first turn of the matching conversation
    supplies ``question`` and, from its ``response`` field, ``gt_answer``.

    Args:
        coral_test_data: Test items carrying ``item_name``, ``dataset_name``
            and ``gen_answer``.
        train_array_data: Downloaded CORAL entries, each with a ``conv_id``
            and a ``turns`` sequence.

    Returns:
        One merged entry per test item that matched a conversation with at
        least one turn. Unmatched items are reported on stdout and dropped.
    """
    print(f"\nProcessing {len(coral_test_data)} coral test entries...")

    # Index conversations by conv_id. Entries with a missing or empty id are
    # left out; for a repeated id the last entry wins.
    conversations = {
        record['conv_id']: record
        for record in train_array_data
        if record.get('conv_id', '')
    }
    print(f"Created mapping for {len(conversations)} conv_ids")

    joined_data = []
    unmatched_conv_ids = []

    for test_item in coral_test_data:
        conv_id = test_item['item_name']  # item_name holds the conv_id

        conversation = conversations.get(conv_id)
        if conversation is None:
            unmatched_conv_ids.append(conv_id)
            continue

        turns = conversation['turns'] if 'turns' in conversation else ()
        if len(turns) == 0:
            print(f"Warning: No turns found for conv_id: {conv_id}")
            unmatched_conv_ids.append(conv_id)
            continue

        # Only the opening turn is used; its response is the ground truth.
        opening_turn = turns[0]
        question = opening_turn.get('question', '')
        gt_answer = opening_turn.get('response', '')

        joined_data.append({
            'item_name': conv_id,
            'dataset_name': test_item['dataset_name'],
            'question': question,
            'gt_answer': gt_answer,
            'gen_answer': test_item['gen_answer'],
        })

    print(f"CORAL join results: {len(joined_data)} successfully joined, {len(unmatched_conv_ids)} unmatched")

    if unmatched_conv_ids:
        print(f"First 5 unmatched CORAL conv_ids: {unmatched_conv_ids[:5]}")

    return joined_data


def join_dc767_data(dc767_test_data: List[Dict], csv_data: List[Dict]) -> List[Dict]:
    """Attach questions and reference answers from the CSV to dc767 test items.

    Each test item's ``item_name`` is read as a zero-based row index into
    ``csv_data``. The row's ``query`` column becomes ``question`` and its
    ``answer`` column becomes ``gt_answer``, both whitespace-stripped.

    Args:
        dc767_test_data: Test items carrying ``item_name``, ``dataset_name``
            and ``gen_answer``.
        csv_data: Parsed CSV rows, in file order.

    Returns:
        One merged entry per test item whose ``item_name`` is a valid,
        in-range row index. Other items are reported on stdout and dropped.
    """
    print(f"\nProcessing {len(dc767_test_data)} dc767 test entries...")

    joined_data = []
    unmatched_indices = []

    for test_item in dc767_test_data:
        item_name = test_item['item_name']

        # item_name is stored as a string; convert it for indexing.
        try:
            row_index = int(item_name)
        except (ValueError, TypeError):
            print(f"Warning: Invalid item_name '{item_name}' - skipping")
            unmatched_indices.append(item_name)
            continue

        # Negative values are rejected rather than wrapping around to the
        # end of the list.
        if not 0 <= row_index < len(csv_data):
            unmatched_indices.append(item_name)
            continue

        csv_row = csv_data[row_index]

        # csv.DictReader fills columns missing from a short row with None,
        # so fall back to '' before stripping instead of crashing.
        question = (csv_row.get('query') or '').strip()  # CSV calls the question 'query'
        gt_answer = (csv_row.get('answer') or '').strip()

        joined_data.append({
            'item_name': item_name,  # kept as given to match the input format
            'dataset_name': test_item['dataset_name'],
            'question': question,
            'gt_answer': gt_answer,
            'gen_answer': test_item['gen_answer'],
        })

    print(f"dc767 join results: {len(joined_data)} successfully joined, {len(unmatched_indices)} unmatched")

    if unmatched_indices:
        print(f"First 5 unmatched dc767 indices: {unmatched_indices[:5]}")

    return joined_data


def merge_and_save(train_data: List[Dict], test_4_data: List[Dict], coral_data: List[Dict], dc767_data: List[Dict], output_path: Path, sort_data: bool = False) -> None:
    """Concatenate the four datasets, report their sizes, and write them as JSON.

    Args:
        train_data: Train entries; placed first in the output.
        test_4_data: test_4 entries; placed second.
        coral_data: Joined CORAL entries; placed third.
        dc767_data: Joined dc767 entries; placed last.
        output_path: File the merged list is written to as UTF-8 JSON.
        sort_data: When true, order the merged list by ``dataset_name`` and
            then by the string form of ``item_name`` instead of keeping the
            concatenation order.

    Raises:
        Exception: Any error raised while writing the file is logged to
            stdout and re-raised unchanged.
    """

    def _sort_key(entry):
        # item_name is compared as a string, so mixed types still sort.
        return (entry.get('dataset_name', ''), str(entry.get('item_name', '')))

    # Fixed order: train, then test_4, then coral, then dc767.
    combined = train_data + test_4_data + coral_data + dc767_data

    print(f"\nMerged {len(combined)} total items")

    # Tally entries per dataset; entries without a name count as 'unknown'.
    per_dataset = {}
    for entry in combined:
        name = entry.get('dataset_name', 'unknown')
        if name in per_dataset:
            per_dataset[name] += 1
        else:
            per_dataset[name] = 1

    print("Dataset distribution:")
    for name in sorted(per_dataset):
        print(f"  {name}: {per_dataset[name]} items")

    if sort_data:
        print("\nSorting merged data by dataset_name and item_name...")
        combined.sort(key=_sort_key)

    print(f"\nSaving merged data to {output_path}...")
    try:
        with open(output_path, 'w', encoding='utf-8') as out_file:
            json.dump(combined, out_file, indent=2, ensure_ascii=False)
        print(f"Successfully saved {len(combined)} items to {output_path}")
    except Exception as err:
        print(f"ERROR: Failed to save to {output_path}: {err}")
        raise


def _load_json(path: str) -> Any:
    """Read and parse the UTF-8 encoded JSON file at *path*."""
    # Explicit UTF-8 matches how the merged output is written, instead of
    # relying on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def main() -> int:
    """Build ``data/evaluation.json`` from the available evaluation inputs.

    Every input file is optional: a missing file is reported and skipped.
    All paths are relative to the current working directory. Train entries
    are reduced to the five evaluation fields, test_4 entries are used as
    they are, and the CORAL and dc767 test items are joined with their
    downloaded source datasets.

    Returns:
        0 when a merged file was written, 1 when no dataset produced data.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    # Hardcoded paths for evaluation data
    train_path = 'data/annotations_train.json'
    test_4_path = 'data/test_4.json'
    coral_test_path = 'data/coral_test.json'
    dc767_test_path = 'data/dc767_test.json'
    output_path = 'data/evaluation.json'

    # Fixed parameters: both downloads are pinned to a specific commit.
    coral_commit = 'e6b8fb6e5d6300faed70984eb445eb7ef13c0056'
    dc767_csv_url = 'https://raw.githubusercontent.com/NVIDIA/nv-ingest/6c9c74ebfdf37ddd397a7838d1e85dbb6f903447/data/digital_corpora_10k_annotations.csv'
    sort_data = False

    # Each list stays empty when its dataset is skipped.
    train_data = []
    test_4_data = []
    coral_full_data = []
    dc767_full_data = []

    # Process train dataset
    if not os.path.exists(train_path):
        print(f"Warning: {train_path} not found. Skipping train dataset.")
    else:
        print("=== Processing Train Dataset ===")

        # Keep only the evaluation fields; this drops the annotations.
        kept_fields = ('item_name', 'dataset_name', 'question', 'gt_answer', 'gen_answer')
        train_data = [
            {field: item[field] for field in kept_fields}
            for item in _load_json(train_path)
        ]

        print(f"Loaded {len(train_data)} train entries (annotations dropped)")

    # Process test_4 dataset
    if not os.path.exists(test_4_path):
        print(f"Warning: {test_4_path} not found. Skipping test_4 dataset.")
    else:
        print("\n=== Processing test_4 Dataset ===")

        # test_4 is already complete, so no joining is needed.
        test_4_data = _load_json(test_4_path)
        print(f"Loaded {len(test_4_data)} test_4 entries")

    # Process CORAL test dataset
    if not os.path.exists(coral_test_path):
        print(f"Warning: {coral_test_path} not found. Skipping CORAL test dataset.")
    else:
        print("\n=== Processing CORAL Test Dataset ===")

        coral_test_data = _load_json(coral_test_path)
        print(f"Loaded {len(coral_test_data)} coral test entries")

        # Download from HuggingFace, then join on conv_id.
        train_array_data = download_coral_dataset(coral_commit)
        coral_full_data = join_coral_data(coral_test_data, train_array_data)

    # Process dc767 test dataset
    if not os.path.exists(dc767_test_path):
        print(f"Warning: {dc767_test_path} not found. Skipping dc767 test dataset.")
    else:
        print("\n=== Processing dc767 Test Dataset ===")

        dc767_test_data = _load_json(dc767_test_path)
        print(f"Loaded {len(dc767_test_data)} dc767 test entries")

        # Download into a private temporary directory: portable across
        # platforms, not a predictable shared path, and removed on exit
        # from the with-block even if the download or join fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_csv_path = os.path.join(temp_dir, "digital_corpora_10k_annotations.csv")
            csv_data = download_dc767_csv(dc767_csv_url, temp_csv_path)

            # Join the data
            dc767_full_data = join_dc767_data(dc767_test_data, csv_data)
        print(f"Cleaned up temporary file: {temp_csv_path}")

    # Nothing to write when every dataset was skipped or came back empty.
    if not (train_data or test_4_data or coral_full_data or dc767_full_data):
        print("\n⚠️  No data to merge. Please check your files.")
        return 1

    print("\n=== Merging Evaluation Datasets ===")
    merge_and_save(train_data, test_4_data, coral_full_data, dc767_full_data, Path(output_path), sort_data)

    print("\n✅ Process completed successfully!")
    print(f"Final merged evaluation dataset saved to: {output_path}")
    return 0


# Script entry point: main()'s return value becomes the process exit status.
# sys.exit is used because the bare exit() helper is injected by the site
# module and is not available in every interpreter configuration.
if __name__ == "__main__":
    sys.exit(main())
