import pandas as pd
import numpy as np
import os
from collections import defaultdict
import json
from config.config import BASE_DIR, PREDICTIONS_TRAIN, DISTANCES_TRAIN, RESULTS_TRAIN
from config.config_restore import DISTANCES_TRAIN


def run_coverage_accuracy_sweep(
    base_path,
    N_values=None,
    L_values=None,
    *,
    prediction_models=None,
    embedding_models=None,
    predictions_dir=None,
    distances_dir=None,
    results_dir=None,
):
    """Sweep neighbour-count and similarity thresholds; write coverage/accuracy CSVs.

    For every (prediction model, embedding model) pair the function reads
    ``<predictions_dir>/<prediction model>.csv`` and
    ``<distances_dir>/<embedding model>_cosine_similarities.csv``.  For each
    image that has both a similarity row and a prediction row, it counts how
    many *other* images have a similarity ``>= L``.  A prediction is kept as
    "confident" for a pair ``(N, L)`` when that count is strictly greater
    than ``N``.

    * ``accuracy`` is the fraction of confident predictions whose predicted
      class equals the true class.
    * ``coverage`` is the number of confident predictions divided by the
      number of rows in the prediction file.
    * Both are reported as ``0.0`` when no prediction is confident.

    Results go to
    ``<results_dir>/coverage_accuracy/<prediction>__<embedding>/coverage_accuracy_sweep.csv``
    with columns ``N, L, coverage, accuracy``.  An existing file is removed
    first and rows are appended after each ``N`` so partial results survive
    an interruption.  ``L`` is written exactly as supplied (it is not
    rounded, so closely spaced thresholds stay distinguishable).

    Args:
        base_path: Accepted for backward compatibility; not used by this
            function.
        N_values: Iterable of neighbour-count thresholds.  Defaults to
            ``10, 20, ..., 500``.
        L_values: Iterable of similarity thresholds.  Defaults to
            ``1.0, 0.995, ..., 0.005``.
        prediction_models: Names of the prediction models to process.
            Defaults to the built-in list below.
        embedding_models: Names of the embedding models to process.
            Defaults to the keys of the built-in mapping below.
        predictions_dir: Directory with prediction CSVs.  Defaults to the
            module-level ``PREDICTIONS_TRAIN``.
        distances_dir: Directory with similarity CSVs.  Defaults to the
            module-level ``DISTANCES_TRAIN``.
        results_dir: Root directory for outputs.  Defaults to the
            module-level ``RESULTS_TRAIN``.

    Raises:
        KeyError: If a prediction file lacks ``image_id`` or the
            ``<model>_pred_class`` / ``<model>_true_class`` columns.
        ValueError: If image ids are duplicated in a prediction file or in
            the row labels of a similarity file.
    """
    if N_values is None:
        N_values = list(range(10, 501, 10))
    if L_values is None:
        L_values = np.round(np.arange(1.0, 0, -0.005), 3)
    thresholds = np.asarray(L_values, dtype=float)

    if embedding_models is None:
        # Only the keys are used here (file names and output directories).
        # The values look like timm model identifiers -- TODO confirm.
        embedding_models = {
            "mobilenet_v2": "mobilenetv2_100",
            "dinov1_s": "vit_small_patch16_224_dino",
            "dinov1_b": "vit_base_patch16_224_dino",
            "dinov1_b8": "vit_base_patch8_224_dino",
            "dinov2_s": "vit_small_patch14_dinov2.lvd142m",
            "dinov2_b": "vit_base_patch14_dinov2.lvd142m",
        }
    if prediction_models is None:
        prediction_models = [
            "resnet50",
            "resnet101",
            "shufflenet_v2_x1_0",
            "deit_tiny_patch16_224",
            "deit_small_patch16_224",
            "deit_base_patch16_224",
        ]
    # Materialise once so one-shot iterables survive the nested loops.
    embedding_models = list(embedding_models)
    prediction_models = list(prediction_models)

    # Config constants are only looked up when no override is given.
    # NOTE(review): this file imports DISTANCES_TRAIN from both config.config
    # and config.config_restore; the later import is the one in effect.
    predictions_dir = PREDICTIONS_TRAIN if predictions_dir is None else predictions_dir
    distances_dir = DISTANCES_TRAIN if distances_dir is None else distances_dir
    results_dir = RESULTS_TRAIN if results_dir is None else results_dir

    for prediction_model in prediction_models:
        predictions_path = os.path.join(predictions_dir, prediction_model + ".csv")
        # Loaded on first use, then shared by every embedding model.
        is_correct = None

        for embedding_model in embedding_models:
            print(f"\n== Processing {prediction_model} with {embedding_model} ==")
            similarity_path = os.path.join(
                distances_dir, embedding_model + "_cosine_similarities.csv"
            )

            print("Loading data...")
            if is_correct is None:
                is_correct = _load_prediction_correctness(
                    predictions_path, prediction_model
                )
            sim_df = _load_similarities(similarity_path)

            output_dir = os.path.join(
                results_dir,
                "coverage_accuracy",
                prediction_model + "__" + embedding_model,
            )
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, "coverage_accuracy_sweep.csv")

            if os.path.exists(output_path):
                os.remove(output_path)
                print(f"Cleared previous output file: {output_path}")

            print("Precomputing neighbor counts...")
            # Only images that also have a prediction act as queries; every
            # column is still a candidate neighbour.
            sim_df = sim_df.loc[sim_df.index.isin(is_correct.index)]
            counts = _count_neighbors(sim_df, thresholds)
            correct = is_correct.loc[sim_df.index].to_numpy(dtype=bool)
            print("Finished precomputing.")

            print("Starting main loop...")
            results_written = False
            for N in N_values:
                print(f"Processing N = {N}")
                batch_df = _sweep_thresholds(
                    counts, correct, thresholds, N, len(is_correct)
                )
                if batch_df.empty:
                    continue
                batch_df.to_csv(
                    output_path, mode="a", index=False, header=not results_written
                )
                results_written = True
                print(f"Saved results after N = {N} to: {output_path}")

            print(f"Completed for {prediction_model} with {embedding_model}")


