import pandas as pd
import numpy as np
import os
from collections import defaultdict
import json
from config.config import BASE_DIR, PREDICTIONS_TRAIN, DISTANCES_TRAIN, RESULTS_TRAIN
def _count_neighbors(sim_df, query_ids, thresholds):
    """Count, per query and threshold, how many other entries are that similar.

    Args:
        sim_df: Similarity table; rows are looked up by query id.
        query_ids: Row labels of ``sim_df`` to evaluate (each must be unique
            in ``sim_df.index`` so that a lookup yields a single row).
        thresholds: 1-D float array of similarity thresholds.

    Returns:
        Integer array of shape ``(len(query_ids), len(thresholds))`` whose
        entry ``[i, j]`` is the number of values in row ``query_ids[i]`` that
        are ``>= thresholds[j]``. The column labelled with the query's own id
        is excluded when present.
    """
    counts = np.zeros((len(query_ids), len(thresholds)), dtype=np.int64)
    for row, query_id in enumerate(query_ids):
        sims = (
            sim_df.loc[query_id]
            .drop(labels=[query_id], errors="ignore")
            .astype(float)
            .to_numpy()
        )
        # Broadcast row values against all thresholds at once. NaN values
        # compare False and are therefore never counted.
        counts[row] = (sims[:, None] >= thresholds[None, :]).sum(axis=0)
    return counts


def _sweep_rows(neighbor_counts, is_correct, thresholds, min_neighbors,
                total_predictions):
    """Build the result rows for one value of N across every threshold.

    Args:
        neighbor_counts: Output of ``_count_neighbors``.
        is_correct: Boolean array, one entry per query, True where the
            predicted class equals the true class.
        thresholds: 1-D float array of similarity thresholds.
        min_neighbors: The N value; a query is "confident" at a threshold
            when it has strictly more than N neighbors there.
        total_predictions: Denominator used for coverage.

    Returns:
        DataFrame with columns ``N``, ``L``, ``coverage`` and ``accuracy``,
        one row per threshold. Accuracy and coverage are 0.0 where no query
        is confident.
    """
    is_confident = neighbor_counts > min_neighbors
    n_confident = is_confident.sum(axis=0)
    n_correct = (is_confident & is_correct[:, None]).sum(axis=0)

    # Only divide where at least one query is confident; this also avoids a
    # division by zero when the predictions table is empty.
    has_any = n_confident > 0
    accuracy = np.zeros(len(thresholds), dtype=float)
    coverage = np.zeros(len(thresholds), dtype=float)
    accuracy[has_any] = n_correct[has_any] / n_confident[has_any]
    coverage[has_any] = n_confident[has_any] / total_predictions

    return pd.DataFrame({
        "N": np.full(len(thresholds), min_neighbors),
        # Thresholds are written unrounded: rounding to 2 decimals would
        # merge neighbouring values of the default 0.005-step grid.
        "L": thresholds,
        "coverage": coverage,
        "accuracy": accuracy,
    })


