import sqlite3
import pandas as pd
from datetime import datetime
import os

DEFAULT_DATASETS = ('bright-fast', 'fever', 'fiqa', 'hotpotqa', 'msmarco', 'nfcorpus', 'nq', 'scifact')

def get_test_set_sizes(datasets=None, data_root="./data/raw_data"):
    """Return the number of distinct test queries for each dataset.

    For every dataset name, reads
    ``{data_root}/{dataset}/{dataset}/qrels/test.tsv`` as a tab-separated file
    (first row treated as the header, all values read as strings) and counts
    the unique non-null values in its first column.

    Args:
        datasets: Dataset names to look up; defaults to ``DEFAULT_DATASETS``.
        data_root: Directory holding one sub-directory per dataset.

    Returns:
        Dict mapping dataset name -> number of distinct first-column values.
        Datasets whose qrels file does not exist are omitted. A zero-byte
        file counts as 0 queries instead of aborting with ``EmptyDataError``.
    """
    if datasets is None:
        datasets = DEFAULT_DATASETS

    test_sizes = {}
    for dataset in datasets:
        test_path = os.path.join(data_root, dataset, dataset, "qrels", "test.tsv")
        if not os.path.exists(test_path):
            continue
        try:
            df = pd.read_csv(test_path, sep='\t', dtype=str)
        except pd.errors.EmptyDataError:
            # File exists but has no content at all: no test queries.
            test_sizes[dataset] = 0
            continue
        test_sizes[dataset] = df.iloc[:, 0].nunique()

    return test_sizes

def get_metrics_summary(db_path, exclude_model_pattern='gpt-oss'):
    """Aggregate finished evaluation rows per (dataset, model, split, experiment).

    Only rows of ``evaluation_results`` with ``is_finished = 1`` are used.
    ``total_queries`` counts distinct ``user_query_id`` values, while the rate
    and ``avg_*`` columns are SQL ``AVG`` over the grouped rows (rounded to 3
    decimals; depth and best rank to 2). NOTE(review): if a query can have
    several finished rows, those averages are per row, not per query.

    Args:
        db_path: Path of the SQLite database to read.
        exclude_model_pattern: Rows whose ``model_name`` contains this pattern
            (case-insensitive, regex as in ``Series.str.contains``) are dropped.
            Defaults to ``'gpt-oss'``; pass ``None`` or ``''`` to keep all models.

    Returns:
        DataFrame with one row per group, ordered by the grouping columns.

    Raises:
        Whatever ``sqlite3`` / ``pandas.read_sql_query`` raise, e.g. when the
        table does not exist. The connection is closed in every case.
    """
    query = '''
    SELECT 
        dataset_name,
        model_name,
        split,
        experiment_id,
        COUNT(DISTINCT user_query_id) as total_queries,
        SUM(CASE WHEN success_at_5 = 1 AND is_finished = 1 THEN 1 ELSE 0 END) as successful_queries,
        SUM(CASE WHEN success_at_5 = 1 AND is_finished = 1 AND turn_id > 1 THEN 1 ELSE 0 END) as successful_after_turn1,
        ROUND(AVG(CASE WHEN success_at_5 = 1 AND is_finished = 1 THEN turn_id ELSE NULL END), 2) as average_successful_depth,
        ROUND(AVG(success_at_5), 3) as success_rate_at_5,
        ROUND(AVG(success_at_10), 3) as success_rate_at_10,
        ROUND(AVG(success_at_50), 3) as success_rate_at_50,
        ROUND(AVG(success_at_100), 3) as success_rate_at_100,
        ROUND(AVG(ndcg_at_5), 3) as avg_ndcg_at_5,
        ROUND(AVG(ndcg_at_10), 3) as avg_ndcg_at_10,
        ROUND(AVG(ndcg_at_50), 3) as avg_ndcg_at_50,
        ROUND(AVG(ndcg_at_100), 3) as avg_ndcg_at_100,
        ROUND(AVG(precision_at_5), 3) as avg_precision_at_5,
        ROUND(AVG(precision_at_10), 3) as avg_precision_at_10,
        ROUND(AVG(precision_at_50), 3) as avg_precision_at_50,
        ROUND(AVG(precision_at_100), 3) as avg_precision_at_100,
        ROUND(AVG(recall_at_5), 3) as avg_recall_at_5,
        ROUND(AVG(recall_at_10), 3) as avg_recall_at_10,
        ROUND(AVG(recall_at_50), 3) as avg_recall_at_50,
        ROUND(AVG(recall_at_100), 3) as avg_recall_at_100,
        ROUND(AVG(best_rank), 2) as avg_best_rank,
        ROUND(AVG(mrr), 3) as avg_mrr,
        ROUND(AVG(map_score), 3) as avg_map
    FROM evaluation_results 
    WHERE is_finished = 1
    GROUP BY dataset_name, model_name, split, experiment_id
    ORDER BY dataset_name, model_name, split, experiment_id
    '''

    conn = sqlite3.connect(db_path)
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails (e.g. missing table), so the
        # connection/file handle is not leaked.
        conn.close()

    if exclude_model_pattern:
        df = df[~df['model_name'].str.contains(exclude_model_pattern, case=False, na=False)]
    return df

