import datetime as dt
import pandas as pd
from dateutil.relativedelta import relativedelta
import numpy as np


def datetime_totime(datetimes):
    '''
    Convert date-times to times of day, all placed on the date 1900-01-01.
    This can be useful when we want to plot histograms of average
    events on a daily basis.

    Arguments
    ---------
    datetimes: iterable
       An iterable (list, array or series) of date-times. Each element may be
       anything pd.Timestamp accepts, e.g. datetime.datetime, pd.Timestamp,
       np.datetime64 or a date-time string.


    Returns
    --------
    times: pd.DatetimeIndex
        The times of day of the inputs, each with the date 1900-01-01.
        Time zone information is dropped (the wall-clock time is kept) and
        times are kept to microsecond precision.

    '''
    base_date = dt.date(1900, 1, 1)
    # pd.Timestamp also parses strings (which dt.datetime.time cannot), and
    # combining with a fixed date avoids a string round-trip that would fail
    # on times carrying microseconds.
    times = [dt.datetime.combine(base_date, pd.Timestamp(d).time()) for d in datetimes]

    return pd.to_datetime(times)


def make_input_roll(data, sequence_length):
    '''
    Produce every contiguous window of length sequence_length from the data.
    For example, if the original data were the single feature [1,2,3,4,5] and
    sequence_length = 2, the windows are [[1,2],[2,3],[3,4],[4,5]]; each window keeps
    its feature axis, so the returned shape is (4, 2, 1).
    
    Arguments
    ---------
        data: numpy array (or array-like)
            This is the data that you want transformed, shaped (n_datapoints, n_features).
            A 1D input is treated as a single feature, i.e. shape (n_datapoints, 1).
        
        sequence_length: integer
            This is an integer that contains the length of each of the returned sequences.
            
    Returns:
    ---------
        output: array
            Array of shape (n_datapoints - sequence_length + 1, sequence_length, n_features)
            with output[i] == data[i:i + sequence_length]. If data has fewer than
            sequence_length rows, a single all-zero window of shape
            (1, sequence_length, max(n_features, 1)) is returned instead.
            The dtype of data is preserved.
    
    '''
    data = np.asarray(data)
    if data.ndim == 1:
        data = data.reshape(-1, 1)

    n_points, n_features = data.shape[0], data.shape[1]

    # Not enough rows for even one window: return a zero placeholder.
    # (Exactly sequence_length rows is enough for one window.)
    if n_points < sequence_length:
        return np.zeros((1, sequence_length, max(n_features, 1)), dtype = data.dtype)

    # window_index[i, j] == i + j, so data[window_index][i] == data[i:i + sequence_length]
    window_index = (np.arange(n_points - sequence_length + 1)[:, None]
                    + np.arange(sequence_length)[None, :])

    return data[window_index]


def data_interval_index(time_data, years = 0, months = 0, days = 0, hours = 0, minutes = 0, seconds = 0, 
                        start = 'start', weeks = 0):
    '''
    Given a data set that contains a field with dates, this function returns the indices of a section of the 
    original data within a time interval.
    
    Arguments
    ---------
        time_data: array
            This is the array (or pandas Series / DatetimeIndex) that contains the times.
        years, months, weeks, days, hours, minutes, seconds: integers
            These are the values that are used for the time interval. For example, if years = 3, the
            time range for the returned dataset will be 3 years. You may use more than one of these arguments
            to build up more interesting time intervals. We refer to these arguments as time_arguments below.
            weeks is keyword-only in practice (it is the last parameter) for backward compatibility.
        start: string or pandas timestamp instance.
            This is the start time of the returned dataset. If a timestamp is given, the selected interval
            is [start, start+time_arguments).
            If you are supplying a string, it must be from the list ['start','end']. In the case of 'start', the
            interval is [first time, first time+time_arguments). If 'end' is passed, the interval is
            [last time-time_arguments, last time], i.e. the most recent time_arguments of the data,
            including the most recent sample.
    
    Returns
    ---------
        output: array
            This is a boolean numpy array that has element = True for the subset of the original data that has 
            data from the time interval given and False otherwise.
    
    '''
    interval = relativedelta(years = years, 
                             months = months, 
                             weeks = weeks, 
                             days = days, 
                             hours = hours, 
                             minutes = minutes, 
                             seconds = seconds)

    # isinstance guards keep timestamp-like `start` values from being compared with strings.
    if isinstance(start, str) and start == 'end':
        end_date = max(time_data)
        start_date = end_date - interval
        # Closed on the right so the most recent sample (end_date itself) is kept.
        output = (time_data >= start_date) & (time_data <= end_date)

    else:
        if isinstance(start, str) and start == 'start':
            start_date = min(time_data)
        else:
            start_date = start
        end_date = start_date + interval
        output = (time_data >= start_date) & (time_data < end_date)

    # Comparisons on a pandas Series yield a Series; np.asarray unwraps it
    # (and returns an ndarray unchanged).
    return np.asarray(output)


