import sqlite3
import json
import time
from datetime import datetime
from flask import Flask, request, jsonify
from contextlib import contextmanager

# Reference point for log() elapsed times and the /health "uptime" figure.
start_time = time.time()
# Flask application that the route decorators below register against.
app = Flask(__name__)

def log(msg):
    """Print *msg* to stdout, prefixed with seconds elapsed since ``start_time``.

    Output looks like ``[12.34s] SQL Server: <msg>``.
    """
    line = f"[{time.time() - start_time:.2f}s] SQL Server: {msg}"
    print(line)

class SQLServer:
    """SQLite-backed store for per-turn retrieval evaluation results.

    All rows live in the single ``evaluation_results`` table. Lookups filter
    on dataset_name, split, experiment_id and model_name (plus user_query_id
    where relevant). No connection is kept on the instance: every method
    opens and closes its own via :meth:`get_connection`.
    """

    # Columns written by save_result / batch_save_results, in exactly the order
    # of the tuples produced by _result_row. ``id`` is left to AUTOINCREMENT.
    _INSERT_COLUMNS = (
        'dataset_name', 'split', 'user_query_id', 'user_query_text', 'turn_id',
        'think_response', 'search_query', 'top_k_doc_ids', 'top_k_texts',
        'success_at_5', 'success_at_10', 'success_at_50', 'success_at_100',
        'ndcg_at_5', 'ndcg_at_10', 'ndcg_at_50', 'ndcg_at_100',
        'precision_at_5', 'precision_at_10', 'precision_at_50', 'precision_at_100',
        'recall_at_5', 'recall_at_10', 'recall_at_50', 'recall_at_100',
        'best_rank', 'mrr', 'map_score', 'is_finished', 'timestamp',
        'model_name', 'experiment_id',
    )
    # Single source of truth for the INSERT statement (one '?' per column), so
    # the single-row and batch paths cannot drift apart.
    _INSERT_SQL = 'INSERT INTO evaluation_results ({}) VALUES ({})'.format(
        ', '.join(_INSERT_COLUMNS), ', '.join('?' * len(_INSERT_COLUMNS))
    )

    def __init__(self, db_path='evaluation_results.db'):
        """Ensure the schema exists in the SQLite file at *db_path*."""
        self.db_path = db_path
        self.init_database()
        log(f"Initialized SQLite database at {db_path}")

    @contextmanager
    def get_connection(self):
        """Yield a fresh connection to ``self.db_path`` and always close it.

        Callers must call ``conn.commit()`` themselves; sqlite3's ``close()``
        does not commit, so uncommitted changes are discarded.
        """
        conn = sqlite3.connect(self.db_path, timeout=30.0)
        try:
            # Enable WAL mode for better concurrency
            conn.execute('PRAGMA journal_mode=WAL')
            conn.execute('PRAGMA synchronous=NORMAL')
            conn.execute('PRAGMA busy_timeout=30000')
            yield conn
        finally:
            conn.close()

    def init_database(self):
        """Create the results table and its lookup indexes if they are missing."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS evaluation_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    dataset_name TEXT,
                    split TEXT,
                    user_query_id TEXT,
                    user_query_text TEXT,
                    turn_id INTEGER,
                    think_response TEXT,
                    search_query TEXT,
                    top_k_doc_ids TEXT,
                    top_k_texts TEXT,
                    success_at_5 BOOLEAN,
                    success_at_10 BOOLEAN,
                    success_at_50 BOOLEAN,
                    success_at_100 BOOLEAN,
                    ndcg_at_5 REAL,
                    ndcg_at_10 REAL,
                    ndcg_at_50 REAL,
                    ndcg_at_100 REAL,
                    precision_at_5 REAL,
                    precision_at_10 REAL,
                    precision_at_50 REAL,
                    precision_at_100 REAL,
                    recall_at_5 REAL,
                    recall_at_10 REAL,
                    recall_at_50 REAL,
                    recall_at_100 REAL,
                    best_rank INTEGER,
                    mrr REAL,
                    map_score REAL,
                    is_finished BOOLEAN DEFAULT FALSE,
                    timestamp TEXT,
                    model_name TEXT,
                    experiment_id TEXT
                )
            ''')

            # Create indexes for better query performance
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_query_lookup 
                ON evaluation_results(dataset_name, split, experiment_id, model_name, user_query_id)
            ''')
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_finished_queries 
                ON evaluation_results(dataset_name, split, experiment_id, model_name, is_finished)
            ''')

            conn.commit()

    def get_finished_query_ids(self, dataset_name, split, experiment_id, model_name):
        """Return the distinct user_query_ids that have a row with is_finished = TRUE."""
        with self.get_connection() as conn:
            cursor = conn.execute('''
                SELECT DISTINCT user_query_id FROM evaluation_results 
                WHERE dataset_name = ? AND split = ? AND experiment_id = ? AND model_name = ? AND is_finished = TRUE
            ''', (dataset_name, split, experiment_id, model_name))
            return [row[0] for row in cursor.fetchall()]

    def get_avg_success_rate(self, dataset_name, split, experiment_id, model_name):
        """Return ``(success_rate, total)`` over the finished rows of one run.

        ``total`` is the number of rows with is_finished = TRUE and
        ``success_rate`` the fraction of them with success_at_5 = 1;
        ``(0.0, 0)`` when there are none. NOTE(review): this counts rows,
        not distinct queries -- a query with several finished rows is
        counted once per row.
        """
        with self.get_connection() as conn:
            total, successful = conn.execute('''
                SELECT COUNT(*) as total_finished,
                       SUM(CASE WHEN success_at_5 = 1 THEN 1 ELSE 0 END) as successful
                FROM evaluation_results 
                WHERE dataset_name = ? AND split = ? AND experiment_id = ? AND model_name = ? AND is_finished = TRUE
            ''', (dataset_name, split, experiment_id, model_name)).fetchone()
        if total == 0:
            return 0.0, 0
        # SUM() is NULL over an empty set; guard it even though total > 0 here.
        return (successful or 0) / total, total

    def get_query_history(self, dataset_name, split, experiment_id, model_name, query_id):
        """Return all rows for one query as tuples, ordered by turn_id.

        Each tuple is ``(turn_id, think_response, search_query, top_k_texts,
        success_at_5)``.
        """
        with self.get_connection() as conn:
            cursor = conn.execute('''
                SELECT turn_id, think_response, search_query, top_k_texts, success_at_5
                FROM evaluation_results 
                WHERE dataset_name = ? AND split = ? AND experiment_id = ? AND model_name = ? AND user_query_id = ?
                ORDER BY turn_id
            ''', (dataset_name, split, experiment_id, model_name, query_id))
            return cursor.fetchall()

    def delete_incomplete_query(self, dataset_name, split, experiment_id, model_name, query_id):
        """Delete one query's rows with is_finished = FALSE; return how many were removed.

        Rows whose is_finished is NULL do not match ``= FALSE`` and are kept.
        """
        with self.get_connection() as conn:
            cursor = conn.execute('''
                DELETE FROM evaluation_results 
                WHERE dataset_name = ? AND split = ? AND experiment_id = ? AND model_name = ? AND user_query_id = ? AND is_finished = FALSE
            ''', (dataset_name, split, experiment_id, model_name, query_id))
            deleted_count = cursor.rowcount
            conn.commit()
            return deleted_count

    @staticmethod
    def _result_row(dataset_name, split, experiment_id, model_name, result_data, timestamp):
        """Build one INSERT parameter tuple, ordered like ``_INSERT_COLUMNS``.

        Raises KeyError if *result_data* lacks any expected key.
        """
        return (
            dataset_name, split, result_data['query_id'], result_data['user_query_text'],
            result_data['turn_id'], result_data['think_response'], result_data['search_query'],
            result_data['top_k_doc_ids'], result_data['top_k_texts'],
            result_data['success_at_5'], result_data['success_at_10'], result_data['success_at_50'], result_data['success_at_100'],
            result_data['ndcg_at_5'], result_data['ndcg_at_10'], result_data['ndcg_at_50'], result_data['ndcg_at_100'],
            result_data['precision_at_5'], result_data['precision_at_10'], result_data['precision_at_50'], result_data['precision_at_100'],
            result_data['recall_at_5'], result_data['recall_at_10'], result_data['recall_at_50'], result_data['recall_at_100'],
            result_data['best_rank'], result_data['mrr'], result_data['map_score'],
            result_data['is_finished'], timestamp, model_name, experiment_id,
        )

    def save_result(self, dataset_name, split, experiment_id, model_name, result_data):
        """Insert one result row stamped with the current local time; return its rowid.

        Raises KeyError if *result_data* lacks an expected key.
        """
        row = self._result_row(dataset_name, split, experiment_id, model_name,
                               result_data, datetime.now().isoformat())
        with self.get_connection() as conn:
            cursor = conn.execute(self._INSERT_SQL, row)
            conn.commit()
            return cursor.lastrowid

    def batch_save_results(self, dataset_name, split, experiment_id, model_name, results_batch):
        """Insert every result in *results_batch* in one transaction; return the count.

        All rows share a single timestamp. The whole batch is validated (rows
        built) before anything is written, so a KeyError inserts nothing.
        """
        timestamp = datetime.now().isoformat()
        batch_data = [
            self._result_row(dataset_name, split, experiment_id, model_name, result_data, timestamp)
            for result_data in results_batch
        ]
        with self.get_connection() as conn:
            conn.executemany(self._INSERT_SQL, batch_data)
            conn.commit()
        return len(batch_data)

# Module-level store shared by every route handler in this file. Constructing
# it creates the results table and indexes if they do not exist yet.
sql_server = SQLServer(db_path='evaluation_results.db')

@app.route('/get_finished_queries', methods=['POST'])
def get_finished_queries():
    """Return the distinct user_query_ids already marked finished for one run.

    Expects a JSON object with dataset_name, split, experiment_id and
    model_name. Responds 400 for a non-object body or missing fields, and
    500 (with the error text) if the lookup fails.
    """
    # silent=True: a malformed or non-JSON body yields None and gets a clean
    # 400 below, instead of a TypeError on `in` outside the try block.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    required_fields = ['dataset_name', 'split', 'experiment_id', 'model_name']
    missing = [field for field in required_fields if field not in data]
    if missing:
        return jsonify({"error": f"Missing required fields: {missing}"}), 400

    try:
        finished_ids = sql_server.get_finished_query_ids(
            data['dataset_name'], data['split'], data['experiment_id'], data['model_name']
        )
    except Exception as e:
        log(f"Error getting finished queries: {e}")
        return jsonify({"error": str(e)}), 500
    return jsonify({"finished_ids": finished_ids})

@app.route('/get_success_rate', methods=['POST'])
def get_success_rate():
    """Return the success@5 rate over finished rows of one run, plus their count.

    Expects a JSON object with dataset_name, split, experiment_id and
    model_name. Responds 400 for a non-object body or missing fields, and
    500 (with the error text) if the lookup fails.
    """
    # silent=True: a malformed or non-JSON body yields None and gets a clean
    # 400 below, instead of a TypeError on `in` outside the try block.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    required_fields = ['dataset_name', 'split', 'experiment_id', 'model_name']
    missing = [field for field in required_fields if field not in data]
    if missing:
        return jsonify({"error": f"Missing required fields: {missing}"}), 400

    try:
        success_rate, total = sql_server.get_avg_success_rate(
            data['dataset_name'], data['split'], data['experiment_id'], data['model_name']
        )
    except Exception as e:
        log(f"Error getting success rate: {e}")
        return jsonify({"error": str(e)}), 500
    return jsonify({"success_rate": success_rate, "total_finished": total})

@app.route('/get_query_history', methods=['POST'])
def get_query_history():
    """Return every stored turn for one query as JSON arrays.

    Expects a JSON object with dataset_name, split, experiment_id,
    model_name and query_id. Responds 400 for a non-object body or missing
    fields, and 500 (with the error text) if the lookup fails.
    """
    # silent=True: a malformed or non-JSON body yields None and gets a clean
    # 400 below, instead of a TypeError on `in` outside the try block.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    required_fields = ['dataset_name', 'split', 'experiment_id', 'model_name', 'query_id']
    missing = [field for field in required_fields if field not in data]
    if missing:
        return jsonify({"error": f"Missing required fields: {missing}"}), 400

    try:
        history = sql_server.get_query_history(
            data['dataset_name'], data['split'], data['experiment_id'],
            data['model_name'], data['query_id']
        )
    except Exception as e:
        log(f"Error getting query history: {e}")
        return jsonify({"error": str(e)}), 500
    return jsonify({"history": history})

@app.route('/delete_incomplete_query', methods=['POST'])
def delete_incomplete_query():
    """Delete a query's unfinished rows and report how many were removed.

    Expects a JSON object with dataset_name, split, experiment_id,
    model_name and query_id. Responds 400 for a non-object body or missing
    fields, and 500 (with the error text) if the delete fails.
    """
    # silent=True: a malformed or non-JSON body yields None and gets a clean
    # 400 below, instead of a TypeError on `in` outside the try block.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    required_fields = ['dataset_name', 'split', 'experiment_id', 'model_name', 'query_id']
    missing = [field for field in required_fields if field not in data]
    if missing:
        return jsonify({"error": f"Missing required fields: {missing}"}), 400

    try:
        deleted_count = sql_server.delete_incomplete_query(
            data['dataset_name'], data['split'], data['experiment_id'],
            data['model_name'], data['query_id']
        )
    except Exception as e:
        log(f"Error deleting incomplete query: {e}")
        return jsonify({"error": str(e)}), 500
    return jsonify({"deleted_count": deleted_count})

@app.route('/save_result', methods=['POST'])
def save_result():
    """Insert one per-turn result row and return its new row id.

    Expects a JSON object with dataset_name, split, experiment_id,
    model_name and a ``result_data`` object. Responds 400 for a non-object
    body, missing top-level fields, or a result_data that is not an object
    or lacks a field. Responds 500 (with the error text) if the insert fails.
    """
    # silent=True: a malformed or non-JSON body yields None and gets a clean
    # 400 below, instead of a TypeError on `in` outside the try block.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    required_fields = ['dataset_name', 'split', 'experiment_id', 'model_name', 'result_data']
    missing = [field for field in required_fields if field not in data]
    if missing:
        return jsonify({"error": f"Missing required fields: {missing}"}), 400

    if not isinstance(data['result_data'], dict):
        return jsonify({"error": "result_data must be a JSON object"}), 400

    try:
        result_id = sql_server.save_result(
            data['dataset_name'], data['split'], data['experiment_id'],
            data['model_name'], data['result_data']
        )
    except KeyError as e:
        # sql_server.save_result indexes result_data directly, so a missing
        # key is the client's error, not a server failure.
        return jsonify({"error": f"result_data is missing field: {e}"}), 400
    except Exception as e:
        log(f"Error saving result: {e}")
        return jsonify({"error": str(e)}), 500
    return jsonify({"result_id": result_id})

@app.route('/batch_save_results', methods=['POST'])
def batch_save_results():
    """Insert a non-empty list of per-turn result rows and return how many were saved.

    Expects a JSON object with dataset_name, split, experiment_id,
    model_name and ``results_batch`` (a non-empty list of objects). Responds
    400 for any malformed input, including an entry that lacks a field.
    Responds 500 (with the error text) if the insert fails.
    """
    # silent=True: a malformed or non-JSON body yields None and gets a clean
    # 400 below, instead of a TypeError on `in` outside the try block.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    required_fields = ['dataset_name', 'split', 'experiment_id', 'model_name', 'results_batch']
    missing = [field for field in required_fields if field not in data]
    if missing:
        return jsonify({"error": f"Missing required fields: {missing}"}), 400

    batch = data['results_batch']
    if not isinstance(batch, list) or len(batch) == 0:
        return jsonify({"error": "results_batch must be a non-empty list"}), 400
    if not all(isinstance(item, dict) for item in batch):
        return jsonify({"error": "results_batch entries must be JSON objects"}), 400

    try:
        saved_count = sql_server.batch_save_results(
            data['dataset_name'], data['split'], data['experiment_id'],
            data['model_name'], batch
        )
    except KeyError as e:
        # sql_server.batch_save_results indexes each entry directly, so a
        # missing key is the client's error, not a server failure.
        return jsonify({"error": f"results_batch entry is missing field: {e}"}), 400
    except Exception as e:
        log(f"Error batch saving results: {e}")
        return jsonify({"error": str(e)}), 500
    return jsonify({"saved_count": saved_count})

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always reports healthy, plus seconds since ``start_time``."""
    uptime_seconds = time.time() - start_time
    payload = {"status": "healthy", "uptime": uptime_seconds}
    return jsonify(payload)

@app.route('/stats', methods=['GET'])
def get_stats():
    """Return the total row count and the distinct experiment/model counts.

    Responds 500 (with the error text) if the query fails.
    """
    try:
        with sql_server.get_connection() as conn:
            # One statement so all three figures come from the same snapshot,
            # even while other requests are inserting.
            total_records, unique_experiments, unique_models = conn.execute(
                'SELECT COUNT(*), COUNT(DISTINCT experiment_id), COUNT(DISTINCT model_name) '
                'FROM evaluation_results'
            ).fetchone()
    except Exception as e:
        # Log like every other endpoint; this failure used to go unrecorded.
        log(f"Error getting stats: {e}")
        return jsonify({"error": str(e)}), 500
    return jsonify({
        "total_records": total_records,
        "unique_experiments": unique_experiments,
        "unique_models": unique_models
    })

if __name__ == '__main__':
    # Script entry point: serve on every interface, port 8000, Flask debug off.
    log("Starting SQL server on port 8000")
    app.run(debug=False, host='0.0.0.0', port=8000)