import os
import pickle
import pandas as pd
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from sklearn.metrics import ndcg_score
import numpy as np
import time
from flask import Flask, request, jsonify

# Startup banner, printed at import time before any model or dataset is loaded.
print("STARTING SERVER")

# Maps each dataset name to the CUDA device index its embeddings are placed on
# (BatchIRServer turns the index into f"cuda:{index}" while loading).
# Iteration order of this dict is the order in which datasets are loaded.
DATASET_GPU_DEVICES = {
    'bright-fast': 0,
    'fever': 0, 
    'fiqa': 0,
    'hotpotqa': 0,
    'msmarco': 2,
    'nfcorpus': 1,
    'nq': 1,
    'scifact': 1
}

# Flask application object that the route decorators in this module register on.
app = Flask(__name__)
# Reference timestamp for the elapsed-seconds prefix that log() prints.
start_time = time.time()

def log(msg):
    """Print ``msg`` prefixed with the seconds elapsed since ``start_time``."""
    now = time.time()
    stamp = f"[{now - start_time:.2f}s]"
    print(f"{stamp} {msg}")

def print_gpu_memory(device_id, stage):
    """Log allocated, reserved and total memory (in GB) for one CUDA device.

    Does nothing when CUDA is unavailable.

    Args:
        device_id: Integer CUDA device index.
        stage: Free-form label included in the log line.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1024**3
    device = f"cuda:{device_id}"
    allocated = torch.cuda.memory_allocated(device) / bytes_per_gb
    reserved = torch.cuda.memory_reserved(device) / bytes_per_gb
    properties = torch.cuda.get_device_properties(device)
    total = properties.total_memory / bytes_per_gb
    log(f"GPU {device_id} {stage}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved, {total:.2f}GB total")

def _empty_query_metrics(query_id, cutoffs=(5, 10, 50, 100)):
    """Return the all-zero metrics record used when no target document is ranked."""
    record = {'query_id': query_id}
    for metric, zero in (('success', False), ('ndcg', 0.0), ('precision', 0.0), ('recall', 0.0)):
        for k in cutoffs:
            record[f'{metric}@{k}'] = zero
    record.update({'best_rank': -1, 'mrr': 0.0, 'map': 0.0, 'results': []})
    return record


def _relevance_by_doc(qrels, query_id):
    """Map doc id -> raw relevance value for one query (first qrels row per doc wins)."""
    rows = qrels[qrels.iloc[:, 0] == query_id]
    doc_column = rows.iloc[:, 1].tolist()
    if rows.shape[1] >= 3:
        score_column = rows.iloc[:, 2].tolist()
    else:
        # Two-column qrels carry no grade; every listed document counts as 1.0.
        score_column = [1.0] * len(doc_column)

    relevance = {}
    for doc_id, score in zip(doc_column, score_column):
        relevance.setdefault(doc_id, score)
    return relevance


def _lookup_doc_text(corpus, doc_id):
    """Return "<title> <text>" for a document, or a placeholder message on failure."""
    try:
        if doc_id in corpus.index:
            return str(corpus.loc[doc_id, 'title']) + ' ' + str(corpus.loc[doc_id, 'text'])
        return f"Document {doc_id} not found in corpus"
    except Exception:
        # Best effort: a broken corpus row must not fail the whole batch.
        return f"Error retrieving text for document {doc_id}"


def compute_batch_metrics_vectorized(batch_queries, doc_ids, similarity_matrix, excluded_sets, qrels, device, corpus):
    """Compute ranking metrics for every query in a batch.

    Args:
        batch_queries: List of dicts with at least 'query_id' and 'target_docs'
            (the ids of the relevant documents for that query).
        doc_ids: List of document ids; position i corresponds to column i of
            ``similarity_matrix``.
        similarity_matrix: Tensor indexed as ``similarity_matrix[batch_idx]`` to
            obtain the per-document scores of query ``batch_idx``.
        excluded_sets: Dict mapping query id -> set of document ids that must
            be removed from the ranking for that query.
        qrels: DataFrame whose columns 0, 1 and (optionally) 2 hold query id,
            document id and relevance.
        device: Torch device on which helper tensors are created.
        corpus: DataFrame indexed by document id with 'title' and 'text'
            columns, used to attach text to the top-5 results.

    Returns:
        One dict per query with 'query_id', 'best_rank', 'mrr', 'map',
        success/ndcg/precision/recall at 5, 10, 50 and 100, and 'results'
        (the top-5 documents). Queries with no ranked target document get an
        all-zero record with ``best_rank == -1``.
    """
    cutoffs = (5, 10, 50, 100)
    max_k = max(cutoffs)
    n_docs = len(doc_ids)
    doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(doc_ids)}

    batch_results = []

    for batch_idx, query_data in enumerate(batch_queries):
        query_id = query_data['query_id']
        target_docs = query_data['target_docs']
        excluded_set = excluded_sets.get(query_id, set())
        similarities = similarity_matrix[batch_idx]

        # dtype must be explicit: torch.tensor([]) is float32 and cannot be
        # used as an index, which crashed queries whose targets are all
        # missing from the index.
        target_indices = torch.tensor(
            [doc_id_to_idx[doc_id] for doc_id in target_docs if doc_id in doc_id_to_idx],
            dtype=torch.long,
            device=device,
        )

        excluded_positions = [doc_id_to_idx[doc_id] for doc_id in excluded_set if doc_id in doc_id_to_idx]
        if excluded_positions:
            excluded_mask = torch.zeros(n_docs, dtype=torch.bool, device=device)
            excluded_mask[torch.tensor(excluded_positions, dtype=torch.long, device=device)] = True
            valid_mask = ~excluded_mask

            # Rank only the surviving documents, then map back to global positions.
            valid_indices = torch.where(valid_mask)[0]
            order = torch.argsort(similarities[valid_mask], descending=True)
            global_sorted_indices = valid_indices[order]
        else:
            # Nothing to exclude (or only ids unknown to the index).
            global_sorted_indices = torch.argsort(similarities, descending=True)

        # rank_lookup[i] is the 1-based rank of document i; 0 means "excluded".
        rank_lookup = torch.zeros(n_docs, dtype=torch.long, device=device)
        rank_lookup[global_sorted_indices] = torch.arange(len(global_sorted_indices), device=device) + 1

        target_ranks = [r for r in rank_lookup[target_indices].cpu().tolist() if r > 0]
        if not target_ranks:
            batch_results.append(_empty_query_metrics(query_id, cutoffs))
            continue

        best_rank = min(target_ranks)
        results = {
            'query_id': query_id,
            'best_rank': best_rank,
            'mrr': max(1.0 / rank for rank in target_ranks),
            # NOTE(review): this is the sum of reciprocal target ranks divided
            # by the number of targets, not textbook average precision. Kept
            # unchanged so previously reported numbers do not shift.
            'map': sum(1.0 / rank for rank in target_ranks) / len(target_docs),
        }

        top_positions = global_sorted_indices[:max_k]
        top_k_docs = [doc_ids[idx] for idx in top_positions.tolist()]
        top_k_scores = similarities[top_positions].cpu().tolist()

        target_doc_set = set(target_docs)
        is_hit = [doc_id in target_doc_set for doc_id in top_k_docs]

        # Resolve graded relevance once per query instead of re-filtering the
        # qrels frame for every relevant document at every cutoff.
        if any(is_hit):
            relevance = _relevance_by_doc(qrels, query_id)
            gains = [
                float(relevance.get(doc_id, 1.0)) if hit else 0.0
                for doc_id, hit in zip(top_k_docs, is_hit)
            ]
        else:
            gains = [0.0] * len(top_k_docs)

        for k in cutoffs:
            hits_at_k = sum(is_hit[:k])
            gains_at_k = gains[:k]

            if any(gain > 0 for gain in gains_at_k):
                # NOTE(review): the ideal ordering is derived only from the
                # documents inside the top-k list, so relevant documents that
                # were not retrieved do not lower this score.
                ndcg = ndcg_score([gains_at_k], [top_k_scores[:k]], k=k)
            else:
                ndcg = 0.0

            results[f'success@{k}'] = best_rank <= k
            results[f'ndcg@{k}'] = ndcg
            results[f'precision@{k}'] = hits_at_k / k
            results[f'recall@{k}'] = hits_at_k / len(target_docs)

            if k == 5:
                results['results'] = [
                    {
                        'rank': position + 1,
                        'doc_id': doc_id,
                        'score': float(top_k_scores[position]),
                        'text': _lookup_doc_text(corpus, doc_id),
                    }
                    for position, doc_id in enumerate(top_k_docs[:5])
                ]

        batch_results.append(results)

    return batch_results

class BatchIRServer:
    """Holds the query encoder plus every dataset's embeddings and answers batch searches.

    Attributes:
        model: SentenceTransformer ('all-MiniLM-L6-v2') kept on CPU for query encoding.
        datasets_loaded: Dict mapping dataset name -> dict with keys
            'embeddings', 'doc_ids', 'corpus', 'queries', 'qrels',
            'excluded_ids' and 'device'.
    """

    def __init__(self):
        """Load the encoder and then every dataset listed in DATASET_GPU_DEVICES."""
        log("Loading sentence transformer model")
        self.model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        self.datasets_loaded = {}
        self._load_all_datasets()

    @staticmethod
    def _log_gpu_memory(device_id, stage):
        """Report GPU memory only for devices that exist on this host.

        print_gpu_memory raises for an out-of-range device index, which would
        otherwise abort start-up on machines with fewer GPUs than configured.
        """
        if torch.cuda.is_available() and device_id < torch.cuda.device_count():
            print_gpu_memory(device_id, stage)

    @staticmethod
    def _read_jsonl_by_id(path):
        """Read a JSON-lines file into a DataFrame indexed by its stringified '_id'."""
        frame = pd.read_json(path, lines=True)
        frame['_id'] = frame['_id'].astype(str)
        return frame.set_index('_id')

    @staticmethod
    def _collect_excluded_ids(queries, qrels):
        """Build query id -> set of excluded doc ids from the queries' 'metadata'.

        Any document that appears in the qrels for the same query is removed
        from that query's excluded set, so a target is never filtered out.

        Args:
            queries: DataFrame indexed by query id; rows may carry a
                'metadata' dict with an 'excluded_ids' entry.
            qrels: Dict mapping split name -> qrels DataFrame (column 0 query
                id, column 1 document id).
        """
        excluded_ids = {}
        for query_id, row in queries.iterrows():
            metadata = row.get('metadata', {})
            # Rows without metadata come back as NaN, not a dict.
            if isinstance(metadata, dict) and 'excluded_ids' in metadata:
                excluded_ids[query_id] = set(metadata['excluded_ids'])

        for split_qrels in qrels.values():
            for query_id, corpus_id in zip(split_qrels.iloc[:, 0], split_qrels.iloc[:, 1]):
                if query_id in excluded_ids and corpus_id in excluded_ids[query_id]:
                    excluded_ids[query_id].discard(corpus_id)
                    log(f"Removed target doc {corpus_id} from excluded_ids for query {query_id}")
        return excluded_ids

    def _load_all_datasets(self):
        """Load corpus, queries, qrels and embeddings for each configured dataset.

        Datasets whose raw-data or embeddings directory is missing are skipped.
        Results are stored in ``self.datasets_loaded``.
        """
        for dataset_name, device_id in DATASET_GPU_DEVICES.items():
            device = f"cuda:{device_id}"
            log(f"=== {dataset_name.upper()} (GPU {device_id}) ===")
            self._log_gpu_memory(device_id, "before loading")

            base_path = f"./data/raw_data/{dataset_name}/{dataset_name}"
            embeddings_path = f"./data/embeddings/all_MiniLM_L6_v2/{dataset_name}"

            if not os.path.exists(base_path) or not os.path.exists(embeddings_path):
                log("Missing files, skipping")
                continue

            corpus = self._read_jsonl_by_id(f"{base_path}/corpus.jsonl")
            queries = self._read_jsonl_by_id(f"{base_path}/queries.jsonl")

            qrels = {}
            for split_name in ['train', 'dev', 'test']:
                split_path = f"{base_path}/qrels/{split_name}.tsv"
                if os.path.exists(split_path):
                    # dtype=str keeps ids comparable with the stringified corpus/query ids.
                    qrels[split_name] = pd.read_csv(split_path, sep='\t', dtype=str)

            # NOTE: pickle.load can execute arbitrary code from the file; only
            # point this at embedding files from a trusted location.
            with open(f"{embeddings_path}/embeddings.pkl", 'rb') as f:
                embeddings_np = pickle.load(f)
            with open(f"{embeddings_path}/ids.pkl", 'rb') as f:
                doc_ids_raw = [str(doc_id) for doc_id in pickle.load(f)]

            # Keep only the first embedding row for each document id.
            seen_docs = set()
            doc_ids = []
            unique_indices = []
            for i, doc_id in enumerate(doc_ids_raw):
                if doc_id not in seen_docs:
                    seen_docs.add(doc_id)
                    doc_ids.append(doc_id)
                    unique_indices.append(i)

            if len(unique_indices) < len(doc_ids_raw):
                log(f"Found {len(doc_ids_raw) - len(unique_indices)} duplicate documents, keeping {len(unique_indices)} unique")
                embeddings_np = embeddings_np[unique_indices]

            log(f"Embeddings shape: {embeddings_np.shape}, Total documents: {len(doc_ids)}")
            embeddings = torch.tensor(embeddings_np, dtype=torch.float32, device=device)
            del embeddings_np  # release the host copy once it lives on the device

            self._log_gpu_memory(device_id, "after loading")

            excluded_ids = {}
            if dataset_name == 'bright-fast':
                excluded_ids = self._collect_excluded_ids(queries, qrels)

            self.datasets_loaded[dataset_name] = {
                'embeddings': embeddings,
                'doc_ids': doc_ids,
                'corpus': corpus,
                'queries': queries,
                'qrels': qrels,
                'excluded_ids': excluded_ids,
                'device': device
            }

        log("=== FINAL GPU MEMORY USAGE ===")
        # Report every configured device instead of a hard-coded [0, 1, 2].
        for device_id in sorted(set(DATASET_GPU_DEVICES.values())):
            self._log_gpu_memory(device_id, "final")

        log(f"Datasets loaded: {list(self.datasets_loaded.keys())}")
        for name, data in self.datasets_loaded.items():
            log(f"{name}: {data['embeddings'].shape} on {data['device']}")

        log("Server initialization completed")

    def batch_search(self, dataset_name, split, queries):
        """Encode a batch of queries, score them against a dataset and compute metrics.

        Args:
            dataset_name: Key into ``self.datasets_loaded``.
            split: Qrels split name (e.g. 'train', 'dev', 'test').
            queries: List of dicts with 'query_id' and 'search_query'.

        Returns:
            ``{'dataset_name', 'split', 'batch_size', 'results'}`` on success,
            or ``{'error': message}`` when the dataset or split is unknown or
            no query has a relevant document in the split's qrels.
        """
        if dataset_name not in self.datasets_loaded:
            return {"error": f"Dataset {dataset_name} not available"}

        data = self.datasets_loaded[dataset_name]

        if split not in data['qrels']:
            return {"error": f"{split} split not available for {dataset_name}"}

        qrels = data['qrels'][split]
        corpus = data['corpus']

        batch_queries = []
        search_texts = []

        for query in queries:
            query_id = str(query['query_id'])
            search_query = query['search_query']

            query_qrels = qrels[qrels.iloc[:, 0] == query_id]
            if len(query_qrels.columns) >= 3:
                # Filter on a separate Series rather than writing a column into
                # the filtered slice (which raised SettingWithCopyWarning).
                scores = pd.to_numeric(query_qrels.iloc[:, 2], errors='coerce')
                target_docs = query_qrels[(scores > 0).to_numpy()].iloc[:, 1].tolist()
            else:
                target_docs = query_qrels.iloc[:, 1].tolist()

            # Queries without any positively judged document cannot be scored.
            if not target_docs:
                continue

            batch_queries.append({
                'query_id': query_id,
                'search_query': search_query,
                'target_docs': target_docs
            })
            search_texts.append(search_query)

        if not batch_queries:
            return {"error": "No valid queries found"}

        log(f"Processing batch of {len(batch_queries)} queries for {dataset_name}")

        query_embeddings = self.model.encode(search_texts, batch_size=32, show_progress_bar=False)
        query_tensor = torch.tensor(query_embeddings, dtype=torch.float32, device=data['device'])

        # One row per query, one column per document.
        similarity_matrix = torch.mm(query_tensor, data['embeddings'].T)

        excluded_sets = {}
        if dataset_name == 'bright-fast':
            for query_data in batch_queries:
                query_id = query_data['query_id']
                excluded_sets[query_id] = data['excluded_ids'].get(query_id, set())

        results = compute_batch_metrics_vectorized(batch_queries, data['doc_ids'], similarity_matrix, excluded_sets, qrels, data['device'], corpus)

        del query_tensor, similarity_matrix
        torch.cuda.empty_cache()

        return {
            'dataset_name': dataset_name,
            'split': split,
            'batch_size': len(results),
            'results': results
        }

# Module-level server instance shared by the route handlers. Constructing it
# loads the encoder and every dataset, so this runs at import time.
ir_server = BatchIRServer()

@app.route('/batch_search', methods=['POST'])
def batch_search():
    """Validate a batch-search request and delegate to ``ir_server.batch_search``.

    Expects a JSON object with 'dataset_name', 'split' and 'queries', where
    'queries' is a non-empty list of objects that each have 'query_id' and
    'search_query'. Returns HTTP 400 with an 'error' message when validation
    fails; otherwise returns the server's result dict as JSON.
    """
    data = request.get_json()

    # A body such as `null` or `42` is valid JSON but not a mapping; without
    # this check the field lookups below raise TypeError and yield a 500.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    required_fields = ['dataset_name', 'split', 'queries']
    missing_fields = [field for field in required_fields if data.get(field) is None]

    if missing_fields:
        return jsonify({"error": f"Missing required fields: {', '.join(missing_fields)}"}), 400

    if not isinstance(data['queries'], list) or len(data['queries']) == 0:
        return jsonify({"error": "queries must be a non-empty list"}), 400

    for i, query in enumerate(data['queries']):
        if not isinstance(query, dict) or 'query_id' not in query or 'search_query' not in query:
            return jsonify({"error": f"Query {i} must have 'query_id' and 'search_query' fields"}), 400

    result = ir_server.batch_search(data['dataset_name'], data['split'], data['queries'])
    return jsonify(result)

@app.route('/datasets', methods=['GET'])
def get_datasets():
    """Return the names of all datasets that were loaded, as JSON."""
    loaded_names = [name for name in ir_server.datasets_loaded]
    payload = {'available_datasets': loaded_names}
    return jsonify(payload)

# Start Flask's built-in server only when this file is executed as a script.
# host='0.0.0.0' listens on all network interfaces, on port 5000.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=False)