def data_interval_index_multiple_periods(time_data, time_period, exclusive = True):
    '''
    This groups the data by the time period given.
    
    Arguments
    ---------
    
        time_data: array
            This is the time data that you wish to convert to indexed groups.
        
        time_period: dictionary
            This is a dictionary containing the time period that you want to vectorise over. Please pass something similar
            to {'weeks' : 0, 'days' : 0, 'hours' : 0, 'minutes' : 0, 'seconds' : 0}. You do not need to pass
            all of the different values in the dictionary above, just the ones you require.
            The resulting duration must be strictly positive.
        
        exclusive: boolean
            This is a boolean value that dictates whether the groups should be mutually exclusive or not.
            If True, consecutive half-open periods [t0, t0+P), [t0+P, t0+2P), ... are used, starting at the
            earliest time t0, and only periods ending at or before the latest time are returned
            (sometimes returning empty groups when the data has gaps).
            If False, a rolling window [t, t+P) starting at each sorted time t is used (always returning data
            in each group); a window is only emitted while its end is at least one period before the latest time.
            
    Returns
    ---------
        output: list of integer arrays
            Each element of the list holds the positional indices into time_data of that element's time period.

    Raises
    ---------
        ValueError
            If time_period does not describe a strictly positive duration (which would otherwise never terminate).
    
    '''
    period = pd.Timedelta(**time_period)
    if period <= pd.Timedelta(0):
        raise ValueError('time_period must describe a strictly positive duration, got {}.'.format(time_period))

    final_date = max(time_data) - period
    groups = []

    if exclusive:
        end_date = min(time_data)
        while end_date <= final_date:
            start_date = end_date
            end_date = start_date + period
            groups.append(np.where((time_data >= start_date) & (time_data < end_date))[0])

    else:
        sorted_time = np.sort(time_data)
        index = 0
        start_date = sorted_time[index]
        end_date = start_date + period
        while end_date <= final_date:
            groups.append(np.where((time_data >= start_date) & (time_data < end_date))[0])
            # With a positive period this never runs past the last sorted time,
            # because the latest time itself can never satisfy the loop condition.
            index += 1
            start_date = sorted_time[index]
            end_date = start_date + period

    return groups


 
def concatenate_data_times(data, time_data, time_period, return_times = True):
    '''
    This function will concatenate the data into vectors that represent a whole time period.
    For example, if the interval of the data was 3 hours, and the time_period was 1 day, this function
    would transform the data from a shape of (N, n_features) to (n_periods, (n_features x 8)).
    When the output array can not be filled with the data in the second dimension, 
    (for example if the intervals of the time_data does not exactly divide the time_period) a 
    padding of 0s will be used at the end of the row.
    Periods that contain no data at all (gaps in the recording) are skipped.
    
    Arguments
    ---------
        data: array
            The data that will be transformed, shaped (N, n_features).
        
        time_data: array
            This is the dates and times associated with the data argument and should be compatible
            with the pd.to_datetime function.
        
        time_period: dict
            This should be a dictionary of the form:
            {'weeks' : 0, 'days' : 0, 'hours' : 0, 'minutes' : 0, 'seconds' : 0}

        return_times: boolean
            This is True by default and decides whether the function should also output the first 
            time of each row along with the new data.

        
    Returns
    ---------
        output: array
            This is a float array that contains the transformed data, one row per non-empty period.
        output_time: list of pd.Timestamp
            Only returned when return_times is True: the first time in each row's period.

    Raises
    ---------
        TypeError
            If no non-empty period could be formed from the data.
    '''
    # A DatetimeIndex makes time_data[position] positional, matching the positional
    # indices produced by the grouping helper even if a Series with a custom index was passed.
    time_data = pd.DatetimeIndex(pd.to_datetime(time_data))
    
    groups = data_interval_index_multiple_periods(time_data, time_period)
    # Exclusive grouping can yield empty periods; they carry no data and no first time.
    groups = [group for group in groups if len(group) > 0]
    
    if len(groups) == 0:
        raise TypeError('There were no groups found! Check whether the input data was long enough to fill a single period of data.')
    
    n_features = data.shape[1]
    row_width = max(len(group) for group in groups) * n_features
    
    output = np.zeros((len(groups), row_width))
    output_time = []

    for row, group in enumerate(groups):
        flat = data[group, :].reshape(-1)
        # Shorter periods leave the tail of the row as zero padding.
        output[row, :flat.shape[0]] = flat
        output_time.append(time_data[group[0]])
    
    if return_times:
        return output, output_time

    return output



