import numpy as np
import sys
import torch
import torch.nn.functional as F

sys.path.append('../../../')
sys.path.append('../../../pipeline')
sys.path.append('../../../utils')

from pipeline.ca_database_api import DataHandler


def _load_split(args, patient_list, noise_ratio):
    """Load one data split through ``DataHandler`` and flatten it per sample.

    The handler is built from the settings shared by every split on ``args``
    (database_save_dir, data_name, exp_id, window_time, slide_time,
    num_level) plus the split-specific ``patient_list`` and ``noise_ratio``.

    Returns:
        ``(handler, x, y)`` where ``x`` is the data reshaped so that all
        leading axes except the last two are merged into one sample axis
        (shape ``(N, *data.shape[-2:])``), and ``y`` is the label array
        flattened to 1-D (assumed to align one-to-one with ``x``'s first
        axis -- TODO confirm against ``DataHandler.get_data``).
    """
    handler = DataHandler(
        database_save_dir=args.database_save_dir,
        data_name=args.data_name,
        exp_id=args.exp_id,
        patient_list=patient_list,
        noise_ratio=noise_ratio,
        window_time=args.window_time,
        slide_time=args.slide_time,
        num_level=args.num_level,
    )
    pack = handler.get_data()
    x = pack.data.reshape(-1, *pack.data.shape[-2:])
    y = pack.label.reshape(-1)
    # ``pack`` itself is dropped on return; only x/y and the handler escape.
    return handler, x, y


def seeg_datasets(args):
    """Load the training, validation and test splits.

    Training and validation data are loaded with ``args.noise_ratio``; the
    test split is always loaded with ``noise_ratio=0``.

    Args:
        args: namespace providing ``database_save_dir``, ``data_name``,
            ``exp_id``, ``train_patient_list``, ``valid_patient_list``,
            ``test_patient_list``, ``noise_ratio``, ``window_time``,
            ``slide_time`` and ``num_level``.

    Returns:
        Tuple ``(train_data_handler, train_label, x_train, y_train,
        x_valid, y_valid, x_test, y_test, n_class)`` where each ``x_*`` has
        shape ``(N, *data.shape[-2:])``, each ``y_*`` is a flat label array,
        ``train_label`` is the one-hot (int64) encoding of ``y_train`` with
        shape ``(len(y_train), n_class)``, and ``n_class`` is one more than
        the largest label found in any split.
    """
    print("Loading the training dataset...")
    train_data_handler, x_train, y_train = _load_split(
        args, args.train_patient_list, args.noise_ratio
    )

    print("Loading the validation dataset...")
    # The validation handler is not returned, so discard it immediately.
    x_valid, y_valid = _load_split(
        args, args.valid_patient_list, args.noise_ratio
    )[1:]

    print("Loading the testing dataset...")
    # The test split is always evaluated on noise-free data.
    x_test, y_test = _load_split(args, args.test_patient_list, 0)[1:]

    # Size the class axis from *all* splits, and from the largest label rather
    # than the count of distinct labels: F.one_hot requires every label to be
    # < num_classes, and the returned n_class must match train_label's width
    # even when one split is missing some class.
    all_labels = np.concatenate((y_train, y_valid, y_test))
    n_class = int(all_labels.max()) + 1 if all_labels.size else 0

    # Encode the flattened labels so rows line up with x_train.
    train_label = F.one_hot(
        torch.tensor(y_train, dtype=torch.long), num_classes=n_class
    )

    return (
        train_data_handler,
        train_label,
        x_train,
        y_train,
        x_valid,
        y_valid,
        x_test,
        y_test,
        n_class,
    )


def test_datasets(args):
    """Load only the test split (always with ``noise_ratio=0``).

    Args:
        args: namespace providing ``database_save_dir``, ``data_name``,
            ``exp_id``, ``test_patient_list``, ``window_time``,
            ``slide_time`` and ``num_level``.

    Returns:
        Tuple ``(x_test, y_test, n_class)`` where ``x_test`` has shape
        ``(N, *data.shape[-2:])``, ``y_test`` is the flattened label array,
        and ``n_class`` is one more than the largest test label.
    """
    print("Loading the testing dataset...")
    test_data_handler = DataHandler(
        database_save_dir=args.database_save_dir,
        data_name=args.data_name,
        exp_id=args.exp_id,
        patient_list=args.test_patient_list,
        noise_ratio=0,
        window_time=args.window_time,
        slide_time=args.slide_time,
        num_level=args.num_level,
    )
    test_data_pack = test_data_handler.get_data()

    # Report the label shape before it is flattened below.
    print("y_test.shape:", test_data_pack.label.shape)

    # Merge every leading axis except the last two into one sample axis.
    x_test = test_data_pack.data.reshape(-1, *test_data_pack.data.shape[-2:])
    y_test = test_data_pack.label.reshape(-1)

    del test_data_pack, test_data_handler

    # Use "largest label + 1" rather than the number of distinct labels so that
    # every label present satisfies 0 <= label < n_class, even when some class
    # is absent from the test patients.
    n_class = int(y_test.max()) + 1 if y_test.size else 0

    return x_test, y_test, n_class