import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import warnings
from utils import sliding_window
from actdata.getdataloader_single import get_act_data
import glob
import re
from data_provider.uea import subsample, interpolate_missing, Normalizer
from sktime.datasets import load_from_tsfile_to_dataframe

# Suppress every Python warning for the whole process (this also hides
# deprecation notices raised by the libraries imported above).
warnings.filterwarnings('ignore')

class UCIloader(Dataset):
    """Leave-one-domain-out dataset built from five UCI-HAR domain files.

    The files ``ucihar_domain_<k>_wd.data`` (k = 0..4) are read from
    ``data_path``.  ``args.target_domain`` selects the held-out domain:
    ``flag == "TRAIN"`` yields the concatenation of the four other domains,
    any other ``flag`` value yields the held-out domain itself.

    Attributes:
        feature_df: tensor holding the samples.
        labels_df: tensor holding one class label per sample.
        domain_df: tensor holding one domain label per sample.
        max_seq_len: size of axis 1 of the held-out domain's samples.
        class_names: sorted unique labels of the held-out domain.
    """

    _DOMAIN_FILES = (
        'ucihar_domain_0_wd.data',
        'ucihar_domain_1_wd.data',
        'ucihar_domain_2_wd.data',
        'ucihar_domain_3_wd.data',
        'ucihar_domain_4_wd.data',
    )

    def __init__(self, args, data_path, flag):
        """Load all domain files and keep the split selected by ``flag``.

        Args:
            args: object exposing ``target_domain``, the index (into the five
                files) of the held-out domain.
            data_path: directory containing the domain files.
            flag: ``"TRAIN"`` for the non-held-out domains; anything else
                selects the held-out domain.
        """
        self.data_path = data_path
        # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects,
        # which can execute code -- only point data_path at trusted files.
        domains = [
            np.load(os.path.join(self.data_path, name), allow_pickle=True)
            for name in self._DOMAIN_FILES
        ]
        domain = args.target_domain
        # Index 0 of each loaded file holds (features, labels, domain labels).
        held_out = domains[domain][0]
        # Shape notes left by the original author for one domain:
        # features (347, 128, 9), labels (347,), domain labels (347,).
        self.feature_df = held_out[0]
        self.labels_df = held_out[1]
        self.domain_df = held_out[2]

        # NOTE: both values are taken from the held-out domain, even when the
        # TRAIN branch below replaces the data with the other domains.
        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)
        if flag == "TRAIN":
            sources = [domains[i][0] for i in range(len(domains)) if i != domain]
            self.feature_df = np.concatenate([src[0] for src in sources], axis=0)
            self.labels_df = np.concatenate([src[1] for src in sources], axis=0)
            self.domain_df = np.concatenate([src[2] for src in sources], axis=0)

        self.feature_df = torch.from_numpy(self.feature_df)
        self.labels_df = torch.from_numpy(self.labels_df)
        self.domain_df = torch.from_numpy(self.domain_df)

    def get_sample_weights(self):
        """Return one inverse-frequency weight per sample.

        Each class gets the weight ``1000 / count``; every sample receives the
        weight of its class.

        Returns:
            list with ``len(self)`` entries, each a float64 tensor of shape
            ``(1,)``.
        """
        y = self.labels_df.numpy()
        # return_inverse maps every sample to the position of its class, so
        # the per-sample lookup is a single indexing operation.
        _, inverse, counts_y = np.unique(y, return_inverse=True, return_counts=True)
        weights = (1000.0 / torch.Tensor(counts_y)).double()
        index = torch.as_tensor(inverse.reshape(-1), dtype=torch.long)
        per_sample = weights[index]
        # Split into shape-(1,) tensors, the element form callers received before.
        return list(per_sample.unsqueeze(1).unbind(0))

    def normalize(self, x):
        """Standardise ``x`` along dim 1 (subtract the mean, divide by the std).

        Slices whose std is exactly zero are divided by 1 instead, so constant
        input maps to zeros rather than NaN/inf.
        """
        mean = x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True)
        std = torch.where(std == 0, torch.ones_like(std), std)
        return (x - mean) / std

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.feature_df)

    def __getitem__(self, index):
        """Return ``(features, label, domain label)`` for ``index`` (features are not normalised here)."""
        return self.feature_df[index], self.labels_df[index], self.domain_df[index]


