import time
import numpy as np
import pandas as pd
from tslearn.datasets import UCR_UEA_datasets
from kmeans import TimeSeriesKMeans
from tslearn.metrics import dtw


# Shared tslearn UCR/UEA archive loader. It does not load anything by itself;
# ``ucr.load_dataset(name)`` is called once per dataset in the main loop below.
ucr = UCR_UEA_datasets()

# --- Helper functions ---
def compute_initial_centroids(X, y, method, rng=None):
    """Compute one initial centroid per class label.

    Parameters
    ----------
    X : array-like
        Training series. Rows are selected with the boolean mask
        ``y == label``, so the first axis of ``X`` must align with ``y``.
    y : array-like
        Class labels, one per row of ``X``. Plain lists are accepted.
    method : str
        ``"euclidean"`` uses the element-wise mean of each class's rows.
        Any other value picks one random member of each class.
    rng : numpy.random.Generator, optional
        Randomness source for the random-member strategy. When ``None``
        (default), the global ``np.random`` state is used, as before. Pass a
        seeded generator (``np.random.default_rng(seed)``) for reproducible
        initialisations.

    Returns
    -------
    numpy.ndarray
        Centroids stacked in the order of ``np.unique(y)``, i.e. sorted labels.
    """
    X = np.asarray(X)
    # Convert labels so ``y == label`` is an element-wise mask. With a Python
    # list it would be a single bool, selecting no rows.
    y = np.asarray(y)

    centroids = []
    for label in np.unique(y):
        class_data = X[y == label]
        if method == "euclidean":
            centroids.append(np.mean(class_data, axis=0))
        else:
            if rng is None:
                idx = np.random.randint(len(class_data))
            else:
                idx = rng.integers(len(class_data))
            centroids.append(class_data[idx])
    return np.array(centroids)

def dtw_kmeans_loss(X, centroids, labels, metric=None):
    """Return the mean squared distance from each series to its assigned centroid.

    Parameters
    ----------
    X : sequence
        Series to evaluate.
    centroids : sequence
        Cluster centres, indexed by the values in ``labels``.
    labels : sequence of int
        Cluster index of each series; must have the same length as ``X``.
    metric : callable, optional
        Distance ``metric(series, centroid) -> float``. Defaults to tslearn's
        ``dtw``, the behaviour of the original implementation.

    Returns
    -------
    float
        ``sum(metric(x_i, centroids[labels[i]]) ** 2) / len(X)``.

    Raises
    ------
    ValueError
        If ``X`` is empty (the mean is undefined) or if ``labels`` and ``X``
        have different lengths.
    """
    if metric is None:
        metric = dtw
    n = len(X)
    if n == 0:
        raise ValueError("cannot compute clustering loss of an empty dataset")
    if len(labels) != n:
        raise ValueError(
            f"got {len(labels)} labels for {n} series; lengths must match"
        )
    total = sum(metric(x, centroids[k]) ** 2 for x, k in zip(X, labels))
    return total / n  # average loss

def run_clustering(X_train, y_train, method="dtw", gamma=None,
                   init_type="euclidean", max_iter=30):
    """Run k-means clustering and return its DTW loss and fit runtime.

    The number of clusters equals the number of distinct labels in
    ``y_train``. The initial centroids come from
    ``compute_initial_centroids(X_train, y_train, init_type)``.

    Parameters
    ----------
    X_train, y_train : array-like
        Training series and their class labels.
    method : str
        Metric name forwarded to ``TimeSeriesKMeans(metric=...)``.
    gamma : float, optional
        When not ``None``, passed as ``metric_params={"gamma": gamma}``.
    init_type : str
        Initialisation strategy forwarded to ``compute_initial_centroids``.
    max_iter : int
        Maximum k-means iterations (default 30, the previously hard-coded value).

    Returns
    -------
    tuple[float, float]
        ``(loss, duration)``. ``loss`` is ``dtw_kmeans_loss`` on the fitted
        centres and labels. ``duration`` is the wall time of ``model.fit``
        alone, in seconds.
    """
    n_clusters = len(np.unique(y_train))
    init_centroids = compute_initial_centroids(X_train, y_train, init_type)

    # Forward gamma only when it is set; the plain "dtw" config passes None.
    metric_params = {} if gamma is None else {"gamma": gamma}

    model = TimeSeriesKMeans(
        n_clusters=n_clusters,
        metric=method,
        init=init_centroids,
        max_iter=max_iter,
        metric_params=metric_params,
        verbose=False,
    )
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # system clock and can jump during a long benchmark.
    start = time.perf_counter()
    model.fit(X_train)
    duration = time.perf_counter() - start
    # Every metric is scored with the same hard-DTW loss, presumably so losses
    # are comparable across methods.
    loss = dtw_kmeans_loss(X_train, model.cluster_centers_, model.labels_)
    return loss, duration