# Metric columns (from get_metrics_summary) reported per model, in table order.
_METRICS_COLS = (
    'success_rate_at_5', 'success_rate_at_10', 'success_rate_at_50', 'success_rate_at_100',
    'avg_ndcg_at_5', 'avg_ndcg_at_10', 'avg_ndcg_at_50', 'avg_ndcg_at_100',
    'avg_precision_at_5', 'avg_precision_at_10', 'avg_precision_at_50', 'avg_precision_at_100',
    'avg_recall_at_5', 'avg_recall_at_10', 'avg_recall_at_50', 'avg_recall_at_100',
    'avg_best_rank', 'avg_mrr', 'avg_map', 'average_successful_depth',
)

_RESULTS_HEADER = (
    "| Model | Success@5 | Success@10 | Success@50 | Success@100 | NDCG@5 | NDCG@10 | NDCG@50 | NDCG@100 | "
    "Precision@5 | Precision@10 | Precision@50 | Precision@100 | Recall@5 | Recall@10 | Recall@50 | Recall@100 | "
    "Best Rank | MRR | MAP | Avg. Depth | Seeds | Source |\n"
)

_RESULTS_SEPARATOR = (
    "|-------|-----------|------------|------------|-------------|--------|---------|---------|---------|"
    "----------|-----------|-----------|-----------|----------|-----------|-----------|-----------|"
    "-----------|-----|-----|------------|-------|--------|\n"
)


def _load_summary(db_path):
    """Load the metrics summary of *db_path* on a best-effort basis.

    Returns ``(df, available)``. A missing file yields ``(empty DataFrame,
    False)`` without calling ``sqlite3.connect``, which would otherwise create
    an empty database file at that path. Any other read failure (e.g. no
    ``evaluation_results`` table) is reported as a warning and also yields
    ``(empty DataFrame, False)`` so the report can still use the other DB.
    """
    import warnings

    if not os.path.exists(db_path):
        return pd.DataFrame(), False
    try:
        return get_metrics_summary(db_path), True
    except Exception as exc:  # best-effort boundary: report, don't crash
        warnings.warn(f"Could not read metrics from {db_path!r}: {exc}")
        return pd.DataFrame(), False


def _write_completion_status(f, df, test_sizes, title):
    """Write *title* plus a dataset x model table of complete test-split seeds.

    A seed (one ``experiment_id`` row of *df*) is complete when its
    ``total_queries`` equals the dataset's expected test size (0 when the size
    is unknown). Only the title is written when *df* has no rows.
    """
    f.write(title)
    if len(df) == 0:
        return

    models = sorted(df['model_name'].unique())
    datasets = sorted(df['dataset_name'].unique())
    test_rows = df[df['split'] == 'test']

    f.write("| Dataset |" + "".join(f" {model} |" for model in models) + " Test Size |\n")
    f.write("|---------|" + "---------|" * len(models) + "-----------|\n")
    for dataset in datasets:
        expected_size = test_sizes.get(dataset, 0)
        in_dataset = test_rows[test_rows['dataset_name'] == dataset]
        f.write(f"| {dataset} |")
        for model in models:
            seeds = in_dataset[in_dataset['model_name'] == model]
            f.write(f" {int((seeds['total_queries'] == expected_size).sum())} |")
        f.write(f" {test_sizes.get(dataset, 'N/A')} |\n")


def _mean_std(values):
    """Return ``(mean, population std)`` of per-seed values; std is None for one seed.

    Missing values (None/NaN, e.g. a NULL SQL AVG) become NaN instead of
    raising, so the mean of such a metric is NaN.
    """
    values = [float('nan') if pd.isna(v) else float(v) for v in values]
    if len(values) == 1:
        return values[0], None
    mean_val = sum(values) / len(values)
    std_val = (sum((x - mean_val) ** 2 for x in values) / len(values)) ** 0.5
    return mean_val, std_val


def _format_cell(mean_val, std_val):
    """Format one metric cell as ``' mean |'`` or ``' mean±std |'``; NaN shows as N/A."""
    if pd.isna(mean_val):
        return " N/A |"
    if std_val is None:
        return f" {mean_val:.3f} |"
    return f" {mean_val:.3f}±{std_val:.3f} |"