class SHARloader(Dataset):
    """Leave-one-domain-out dataset over four ``shar_domain_*_wd.data`` files.

    The files numbered 1, 2, 3 and 5 are read from ``data_path`` and indexed
    0..3 in that order.  ``args.target_domain`` is the index of the held-out
    domain: ``flag == "TRAIN"`` yields the other three domains concatenated,
    any other ``flag`` yields the held-out domain.

    Attributes:
        feature_df: tensor of samples, reshaped to ``(-1, 151, 3)``.
        labels_df: tensor with one class label per sample.
        domain_df: tensor with one domain label per sample.
        max_seq_len: size of axis 1 of ``feature_df``.
        class_names: sorted unique values of ``labels_df``.
    """

    def __init__(self, args, data_path, flag):
        self.data_path = data_path
        file_names = (
            'shar_domain_1_wd.data',
            'shar_domain_2_wd.data',
            'shar_domain_3_wd.data',
            'shar_domain_5_wd.data',
        )
        per_domain = [
            np.load(os.path.join(self.data_path, fname), allow_pickle = True)
            for fname in file_names
        ]
        target = args.target_domain
        # Index 0 of each loaded file holds (features, labels, domain labels).
        held_out = per_domain[target][0]
        features, labels, domain_ids = held_out[0], held_out[1], held_out[2]

        if flag == "TRAIN":
            # Every domain except the held-out one becomes training data.
            others = [per_domain[k][0] for k in range(len(per_domain)) if k != target]
            features = np.concatenate([part[0] for part in others], axis=0)
            labels = np.concatenate([part[1] for part in others], axis=0)
            domain_ids = np.concatenate([part[2] for part in others], axis=0)

        # The original author's note gives (583, 453) for the stored features;
        # each sample is viewed here as 151 rows of 3 values.
        self.feature_df = torch.from_numpy(features.reshape(-1, 151, 3))
        self.labels_df = torch.from_numpy(labels)
        self.domain_df = torch.from_numpy(domain_ids)

        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)

    def get_sample_weights(self):
        """Return, per sample, the weight ``1000 / count`` of its class.

        Each entry is a float64 tensor of shape ``(1,)``.
        """
        labels = self.labels_df.numpy()
        classes, counts = np.unique(labels, return_counts=True)
        class_weights = (1000.0 / torch.Tensor(counts)).double()
        return [class_weights[np.where(classes == label)] for label in labels]

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.feature_df)

    def __getitem__(self, index):
        """Return ``(features, label, domain label)`` for ``index``."""
        sample = self.feature_df[index]
        return sample, self.labels_df[index], self.domain_df[index]

