import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tslearn.datasets import UCR_UEA_datasets
from tslearn.barycenters import softdtw_barycenter, dtw_barycenter_averaging
from tslearn.metrics import soft_dtw, dtw
from ns_numba_ops import barycenter as nsdtw_barycenter
from ns_numba_ops import sdtw as ns_dtw
from ns_numba_ops import sdtw_value_and_grad

# --- Dataset loader ---
# Single module-level loader instance; classify_dataset() calls
# ucr.load_dataset(name) on it to fetch each dataset by name.
ucr = UCR_UEA_datasets()

def my_value_and_grad(X, Y, gamma):
    """Call ``sdtw_value_and_grad`` on ``(X, Y)`` with ``gamma`` as a keyword.

    Thin adapter: positional ``gamma`` here is forwarded as ``gamma=gamma``,
    and whatever ``sdtw_value_and_grad`` returns is handed back unchanged.
    """
    outcome = sdtw_value_and_grad(X, Y, gamma=gamma)
    return outcome

def _predict_nearest_centroid(X, centroids, labels, distance):
    """Label each series in ``X`` with the class of its closest centroid.

    ``centroids[i]`` must be the centroid built for ``labels[i]``;
    ``distance`` is a two-argument callable ``distance(series, centroid)``.
    Returns a NumPy array holding one predicted label per series.
    """
    return np.array([
        # Map the winning centroid index back to the actual class label
        # instead of assuming labels are the integers 1..K.
        labels[int(np.argmin([distance(x, c) for c in centroids]))]
        for x in X
    ])


def _select_gamma(gammas, fit_centroids, distance, labels, X_val, y_val):
    """Return ``(best_gamma, best_centroids)`` by validation accuracy.

    ``fit_centroids(gamma)`` must return one centroid per entry of ``labels``
    (same order); ``distance`` is called as ``distance(x, c, gamma=gamma)``.
    Ties keep the earliest gamma in ``gammas`` (strict ``>`` comparison).
    """
    best_gamma, best_centroids, best_val_acc = None, None, -1.0
    for gamma in gammas:
        centroids = fit_centroids(gamma)
        y_pred_val = _predict_nearest_centroid(
            X_val, centroids, labels,
            # Bind gamma as a default so the callable never sees a later value.
            lambda x, c, g=gamma: distance(x, c, gamma=g),
        )
        val_acc = accuracy_score(y_val, y_pred_val)
        if val_acc > best_val_acc:
            best_val_acc, best_gamma, best_centroids = val_acc, gamma, centroids
    return best_gamma, best_centroids


