import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tslearn.datasets import UCR_UEA_datasets
from ns_numba_ops import sdtw as ns_dtw
from tslearn.metrics import soft_dtw


# --- Dataset loader ---
# Module-level tslearn UCR/UEA archive loader, created once at import time and
# reused by classify_dataset() for every dataset it loads.
ucr = UCR_UEA_datasets()

def _one_nn_predict(distance_fn, X_ref, y_ref, X_query, gamma):
    """Predict a label for each series in ``X_query`` by 1-nearest-neighbour search.

    Parameters
    ----------
    distance_fn : callable
        Called as ``distance_fn(x, x_ref, gamma)`` (gamma passed positionally);
        must return a scalar where smaller means "closer".
    X_ref, y_ref : array-like
        Reference series and their labels, aligned by index.
    X_query : array-like
        Series to classify.
    gamma : float
        Smoothing parameter forwarded unchanged to ``distance_fn``.

    Returns
    -------
    numpy.ndarray
        One predicted label per element of ``X_query``.
    """
    preds = []
    for x in X_query:
        dists = [distance_fn(x, x_ref, gamma) for x_ref in X_ref]
        # np.argmin returns the first minimum, so ties go to the earliest reference.
        preds.append(y_ref[np.argmin(dists)])
    return np.array(preds)


def _tune_gamma_and_test(distance_fn, gammas, X_train, y_train,
                         X_val, y_val, X_test, y_test):
    """Pick the gamma with the best validation accuracy, then score it on test.

    Returns
    -------
    tuple
        ``(best_gamma, test_accuracy)``. When several gammas tie on validation
        accuracy, the one appearing first in ``gammas`` is kept.
    """
    best_gamma, best_val_acc = None, -1  # accuracy_score is >= 0, so gamma #1 always wins initially
    for gamma in gammas:
        y_pred_val = _one_nn_predict(distance_fn, X_train, y_train, X_val, gamma)
        acc = accuracy_score(y_val, y_pred_val)
        if acc > best_val_acc:  # strict '>' keeps the earliest gamma on ties
            best_val_acc = acc
            best_gamma = gamma

    y_pred_test = _one_nn_predict(distance_fn, X_train, y_train, X_test, best_gamma)
    return best_gamma, accuracy_score(y_test, y_pred_test)


def classify_dataset(name, gammas=(0.1, 0.01, 0.001, 0.0001), random_state=None):
    """
    Evaluate 1-NN classifiers based on Soft-DTW and NS-DTW on one UCR/UEA dataset.

    The series are split 50/25/25 into train/validation/test (stratified by
    label). For each distance, gamma is chosen by validation accuracy and the
    chosen gamma is then scored on the test portion. Plain DTW is not evaluated.

    Parameters
    ----------
    name : str
        Dataset name passed to ``ucr.load_dataset``.
    gammas : iterable of float
        Candidate smoothing values; must be non-empty.
    random_state : int or None
        Forwarded to both ``train_test_split`` calls; pass an int for
        reproducible splits (default ``None`` keeps the original random behaviour).

    Returns
    -------
    dict
        ``{"soft_dtw_1nn_best_gamma=<g>": acc, "ns_dtw_1nn_best_gamma=<g>": acc}``.
        The selected gamma is embedded in each key, so keys (and therefore CSV
        columns) differ between datasets when the chosen gamma differs.

    Raises
    ------
    ValueError
        If ``gammas`` is empty or the dataset could not be loaded.
    """
    gammas = tuple(gammas)
    if not gammas:
        raise ValueError("gammas must contain at least one value")

    # --- Load dataset ---
    # NOTE(review): only the archive's official *training* split is used; the
    # official test split (3rd/4th return values) is discarded before the
    # 50/25/25 re-split below — confirm this is intended.
    X, y, _, _ = ucr.load_dataset(name)
    if X is None:
        # tslearn signals an unavailable dataset with None rather than raising.
        raise ValueError(f"dataset {name!r} could not be loaded by tslearn")
    X = X.astype(np.float64)

    # --- Train/Val/Test split (50/25/25) ---
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=0.5, stratify=y, random_state=random_state
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=random_state
    )
    splits = (X_train, y_train, X_val, y_val, X_test, y_test)

    results = {}

    # Soft-DTW 1-NN (gamma tuned on validation)
    best_gamma, test_acc = _tune_gamma_and_test(soft_dtw, gammas, *splits)
    results[f"soft_dtw_1nn_best_gamma={best_gamma}"] = test_acc

    # NS-DTW 1-NN (gamma tuned on validation)
    best_gamma, test_acc = _tune_gamma_and_test(ns_dtw, gammas, *splits)
    results[f"ns_dtw_1nn_best_gamma={best_gamma}"] = test_acc

    return results

# --- Loop over datasets ---
# Runs classify_dataset() on each listed UCR/UEA dataset, collects one row per
# successful dataset, and checkpoints the accumulated table to CSV.
all_results = []
lists = [
    'Adiac', 'ArrowHead', 'Beef', 'BeetleFly', 'BirdChicken', 'Car', 'CBF',
    'ChlorineConcentration', 'CinCECGTorso', 'Coffee', 'Computers',
    'CricketX', 'CricketY', 'CricketZ', 'DiatomSizeReduction',
    'DistalPhalanxOutlineAgeGroup', 'DistalPhalanxOutlineCorrect',
    'DistalPhalanxTW', 'Earthquakes', 'ECG200', 'ECG5000', 'ECGFiveDays',
    'FaceAll', 'FaceFour', 'FacesUCR', 'GunPoint', 'Ham', 'MedicalImages',
    'MiddlePhalanxOutlineAgeGroup', 'MiddlePhalanxOutlineCorrect',
    'MiddlePhalanxTW', 'MoteStrain', 'ProximalPhalanxTW',
    'RefrigerationDevices', 'ScreenType', 'ShapeletSim', 'ShapesAll',
    'SmallKitchenAppliances', 'SonyAIBORobotSurface1',
    'SonyAIBORobotSurface2', 'SyntheticControl', 'Trace', 'TwoLeadECG',
    'Wine', 'WordSynonyms', 'Worms', 'WormsTwoClass',
]
for dataset_name in lists:
    print("-" * 10, dataset_name, "-" * 10)
    try:
        print(f"Processing {dataset_name} ...")
        accs = classify_dataset(dataset_name)
        all_results.append({"Dataset": dataset_name, **accs})
    except Exception as e:
        # Top-level boundary: report the failure and continue with the next dataset.
        print(f"❌ {dataset_name} failed: {e}")

    # --- Collect results ---
    # Checkpoint after every dataset so partial results survive an interruption.
    # Skip while nothing has succeeded yet: an empty DataFrame has no "Dataset"
    # column, and set_index("Dataset") would raise KeyError outside the try,
    # aborting the whole run.
    if all_results:
        df_results = pd.DataFrame(all_results).set_index("Dataset").sort_index()

        # Save results
        df_results.to_csv("ucr_classification_1nn.csv")
