"""
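Worker script for parallel bounded inspiration selection (v2).

Each worker processes the subset of recommendation files whose position in
the sorted file listing satisfies ``index % num_workers == worker_id``.
Writes the per-tier selections to ``--output_dir`` and, when
``--embeddings_dir`` is given, the full list of candidate similarity scores
(not the embedding vectors themselves) to that directory.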
"""

import os
import json
import argparse
from typing import List, Dict, Optional, Tuple
from tqdm import tqdm
import numpy as np

# Tier configuration: (tier_name, low, high) triples.
# Every consumer in this file tests membership as the half-open interval
# ``low <= similarity < high``, where ``similarity`` is the cosine similarity
# between a ground-truth paper and a recommended paper.
TIERS = [
    ("hard", 0.90, 0.92),
    ("medium", 0.92, 0.94),
    ("easy", 0.94, 0.97),
]


def get_model_and_util():
    """Load the sentence-embedding model and return it with the ``util`` module.

    ``sentence_transformers`` is imported inside the function rather than at
    module level, so the import cost is only paid when this is called.

    Returns:
        A ``(model, util)`` tuple: the ``SentenceTransformer`` built from the
        ``allenai/specter2_base`` checkpoint, and the ``util`` module of
        ``sentence_transformers``.
    """
    from sentence_transformers import SentenceTransformer, util as st_util

    print("Loading SPECTER2 model...")
    encoder = SentenceTransformer('allenai/specter2_base')
    print("Model loaded!")
    return encoder, st_util


def format_paper_text(title: str, abstract: str) -> str:
    """Build the text that is fed to the encoder for one paper.

    Args:
        title: Paper title.
        abstract: Paper abstract; may be empty or ``None``.

    Returns:
        ``title`` alone when the abstract is falsy, otherwise the title and
        abstract joined by the literal ``[SEP]`` marker.
    """
    if not abstract:
        return title
    return "{}[SEP]{}".format(title, abstract)


def get_tier_name(similarity: float) -> Optional[str]:
    """Classify a similarity score against the configured ``TIERS``.

    Args:
        similarity: Similarity score to classify.

    Returns:
        The name of the first tier whose half-open interval
        ``[low, high)`` contains ``similarity``; ``"below_range"`` when the
        score is under the lowest tier bound; ``"above_range"`` when it is at
        or over the highest tier bound; ``None`` otherwise (a score that falls
        in a gap between tiers, or NaN, for which every comparison is false).
    """
    for tier_name, low, high in TIERS:
        if low <= similarity < high:
            return tier_name

    if not TIERS:
        return None

    # Derive the outer bounds from TIERS instead of repeating them as magic
    # numbers, so this function stays correct if the tiers are retuned.
    lowest_bound = min(low for _, low, _ in TIERS)
    highest_bound = max(high for _, _, high in TIERS)

    if similarity < lowest_bound:
        return "below_range"
    if similarity >= highest_bound:
        return "above_range"
    return None


def calculate_all_similarities(
    gt_title: str,
    gt_abstract: str,
    recommendations: List[Dict],
    model,
    util
) -> List[Tuple[Dict, float]]:
    """Calculate similarities between GT and all valid recommendations.

    A recommendation is considered valid when it is a dict with a non-empty
    ``title`` and a non-empty ``abstract``; everything else is ignored.

    Args:
        gt_title: Title of the ground-truth paper.
        gt_abstract: Abstract of the ground-truth paper (may be empty).
        recommendations: Candidate papers. ``None`` is treated as an empty
            list, and entries that are not dicts are skipped.
        model: Encoder exposing ``encode(texts, convert_to_tensor=True)``.
        util: Object exposing ``cos_sim(a, b)``.

    Returns:
        One ``(recommendation, similarity)`` pair per valid recommendation,
        in input order. Empty when there is no valid recommendation, in which
        case the model is never invoked.
    """
    # Input comes from parsed JSON, so the list itself or individual entries
    # may be null; filter those out instead of failing on ``.get``.
    valid_recs = [
        rec for rec in (recommendations or [])
        if isinstance(rec, dict) and rec.get('title') and rec.get('abstract')
    ]

    if not valid_recs:
        return []

    gt_text = format_paper_text(gt_title, gt_abstract)
    gt_emb = model.encode(gt_text, convert_to_tensor=True)

    rec_texts = [format_paper_text(rec['title'], rec['abstract']) for rec in valid_recs]
    rec_embs = model.encode(rec_texts, convert_to_tensor=True)

    # Row 0 of the result holds the GT paper's score against every candidate.
    similarities = util.cos_sim(gt_emb, rec_embs)[0].cpu().numpy()

    return [(rec, float(sim)) for rec, sim in zip(valid_recs, similarities)]


def select_top_per_tier(rec_sim_pairs: List[Tuple[Dict, float]]) -> Dict[str, Optional[Dict]]:
    """Pick, for every tier in ``TIERS``, the candidate with the highest similarity.

    Args:
        rec_sim_pairs: ``(recommendation, similarity)`` pairs.

    Returns:
        A dict keyed by tier name. The value is ``None`` when no candidate
        falls inside the tier's ``[low, high)`` interval; otherwise it is a
        dict with the winner's ``title``, ``abstract``, ``paperId``,
        ``similarity``, the ``tier`` name and a printable ``tier_range``.
        When several candidates share the top similarity, the one that
        appears first in ``rec_sim_pairs`` wins.
    """
    picks = {}

    for name, lower, upper in TIERS:
        winner = None
        for pair in rec_sim_pairs:
            rec, score = pair
            if not (lower <= score < upper):
                continue
            # Strict comparison keeps the earliest candidate on ties.
            if winner is None or score > winner[1]:
                winner = (rec, score)

        if winner is None:
            picks[name] = None
            continue

        top_rec, top_score = winner
        picks[name] = {
            'title': top_rec.get('title', ''),
            'abstract': top_rec.get('abstract', ''),
            'paperId': top_rec.get('paperId', ''),
            'similarity': top_score,
            'tier': name,
            'tier_range': f"[{lower}, {upper})"
        }

    return picks


def process_file(
    sft_path: str,
    rec_path: str,
    model,
    util,
    save_embeddings: bool = True
) -> Tuple[Optional[Dict], Optional[Dict]]:
    """Process a single file and select bounded inspirations.

    Reads the SFT record and its recommendation file, pairs up their
    inspirations by position, scores every valid recommendation against the
    ground-truth paper and keeps the top candidate of each tier.

    Args:
        sft_path: Path to the SFT JSON file (read keys: ``research_question``,
            ``background_survey``, ``hypothesis``, ``hypothesis_components``,
            ``inspiration``).
        rec_path: Path to the recommendation JSON file (read key:
            ``inspirations``, each with a ``recommendations`` list).
        model: Encoder forwarded to ``calculate_all_similarities``.
        util: Similarity helper forwarded to ``calculate_all_similarities``.
        save_embeddings: When true, also build the per-candidate similarity
            listing returned as the second tuple element.

    Returns:
        ``(result, embeddings_result)``. Both are ``None`` when either file
        cannot be loaded, does not contain a JSON object, or when the two
        files disagree on the number of inspirations. ``embeddings_result``
        is also ``None`` whenever ``save_embeddings`` is false.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(sft_path, encoding='utf-8') as f:
            sft_data = json.load(f)
        with open(rec_path, encoding='utf-8') as f:
            rec_data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # The file is still skipped, but the reason is no longer hidden.
        print(f"Skipping {sft_path}: cannot load input JSON ({exc})")
        return None, None

    if not isinstance(sft_data, dict) or not isinstance(rec_data, dict):
        print(f"Skipping {sft_path}: top-level JSON value is not an object")
        return None, None

    # ``or []`` also covers keys that are present but null.
    inspirations = sft_data.get('inspiration') or []
    rec_inspirations = rec_data.get('inspirations') or []

    if len(inspirations) != len(rec_inspirations):
        print(
            f"Skipping {sft_path}: {len(inspirations)} inspirations vs "
            f"{len(rec_inspirations)} recommendation groups"
        )
        return None, None

    result = {
        'research_question': sft_data.get('research_question', ''),
        'background_survey': sft_data.get('background_survey', ''),
        'hypothesis': sft_data.get('hypothesis', ''),
        'hypothesis_components': sft_data.get('hypothesis_components', {}),
        'inspirations': []
    }
    components = result['hypothesis_components']

    embeddings_result = {'inspirations': []} if save_embeddings else None

    for idx, (insp, rec_insp) in enumerate(zip(inspirations, rec_inspirations)):
        gt_title = insp.get('found_title', '')
        gt_abstract = insp.get('found_abstract', '')
        gt_concept = insp.get('insp_concise', '') or insp.get('insp', '')
        relation = insp.get('relation', '')
        # Components are looked up by the inspiration's index as a string key;
        # anything that is not a dict (e.g. null) yields no delta hypothesis.
        delta_hyp = components.get(str(idx), '') if isinstance(components, dict) else ''

        insp_result = {
            'idx': idx,
            'gt_title': gt_title,
            'gt_abstract': gt_abstract,
            'gt_concept': gt_concept,
            'relation': relation,
            'delta_hypothesis': delta_hyp,
            'bounded_selections': {},
            'total_candidates': 0,
            'candidates_per_tier': {}
        }

        emb_result = {
            'idx': idx,
            'gt_title': gt_title,
            'all_similarities': []
        } if save_embeddings else None

        # Determine whether this inspiration can be scored at all.
        error = None
        rec_sim_pairs = []
        if not gt_title:
            error = 'no_gt_title'
        else:
            recommendations = rec_insp.get('recommendations') or []
            rec_sim_pairs = calculate_all_similarities(
                gt_title, gt_abstract, recommendations, model, util
            )
            if not rec_sim_pairs:
                error = 'no_valid_recommendations'

        if error is not None:
            # Keep a placeholder entry so output indices stay aligned with input.
            insp_result['error'] = error
            result['inspirations'].append(insp_result)
            if emb_result is not None:
                emb_result['error'] = error
                embeddings_result['inspirations'].append(emb_result)
            continue

        insp_result['total_candidates'] = len(rec_sim_pairs)
        insp_result['candidates_per_tier'] = {
            tier_name: sum(1 for _, s in rec_sim_pairs if low <= s < high)
            for tier_name, low, high in TIERS
        }

        selections = select_top_per_tier(rec_sim_pairs)
        insp_result['bounded_selections'] = selections
        insp_result['tiers_with_selection'] = sum(1 for v in selections.values() if v is not None)

        result['inspirations'].append(insp_result)

        if emb_result is not None:
            # Most similar candidates first; abstracts are truncated to 500
            # characters in this listing.
            sorted_pairs = sorted(rec_sim_pairs, key=lambda x: x[1], reverse=True)
            emb_result['all_similarities'] = [
                {
                    'title': rec.get('title', ''),
                    'abstract': rec.get('abstract', '')[:500],
                    'paperId': rec.get('paperId', ''),
                    'similarity': sim,
                    'tier': get_tier_name(sim)
                }
                for rec, sim in sorted_pairs
            ]
            embeddings_result['inspirations'].append(emb_result)

    return result, embeddings_result


def main():
    """Command-line entry point for one worker of the parallel selection job.

    Lists the ``.json`` files in ``--rec_dir`` in sorted order, keeps every
    ``num_workers``-th file starting at offset ``worker_id``, runs
    ``process_file`` on each one that also exists in ``--sft_dir``, and writes
    the result under the same file name to ``--output_dir`` (and the
    similarity listing to ``--embeddings_dir`` when that option is given).
    Prints summary statistics at the end.
    """
    parser = argparse.ArgumentParser(description='Worker for bounded inspiration selection v2')
    parser.add_argument('--sft_dir', type=str, required=True)
    parser.add_argument('--rec_dir', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--embeddings_dir', type=str, default=None)
    parser.add_argument('--worker_id', type=int, required=True)
    parser.add_argument('--num_workers', type=int, required=True)

    args = parser.parse_args()

    # Reject shard settings that would divide by zero or silently select no files.
    if args.num_workers < 1:
        parser.error('--num_workers must be at least 1')
    if not 0 <= args.worker_id < args.num_workers:
        parser.error('--worker_id must satisfy 0 <= worker_id < num_workers')

    os.makedirs(args.output_dir, exist_ok=True)
    save_embeddings = args.embeddings_dir is not None
    if save_embeddings:
        os.makedirs(args.embeddings_dir, exist_ok=True)

    # Sorting gives every worker the same ordering, so the strided slices of
    # different workers partition the file list without overlap.
    all_files = sorted(f for f in os.listdir(args.rec_dir) if f.endswith('.json'))
    my_files = all_files[args.worker_id::args.num_workers]

    print(f"Worker {args.worker_id}: Processing {len(my_files)} / {len(all_files)} files")
    print(f"Save embeddings: {save_embeddings}")

    # Loading the model is the expensive step; skip it for an empty shard.
    model, util = get_model_and_util() if my_files else (None, None)

    stats = {
        'processed': 0,
        'inspirations': 0,
        'with_valid_recs': 0,
        'with_selection': 0,
        'tier_coverage': {tier[0]: 0 for tier in TIERS}
    }

    for filename in tqdm(my_files, desc=f"Worker {args.worker_id}"):
        sft_path = os.path.join(args.sft_dir, filename)
        rec_path = os.path.join(args.rec_dir, filename)

        # A recommendation file without a matching SFT file cannot be processed.
        if not os.path.exists(sft_path):
            continue

        result, emb_result = process_file(sft_path, rec_path, model, util, save_embeddings)

        if result is None:
            continue

        # ensure_ascii=False emits raw non-ASCII characters, so the output
        # encoding must be pinned to UTF-8 rather than left to the locale.
        output_path = os.path.join(args.output_dir, filename)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)

        if save_embeddings and emb_result:
            emb_path = os.path.join(args.embeddings_dir, filename)
            with open(emb_path, 'w', encoding='utf-8') as f:
                json.dump(emb_result, f, indent=2, ensure_ascii=False)

        stats['processed'] += 1
        for insp in result['inspirations']:
            stats['inspirations'] += 1
            if insp.get('total_candidates', 0) > 0:
                stats['with_valid_recs'] += 1
            if insp.get('tiers_with_selection', 0) > 0:
                stats['with_selection'] += 1

            selections = insp.get('bounded_selections', {})
            for tier_name in stats['tier_coverage']:
                if selections.get(tier_name) is not None:
                    stats['tier_coverage'][tier_name] += 1

    print(f"\nWorker {args.worker_id} done!")
    print(f"  Files processed: {stats['processed']}")
    print(f"  Inspirations: {stats['inspirations']}")
    print(f"  With valid recs: {stats['with_valid_recs']}")
    print(f"  With selection: {stats['with_selection']}")
    print(f"  Tier coverage: {stats['tier_coverage']}")


# Run the worker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

