import pandas as pd
import numpy as np
import os
import argparse
from collections import defaultdict

from config.config import (
    PREDICTIONS_TEST_INTERNAL,
    PREDICTIONS_TEST_EXTERNAL,
    DISTANCES_TEST_INTERNAL,
    DISTANCES_TEST_EXTERNAL,
    RESULTS_TRAIN,
    RESULTS_TEST_INTERNAL,
    RESULTS_TEST_EXTERNAL,
)


def _load_best_n_lookup(integrals_path):
    """Return ``{(model, emb_model): N}`` using the best-scoring row per pair.

    The CSV at *integrals_path* must provide the columns ``model``,
    ``emb_model``, ``normalized_adjusted`` and ``N``.  Rows with a missing
    score or N are ignored; for every remaining (model, emb_model) pair the N
    of the row with the highest ``normalized_adjusted`` is kept.
    """
    df = pd.read_csv(integrals_path)
    print(df)
    df = df.dropna(subset=['normalized_adjusted', 'N'])
    # assign() builds a new frame, so the cast never writes into a dropna() result.
    df = df.assign(N=df['N'].astype(int))
    best_rows = df.loc[df.groupby(['model', 'emb_model'])['normalized_adjusted'].idxmax()]
    return {
        (model, emb_model): int(n)
        for model, emb_model, n in zip(best_rows['model'], best_rows['emb_model'], best_rows['N'])
    }


def _load_predictions(predictions_path, required_columns):
    """Load a predictions CSV indexed by normalized ``image_id``.

    Raises:
        ValueError: if ``image_id`` or any of *required_columns* is missing.
    """
    pred_df = pd.read_csv(predictions_path, sep=None, engine="python")
    pred_df.columns = pred_df.columns.str.strip()
    missing = [col for col in ["image_id", *required_columns] if col not in pred_df.columns]
    if missing:
        raise ValueError(f"missing column(s) {missing} in {predictions_path}")
    # Cast to str first: purely numeric ids are parsed as integers, on which
    # the .str accessor is unavailable.  The similarity index is cast the same way.
    pred_df['image_id'] = (
        pred_df['image_id'].astype(str).str.replace(".jpg", "", regex=False).str.strip()
    )
    return pred_df.set_index("image_id")


def _load_similarities(similarity_path):
    """Load a similarity matrix CSV with stripped string row/column labels."""
    sim_df = pd.read_csv(similarity_path, index_col=0)
    sim_df.index = sim_df.index.astype(str).str.strip()
    sim_df.columns = sim_df.columns.astype(str).str.strip()
    return sim_df


def _count_neighbors(sim_df, known_ids, l_values):
    """Map each query id to its neighbor counts, one count per threshold.

    For every row of *sim_df* whose label is in *known_ids*, entry ``j`` of the
    returned array is the number of similarities ``>= l_values[j]``, with the
    query's own column excluded.
    """
    counts = {}
    for query_id in sim_df.index:
        if query_id not in known_ids:
            continue
        sims = sim_df.loc[query_id].drop(labels=[query_id], errors="ignore").astype(float)
        # Flatten so the broadcast below always sees a 1-D vector.
        sims = np.ravel(sims.to_numpy())
        # One broadcast comparison covers every threshold at once.
        counts[query_id] = (sims[:, None] >= l_values[None, :]).sum(axis=0)
    return counts


def _evaluate_subset(subset_df, subset_label, neighbor_counts, l_values, best_n,
                     pred_col, true_col):
    """Return one coverage/accuracy result row per threshold for one subset.

    A prediction is "confident" at a threshold when its neighbor count is
    strictly greater than *best_n*.  Coverage is the confident share of all
    rows in *subset_df* (including rows without similarity data); accuracy is
    the share of confident predictions equal to the true class.  Both are 0.0
    when nothing is confident.
    """
    has_counts = subset_df.index.isin(list(neighbor_counts))
    is_correct = subset_df[pred_col].eq(subset_df[true_col]).to_numpy()[has_counts]

    if has_counts.any():
        count_matrix = np.vstack([neighbor_counts[qid] for qid in subset_df.index[has_counts]])
    else:
        count_matrix = np.zeros((0, len(l_values)), dtype=int)

    confident = count_matrix > best_n  # shape: (queries, thresholds)
    n_confident = confident.sum(axis=0)
    n_correct = (confident & is_correct[:, None]).sum(axis=0)
    subset_size = len(subset_df)

    rows = []
    for l_value, n_conf, n_ok in zip(l_values, n_confident, n_correct):
        if n_conf:
            accuracy = float(n_ok / n_conf)
            coverage = float(n_conf / subset_size)
        else:
            accuracy = 0.0
            coverage = 0.0
        rows.append({
            "N": best_n,
            # The grid step is 0.005, so three decimals are needed to keep
            # neighbouring thresholds distinct in the output.
            "L": round(float(l_value), 3),
            "coverage": coverage,
            "accuracy": accuracy,
            "subset": subset_label,
        })
    return rows