def classify_dataset(name, max_iter=50, *, gammas=(0.1, 0.01, 0.001, 0.0001),
                     random_state=None):
    """Run DTW, Soft-DTW, and NS-DTW classification with 50-25-25 split.

    One barycenter is fitted per class on the training half; a series is
    classified as the class whose barycenter is nearest under the matching
    distance. For Soft-DTW and NS-DTW, gamma is chosen on the validation
    quarter and the winning centroids are then scored on the test quarter.

    Args:
        name: Dataset name passed to ``ucr.load_dataset``.
        max_iter: Iteration cap forwarded to every barycenter routine.
        gammas: Candidate gamma values tried for Soft-DTW and NS-DTW.
        random_state: Forwarded to both ``train_test_split`` calls; the
            default ``None`` keeps the splits unseeded.

    Returns:
        Dict mapping ``"dtw"``, ``"soft_dtw_best_gamma=<g>"`` and
        ``"ns_dtw_best_gamma=<g>"`` to test-set accuracy.

    Raises:
        ValueError: If ``gammas`` is empty.
    """
    gammas = list(gammas)
    if not gammas:
        raise ValueError("gammas must contain at least one value")

    # NOTE(review): only the first two arrays returned by load_dataset are
    # used; the other two are discarded and the 50/25/25 split is carved out
    # of the first pair alone -- confirm this is intended.
    X, y, _, _ = ucr.load_dataset(name)
    X = X.astype(np.float64)

    # --- Train/Val/Test split: 50/25/25 ---
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=0.5, stratify=y, random_state=random_state
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, stratify=y_temp,
        random_state=random_state
    )

    # Sorted class labels; every centroid list below follows this order so a
    # centroid index can be translated back into its label.
    unique_labels = np.unique(y_train)
    results = {}

    # --- DTW (baseline, no gamma tuning) ---
    dba_centroids = [
        dtw_barycenter_averaging(X_train[y_train == label],
                                 init_barycenter=None, max_iter=max_iter)
        for label in unique_labels
    ]
    y_pred_dtw = _predict_nearest_centroid(
        X_test, dba_centroids, unique_labels, dtw
    )
    results["dtw"] = accuracy_score(y_test, y_pred_dtw)

    # --- Soft-DTW (gamma tuned on validation) ---
    def fit_soft_centroids(gamma):
        # NOTE(review): squeeze() drops every length-1 axis, so a class with
        # a single training series would also lose its sample axis -- verify.
        return [
            softdtw_barycenter(X_train[y_train == label].squeeze(),
                               gamma=gamma, init=None, max_iter=max_iter)
            for label in unique_labels
        ]

    soft_gamma, soft_centroids = _select_gamma(
        gammas, fit_soft_centroids, soft_dtw, unique_labels, X_val, y_val
    )
    # Evaluate best gamma on test set
    y_pred_soft = _predict_nearest_centroid(
        X_test, soft_centroids, unique_labels,
        lambda x, c: soft_dtw(x, c, gamma=soft_gamma),
    )
    results[f"soft_dtw_best_gamma={soft_gamma}"] = accuracy_score(
        y_test, y_pred_soft
    )

    # --- NS-DTW (gamma tuned on validation) ---
    def fit_ns_centroids(gamma):
        # gamma is a parameter of this function, so the callback below is
        # bound to exactly the value being tuned.
        def value_and_grad(Xc, Yc):
            return my_value_and_grad(Xc, Yc, gamma)

        return [
            nsdtw_barycenter(X_train[y_train == label],
                             max_iter=max_iter,
                             X_init="euclidean_mean",
                             value_and_grad=value_and_grad)
            for label in unique_labels
        ]

    ns_gamma, ns_centroids = _select_gamma(
        gammas, fit_ns_centroids, ns_dtw, unique_labels, X_val, y_val
    )
    # Evaluate best gamma on test set
    y_pred_ns = _predict_nearest_centroid(
        X_test, ns_centroids, unique_labels,
        lambda x, c: ns_dtw(x, c, gamma=ns_gamma),
    )
    results[f"ns_dtw_best_gamma={ns_gamma}"] = accuracy_score(
        y_test, y_pred_ns
    )

    return results

# --- Benchmark driver: evaluate every dataset in turn ---
all_results = []
lists = [
    'Adiac',
    'ArrowHead',
    'Beef',
    'BeetleFly',
    'BirdChicken',
    'Car',
    'CBF',
    'ChlorineConcentration',
    'CinCECGTorso',
    'Coffee',
    'Computers',
    'CricketX',
    'CricketY',
    'CricketZ',
    'DiatomSizeReduction',
    'DistalPhalanxOutlineAgeGroup',
    'DistalPhalanxOutlineCorrect',
    'DistalPhalanxTW',
    'Earthquakes',
    'ECG200',
    'ECG5000',
    'ECGFiveDays',
    'FaceAll',
    'FaceFour',
    'FacesUCR',
    'GunPoint',
    'Ham',
    'MedicalImages',
    'MiddlePhalanxOutlineAgeGroup',
    'MiddlePhalanxOutlineCorrect',
    'MiddlePhalanxTW',
    'MoteStrain',
    'ProximalPhalanxTW',
    'RefrigerationDevices',
    'ScreenType',
    'ShapeletSim',
    'ShapesAll',
    'SmallKitchenAppliances',
    'SonyAIBORobotSurface1',
    'SonyAIBORobotSurface2',
    'SyntheticControl',
    'Trace',
    'TwoLeadECG',
    'Wine',
    'WordSynonyms',
    'Worms',
    'WormsTwoClass',
]
for dataset_name in lists:
    # Banner plus progress line so long runs show which dataset is active.
    print("-" * 10, dataset_name, "-" * 10)
    print(f"Processing {dataset_name} ...")
    accs = classify_dataset(dataset_name)
    # One row per dataset: its name followed by the accuracies returned
    # by classify_dataset.
    all_results.append(dict({"Dataset": dataset_name}, **accs))


# --- Assemble the accuracy table, indexed and sorted by dataset name ---
df_results = pd.DataFrame(all_results)
df_results = df_results.set_index("Dataset")
df_results = df_results.sort_index()
print("\nFinal Accuracy Table:")
print(df_results)

# Persist the table as CSV in the current working directory
df_results.to_csv("ucr_classification_accuracy.csv")