def _in_window(times_array, window_start, window_end):
    '''
    Return a boolean mask of the entries of times_array inside the closed
    interval [window_start, window_end].
    '''
    return (times_array >= window_start) & (times_array <= window_end)


def _select_time_section(subject_data, subject_times, time_section):
    '''
    Restrict a subject's data and times to the requested part of the day.

    'day' keeps rows whose time of day is strictly between 06:00:01 and 21:00:01;
    'night' keeps rows strictly before 06:00:01 or strictly after 21:00:01.
    Any other value (e.g. 'both') returns the inputs unchanged.
    '''
    if time_section not in ('day', 'night'):
        return subject_data, subject_times

    time_of_day = datetime_totime([pd.Timestamp(t) for t in subject_times])
    morning = pd.Timestamp('1900-01-01T06:00:01')
    evening = pd.Timestamp('1900-01-01T21:00:01')

    if time_section == 'day':
        keep = (time_of_day > morning) & (time_of_day < evening)
    else:
        keep = (time_of_day < morning) | (time_of_day > evening)

    return subject_data[keep], subject_times[keep]


def _uti_label_windows(times_array, uti_info, days_side = 4):
    '''
    Build (mask, label) pairs for one subject's UTI records, in record order.

    A valid record labels the +/- days_side day window around its (date-truncated)
    raise date as 1, and the same window shifted 2, 3 or 4 weeks earlier (chosen
    at random) as 0. A record that is not valid labels its own window as 0.
    '''
    uti_days = [pd.to_datetime(raised).date() for raised in uti_info['datetimeRaised']]
    validity_list = uti_info['valid']

    labelled = []
    for n_uti, uti_day in enumerate(uti_days):
        validity = validity_list[n_uti]

        uti_day = pd.Timestamp(uti_day)
        uti_start = uti_day - pd.Timedelta(days = days_side)
        uti_end = uti_day + pd.Timedelta(days = days_side)

        # Drawn before the validity check: one draw from NumPy's global RNG per record.
        weeks_before = np.random.choice(np.arange(2, 5))

        # NOTE(review): `if validity` treats nan as truthy, so records whose validity
        # is unknown (nan) are labelled as positives — confirm this is intended.
        if validity:
            shift = pd.Timedelta(weeks = weeks_before)
            labelled.append((_in_window(times_array, uti_start, uti_end), 1))
            labelled.append((_in_window(times_array, uti_start - shift, uti_end - shift), 0))
        else:
            labelled.append((_in_window(times_array, uti_start, uti_end), 0))

    return labelled


def _ha_label_windows(times_array, ha_dates):
    '''
    Build (mask, label) pairs for one subject's hospital admissions, in record order.

    The admission day [midnight, midnight + 1 day + 1 second] is labelled 1 and the
    same window shifted 2, 3 or 4 weeks earlier (chosen at random) is labelled -1.
    '''
    labelled = []
    for ha_day in [pd.to_datetime(ha_date).date() for ha_date in ha_dates]:
        ha_start = pd.Timestamp(ha_day)
        ha_end = ha_start + pd.Timedelta(days = 1, seconds = 1)

        weeks_before = np.random.choice(np.arange(2, 5))
        shift = pd.Timedelta(weeks = weeks_before)

        labelled.append((_in_window(times_array, ha_start, ha_end), 1))
        labelled.append((_in_window(times_array, ha_start - shift, ha_end - shift), -1))

    return labelled