class OPPloader(Dataset):
    """Leave-one-domain-out dataset of sliding windows over four ``oppor_domain_S*_wd.data`` files.

    The files S1..S4 are read from ``data_path`` and indexed 0..3.
    ``args.target_domain`` is the index of the held-out domain:
    ``flag == "TRAIN"`` yields the windows of the other three domains, any
    other ``flag`` yields the windows of the held-out domain.

    Attributes:
        SLIDING_WINDOW_LEN: window length passed to the windowing helper (30).
        SLIDING_WINDOW_STEP: window step passed to the windowing helper (15).
        NUM_FEAFRUES / NUM_FEATURES: the constant 77 (not read by this class).
        feature_df: float32 tensor of windows.
        labels_df: uint8 tensor with one label per window.
        domain_df: uint8 tensor with one domain label per window.
        max_seq_len: size of axis 1 of ``feature_df``.
        class_names: sorted unique labels of the selected split.
    """

    def __init__(self, args, data_path, flag):
        """Load the four domain files and window the split selected by ``flag``.

        Args:
            args: object exposing ``target_domain`` (index of the held-out domain).
            data_path: directory containing the domain files.
            flag: ``"TRAIN"`` for the non-held-out domains; anything else
                selects the held-out domain.
        """
        self.data_path = data_path
        self.SLIDING_WINDOW_LEN = 30
        self.SLIDING_WINDOW_STEP = 15
        self.NUM_FEAFRUES = 77  # misspelled name kept so existing readers keep working
        self.NUM_FEATURES = self.NUM_FEAFRUES  # correctly spelled alias
        # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects,
        # which can execute code -- only point data_path at trusted files.
        domains = [
            np.load(os.path.join(self.data_path, name), allow_pickle=True)
            for name in (
                'oppor_domain_S1_wd.data',
                'oppor_domain_S2_wd.data',
                'oppor_domain_S3_wd.data',
                'oppor_domain_S4_wd.data',
            )
        ]
        domain = args.target_domain
        # Looked up for every flag so an out-of-range target_domain still
        # raises IndexError, exactly as before.
        held_out = domains[domain][0]
        if flag == "TRAIN":
            recordings = [domains[i][0] for i in range(len(domains)) if i != domain]
        else:
            recordings = [held_out]

        # Window only the recordings that are actually kept; the held-out
        # domain is no longer windowed and then thrown away for TRAIN.
        windows, labels, domain_ids = [], [], []
        for rec in recordings:
            x_win, y_win, d_win = self.opp_sliding_window_w_d(
                rec[0], rec[1], rec[2], self.SLIDING_WINDOW_LEN, self.SLIDING_WINDOW_STEP)
            windows.append(x_win)
            labels.append(y_win)
            domain_ids.append(d_win)
        self.feature_df = np.concatenate(windows, axis=0)
        self.labels_df = np.concatenate(labels, axis=0)
        self.domain_df = np.concatenate(domain_ids, axis=0)

        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)

        self.feature_df = torch.from_numpy(self.feature_df)
        self.labels_df = torch.from_numpy(self.labels_df)
        self.domain_df = torch.from_numpy(self.domain_df)

    def opp_sliding_window_w_d(self, data_x, data_y, d, ws, ss): # window size, step size
        """Cut features, labels and domain labels into sliding windows.

        Args:
            data_x: 2-D feature array; windows span all of axis 1.
            data_y: per-row labels.
            d: per-row domain labels.
            ws: window size.
            ss: step size.

        Returns:
            ``(windows, labels, domain_labels)``: float32 windows, plus uint8
            label and domain-label vectors holding the LAST value of each
            window.  The uint8 cast means values above 255 would wrap.
        """
        data_x = sliding_window.sliding_window(data_x, (ws, data_x.shape[1]), (ss, 1))
        data_y = np.asarray([[i[-1]] for i in sliding_window.sliding_window(data_y, ws, ss)])
        data_d = np.asarray([[i[-1]] for i in sliding_window.sliding_window(d, ws, ss)])
        return (
            data_x.astype(np.float32),
            data_y.reshape(len(data_y)).astype(np.uint8),
            data_d.reshape(len(data_d)).astype(np.uint8),
        )

    def get_sample_weights(self):
        """Return one inverse-frequency weight per sample.

        Each class gets the weight ``1000 / count``; every sample receives the
        weight of its class.

        Returns:
            list with ``len(self)`` entries, each a float64 tensor of shape
            ``(1,)``.
        """
        y = self.labels_df.numpy()
        # return_inverse maps every sample to the position of its class, so
        # the per-sample lookup is a single indexing operation.
        _, inverse, counts_y = np.unique(y, return_inverse=True, return_counts=True)
        weights = (1000.0 / torch.Tensor(counts_y)).double()
        index = torch.as_tensor(inverse.reshape(-1), dtype=torch.long)
        per_sample = weights[index]
        # Split into shape-(1,) tensors, the element form callers received before.
        return list(per_sample.unsqueeze(1).unbind(0))

    def __len__(self):
        """Number of windows in the selected split."""
        return len(self.feature_df)

    def __getitem__(self, index):
        """Return ``(window, label, domain label)`` for ``index``."""
        return self.feature_df[index], self.labels_df[index], self.domain_df[index]
    

