import numpy as np
import pandas as pd
from aeon.datasets import load_classification
import sys
sys.path.append("../../")
import os
from joblib import Parallel, delayed
import argparse
from pathlib import Path
import json
from tqdm import tqdm
from functools import partial
import datetime

from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import f1_score, accuracy_score
from sklearn.neighbors import KNeighborsClassifier

from src.model import primal_fit_to
from src.utils import primal_left
from src.representation import augment, polynomial_feature_map
from src.numpy_metric import hs_metric,operator_metric,eigenvalue_metric,subspace_metric, chordal_metric, martin_metric





#####################################################################################################################################
# PARAMETERS #
#####################################################################################################################################

# Names of the datasets processed by the experiment, in order; each name is
# passed to `load_classification` in the main loop.
DATASETS = [
    "AtrialFibrillation",
    "BasicMotions",
    "Cricket",
    "EigenWorms",
    "Epilepsy",
    "ERing",
    "FingerMovements",
    "HandMovementDirection", 
    "Handwriting",  
    "Heartbeat",
    "NATOPS",  
    "SelfRegulationSCP1", 
    "StandWalkJump",
    "UWaveGestureLibrary",
]

# Values passed as the `alpha` keyword of `chordal_metric`; one metric variant
# named "Chordal_<int(alpha*100)>" is evaluated per value.
ALPHAS = [0.01,0.1,0.5,0.9,0.99]

# When True the tqdm progress bars are shown; when False they are disabled.
PLOT_PROGRESS = True



#####################################################################################################################################
# UTILITIES #
#####################################################################################################################################

def filtering(D,R,L,threshold=1e-2):
    """Drop the components whose exponentiated eigenvalue has a small modulus.

    A component ``k`` is kept when ``|exp(D[k])| > threshold``.

    Parameters
    ----------
    D : array indexed by component (boolean-masked along its first axis).
    R, L : 2-D arrays whose columns are indexed by component.
    threshold : float
        Strict lower bound on ``|exp(D[k])|`` for a component to be kept.

    Returns
    -------
    tuple
        ``(D_f, R_f, L_f)``: ``D`` restricted to the kept entries and ``R``,
        ``L`` restricted to the corresponding columns.
    """
    magnitudes = np.abs(np.exp(D))
    keep = magnitudes > threshold
    return D[keep], R[:, keep], L[:, keep]

def set_parameters(n_samples,sampling_ratio = 0.2,max_sampfreq=100,max_context_window=50): 
    """Derive the sampling frequency and context window from the series length.

    Parameters
    ----------
    n_samples : int
        Number of samples in one time series.
    sampling_ratio : float
        Fraction of ``n_samples // 2`` used as the sampling frequency
        (truncated to an integer).
    max_sampfreq : int
        Upper bound on the returned sampling frequency.
    max_context_window : int
        Upper bound on the returned context window.

    Returns
    -------
    tuple of int
        ``(sampfreq, context_window)`` where
        ``sampfreq = min(max_sampfreq, int((n_samples // 2) * sampling_ratio))``
        clamped to be at least 1, and
        ``context_window = min(max_context_window, n_samples // 2)``.
    """
    half = n_samples // 2
    # For short series the truncation yields 0; callers compute 1/sampfreq,
    # so never return a sampling frequency below 1.
    sampfreq = max(1, min(max_sampfreq, int(half * sampling_ratio)))
    context_window = min(max_context_window, half)
    return sampfreq, context_window

def metric_cross_matrix(
    metric_func: callable,
    Ts_lst: list,
    Tt_lst: list = None,
    n_jobs: int = 1,
    metric_kwargs: dict = None
    ) -> np.ndarray:
    """Evaluate ``metric_func`` on every (source, target) pair.

    Each entry is computed as
    ``metric_func(*Ts_lst[i], *Tt_lst[j], **metric_kwargs)``, i.e. the elements
    of both lists are unpacked positionally and must yield a scalar.

    Parameters
    ----------
    metric_func : callable
        Pairwise metric to evaluate.
    Ts_lst : sequence
        Source items (row index of the result).
    Tt_lst : sequence, optional
        Target items (column index of the result). When omitted, the sources
        are compared with themselves: only the upper triangle is evaluated and
        mirrored, so ``metric_func`` is assumed to be symmetric in that case.
    n_jobs : int
        Number of workers handed to ``joblib.Parallel``.
    metric_kwargs : dict, optional
        Extra keyword arguments forwarded to ``metric_func``.

    Returns
    -------
    np.ndarray
        Matrix of shape ``(len(Ts_lst), len(Tt_lst))`` (square when
        ``Tt_lst`` is omitted). Empty inputs yield an empty matrix.
    """
    # Avoid a shared mutable default for the keyword arguments.
    kwargs = {} if metric_kwargs is None else metric_kwargs
    symmetric = Tt_lst is None
    targets = Ts_lst if symmetric else Tt_lst

    def compute(i, j):
        # Column index j must address the *target* list.
        return metric_func(*Ts_lst[i], *targets[j], **kwargs)

    m, n = len(Ts_lst), len(targets)
    if symmetric:
        rows, cols = np.triu_indices(n)
    else:
        # Row-major enumeration of every (i, j) pair.
        rows, cols = np.indices((m, n)).reshape(2, -1)
    pairs = list(zip(rows.tolist(), cols.tolist()))

    values = Parallel(n_jobs=n_jobs)(
        delayed(compute)(i, j) for i, j in tqdm(pairs, disable=not PLOT_PROGRESS)
    )

    mat = np.zeros((m, n))
    if pairs:
        mat[rows, cols] = values
    if symmetric:
        # Mirror the upper triangle without doubling the diagonal.
        mat = mat + mat.T - np.diag(mat.diagonal())
    return mat