def _strict_results(df, dataset, test_sizes, required_seeds):
    """Aggregate models of *dataset* with exactly *required_seeds* complete test seeds."""
    expected_size = test_sizes.get(dataset, 0)
    data = df[(df['dataset_name'] == dataset) & (df['split'] == 'test')]
    results = []
    for model in data['model_name'].unique():
        model_data = data[data['model_name'] == model]
        complete = model_data[model_data['total_queries'] == expected_size]
        # Fewer (or more) complete runs than required keeps the model out of
        # the strict table; it may still appear from the flexible DB.
        if len(complete) != required_seeds:
            continue
        result_row = {'model': model, 'source': 'strict', 'seeds': str(required_seeds)}
        for metric in _METRICS_COLS:
            result_row[metric] = _mean_std(complete[metric].tolist())
        results.append(result_row)
    return results


def _flexible_results(df, dataset, skip_models):
    """Aggregate all test seeds per model of *dataset*, ignoring completeness."""
    data = df[(df['dataset_name'] == dataset) & (df['split'] == 'test')]
    results = []
    for model in data['model_name'].unique():
        if model in skip_models:
            continue
        model_data = data[data['model_name'] == model]
        if model_data.empty:
            continue
        result_row = {'model': model, 'source': 'flexible', 'seeds': str(len(model_data))}
        for metric in _METRICS_COLS:
            result_row[metric] = _mean_std(model_data[metric].tolist())
        results.append(result_row)
    return results


def _sort_value(result_row):
    """Sort key: mean Success@5, with missing (NaN) means ranked last."""
    mean_val = result_row['success_rate_at_5'][0]
    return float('-inf') if pd.isna(mean_val) else mean_val


def generate_unified_metrics(strict_db="evaluation_results_extrinsic_.db", flexible_db="evaluation_results.db",
                             output_file="METRICS.md", required_seeds=3):
    """Write a Markdown report combining a strict and a flexible results database.

    The report contains:
    - which databases could be read;
    - per database, a completion table counting test-split seeds whose
      ``total_queries`` matches the test-set size from ``get_test_set_sizes``;
    - per dataset, a results table. A model is taken from the strict DB when
      it has exactly *required_seeds* complete test seeds. Otherwise it is
      taken from the flexible DB using all of its test seeds. Cells show
      ``mean±std`` (population std); a single seed shows only the mean.
      Rows are sorted by mean Success@5, descending.

    Args:
        strict_db: Path of the strict SQLite database (may be absent).
        flexible_db: Path of the flexible SQLite database (may be absent).
        output_file: Markdown file to (over)write, encoded as UTF-8.
        required_seeds: Complete seeds a model needs for the strict table.

    Returns:
        *output_file*.

    Raises:
        ValueError: If *required_seeds* is less than 1.
    """
    if required_seeds < 1:
        raise ValueError("required_seeds must be at least 1")

    test_sizes = get_test_set_sizes()
    df_strict, strict_available = _load_summary(strict_db)
    df_flexible, flexible_available = _load_summary(flexible_db)

    # Explicit UTF-8: the report contains "±" and arbitrary model names.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("# Unified Evaluation Results\n\n")
        f.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(f"Strict Database: {strict_db} {'(Available)' if strict_available else '(Not Available)'}\n")
        f.write(f"Flexible Database: {flexible_db} {'(Available)' if flexible_available else '(Not Available)'}\n\n")

        if strict_available:
            _write_completion_status(f, df_strict, test_sizes, "## Completion Status (Strict Database)\n\n")
        if flexible_available:
            _write_completion_status(f, df_flexible, test_sizes, "\n## Completion Status (Flexible Database)\n\n")

        # Unavailable databases were loaded as empty DataFrames, so `empty`
        # covers both "unavailable" and "no rows".
        all_datasets = set()
        for df in (df_strict, df_flexible):
            if not df.empty:
                all_datasets.update(df['dataset_name'].unique())

        for dataset in sorted(all_datasets):
            f.write(f"\n## {dataset.upper()} Dataset Results\n\n")

            all_results = []
            if not df_strict.empty:
                all_results.extend(_strict_results(df_strict, dataset, test_sizes, required_seeds))
            # Models already reported from the strict DB are not repeated.
            strict_models = {row['model'] for row in all_results}
            if not df_flexible.empty:
                all_results.extend(_flexible_results(df_flexible, dataset, strict_models))

            if not all_results:
                f.write("No data available for this dataset.\n")
                continue

            all_results.sort(key=_sort_value, reverse=True)
            f.write(_RESULTS_HEADER)
            f.write(_RESULTS_SEPARATOR)
            for result_row in all_results:
                cells = "".join(_format_cell(*result_row[metric]) for metric in _METRICS_COLS)
                f.write(f"| {result_row['model']} |{cells} {result_row['seeds']} | {result_row['source']} |\n")

    return output_file

if __name__ == '__main__':
    # Script entry point: build the report with the default database paths,
    # then tell the user which file was written.
    report_path = generate_unified_metrics()
    print(f"Unified metrics generated: {report_path}")