class EMGloader(Dataset):
    """One split (train / validation / test) of the data returned by ``get_act_data(args)``.

    Attributes:
        data_path: stored as given; not used for loading in this class.
        feature_df: the split's ``x`` after ``squeeze(2).permute(0, 2, 1)``;
            the original author's note gives (1704, 8, 1, 200) -> (1704, 200, 8).
        labels_df: the split's ``labels`` as a long tensor.
        domain_df: the split's ``dlabels`` as a long tensor.
        max_seq_len: size of axis 1 of ``feature_df``.
        class_names: sorted unique values of the split's labels.
    """

    def __init__(self, args, data_path, flag):
        """Load the splits and keep the one named by ``flag``.

        Args:
            args: forwarded unchanged to ``get_act_data``.
            data_path: kept on the instance only.
            flag: ``"TRAIN"``, ``"VAL"`` or ``"TEST"``.

        Raises:
            ValueError: if ``flag`` is not one of the three split names.
        """
        # Fail fast: previously an unknown flag surfaced only after the full
        # data load, as an AttributeError on None.
        if flag not in ("TRAIN", "VAL", "TEST"):
            raise ValueError(
                "flag must be 'TRAIN', 'VAL' or 'TEST', got {!r}".format(flag))
        self.data_path = data_path

        # An unused local used to list four groups of nine consecutive ids
        # (0-8, 9-17, 18-26, 27-35); presumably the domain grouping applied
        # inside get_act_data -- verify there.
        train, valid, test = get_act_data(args)
        chosen = {"TRAIN": train, "VAL": valid, "TEST": test}[flag]
        self.feature_df = chosen.x
        self.labels_df = chosen.labels
        self.domain_df = chosen.dlabels

        # Drop the singleton axis 2, then swap the last two axes.
        self.feature_df = self.feature_df.squeeze(2).permute(0, 2, 1)

        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)
        self.labels_df = torch.from_numpy(self.labels_df).long()
        self.domain_df = torch.from_numpy(self.domain_df).long()

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.feature_df)

    def __getitem__(self, index):
        """Return ``(features, label, domain label)`` for ``index``."""
        return self.feature_df[index], self.labels_df[index], self.domain_df[index]
    
class DSADSloader(Dataset):
    """One split (train / validation / test) of the data returned by ``get_act_data(args)``.

    Attributes:
        data_path: stored as given; not used for loading in this class.
        feature_df: the split's ``x`` after ``squeeze(2).permute(0, 2, 1)``.
        labels_df: the split's ``labels`` as a long tensor.
        domain_df: the split's ``dlabels`` as a long tensor.
        max_seq_len: size of axis 1 of ``feature_df``.
        class_names: sorted unique values of the split's labels.
    """

    def __init__(self, args, data_path, flag):
        """Load the splits and keep the one named by ``flag``.

        Args:
            args: forwarded unchanged to ``get_act_data``.
            data_path: kept on the instance only.
            flag: ``"TRAIN"``, ``"VAL"`` or ``"TEST"``.

        Raises:
            ValueError: if ``flag`` is not one of the three split names.
        """
        # Fail fast: previously an unknown flag surfaced only after the full
        # data load, as an AttributeError on None.
        if flag not in ("TRAIN", "VAL", "TEST"):
            raise ValueError(
                "flag must be 'TRAIN', 'VAL' or 'TEST', got {!r}".format(flag))
        self.data_path = data_path

        # An unused local used to list the grouping [[0,1],[2,3],[4,5],[6,7]];
        # presumably the domain grouping applied inside get_act_data -- verify there.
        train, valid, test = get_act_data(args)
        chosen = {"TRAIN": train, "VAL": valid, "TEST": test}[flag]
        self.feature_df = chosen.x
        self.labels_df = chosen.labels
        self.domain_df = chosen.dlabels

        # Drop the singleton axis 2, then swap the last two axes.
        self.feature_df = self.feature_df.squeeze(2).permute(0, 2, 1)
        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)

        self.labels_df = torch.from_numpy(self.labels_df).long()
        self.domain_df = torch.from_numpy(self.domain_df).long()

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.feature_df)

    def __getitem__(self, index):
        """Return ``(features, label, domain label)`` for ``index``."""
        return self.feature_df[index], self.labels_df[index], self.domain_df[index]

