#!/usr/bin/env python3

import json
import requests
import pandas as pd
import time
import os
import argparse
from datetime import datetime
from tqdm import tqdm
import asyncio
from concurrent.futures import ThreadPoolExecutor
from collections import deque

def _positive_int(value):
    """argparse ``type`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_arguments(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings.  When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.

    Raises:
        SystemExit: raised by argparse on missing/invalid options.
    """
    parser = argparse.ArgumentParser(description='Run IR evaluation with VLLM server')
    
    parser.add_argument('--dataset-name', type=str, required=True,
                       help='Dataset name (e.g., fever)')
    parser.add_argument('--split', type=str, required=True,
                       help='Dataset split (e.g., test)')
    parser.add_argument('--model-name', type=str, required=True,
                       help='VLLM model name (e.g., LFM2-350M)')
    parser.add_argument('--experiment-id', type=str, required=True,
                       help='Experiment identifier (e.g., seed_001)')
    # A batch size below 1 would leave the active batch permanently empty, so
    # the run would end immediately without evaluating anything; reject it here.
    parser.add_argument('--batch-size', type=_positive_int, required=True,
                       help='Batch size for processing (e.g., 64)')
    parser.add_argument('--vllm-server-url', type=str, default='http://localhost:5001',
                       help='VLLM server URL (default: http://localhost:5001)')
    parser.add_argument('--ir-server-url', type=str, default='http://localhost:5000',
                       help='IR server URL (default: http://localhost:5000)')
    parser.add_argument('--sql-server-url', type=str, default='http://localhost:8000',
                       help='SQL server URL (default: http://localhost:8000)')
    parser.add_argument('--max-turns', type=int, default=5,
                       help='Maximum turns per query (default: 5)')
    
    return parser.parse_args(argv)

class SQLClient:
    """Small HTTP client for the results SQL server.

    Every request body starts from ``base_params`` (dataset, split, experiment
    and model identifiers) and adds whatever the individual call needs.
    """

    def __init__(self, sql_server_url, dataset_name, split, experiment_id, model_name):
        self.sql_server_url = sql_server_url
        # Identifying fields included in every request payload.
        self.base_params = dict(
            dataset_name=dataset_name,
            split=split,
            experiment_id=experiment_id,
            model_name=model_name,
        )

    def _make_request(self, endpoint, data):
        """POST ``data`` as JSON to ``endpoint`` and return the decoded JSON reply.

        Raises:
            Exception: on a non-200 status code, or when ``requests`` reports
                a failure while performing the call.
        """
        url = f"{self.sql_server_url}/{endpoint}"
        try:
            reply = requests.post(url, json=data, timeout=30)
            if reply.status_code != 200:
                raise Exception(f"SQL Server error {reply.status_code}: {reply.text}")
            return reply.json()
        except requests.exceptions.RequestException as e:
            raise Exception(f"Failed to connect to SQL server: {e}")

    def get_finished_query_ids(self):
        """Return the set of ids listed under ``finished_ids`` in the reply."""
        reply = self._make_request('get_finished_queries', self.base_params)
        finished = reply['finished_ids']
        return set(finished)

    def get_avg_success_rate(self):
        """Return the ``(success_rate, total_finished)`` pair from the reply."""
        reply = self._make_request('get_success_rate', self.base_params)
        return reply['success_rate'], reply['total_finished']

    def get_query_history(self, query_id):
        """Return the ``history`` field the server reports for ``query_id``."""
        payload = dict(self.base_params, query_id=query_id)
        return self._make_request('get_query_history', payload)['history']

    def delete_incomplete_query(self, query_id):
        """Ask the server to delete ``query_id``'s incomplete rows; return ``deleted_count``."""
        payload = dict(self.base_params, query_id=query_id)
        return self._make_request('delete_incomplete_query', payload)['deleted_count']

    def save_result(self, result_data):
        """Send a single result row and return the ``result_id`` from the reply."""
        payload = dict(self.base_params, result_data=result_data)
        return self._make_request('save_result', payload)['result_id']

    def batch_save_results(self, results_batch):
        """Send several result rows in one request and return ``saved_count``."""
        payload = dict(self.base_params, results_batch=results_batch)
        return self._make_request('batch_save_results', payload)['saved_count']

