from sktime.classification.interval_based import TimeSeriesForestClassifier
from sktime.transformations.panel.rocket import MiniRocket, MiniRocketMultivariate
from sklearn.linear_model import LogisticRegression
import joblib

import numpy as np
import os
import sys
sys.path.append('../../')
sys.path.append('../../pipeline')

from pipeline.ca_database_api import DataHandler


def sktime_dataset(args):
    """Load the training and testing splits as flat arrays of windows.

    Two ``DataHandler`` instances are built from ``args``: one over
    ``args.train_patient_list`` using ``args.noise_ratio`` and queried with
    ``model_label=args.model_label``, and one over ``args.test_patient_list``
    with ``noise_ratio=0`` and queried without a ``model_label``.

    Args:
        args: Object exposing ``database_save_dir``, ``data_name``, ``exp_id``,
            ``train_patient_list``, ``test_patient_list``, ``noise_ratio``,
            ``window_time``, ``slide_time``, ``num_level`` and ``model_label``.

    Returns:
        Tuple ``(test_data_handler, x_train, y_train, x_test, y_test,
        n_class)``. ``n_class`` is the number of distinct values found in
        ``y_test`` (the training labels are not consulted).
    """
    def flatten(pack):
        # Merge every leading axis into one so each row is a single sample;
        # the last two axes of the data are left untouched.
        samples = pack.data
        flat_data = samples.reshape(-1, *samples.shape[-2:])
        flat_label = pack.label.reshape(-1)
        return flat_data, flat_label

    print("Loading the training dataset...")
    # Settings common to both splits; only the patient list and the noise
    # ratio differ between the two handlers.
    shared_kwargs = dict(
        database_save_dir=args.database_save_dir,
        data_name=args.data_name,
        exp_id=args.exp_id,
        window_time=args.window_time,
        slide_time=args.slide_time,
        num_level=args.num_level,
    )
    train_handler = DataHandler(
        patient_list=args.train_patient_list,
        noise_ratio=args.noise_ratio,
        **shared_kwargs,
    )
    train_pack = train_handler.get_data(model_label=args.model_label)
    x_train, y_train = flatten(train_pack)
    # Drop the training-side objects before the test split is loaded.
    del train_pack, train_handler

    print("Loading the testing dataset...")
    test_data_handler = DataHandler(
        patient_list=args.test_patient_list,
        noise_ratio=0,
        **shared_kwargs,
    )
    test_pack = test_data_handler.get_data()
    x_test, y_test = flatten(test_pack)
    del test_pack

    n_class = len(np.unique(y_test))
    return test_data_handler, x_train, y_train, x_test, y_test, n_class


def _apply_in_chunks(func, x, chunk_size):
    """Apply ``func`` to consecutive row-slices of ``x`` and stack the results.

    Args:
        func: Callable taking ``x[start:stop]`` and returning an array-like
            whose first axis indexes the rows of that slice.
        x: Object supporting ``.shape[0]`` and row slicing.
        chunk_size: Maximum number of rows handed to ``func`` per call.

    Returns:
        ``np.concatenate`` of the per-chunk results along axis 0.
    """
    n_rows = x.shape[0]
    # ``max(n_rows, 1)`` still issues one call on empty input, so the callee
    # (rather than ``np.concatenate`` on an empty list) decides that case.
    pieces = [
        func(x[start:start + chunk_size])
        for start in range(0, max(n_rows, 1), chunk_size)
    ]
    return np.concatenate(pieces, axis=0)


def _proba_to_labels(model, proba):
    """Convert class probabilities into the labels ``model`` was fitted on."""
    # Columns of ``predict_proba`` follow ``model.classes_``; a bare argmax is
    # only the label itself when the training labels are exactly 0..n-1.
    return np.asarray(model.classes_)[np.argmax(proba, axis=-1)]


def _save_artifacts(save_dir, artifacts):
    """Dump each ``{filename: object}`` pair into ``save_dir`` with joblib."""
    # Create the target directory so a fresh output path does not make the
    # dump fail after training has already been paid for.
    os.makedirs(save_dir, exist_ok=True)
    for filename, obj in artifacts.items():
        joblib.dump(obj, os.path.join(save_dir, filename), compress=3)
    print(f'Saving the model to {save_dir} done')