class USCHADloader(Dataset):
    """One split (train / validation / test) of the data returned by ``get_act_data(args)``.

    Attributes:
        data_path: stored as given; not used for loading in this class.
        feature_df: the split's ``x`` after ``squeeze(2).permute(0, 2, 1)``.
        labels_df: the split's ``labels`` as a long tensor.
        domain_df: the split's ``dlabels`` as a long tensor.
        max_seq_len: size of axis 1 of ``feature_df``.
        class_names: sorted unique values of the split's labels.
    """

    def __init__(self, args, data_path, flag):
        """Load the splits and keep the one named by ``flag``.

        Args:
            args: forwarded unchanged to ``get_act_data``.
            data_path: kept on the instance only.
            flag: ``"TRAIN"``, ``"VAL"`` or ``"TEST"``.

        Raises:
            ValueError: if ``flag`` is not one of the three split names.
        """
        # Fail fast: previously an unknown flag surfaced only after the full
        # data load, as an AttributeError on None.
        if flag not in ("TRAIN", "VAL", "TEST"):
            raise ValueError(
                "flag must be 'TRAIN', 'VAL' or 'TEST', got {!r}".format(flag))
        self.data_path = data_path

        # An unused local used to list the grouping
        # [[0,1,2,11],[3,5,6,9],[7,8,10,13],[4,12]]; presumably the domain
        # grouping applied inside get_act_data -- verify there.
        train, valid, test = get_act_data(args)
        chosen = {"TRAIN": train, "VAL": valid, "TEST": test}[flag]
        self.feature_df = chosen.x
        self.labels_df = chosen.labels
        self.domain_df = chosen.dlabels

        # Drop the singleton axis 2, then swap the last two axes.
        self.feature_df = self.feature_df.squeeze(2).permute(0, 2, 1)
        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)
        self.labels_df = torch.from_numpy(self.labels_df).long()
        self.domain_df = torch.from_numpy(self.domain_df).long()

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.feature_df)

    def __getitem__(self, index):
        """Return ``(features, label, domain label)`` for ``index``."""
        return self.feature_df[index], self.labels_df[index], self.domain_df[index]

class PAMAPloader(Dataset):
    """One split (train / validation / test) of the data returned by ``get_act_data(args)``.

    Attributes:
        data_path: stored as given; not used for loading in this class.
        feature_df: the split's ``x`` after ``squeeze(2).permute(0, 2, 1)``.
        labels_df: the split's ``labels`` as a long tensor.
        domain_df: the split's ``dlabels`` as a long tensor.
        max_seq_len: size of axis 1 of ``feature_df``.
        class_names: the fixed label set 0..17 (NOT derived from the data).
    """

    def __init__(self, args, data_path, flag):
        """Load the splits and keep the one named by ``flag``.

        Args:
            args: forwarded unchanged to ``get_act_data``.
            data_path: kept on the instance only.
            flag: ``"TRAIN"``, ``"VAL"`` or ``"TEST"``.

        Raises:
            ValueError: if ``flag`` is not one of the three split names.
        """
        # Fail fast: previously an unknown flag surfaced only after the full
        # data load, as an AttributeError on None.
        if flag not in ("TRAIN", "VAL", "TEST"):
            raise ValueError(
                "flag must be 'TRAIN', 'VAL' or 'TEST', got {!r}".format(flag))
        self.data_path = data_path

        # An unused local used to list the grouping [[2,3,8],[1,5],[0,7],[4,6]];
        # presumably the domain grouping applied inside get_act_data -- verify there.
        train, valid, test = get_act_data(args)
        chosen = {"TRAIN": train, "VAL": valid, "TEST": test}[flag]
        self.feature_df = chosen.x
        self.labels_df = chosen.labels
        self.domain_df = chosen.dlabels

        # Drop the singleton axis 2, then swap the last two axes.
        self.feature_df = self.feature_df.squeeze(2).permute(0, 2, 1)
        self.max_seq_len = self.feature_df.shape[1]
        # The class list is hard-coded for PAMAP2: with Diversify's domain
        # division, target domain 2 would otherwise miss two class labels.
        # This was kept on purpose so results stay comparable with the other
        # methods in the Diversify paper's setting, and because it is a good
        # test of the model when it faces unseen class labels.
        self.class_names = np.arange(18)
        self.labels_df = torch.from_numpy(self.labels_df).long()
        self.domain_df = torch.from_numpy(self.domain_df).long()

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.feature_df)

    def __getitem__(self, index):
        """Return ``(features, label, domain label)`` for ``index``."""
        return self.feature_df[index], self.labels_df[index], self.domain_df[index]