if __name__ == "__main__":
    # One-fold classification experiment: for every dataset in DATASETS, compute a
    # spectral decomposition per time series, build a pairwise distance matrix for
    # each metric, tune and evaluate a precomputed-distance k-NN classifier, and
    # append the scores to a per-fold CSV file. Already stored (dataset, metric)
    # rows are skipped, so an interrupted run can be resumed.

    parser = argparse.ArgumentParser(description="One fold classification experiment")

    parser.add_argument("--exp_id", type=int, required=True, help="Experiment id")
    parser.add_argument("--fold_id", type=str, required=True, help="Fold id")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for reproducibility")

    parser.add_argument("--save_folder", type=str, default="results", help="Folder to save results")
    parser.add_argument("--n_jobs", type=int, default=1, help="Number of parallel jobs")
    parser.add_argument("--test_size", type=float, default=0.3, help="Proportion of the dataset to include in the test split")
    parser.add_argument("--sampling_ratio", type=float, default=0.2, help="Sampling ratio to determine the sampling frequency and context window")
    parser.add_argument("--max_sampfreq", type=int, default=100, help="Maximum sampling frequency")
    parser.add_argument("--max_context_window", type=int, default=50, help="Maximum context window")
    parser.add_argument("--eigen_tol", type=float, default=1e-2, help="Tolerance for eigenvalue filtering")
    parser.add_argument("--max_rank", type=int, default=10, help="Maximum rank for the spectral decomposition")
    parser.add_argument("--poly_order", type=int, default=1, help="Polynomial order for the spectral decomposition")
    parser.add_argument("--tikhonov_reg", type=float, default=1e-6, help="Tikhonov regularization parameter")

    args = parser.parse_args()

    # set seed
    np.random.seed(args.seed)

    # Output files; the experiment folder must exist before anything is
    # written to the score or log file.
    out_dir = Path(args.save_folder) / f"exp_{args.exp_id}"
    out_dir.mkdir(parents=True, exist_ok=True)
    score_file = out_dir / f"scores_fold_{args.fold_id}.csv"
    log_file = out_dir / f"log_fold_{args.fold_id}.txt"

    # Metric names do not depend on the data, so they are built once and used
    # to detect datasets that are already fully processed.
    metrics_names = [
        "Hilbert-Schmidt",
        "Operator",
        "SOT",
        "GOT",
        "Martin"
    ]
    metrics_names += [f"Chordal_{int(a*100)}" for a in ALPHAS]
    n_metrics = len(metrics_names)

    n_datasets = len(DATASETS)
    for i, dataset_name in enumerate(DATASETS):
        print(f"Processing dataset {i+1}/{n_datasets}: {dataset_name}")

        # Metrics already stored for this dataset (resume support).
        done = set()
        if score_file.is_file():
            df = pd.read_csv(score_file)
            done = set(df.loc[df["dataset"] == dataset_name, "metric"])
        if done.issuperset(metrics_names):
            # Nothing left to do: skip loading and the spectral decompositions.
            for metric_name in metrics_names:
                print(f"Scores for dataset {dataset_name} with metric {metric_name} already computed. Skipping...")
            continue

        #loading data
        try:
            X, y = load_classification(dataset_name)
        except Exception as exc:
            print(f"Failed to load dataset {dataset_name}. Skipping...")
            with open(log_file, 'a') as f:
                f.write(f"Failed to load dataset {dataset_name}. Skipping...\n")
                f.write(f"    {exc!r}\n")
            continue
        # After the swap, X is unpacked below as (n_ts, n_samples, n_d).
        X = np.swapaxes(X, 1, 2)
        encoder = OrdinalEncoder()
        y = encoder.fit_transform(y.reshape(-1,1)).ravel()
        n_ts,n_samples,n_d = X.shape
        train_idxs,test_idxs = train_test_split(np.arange(n_ts),test_size=args.test_size,random_state=args.seed,shuffle=True)
        sampfreq, context_window = set_parameters(n_samples, args.sampling_ratio, args.max_sampfreq, args.max_context_window)

        def compute_T(series):
            # Fit the model on one series and return its filtered (D, R, L) triple.
            augmented = augment(series, context_window)
            Z = polynomial_feature_map(augmented, order=args.poly_order)
            e = primal_fit_to(Z,1/sampfreq,tikhonov_reg=args.tikhonov_reg,rank=args.max_rank,symmetry=None)
            D,R,L = e["values"],e["right"],primal_left(e,Z)
            return filtering(D,R,L,threshold=args.eigen_tol)

        print("Computing spectral decompositions...")
        T_lst = Parallel(n_jobs=args.n_jobs)(delayed(compute_T)(x_train) for x_train in tqdm(X, disable=not PLOT_PROGRESS))
        T_lst = np.array(T_lst,dtype=object)

        # setting metrics (same order as metrics_names)
        metrics = [
            partial(hs_metric,sampfreqs=sampfreq,sampfreqt=sampfreq),
            partial(operator_metric,sampfreqs=sampfreq,sampfreqt=sampfreq),
            partial(eigenvalue_metric,sampfreqs=sampfreq,sampfreqt=sampfreq),
            partial(subspace_metric,sampfreqs=sampfreq,sampfreqt=sampfreq),
            partial(martin_metric,sampfreqs=sampfreq,sampfreqt=sampfreq)
        ]
        metrics += [partial(chordal_metric, alpha=a) for a in ALPHAS]

        for j, (metric_func, metric_name) in enumerate(zip(metrics, metrics_names)):
            #check if the score has already been computed
            if metric_name in done:
                print(f"Scores for dataset {dataset_name} with metric {metric_name} already computed. Skipping...")
                continue

            # Computing scores for the metric
            print(f"Processing {dataset_name} ({i+1}/{n_datasets}) -- {metric_name} ({j+1}/{n_metrics})")
            try:
                start_time = datetime.datetime.now()
                # NOTE(review): n_jobs is not forwarded, so the pairwise matrix is
                # computed with the default n_jobs=1 -- confirm this is intended.
                D = metric_cross_matrix(metric_func, T_lst)
                param_grid = {'n_neighbors': list(range(1, 11))}
                knn = KNeighborsClassifier(metric="precomputed")
                grid = GridSearchCV(knn, param_grid, cv=5, n_jobs=args.n_jobs)
                grid.fit(D[train_idxs,:][:,train_idxs], y[train_idxs])
                best_k = grid.best_params_['n_neighbors']
                knn = grid.best_estimator_
                # Rows: test series, columns: distances to the training series.
                pred = knn.predict(D[test_idxs,:][:,train_idxs])
                acc = accuracy_score(y[test_idxs], pred)
                f1_weighted = f1_score(y[test_idxs], pred, average='weighted')
                f1_macro = f1_score(y[test_idxs], pred, average='macro')
                end_time = datetime.datetime.now()
                compute_time = end_time - start_time
                dct = {"dataset": dataset_name, "metric": metric_name, "best_k": best_k ,"validation_accuracy": grid.best_score_, "accuracy": acc, "f1_weighted": f1_weighted, "f1_macro": f1_macro, "compute_time": compute_time.total_seconds()}
                if not score_file.is_file():
                    df = pd.DataFrame([dct])
                else:
                    df = pd.read_csv(score_file)
                    df = pd.concat([df, pd.DataFrame([dct])], ignore_index=True)
                df.to_csv(score_file, index=False)
                print(f"{dataset_name} ({i+1}/{n_datasets}) -- {metric_name} ({j+1}/{n_metrics}) -- accuracy: {acc:.2f}, f1_weighted: {f1_weighted:.2f}, f1_macro: {f1_macro:.2f}, compute_time: {compute_time.total_seconds():.2f}")
            except Exception as exc:
                # Best effort: record the failure (with its cause) and move on to
                # the next metric; KeyboardInterrupt/SystemExit still propagate.
                print(f"An error occurred while processing {dataset_name} with metric {metric_name}. Skipping...")
                with open(log_file, 'a') as f:
                    f.write(f"An error occurred while processing {dataset_name} with metric {metric_name}. Skipping...\n")
                    f.write(f"    {exc!r}\n")
                continue
    print("All experiments completed.")