def sktime_model(
        model_name,
        x_train,
        y_train,
        x_test,
        n_jobs,
        load_dir,
        save_dir,
        chunk_size=50000,
):
    """Train (or load) a time-series classifier and predict labels for ``x_test``.

    ``'tsf'`` uses ``TimeSeriesForestClassifier``. ``'minirocket'`` fits a
    MiniRocket transform (univariate when ``x_train.shape[1] == 1``, otherwise
    multivariate) followed by ``LogisticRegression(max_iter=1000)``.

    Args:
        model_name: Either ``'tsf'`` or ``'minirocket'``.
        x_train: Training samples. Not used when a model is loaded from
            ``load_dir``. For ``'minirocket'`` its second axis selects the
            univariate or multivariate transform.
        y_train: Training labels. Not used when a model is loaded.
        x_test: Samples to predict on.
        n_jobs: Passed as ``n_jobs`` to the sktime estimators.
        load_dir: Directory holding ``model.pkl`` (and ``mini_rocket.pkl`` for
            ``'minirocket'``), or ``None`` to train from scratch.
        save_dir: Directory to write the fitted objects to, or ``None`` to
            skip saving. It is created if it does not exist.
        chunk_size: Number of rows pushed through the MiniRocket transform
            (and the classifier at prediction time) per call.

    Returns:
        Array with one predicted class label per row of ``x_test``, taken
        from the fitted model's ``classes_``.

    Raises:
        AssertionError: If ``model_name`` is not a supported name.
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    assert model_name in ['tsf', 'minirocket'], (
        f"model_name must be 'tsf' or 'minirocket', got {model_name!r}"
    )
    if chunk_size < 1:
        raise ValueError(
            f'chunk_size must be a positive integer, got {chunk_size!r}'
        )

    model = None
    if load_dir is not None:
        print(f'Loading the model from: {load_dir}')
        # NOTE: joblib.load unpickles arbitrary objects; only point load_dir
        # at files from a trusted source.
        model = joblib.load(os.path.join(load_dir, 'model.pkl'))

    if model_name == 'tsf':
        if model is None:
            model = TimeSeriesForestClassifier(n_jobs=n_jobs)

            print('-' * 10, model_name + ' Training starting', '-' * 10)
            model.fit(x_train, y_train)

        print('-' * 10, model_name + ' Testing starting', '-' * 10)
        y_pred = _proba_to_labels(model, model.predict_proba(x_test))

        if save_dir is not None:
            _save_artifacts(save_dir, {'model.pkl': model})

        return y_pred

    if model_name == 'minirocket':
        if model is None:
            if x_train.shape[1] == 1:
                mini_rocket = MiniRocket(n_jobs=n_jobs)
            else:
                mini_rocket = MiniRocketMultivariate(n_jobs=n_jobs)
            model = LogisticRegression(max_iter=1000)

            mini_rocket.fit(x_train)
            train_features = _apply_in_chunks(
                mini_rocket.transform, x_train, chunk_size
            )

            print('-' * 10, model_name + ' Training starting', '-' * 10)
            model.fit(train_features, y_train)
            # The transformed training set is no longer needed; release it
            # before the test set is transformed.
            del train_features
        else:
            mini_rocket = joblib.load(os.path.join(load_dir, 'mini_rocket.pkl'))

        print('-' * 10, model_name + ' Testing starting', '-' * 10)
        proba = _apply_in_chunks(
            lambda chunk: model.predict_proba(mini_rocket.transform(chunk)),
            x_test,
            chunk_size,
        )
        y_pred = _proba_to_labels(model, proba)

        if save_dir is not None:
            # Keep the transform next to the classifier: both are required
            # to reproduce predictions after loading.
            _save_artifacts(
                save_dir,
                {'mini_rocket.pkl': mini_rocket, 'model.pkl': model},
            )

        return y_pred