# --- Main loop ---
all_results = []

# Smoothing parameters swept for both soft-DTW and NS-DTW.
gammas = [0.1, 0.01, 0.001, 0.0001]

# (display_name, metric, gamma).
# Display names are derived from gamma so every configuration gets a unique
# result column. A hand-typed duplicate name would silently overwrite another
# configuration's loss in the per-dataset results dict.
# f"{g:g}" renders 0.0001 as "0.0001" (not "1e-04").
methods = [("dtw", "dtw", None)]
methods += [(f"softdtw_g{g:g}", "softdtw", g) for g in gammas]
methods += [(f"nsdtw_g{g:g}", "ns_dtw", g) for g in gammas]

inits = ["euclidean"]

# UCR archive datasets to benchmark, listed in the order they are processed.
lists = [
    "Adiac",
    "ArrowHead",
    "Beef",
    "BeetleFly",
    "BirdChicken",
    "Car",
    "CBF",
    "ChlorineConcentration",
    "CinCECGTorso",
    "Coffee",
    "Computers",
    "CricketX",
    "CricketY",
    "CricketZ",
    "DiatomSizeReduction",
    "DistalPhalanxOutlineAgeGroup",
    "DistalPhalanxOutlineCorrect",
    "DistalPhalanxTW",
    "Earthquakes",
    "ECG200",
    "ECG5000",
    "ECGFiveDays",
    "FaceAll",
    "FaceFour",
    "FacesUCR",
    "GunPoint",
    "Ham",
    "MedicalImages",
    "MiddlePhalanxOutlineAgeGroup",
    "MiddlePhalanxOutlineCorrect",
    "MiddlePhalanxTW",
    "MoteStrain",
    "ProximalPhalanxTW",
    "RefrigerationDevices",
    "ScreenType",
    "ShapeletSim",
    "ShapesAll",
    "SmallKitchenAppliances",
    "SonyAIBORobotSurface1",
    "SonyAIBORobotSurface2",
    "SyntheticControl",
    "Trace",
    "TwoLeadECG",
    "Wine",
    "WordSynonyms",
    "Worms",
    "WormsTwoClass",
]

for dataset_name in lists:
    print("-" * 10, dataset_name, "-" * 10)
    try:
        X_train, y_train, _, _ = ucr.load_dataset(dataset_name)
        # NOTE(review): tslearn documents None return values when a dataset
        # cannot be loaded. Fail with a clear message instead of an
        # AttributeError on ``.shape``.
        if X_train is None or y_train is None:
            raise ValueError(f"dataset {dataset_name!r} could not be loaded")
        print(X_train.shape, y_train.shape)
        X_train = X_train.astype(np.float64)

        dataset_losses = {}
        for display_name, metric, gamma in methods:
            for init in inits:
                # runtime is measured but not written to the CSV.
                loss, runtime = run_clustering(
                    X_train, y_train, method=metric, gamma=gamma, init_type=init
                )
                col_name = f"{display_name}_{init}"
                dataset_losses[col_name] = loss
                print(f"{dataset_name} — {display_name} + {init}: loss={loss:.3f}")

        all_results.append({"Dataset": dataset_name, **dataset_losses})

    except Exception as e:
        # Best-effort sweep: a failing dataset is reported and skipped, so
        # the remaining datasets still run.
        print(f"❌ {dataset_name} failed: {e}")

    # Checkpoint the wide-format table after every dataset so partial results
    # survive an interrupted run. Skip it while nothing has succeeded:
    # ``pd.DataFrame([])`` has no "Dataset" column, so ``set_index`` would
    # raise outside the try and abort the whole sweep.
    if all_results:
        df_wide = pd.DataFrame(all_results).set_index("Dataset").sort_index()
        df_wide.to_csv("ucr_clustering_nsdtw_losses_.csv")