def _load_prediction_correctness(predictions_path, prediction_model):
    """Return a boolean Series (indexed by image id) marking correct predictions."""
    pred_df = pd.read_csv(predictions_path, sep=None, engine="python")
    pred_df.columns = pred_df.columns.str.strip()

    pred_col = f"{prediction_model}_pred_class"
    true_col = f"{prediction_model}_true_class"
    missing = [
        col for col in ("image_id", pred_col, true_col) if col not in pred_df.columns
    ]
    if missing:
        # Fail loudly instead of silently producing an all-zero sweep.
        raise KeyError(f"{predictions_path} is missing required column(s): {missing}")

    # Normalise ids the same way as the similarity labels (string, stripped),
    # additionally removing the ".jpg" extension.
    image_ids = (
        pred_df["image_id"]
        .astype(str)
        .str.replace(".jpg", "", regex=False)
        .str.strip()
    )
    if image_ids.duplicated().any():
        raise ValueError(f"{predictions_path} contains duplicate image ids")

    matches = pred_df[pred_col] == pred_df[true_col]
    return pd.Series(
        matches.to_numpy(dtype=bool),
        index=pd.Index(image_ids.to_numpy(), name="image_id"),
    )


def _load_similarities(similarity_path):
    """Read a similarity matrix CSV with stripped string row/column labels."""
    sim_df = pd.read_csv(similarity_path, index_col=0)
    sim_df.index = sim_df.index.astype(str).str.strip()
    sim_df.columns = sim_df.columns.astype(str).str.strip()
    if sim_df.index.has_duplicates:
        raise ValueError(f"{similarity_path} contains duplicate row ids")
    return sim_df


def _count_neighbors(sim_df, thresholds):
    """Count, per row and threshold, the other images with similarity >= threshold."""
    values = sim_df.to_numpy(dtype=float)
    column_ids = np.asarray(sim_df.columns, dtype=object)
    counts = np.empty((values.shape[0], len(thresholds)), dtype=np.int64)

    for row, query_id in enumerate(sim_df.index):
        # Exclude the query's own column so self-similarity is not a neighbour.
        others = values[row, column_ids != query_id]
        # NaN never satisfies ">= L"; drop it before sorting.
        others = np.sort(others[~np.isnan(others)])
        # Elements >= L are those not strictly below L in the sorted array.
        counts[row] = others.size - np.searchsorted(others, thresholds, side="left")
    return counts


def _sweep_thresholds(counts, correct, thresholds, min_neighbors, total_predictions):
    """Build the result rows (one per threshold) for a single neighbour-count cutoff."""
    confident = counts > min_neighbors
    n_confident = confident.sum(axis=0)
    n_correct = (confident & correct[:, np.newaxis]).sum(axis=0)

    accuracy = np.zeros(len(thresholds), dtype=float)
    coverage = np.zeros(len(thresholds), dtype=float)
    # Thresholds with no confident prediction keep accuracy = coverage = 0.0.
    selected = n_confident > 0
    accuracy[selected] = n_correct[selected] / n_confident[selected]
    coverage[selected] = n_confident[selected] / total_predictions

    return pd.DataFrame(
        {
            "N": np.full(len(thresholds), min_neighbors),
            "L": thresholds,
            "coverage": coverage,
            "accuracy": accuracy,
        }
    )

# === Entry point ===
# When run as a script, sweep every model pair with the default N and L grids,
# passing the configured base directory straight through.
if __name__ == "__main__":
    run_coverage_accuracy_sweep(BASE_DIR)