def run_coverage_accuracy_best_N(dataset_type: str):
    """Compute coverage/accuracy curves at the best N for each model pair.

    For every (prediction model, embedding model) pair that has an entry in
    ``integrals/integrals_by_model_and_embedding.csv`` under ``RESULTS_TRAIN``,
    the N with the highest ``normalized_adjusted`` score is selected.  The
    pair's predictions and cosine-similarity matrix are then loaded, and for
    each similarity threshold L (from -0.2 to 1.0 in steps of 0.005) a
    prediction counts as confident when more than N other images have
    similarity >= L.  Coverage and accuracy of the confident predictions are
    computed separately for ``subset_A``, ``subset_B`` and ``subset_C`` and
    written to
    ``<results dir>/coverage_accuracy_best_N/<model>__<embedding>/coverage_accuracy_best_N.csv``.

    Pairs whose input files cannot be loaded, or whose predictions file lacks
    a required column, are reported and skipped.

    Args:
        dataset_type: ``"EXTERNAL"`` selects the external test folders; any
            other value selects the internal ones.
    """
    L_values = np.round(np.arange(-0.2, 1.001, 0.005), 3)

    # Only the keys are used here; they name the similarity files to load.
    EMBEDDING_MODELS = {
        "mobilenet_v2": ("mobilenetv2_100", 1000, 224),
        "dinov1_s": ("vit_small_patch16_224_dino", 384, 224),
        "dinov1_b": ("vit_base_patch16_224_dino", 768, 224),
        "dinov1_b8": ("vit_base_patch8_224_dino", 768, 224),
        "dinov2_s": ("vit_small_patch14_dinov2.lvd142m", 384, 518),
        "dinov2_b": ("vit_base_patch14_dinov2.lvd142m", 768, 518),
    }

    PREDICTION_MODELS = [
        "resnet50",
        "resnet101",
        "shufflenet_v2_x1_0",
        "deit_tiny_patch16_224",
        "deit_small_patch16_224",
        "deit_base_patch16_224",
    ]

    SUBSETS = ['subset_A', 'subset_B', 'subset_C']

    # Select dataset paths
    if dataset_type == "EXTERNAL":
        prediction_test_folder = PREDICTIONS_TEST_EXTERNAL
        distances_folder = DISTANCES_TEST_EXTERNAL
        results_test_dir = RESULTS_TEST_EXTERNAL
    else:  # INTERNAL
        prediction_test_folder = PREDICTIONS_TEST_INTERNAL
        distances_folder = DISTANCES_TEST_INTERNAL
        results_test_dir = RESULTS_TEST_INTERNAL

    # Load best N per model/embedding from integrals table
    integrals_path = os.path.join(RESULTS_TRAIN, "integrals/integrals_by_model_and_embedding.csv")
    best_N_lookup = _load_best_n_lookup(integrals_path)

    for prediction_model in PREDICTION_MODELS:
        pred_col = f"{prediction_model}_pred_class"
        true_col = f"{prediction_model}_true_class"
        predictions_path = os.path.join(prediction_test_folder, prediction_model + ".csv")

        for embedding_model in EMBEDDING_MODELS:
            best_N = best_N_lookup.get((prediction_model, embedding_model))
            if best_N is None:
                continue
            print(f"\n== Processing {prediction_model} with {embedding_model}, best N = {best_N} ==")
            similarity_path = os.path.join(distances_folder, embedding_model + "_cosine_similarities.csv")

            # Deliberately broad: any unreadable or malformed input skips this
            # pair instead of aborting the remaining pairs.
            try:
                print("Loading data...")
                pred_df = _load_predictions(predictions_path, ["subset", pred_col, true_col])
                sim_df = _load_similarities(similarity_path)
            except Exception as e:
                print(f"Failed to load data for {prediction_model} and {embedding_model}: {e}")
                continue

            output_dir = os.path.join(results_test_dir, "coverage_accuracy_best_N",
                                      prediction_model + "__" + embedding_model)
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, "coverage_accuracy_best_N.csv")
            print(output_path)
            if os.path.exists(output_path):
                os.remove(output_path)
                print(f"Cleared previous output file: {output_path}")

            print("Precomputing neighbor counts...")
            neighbor_counts = _count_neighbors(sim_df, pred_df.index, L_values)
            print("Finished precomputing.")

            all_subset_results = []

            print("Starting evaluation by subset...")
            for subset in SUBSETS:
                print(f"  Subset: {subset}")
                subset_df = pred_df[pred_df['subset'] == subset]
                if subset_df.empty:
                    continue
                all_subset_results.extend(
                    _evaluate_subset(
                        subset_df,
                        subset.replace("subset_", ""),
                        neighbor_counts,
                        L_values,
                        best_N,
                        pred_col,
                        true_col,
                    )
                )

            # Save combined results for all subsets
            pd.DataFrame(all_subset_results).to_csv(output_path, index=False)
            print(f"Saved subset-wise best-N results to: {output_path}")

    print("\nAll model-embedding pairs completed.")


if __name__ == '__main__':
    # Command-line entry point: one optional positional argument picks the
    # dataset to evaluate, falling back to INTERNAL when it is omitted.
    cli = argparse.ArgumentParser(
        description="Run coverage/accuracy evaluation for INTERNAL or EXTERNAL datasets."
    )
    dataset_arg_options = {
        "nargs": "?",
        "default": "INTERNAL",
        "choices": ["INTERNAL", "EXTERNAL"],
        "help": "Dataset type to process (default: INTERNAL)",
    }
    cli.add_argument("dataset_type", **dataset_arg_options)
    parsed = cli.parse_args()

    run_coverage_accuracy_best_N(parsed.dataset_type)