def run_coverage_accuracy_sweep(base_path, N_values=None, L_values=None):
    """Sweep neighbor-count and similarity thresholds; save coverage/accuracy.

    For every (prediction model, embedding model) pair this loads
    ``<PREDICTIONS_TRAIN>/<prediction model>.csv`` and
    ``<DISTANCES_TRAIN>/<embedding model>_cosine_similarities.csv``. A query
    image is treated as "confident" for a pair ``(N, L)`` when strictly more
    than ``N`` of its similarities to other entries are ``>= L``. For each
    ``(N, L)`` the function records:

    * ``accuracy``: fraction of confident queries whose
      ``<model>_pred_class`` equals ``<model>_true_class``;
    * ``coverage``: number of confident queries divided by the number of
      rows in the predictions file.

    Both are 0.0 when no query is confident. Results are appended, one batch
    per N, to
    ``<RESULTS_TRAIN>/coverage_accuracy/<pred>__<emb>/coverage_accuracy_sweep.csv``
    (any previous file at that path is removed first). Pairs whose input
    files cannot be loaded or are unusable are reported and skipped.

    Args:
        base_path: Not used by this function; all paths come from the config
            constants. Kept so existing callers continue to work.
        N_values: Iterable of neighbor-count thresholds. Defaults to
            1..30 followed by 40..990 in steps of 10.
        L_values: Sequence of similarity thresholds. Defaults to
            -0.2..1.0 in steps of 0.005 (rounded to 3 decimals).

    Returns:
        None. Output is written to CSV files and progress is printed.
    """
    if N_values is None:
        N_values = list(range(1, 31)) + list(range(40, 1000, 10))
    else:
        # Materialise so the values can be re-used for every model pair.
        N_values = list(N_values)
    if L_values is None:
        L_values = np.round(np.arange(-0.2, 1.001, 0.005), 3)
    thresholds = np.asarray(L_values, dtype=float)

    # Only the keys are used here (to locate the similarity files); the tuple
    # values are not read by this function.
    EMBEDDING_MODELS = {
        "mobilenet_v2": ("mobilenetv2_100", 1000, 224),
        "dinov1_s": ("vit_small_patch16_224_dino", 384, 224),
        "dinov1_b": ("vit_base_patch16_224_dino", 768, 224),
        "dinov1_b8": ("vit_base_patch8_224_dino", 768, 224),
        "dinov2_s": ("vit_small_patch14_dinov2.lvd142m", 384, 518),
        "dinov2_b": ("vit_base_patch14_dinov2.lvd142m", 768, 518),
    }

    PREDICTION_MODELS = [
        "resnet50",
        "resnet101",
        "shufflenet_v2_x1_0",
        "deit_tiny_patch16_224",
        "deit_small_patch16_224",
        "deit_base_patch16_224",
    ]

    for prediction_model in PREDICTION_MODELS:
        pred_col = f"{prediction_model}_pred_class"
        true_col = f"{prediction_model}_true_class"
        predictions_path = os.path.join(
            str(PREDICTIONS_TRAIN), prediction_model + ".csv"
        )

        for embedding_model in EMBEDDING_MODELS:
            print(f"\n== Processing {prediction_model} with {embedding_model} ==")

            similarity_path = os.path.join(
                str(DISTANCES_TRAIN),
                embedding_model + "_cosine_similarities.csv",
            )

            try:
                print("Loading data...")
                # sep=None lets pandas sniff the delimiter of the file.
                pred_df = pd.read_csv(predictions_path, sep=None, engine="python")
                pred_df.columns = pred_df.columns.str.strip()
                pred_df.set_index("image_id", inplace=True)
                # Normalise ids so they match the similarity table labels.
                pred_df.index = (
                    pred_df.index.str.replace(".jpg", "", regex=False).str.strip()
                )

                sim_df = pd.read_csv(similarity_path, index_col=0)
                sim_df.index = sim_df.index.astype(str).str.strip()
                sim_df.columns = sim_df.columns.astype(str).str.strip()
            except Exception as e:
                # Best-effort batch job: report the problem and move on.
                print(f"Failed to load data for {prediction_model} and {embedding_model}: {e}")
                continue

            # Validate the inputs BEFORE touching any previous output, so a
            # bad input file never destroys earlier results.
            missing_cols = [
                col for col in (pred_col, true_col) if col not in pred_df.columns
            ]
            if missing_cols:
                print(
                    f"Skipping {prediction_model} with {embedding_model}: "
                    f"missing prediction columns {missing_cols}"
                )
                continue

            # Only queries that have both a similarity row and a prediction.
            query_ids = [qid for qid in sim_df.index if qid in pred_df.index]
            labels = pred_df.loc[query_ids, [pred_col, true_col]]
            if len(set(query_ids)) != len(query_ids) or len(labels) != len(query_ids):
                # Duplicate ids would make label lookups return several rows
                # per query and silently corrupt the statistics.
                print(
                    f"Skipping {prediction_model} with {embedding_model}: "
                    f"duplicate image ids among the evaluated queries"
                )
                continue

            output_dir = os.path.join(
                RESULTS_TRAIN,
                "coverage_accuracy",
                prediction_model + "__" + embedding_model,
            )
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, "coverage_accuracy_sweep.csv")

            # Results are appended per N below, so start from a clean file.
            if os.path.exists(output_path):
                os.remove(output_path)
                print(f"Cleared previous output file: {output_path}")

            print("Precomputing neighbor counts...")
            neighbor_counts = _count_neighbors(sim_df, query_ids, thresholds)
            # Correctness does not depend on N or L, so compute it only once.
            is_correct = labels[pred_col].eq(labels[true_col]).to_numpy(dtype=bool)
            print("Finished precomputing.")

            print("Starting main loop...")
            header_written = False
            total_predictions = len(pred_df)

            for min_neighbors in N_values:
                print(f"Processing N = {min_neighbors}")
                if thresholds.size == 0:
                    # Nothing to record without thresholds.
                    continue

                batch_df = _sweep_rows(
                    neighbor_counts,
                    is_correct,
                    thresholds,
                    min_neighbors,
                    total_predictions,
                )
                # Append incrementally so partial results survive an
                # interrupted run; the header is written only once.
                batch_df.to_csv(
                    output_path, mode="a", index=False, header=not header_written
                )
                header_written = True
                print(f"Saved results after N = {min_neighbors} to: {output_path}")

            print(f"Completed for {prediction_model} with {embedding_model}")

# Script entry point: run the sweep with the default N and L grids.
if __name__ == "__main__":
    # BASE_DIR comes from the project config and is forwarded as base_path.
    base_dir = BASE_DIR
    run_coverage_accuracy_sweep(base_path=base_dir)
