import os
import sys
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.multiprocessing import Pool

from data_provider.uea import collate_fn

# Extend the import search path with the directory two levels up and its
# ``pipeline`` sub-directory so that the ``pipeline`` import below can resolve.
# NOTE(review): these are relative entries, so they are presumably resolved
# against the process's current working directory rather than this file's
# location -- confirm the expected launch directory.
sys.path.append('../../')
sys.path.append('../../pipeline')
# Debug aid: print the absolute form of every entry currently on sys.path.
print([os.path.abspath(p) for p in sys.path])
from pipeline.ca_database_api import DataHandler


class CLDataSet(Dataset):
    """In-memory ``Dataset`` built from a ``DataHandler`` for TimesNet / PatchTST.

    The constructor loads everything through ``DataHandler`` and converts it
    to tensors whose layout depends on ``args.model``:

    * ``'TimesNet'``: the segmented data is flattened to 3-D and its last two
      axes are swapped; labels are flattened to 1-D.
    * ``'PatchTST'``: the unsegmented data is flattened to 3-D (no axis swap);
      labels keep their last axis and are flattened to 2-D.

    Attributes set by ``__init__``:
        data_handler: the ``DataHandler`` used to load the data.
        model: copy of ``args.model``.
        data / label: float / long tensors indexed along their first axis.
        class_name: sorted unique label values found in ``label``.
        n_class: number of distinct label values.
        sample_ratio: fraction of label entries per class, ordered like
            ``class_name``.
        patient_list, num_level: copies of the same-named ``args`` fields.
        max_len: ``data.size(-2)`` (taken after any axis swap).
        n_features: ``data.size(-1)`` (taken after any axis swap).
        nProcessLoader: copy of ``args.n_process_loader``.
        reload_pool: process pool with ``args.n_process_loader`` workers.
    """

    def __init__(
            self,
            args,
    ):
        """Load the data described by ``args`` and prepare tensors.

        Args:
            args: namespace-like object providing ``database_save_dir``,
                ``data_name``, ``exp_id``, ``patient_list``, ``noise_ratio``,
                ``window_time``, ``slide_time``, ``num_level``,
                ``model_label``, ``model`` and ``n_process_loader``.

        Raises:
            ValueError: if ``args.model`` is neither ``'TimesNet'`` nor
                ``'PatchTST'``.
        """
        super().__init__()
        print("Loading the dataset...")
        self.data_handler = DataHandler(
            database_save_dir=args.database_save_dir,
            data_name=args.data_name,
            exp_id=args.exp_id,
            patient_list=args.patient_list,
            noise_ratio=args.noise_ratio,
            window_time=args.window_time,
            slide_time=args.slide_time,
            num_level=args.num_level,
        )
        data_pack = self.data_handler.get_data(segment=False, model_label=args.model_label)
        seg_data_pack = self.data_handler.get_segment_data(data_pack)

        self.model = args.model

        if self.model == 'TimesNet':
            self.data = torch.tensor(seg_data_pack.data, dtype=torch.float)
            # data.size(): (num_level * seg_big_num * seg_small_num) x length x n_features
            self.data = self.data.view(-1, *self.data.size()[-2:]).permute(0, 2, 1)
            self.label = torch.tensor(seg_data_pack.label, dtype=torch.long)
            # label.size(): num_level * seg_big_num * seg_small_num
            self.label = self.label.view(-1)
        elif self.model == 'PatchTST':
            self.data = torch.tensor(data_pack.data, dtype=torch.float)
            # data.size(): (num_level * seg_big_num) x length x n_features
            self.data = self.data.view(-1, *self.data.size()[-2:])
            self.label = torch.tensor(seg_data_pack.label, dtype=torch.long)
            # label.size(): (num_level * seg_big_num) x seg_small_num
            self.label = self.label.view(-1, self.label.size(-1))
        else:
            raise ValueError('Please enter correct model name.')

        # The raw packs are no longer needed once the tensors exist.
        del data_pack, seg_data_pack

        # Count per observed label value instead of one-hot encoding with
        # num_classes=len(unique): the latter fails when label ids are not
        # exactly 0..K-1.
        self.class_name, self.sample_ratio = self._class_distribution(self.label)
        self.n_class = len(self.class_name)
        print('Class weight is:', self.sample_ratio)

        self.patient_list = args.patient_list
        self.max_len = self.data.size(-2)
        self.n_features = self.data.size(-1)
        print('Number of features:', self.n_features)
        self.nProcessLoader = args.n_process_loader
        # NOTE(review): this pool is never used or closed inside this class;
        # whoever uses it is responsible for shutting it down.
        self.reload_pool = Pool(args.n_process_loader)
        self.num_level = args.num_level

    @staticmethod
    def _class_distribution(label):
        """Return ``(sorted unique label values, fraction of entries per value)``.

        ``label`` may have any shape; all of its entries are counted.
        """
        class_name, counts = label.unique(return_counts=True)
        return class_name, counts.float() / label.numel()

    def __getstate__(self):
        """Return the picklable state, leaving the process pool out.

        Pool objects refuse to be pickled, which would otherwise prevent the
        dataset from being sent to DataLoader workers started with ``spawn``.
        Unpickled copies therefore have ``reload_pool`` set to ``None``.
        """
        state = self.__dict__.copy()
        state['reload_pool'] = None
        return state

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.data.size(0)

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair stored at ``index``."""
        return self.data[index], self.label[index]

    def _collate(self, batch):
        """Collate ``batch`` with the imported ``collate_fn`` and ``self.max_len``."""
        return collate_fn(batch, max_len=self.max_len)

    def get_data_loader(self, batch_size, shuffle=False, num_workers=0):
        """Build a ``DataLoader`` over this dataset.

        Args:
            batch_size: number of samples per batch.
            shuffle: whether the loader reshuffles samples.
            num_workers: number of loader worker processes.

        Returns:
            A ``DataLoader`` that keeps the last partial batch and collates
            with ``collate_fn(batch, max_len=self.max_len)``.
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            drop_last=False,
            # A bound method (unlike a local lambda) can be pickled, which
            # worker processes started with ``spawn`` require.
            collate_fn=self._collate,
        )