class WESADloader(Dataset):
    """One split (train / validation / test) of the data returned by ``get_act_data(args)``.

    Attributes:
        data_path: stored as given; not used for loading in this class.
        feature_df: the split's ``x`` after ``squeeze(2).permute(0, 2, 1)``.
        labels_df: the split's ``labels`` as a long tensor.
        domain_df: the split's ``dlabels`` as a long tensor.
        max_seq_len: size of axis 1 of ``feature_df``.
        class_names: sorted unique values of the split's labels.
    """

    def __init__(self, args, data_path, flag):
        """Load the splits and keep the one named by ``flag``.

        Args:
            args: forwarded unchanged to ``get_act_data``.
            data_path: kept on the instance only.
            flag: ``"TRAIN"``, ``"VAL"`` or ``"TEST"``.

        Raises:
            ValueError: if ``flag`` is not one of the three split names.
        """
        # Fail fast: previously an unknown flag surfaced only after the full
        # data load, as an AttributeError on None.
        if flag not in ("TRAIN", "VAL", "TEST"):
            raise ValueError(
                "flag must be 'TRAIN', 'VAL' or 'TEST', got {!r}".format(flag))
        self.data_path = data_path

        # An unused local used to list the grouping
        # [[0,1,2,3],[4,5,6,7],[8,9,10,11],[12,13,14]]; presumably the domain
        # grouping applied inside get_act_data -- verify there.
        train, valid, test = get_act_data(args)
        chosen = {"TRAIN": train, "VAL": valid, "TEST": test}[flag]
        self.feature_df = chosen.x
        self.labels_df = chosen.labels
        self.domain_df = chosen.dlabels

        # Drop the singleton axis 2, then swap the last two axes.
        # NOTE(review): an earlier comment here read "(1704, 8, 200) ->
        # (1704, 200, 8)", which does not match the squeeze(2) -- confirm the
        # actual shapes for this dataset.
        self.feature_df = self.feature_df.squeeze(2).permute(0, 2, 1)
        self.max_seq_len = self.feature_df.shape[1]

        self.class_names = np.unique(self.labels_df)
        self.labels_df = torch.from_numpy(self.labels_df).long()
        self.domain_df = torch.from_numpy(self.domain_df).long()

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.feature_df)

    def __getitem__(self, index):
        """Return ``(features, label, domain label)`` for ``index``."""
        return self.feature_df[index], self.labels_df[index], self.domain_df[index]

class EEGloader(Dataset):
    """One split (train / validation / test) of the data returned by ``get_act_data(args)``.

    Attributes:
        data_path: stored as given; not used for loading in this class.
        feature_df: the split's ``x`` after ``squeeze(2).permute(0, 2, 1)``.
        labels_df: the split's ``labels`` as a long tensor.
        domain_df: the split's ``dlabels`` as a long tensor.
        max_seq_len: size of axis 1 of ``feature_df``.
        class_names: sorted unique values of the split's labels.
    """

    def __init__(self, args, data_path, flag):
        """Load the splits and keep the one named by ``flag``.

        Args:
            args: forwarded unchanged to ``get_act_data``.
            data_path: kept on the instance only.
            flag: ``"TRAIN"``, ``"VAL"`` or ``"TEST"``.

        Raises:
            ValueError: if ``flag`` is not one of the three split names.
        """
        # Fail fast: previously an unknown flag surfaced only after the full
        # data load, as an AttributeError on None.
        if flag not in ("TRAIN", "VAL", "TEST"):
            raise ValueError(
                "flag must be 'TRAIN', 'VAL' or 'TEST', got {!r}".format(flag))
        self.data_path = data_path

        # An unused local used to list four groups of five consecutive ids
        # (0-4, 5-9, 10-14, 15-19); presumably the domain grouping applied
        # inside get_act_data -- verify there.
        train, valid, test = get_act_data(args)
        chosen = {"TRAIN": train, "VAL": valid, "TEST": test}[flag]
        self.feature_df = chosen.x
        self.labels_df = chosen.labels
        self.domain_df = chosen.dlabels

        # Drop the singleton axis 2, then swap the last two axes.
        # NOTE(review): an earlier comment here read "(1704, 8, 200) ->
        # (1704, 200, 8)", which does not match the squeeze(2) -- confirm the
        # actual shapes for this dataset.
        self.feature_df = self.feature_df.squeeze(2).permute(0, 2, 1)
        self.max_seq_len = self.feature_df.shape[1]

        self.class_names = np.unique(self.labels_df)
        self.labels_df = torch.from_numpy(self.labels_df).long()
        self.domain_df = torch.from_numpy(self.domain_df).long()

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.feature_df)

    def __getitem__(self, index):
        """Return ``(features, label, domain label)`` for ``index``."""
        return self.feature_df[index], self.labels_df[index], self.domain_df[index]
    