def vectorised_movement_with_healthcare(subjects_list, container_movement, container_time, 
                                        uti_dict = None, ha_dict = None, 
                                        output_structure = 'population', months_train = 8, 
                                        concat_period = 24, time_section = 'both'):
    '''
    This function can be used to turn containers with subject vector data and times along
    with uti dates and hospital admission dates into a dataset that can be trained on.
    -1 refers to unlabelled data, 0 refers to negative cases and 1 refers to positive cases.

    Labelling (applied in this order, so later labels overwrite earlier ones):
        - every row starts as -1;
        - for each hospital admission: the admission day is set to 1 and the same window
          2-4 weeks earlier is set to -1;
        - for each UTI record: if valid, the +/- 4 day window is set to 1 and the same
          window 2-4 weeks earlier is set to 0; if not valid, the +/- 4 day window is set to 0.
    The 2-4 week offsets are drawn from NumPy's global random generator.
    
    Arguments
    ---------
        subjects_list: iterable
            The subject keys to process, in order.

        container_movement: numpy container or dictionary
            This is the container with the subjects as strings as the keys and an array
            of data as the values.
        
        container_time: numpy container or dictionary
            This is the container with the subjects as strings and the times corresponding
            to the data in the container_movement dictionary.
        
        uti_dict: dictionary or None
            This is a dictionary that contains the subjects names as keys and dictionaries
            as values. This next layer of dictionaries contains the keys ['datetimeRaised', 
            'datetimeClosed', 'valid'], which all have lists as their values. The values for 
            'datetimeRaised' and 'datetimeClosed' are lists containing pandas timestamps. The
            list for the value of 'valid' contains either True, False or nan referring to 
            whether the corresponding UTI was a true positive or a false positive detection.
            None (the default) means no UTI data.
        
        ha_dict: dictionary or None
            This is a dictionary that contains the subjects names as keys and lists as values.
            The values are timestamps and contain the dates of hospital admissions.
            None (the default) means no hospital admission data.
        
        output_structure: string
            This is a string that is either 'population' or 'single_subject' and dictates
            whether the output is a dictionary containing the data split by subject, or 
            if the population data is returned all together.

        months_train: integer
            Rows whose first time is within months_train months of a subject's earliest row
            form the training split; the remaining rows form the test split. In
            'single_subject' mode, subjects whose rows span at most months_train calendar
            months are skipped.
        
        concat_period: integer
            This is the time period in hours that is used to concatenate vectors. For example
            if the original data contained 3 hour windows, this argument allows you to 
            concatenate horizontally these vectors to return a vector that represents
            the concat_period.

        time_section: string
            This is a string from the list ['day', 'night', 'both'] and defines the part of 
            the data that is returned. If 'both' then the returned data will contain the 
            vectorised movement from the day and night. If 'day', then the returned data
            will contain the data from the day (6am-9pm) only. If 'night' then the returned data
            will contain the data from the night (9pm-6am) only.
        
    Returns
    ---------
        if output_structure == 'single_subject':
            subject_data: dictionary
                This is a dictionary of the form: 
                {subject_names: {'X_train', 'X_test', 'Y_train', 'Y_test'}}.
        
        if output_structure == 'population':
            (X_train, Y_train, X_test, Y_test): tuple of arrays
                This contains the data.

    Raises
    ---------
        ValueError
            If output_structure is not 'population' or 'single_subject', or if in
            'population' mode no subject had enough data.
    
    '''
    if uti_dict is None:
        uti_dict = {}
    if ha_dict is None:
        ha_dict = {}
    if output_structure not in ('population', 'single_subject'):
        raise ValueError("output_structure must be 'population' or 'single_subject', "
                         "got {!r}.".format(output_structure))

    subject_data = {}
    X_train_parts, Y_train_parts = [], []
    X_test_parts, Y_test_parts = [], []

    for subject in subjects_list:

        subject_3_hour = container_movement[subject]
        subject_3_hour_times = container_time[subject]

        # Require at least one concat_period worth of rows; assumes each row is a
        # 3-hour window (as the variable names suggest) — TODO confirm.
        if subject_3_hour.shape[0] < concat_period/3:
            continue

        subject_3_hour, subject_3_hour_times = _select_time_section(
            subject_3_hour, subject_3_hour_times, time_section)

        subject_X, times_X = concatenate_data_times(subject_3_hour, subject_3_hour_times, {'hours': concat_period})

        if output_structure == 'single_subject':
            # Number of calendar months between the first and last concatenated rows.
            time_interval = (times_X[-1].year - times_X[0].year)*12 + (times_X[-1].month - times_X[0].month)
            if time_interval <= months_train:
                continue

        times_array = np.asarray(times_X)

        # UTI windows are built before HA windows so the random draws happen in the
        # same order as before; HA labels are then applied first and UTI labels last.
        uti_windows = _uti_label_windows(times_array, uti_dict[subject]) if subject in uti_dict else []
        ha_windows = _ha_label_windows(times_array, ha_dict[subject]) if subject in ha_dict else []

        subject_Y = np.full((subject_X.shape[0], 1), -1.0)
        for window_mask, label in ha_windows + uti_windows:
            subject_Y[window_mask] = label

        train_index = data_interval_index(pd.to_datetime(times_X), months = months_train)
        test_index = ~train_index

        if output_structure == 'population':
            X_train_parts.append(subject_X[train_index])
            X_test_parts.append(subject_X[test_index])
            Y_train_parts.append(subject_Y[train_index])
            Y_test_parts.append(subject_Y[test_index])

        else:
            subject_data[subject] = {
                'X_train': subject_X[train_index],
                'X_test': subject_X[test_index],
                'Y_train': np.ravel(subject_Y[train_index]),
                'Y_test': np.ravel(subject_Y[test_index]),
            }

    if output_structure == 'single_subject':
        return subject_data

    if not X_train_parts:
        raise ValueError('No subject in subjects_list had enough data to build the population dataset.')

    # Stack once at the end instead of growing the arrays subject by subject.
    X_train = np.vstack(X_train_parts)
    X_test = np.vstack(X_test_parts)
    Y_train = np.vstack(Y_train_parts).ravel()
    Y_test = np.vstack(Y_test_parts).ravel()

    return (X_train, Y_train, X_test, Y_test)