class VLLMEvaluationRunner:
    def __init__(self, args):
        """Keep the parsed options and initialise empty scheduling/buffer state."""
        self.args = args

        # Scheduling state: loaded queries, ids still waiting, contexts in flight.
        self.queries_df = None
        self.pending_queue = deque()
        self.active_queries = {}

        # Progress counters.
        self.completed_queries = 0
        self.total_queries = 0

        # Client for the results SQL server, bound to this run's identifiers.
        self.sql_client = SQLClient(
            sql_server_url=args.sql_server_url,
            dataset_name=args.dataset_name,
            split=args.split,
            experiment_id=args.experiment_id,
            model_name=args.model_name,
        )

        # Rows waiting to be written, and the buffer size that triggers a write.
        self.pending_results = []
        self.batch_save_size = 10
    
    def get_avg_success_rate(self):
        """Return whatever ``self.sql_client.get_avg_success_rate()`` returns.

        Callers in this class unpack it as ``(avg_success, completed_count)``.
        """
        return self.sql_client.get_avg_success_rate()
    
    def get_finished_query_ids(self):
        """Return the finished query ids reported by ``self.sql_client``."""
        return self.sql_client.get_finished_query_ids()
    
    def get_query_history(self, query_id):
        """Return the stored turn history for ``query_id`` from ``self.sql_client``."""
        return self.sql_client.get_query_history(query_id)
    
    def delete_incomplete_query(self, query_id):
        """Forward the delete request for ``query_id`` to ``self.sql_client`` and return its result."""
        return self.sql_client.delete_incomplete_query(query_id)
    
    def prepare_result_data(self, result, turn_id, think_response, search_query, is_finished=False):
        user_query_text = self.queries_df.loc[result['query_id'], 'text']
        top_k_doc_ids = json.dumps([r['doc_id'] for r in result.get('results', [])])
        
        top_k_texts = []
        for i, r in enumerate(result.get('results', [])[:5]):
            doc_text = r.get('text', r['doc_id'])
            if len(doc_text) > 300:
                truncated_text = doc_text[:150] + "... " + doc_text[-86:]
            else:
                truncated_text = doc_text
            top_k_texts.append(f'{i+1}. """{truncated_text}"""')
        top_k_texts_json = json.dumps(top_k_texts)
        
        return {
            'query_id': result['query_id'],
            'user_query_text': user_query_text,
            'turn_id': turn_id,
            'think_response': think_response,
            'search_query': search_query,
            'top_k_doc_ids': top_k_doc_ids,
            'top_k_texts': top_k_texts_json,
            'success_at_5': result['success@5'],
            'success_at_10': result['success@10'],
            'success_at_50': result['success@50'],
            'success_at_100': result['success@100'],
            'ndcg_at_5': result['ndcg@5'],
            'ndcg_at_10': result['ndcg@10'],
            'ndcg_at_50': result['ndcg@50'],
            'ndcg_at_100': result['ndcg@100'],
            'precision_at_5': result['precision@5'],
            'precision_at_10': result['precision@10'],
            'precision_at_50': result['precision@50'],
            'precision_at_100': result['precision@100'],
            'recall_at_5': result['recall@5'],
            'recall_at_10': result['recall@10'],
            'recall_at_50': result['recall@50'],
            'recall_at_100': result['recall@100'],
            'best_rank': result['best_rank'],
            'mrr': result['mrr'],
            'map_score': result['map'],
            'is_finished': is_finished
        }
    
    def save_result(self, result, turn_id, think_response, search_query, is_finished=False):
        result_data = self.prepare_result_data(result, turn_id, think_response, search_query, is_finished)
        self.pending_results.append(result_data)
        
        if len(self.pending_results) >= self.batch_save_size:
            self.flush_pending_results()
    
    def flush_pending_results(self):
        """Send all buffered rows to the SQL client in one call, then empty the buffer.

        Does nothing when the buffer is already empty.
        """
        if not self.pending_results:
            return
        self.sql_client.batch_save_results(self.pending_results)
        # Cleared in place, and only after the save call returned without raising.
        self.pending_results.clear()
    
    def load_queries(self):
        base_path = f"./data/raw_data/{self.args.dataset_name}/{self.args.dataset_name}"
        queries_path = f"{base_path}/queries.jsonl"
        qrels_path = f"{base_path}/qrels/{self.args.split}.tsv"
        
        queries = pd.read_json(queries_path, lines=True)
        queries['_id'] = queries['_id'].astype(str)
        queries = queries.set_index('_id')
        self.queries_df = queries
        
        qrels = pd.read_csv(qrels_path, sep='\t', dtype=str)
        available_query_ids = qrels.iloc[:, 0].unique()
        
        return queries, available_query_ids
    
    def call_ir_server(self, batch_queries):
        """POST a batch of search queries to the IR server's ``/batch_search`` endpoint.

        Args:
            batch_queries: Value sent as the ``queries`` field of the payload,
                together with the dataset name and split from ``self.args``.

        Returns:
            The decoded JSON body of the server's reply.

        Raises:
            Exception: on a non-200 status code, or when ``requests`` reports
                a failure (the original error is attached as the cause).
        """
        payload = {
            'dataset_name': self.args.dataset_name,
            'split': self.args.split,
            'queries': batch_queries
        }
        
        try:
            response = requests.post(f"{self.args.ir_server_url}/batch_search", 
                                   json=payload, timeout=120)
            if response.status_code == 200:
                return response.json()
            else:
                # Include the body, as the SQL and VLLM clients do, so the
                # server's own error message is not lost.
                raise Exception(f"IR Server error: {response.status_code} - {response.text}")
        except requests.exceptions.RequestException as e:
            raise Exception(f"Failed to connect to IR server: {e}") from e
    
    def call_vllm_server(self, prompts, eos_token, temperature=0.3, max_tokens=512):
        """POST a batch of prompts to the VLLM server's ``/batch_inference`` endpoint.

        Args:
            prompts: Prompts to complete, sent as-is.
            eos_token: Sent as the ``eos_token`` field of the request.
            temperature: Sent as the ``temperature`` field of the request.
            max_tokens: Sent as the ``max_tokens`` field of the request.

        Returns:
            The decoded JSON body of the server's reply.

        Raises:
            Exception: on a non-200 status code or a ``requests`` failure.
        """
        body = dict(
            prompts=prompts,
            eos_token=eos_token,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        try:
            endpoint = f"{self.args.vllm_server_url}/batch_inference"
            reply = requests.post(endpoint, json=body, timeout=600)
            if reply.status_code != 200:
                raise Exception(f"VLLM Server error: {reply.status_code} - {reply.text}")
            return reply.json()
        except requests.RequestException as e:
            raise Exception(f"VLLM request failed: {e}")
    
    def build_conversation_context(self, original_query, history):
        conversation = f"<user_query>{original_query}</user_query>\n\n"
        
        for turn_id, think_resp, search_q, top_texts, success in history:
            conversation += f"<think>{think_resp}</think>\n\n"
            conversation += f"<search_query>{search_q}</search_query>\n\n"
            
            if top_texts:
                top_k_list = json.loads(top_texts)
                top_k_content = "\n".join(top_k_list[:5])
                conversation += f"<top_k_response>{top_k_content}</top_k_response>\n\n"
            else:
                conversation += f"<top_k_response>No results found</top_k_response>\n\n"
        
        return conversation
    
    def generate_batch_llm_responses(self, query_contexts):
        """Run the two-step generation (think, then search query) for a batch.

        For every context the model first completes a ``<think>`` section, then,
        with that section appended, completes a ``<search_query>`` section.

        Args:
            query_contexts: Dicts with ``query_id``, ``original_query`` and
                ``history`` keys.

        Returns:
            One dict per input context with ``query_id``, ``original_query``,
            ``think_response``, ``search_query`` and ``turn`` (history length + 1).
        """
        def without_eos(text, eos):
            # The generated text may still end with the stop tag itself.
            return text[:-len(eos)] if text.endswith(eos) else text

        def prompt_prefix(ctx):
            question, past_turns = ctx['original_query'], ctx['history']
            if len(past_turns) == 0:
                return f"<user_query>{question}</user_query>\n\n"
            return self.build_conversation_context(question, past_turns)

        # The prefix is shared by the think prompt and the search prompt.
        prefixes = [prompt_prefix(ctx) for ctx in query_contexts]

        # Step 1: reasoning.
        think_reply = self.call_vllm_server(
            [prefix + "<think>" for prefix in prefixes],
            "</think>", temperature=0.3, max_tokens=512,
        )
        think_results = think_reply['results']
        thoughts = [
            without_eos(think_results[pos]['text'], "</think>")
            for pos, _ in enumerate(query_contexts)
        ]

        # Step 2: search query, conditioned on the cleaned reasoning.
        search_reply = self.call_vllm_server(
            [
                prefix + f"<think>{thought}</think>\n\n<search_query>"
                for prefix, thought in zip(prefixes, thoughts)
            ],
            "</search_query>", temperature=0.3, max_tokens=512,
        )
        search_results = search_reply['results']

        return [
            {
                'query_id': ctx['query_id'],
                'original_query': ctx['original_query'],
                'think_response': thoughts[pos],
                'search_query': without_eos(search_results[pos]['text'], "</search_query>"),
                'turn': len(ctx['history']) + 1,
            }
            for pos, ctx in enumerate(query_contexts)
        ]
    
    def execute_batch_ir_search(self, contexts):
        """Send each context's search query to the IR server and return its ``results`` list.

        Only ``query_id`` and ``search_query`` are forwarded from each context.
        """
        batch_queries = [
            {'query_id': ctx['query_id'], 'search_query': ctx['search_query']}
            for ctx in contexts
        ]
        return self.call_ir_server(batch_queries)['results']
    
    def create_query_context(self, query_id):
        original_query = self.queries_df.loc[query_id, 'text']
        history = self.get_query_history(query_id)
        return {
            'query_id': query_id,
            'original_query': original_query,
            'history': history
        }
    
    def fill_active_batch(self):
        while len(self.active_queries) < self.args.batch_size and len(self.pending_queue) > 0:
            query_id = self.pending_queue.popleft()
            self.delete_incomplete_query(query_id)
            self.active_queries[query_id] = self.create_query_context(query_id)
    
    def process_completed_query(self, query_id, result, context):
        is_successful = result['success@5']
        is_final_turn = context['turn'] >= self.args.max_turns
        is_finished = is_successful or is_final_turn
        
        self.save_result(result, context['turn'], context['think_response'], 
                        context['search_query'], is_finished)
        
        if is_finished:
            self.completed_queries += 1
            del self.active_queries[query_id]
            return True
        else:
            self.active_queries[query_id] = self.create_query_context(query_id)
            return False
    
    def run_evaluation(self):
        """Run the multi-turn evaluation loop until every query is finished.

        Loads the queries, skips ids the SQL server already reports as
        finished, then repeatedly: fills the active batch, generates a think
        response and a search query per active query, runs the IR search, and
        records each result.  Progress is shown on a tqdm bar.

        The progress bar is closed and buffered results are flushed even when
        the loop raises; the completion message is printed only after a run
        that ended without an exception.
        """
        _, available_query_ids = self.load_queries()
        finished_query_ids = self.get_finished_query_ids()
        
        pending_query_ids = [qid for qid in available_query_ids if qid not in finished_query_ids]
        self.pending_queue.extend(pending_query_ids)
        self.total_queries = len(available_query_ids)
        self.completed_queries = len(finished_query_ids)
        
        if finished_query_ids:
            avg_success, _ = self.get_avg_success_rate()
            tqdm.write(f"Resuming: {len(finished_query_ids)} completed queries (Avg Success: {avg_success:.2%})")
        
        pbar = tqdm(total=self.total_queries, initial=self.completed_queries, 
                   desc=f"{self.args.dataset_name}-{self.args.split}")
        
        try:
            while len(self.pending_queue) > 0 or len(self.active_queries) > 0:
                self.fill_active_batch()
                
                # Nothing could be activated (e.g. a non-positive batch size).
                if len(self.active_queries) == 0:
                    break
                
                query_contexts = list(self.active_queries.values())
                
                batch_contexts = self.generate_batch_llm_responses(query_contexts)
                
                results = self.execute_batch_ir_search(batch_contexts)
                
                finished_this_round = 0
                for j, result in enumerate(results):
                    # Results are paired with contexts by position.
                    context = batch_contexts[j]
                    query_id = result['query_id']
                    
                    if self.process_completed_query(query_id, result, context):
                        finished_this_round += 1
                
                pbar.update(finished_this_round)
                
                avg_success, _ = self.get_avg_success_rate()
                pbar.set_postfix({
                    'Success': f'{avg_success:.2%}',
                    'Active': len(self.active_queries),
                    'Pending': len(self.pending_queue)
                })
        finally:
            # Close the bar first so it is released even if the flush fails.
            pbar.close()
            self.flush_pending_results()
        
        # Reached only when the loop and the final flush both succeeded.
        tqdm.write("Evaluation completed. All results saved to database.")

def main():
    """Check that the SQL, VLLM and IR servers respond, then run the evaluation.

    The servers are probed in that order; the first one that cannot be reached
    or does not answer with status 200 causes a message to be printed and the
    function to return without starting the run.
    """
    args = parse_arguments()

    # (label, base URL, probe path, script named in the warning, hint on connection error)
    probes = [
        ("SQL", args.sql_server_url, "health", "sql_server.py",
         "Please start sql_server.py first"),
        ("VLLM", args.vllm_server_url, "health", "vllm_server.py",
         "Please start the VLLM server first"),
        ("IR", args.ir_server_url, "datasets", "ir_server.py",
         "Please start the IR server first"),
    ]

    for label, base_url, path, script, hint in probes:
        try:
            reply = requests.get(f"{base_url}/{path}", timeout=5)
        except requests.exceptions.RequestException:
            print(f"Error: Cannot connect to {label} server at {base_url}")
            print(hint)
            return
        if reply.status_code != 200:
            print(f"Warning: {label} server health check failed. Make sure {script} is running on {base_url}")
            return

    for label, base_url, *_ in probes:
        print(f"Connected to {label} server: {base_url}")
    print(f"Starting evaluation: {args.dataset_name}-{args.split} with {args.model_name}")

    VLLMEvaluationRunner(args).run_evaluation()

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()