class UEAloader(Dataset):
    """
    Dataset class for datasets included in:
        Time Series Classification Archive (www.timeseriesclassification.com)
    Argument:
        limit_size: optional, for debugging. A value > 1 is used as a sample count;
            any other value is used as a proportion of the available samples.
    Attributes:
        all_df: (num_samples * seq_len, num_columns) dataframe indexed by integer indices, with multiple rows corresponding to the same index (sample).
            Each row is a time step; Each column contains either metadata (e.g. timestamp) or a feature.
        feature_df: `all_df` (all columns are used) after being passed through `Normalizer().normalize`
        feature_names: names of columns contained in `all_df` (same as all_df.columns)
        all_IDs: (num_samples,) series of IDs contained in `all_df`/`feature_df` (same as all_df.index.unique() )
        labels_df: (num_samples, num_labels) pd.DataFrame of label(s) for each sample
        max_seq_len: maximum sequence (time series) length, set while loading the file.
        class_names: the label categories, set while loading the file.
    """

    def __init__(self, args, data_path, file_list=None, limit_size=None, flag=None):
        """Load the first matching .ts file under `data_path` and normalise its features.

        Args:
            args: kept on the instance; `__getitem__` reads `args.augmentation_ratio`.
            data_path: directory that holds the .ts files.
            file_list: optional list of file names inside `data_path` to consider.
            limit_size: optional cap on the number of samples (see class docstring).
            flag: optional regular expression a file path must match (e.g. "TRAIN").
        """
        self.args = args
        self.root_path = data_path
        self.flag = flag
        print(data_path)
        self.all_df, self.labels_df = self.load_all(data_path, file_list=file_list, flag=flag)
        self.all_IDs = self.all_df.index.unique()  # all sample IDs (integer indices 0 ... num_samples-1)

        if limit_size is not None:
            if limit_size > 1:
                limit_size = int(limit_size)
            else:  # interpret as proportion if in (0, 1]
                limit_size = int(limit_size * len(self.all_IDs))
            self.all_IDs = self.all_IDs[:limit_size]
            self.all_df = self.all_df.loc[self.all_IDs]

        # use all features
        self.feature_names = self.all_df.columns
        self.feature_df = self.all_df

        # pre_process
        normalizer = Normalizer()
        self.feature_df = normalizer.normalize(self.feature_df)
        print(len(self.all_IDs))

    @staticmethod
    def _elementwise(frame, func):
        """Apply `func` to every cell of `frame`, whichever pandas version is installed.

        `DataFrame.applymap` is deprecated since pandas 2.1 in favour of
        `DataFrame.map`; older releases only provide `applymap`.
        """
        if hasattr(pd.DataFrame, "map"):
            return frame.map(func)
        return frame.applymap(func)

    def load_all(self, root_path, file_list=None, flag=None):
        """
        Loads a dataset from the .ts files contained in `root_path` into a dataframe
        Args:
            root_path: directory containing all individual .ts files
            file_list: optionally, provide a list of file paths within `root_path` to consider.
                Otherwise, entire `root_path` contents will be used.
            flag: optionally, a regular expression; only paths it matches (`re.search`) are kept.
        Returns:
            all_df: dataframe with the data of the FIRST matching .ts file
            labels_df: dataframe containing label(s) for each sample
        Raises:
            Exception: if no candidate path exists, or none of them is a .ts file.
        """
        # Select paths for training and evaluation
        if file_list is None:
            data_paths = glob.glob(os.path.join(root_path, '*'))  # list of all paths
        else:
            data_paths = [os.path.join(root_path, p) for p in file_list]
        if len(data_paths) == 0:
            raise Exception('No files found using: {}'.format(os.path.join(root_path, '*')))
        if flag is not None:
            # NOTE: the pattern is searched in the whole path, directory names included.
            data_paths = [p for p in data_paths if re.search(flag, p)]
        input_paths = [p for p in data_paths if os.path.isfile(p) and p.endswith('.ts')]
        if len(input_paths) == 0:
            pattern = '*.ts'
            raise Exception("No .ts files found using pattern: '{}'".format(pattern))

        all_df, labels_df = self.load_single(input_paths[0])  # a single file contains dataset

        return all_df, labels_df

    def load_single(self, filepath):
        """Read one .ts file into a long-format feature dataframe and a label dataframe.

        Also sets `self.class_names` and `self.max_seq_len`.

        Returns:
            df: (num_samples * seq_len, feat_dim) dataframe indexed by sample ID
            labels_df: (num_samples, 1) dataframe of int8 category codes
        """
        df, labels = load_from_tsfile_to_dataframe(filepath, return_separate_X_and_y=True,
                                                             replace_missing_vals_with='NaN')
        labels = pd.Series(labels, dtype="category")
        self.class_names = labels.cat.categories
        # NOTE(review): the inherited remark said "int8-32 gives an error when using
        # nn.CrossEntropyLoss", yet int8 is what is stored -- presumably callers cast
        # the labels before computing the loss; confirm.
        labels_df = pd.DataFrame(labels.cat.codes,
                                 dtype=np.int8)

        # (num_samples, num_dimensions) array containing the length of each series
        lengths = self._elementwise(df, lambda x: len(x)).values

        horiz_diffs = np.abs(lengths - np.expand_dims(lengths[:, 0], -1))

        if np.sum(horiz_diffs) > 0:  # if any row (sample) has varying length across dimensions
            df = self._elementwise(df, subsample)

        lengths = self._elementwise(df, lambda x: len(x)).values
        vert_diffs = np.abs(lengths - np.expand_dims(lengths[0, :], 0))
        if np.sum(vert_diffs) > 0:  # if any column (dimension) has varying length across samples
            self.max_seq_len = int(np.max(lengths[:, 0]))
        else:
            # Cast so max_seq_len is a plain int in both branches.
            self.max_seq_len = int(lengths[0, 0])

        # First create a (seq_len, feat_dim) dataframe for each sample, indexed by a single integer ("ID" of the sample)
        # Then concatenate into a (num_samples * seq_len, feat_dim) dataframe, with multiple rows corresponding to the
        # sample index (i.e. the same scheme as all datasets in this project)

        df = pd.concat((pd.DataFrame({col: df.loc[row, col] for col in df.columns}).reset_index(drop=True).set_index(
            pd.Series(lengths[row, 0] * [row])) for row in range(df.shape[0])), axis=0)

        # Replace NaN values
        grp = df.groupby(by=df.index)
        df = grp.transform(interpolate_missing)

        return df, labels_df

    def instance_norm(self, case):
        """Normalise one sample, but only when `root_path` contains 'EthanolConcentration'.

        For that dataset the mean over dim 0 is subtracted and the result is divided by
        sqrt(var over dim 1 + 1e-5); every other dataset is returned unchanged.
        """
        if self.root_path.count('EthanolConcentration') > 0:  # special process for numerical stability
            mean = case.mean(0, keepdim=True)
            case = case - mean
            stdev = torch.sqrt(torch.var(case, dim=1, keepdim=True, unbiased=False) + 1e-5)
            case /= stdev
            return case
        else:
            return case

    def __getitem__(self, ind):
        """Return `(features, labels, domain placeholder)` for the `ind`-th sample ID.

        The third element is always `torch.zeros(1).int()`.
        """
        batch_x = self.feature_df.loc[self.all_IDs[ind]].values
        labels = self.labels_df.loc[self.all_IDs[ind]].values
        if self.flag == "TRAIN" and self.args.augmentation_ratio > 0:
            num_samples = len(self.all_IDs)
            num_columns = self.feature_df.shape[1]
            # Assumes every sample has the same number of rows.
            seq_len = int(self.feature_df.shape[0] / num_samples)
            batch_x = batch_x.reshape((1, seq_len, num_columns))
            # NOTE(review): `run_augmentation_single` is not imported in the visible part
            # of this file, so this branch would raise NameError unless the name is
            # provided elsewhere -- confirm and add the import.
            batch_x, labels, augmentation_tags = run_augmentation_single(batch_x, labels, self.args)

            batch_x = batch_x.reshape((1 * seq_len, num_columns))

        return self.instance_norm(torch.from_numpy(batch_x)), \
               torch.from_numpy(labels), torch.zeros(1).int()

    def __len__(self):
        """Number of samples (unique IDs)."""
        return len(self.all_IDs)