import os
import logging
import numpy as np
import pandas as pd
import datetime as dt


class ChainedAssignent:
    """Context manager that temporarily overrides pandas' chained-assignment mode.

    On entry the current value of the ``mode.chained_assignment`` option is
    remembered and replaced by the requested one; on exit the remembered value
    is written back, whether or not the body raised.

    Args:
        chained: New option value for the duration of the ``with`` block. Must
            be one of ``None``, ``'warn'`` or ``'raise'``.
    """

    def __init__(self, chained=None):
        allowed_modes = [None, 'warn', 'raise']
        assert chained in allowed_modes, "chained must be in " + str(allowed_modes)
        # Option value that is active while inside the ``with`` block.
        self.swcw = chained

    def __enter__(self):
        # Remember the previous setting so __exit__ can put it back.
        self.saved_swcw = pd.get_option("mode.chained_assignment")
        pd.set_option("mode.chained_assignment", self.swcw)
        return self

    def __exit__(self, *args):
        # Returns None, so exceptions raised inside the block still propagate.
        pd.set_option("mode.chained_assignment", self.saved_swcw)


def get_environment_path():
    """Return the project root directory derived from the current working directory.

    The root is the part of ``os.getcwd()`` up to and including the last
    occurrence of the project directory name. The result always ends with "/".

    If the working directory does not contain the project directory name, the
    working directory itself (with a trailing "/") is returned, rather than a
    meaningless truncated prefix of it.

    Returns:
        str: The directory path, terminated by "/".
    """
    # NOTE: SPECIFY YOUR PATH TO THE VIRTUAL ENVIRONMENT!
    venv_name = "TMLR_event-triggered-time-varying-bayesian-optimization"
    path = os.getcwd()
    start_idx = path.rfind(venv_name)
    if start_idx != -1:
        # Keep everything up to the project directory, plus the single
        # character following it (the path separator, if there is one).
        path = path[:start_idx + len(venv_name) + 1]
    # When the name is absent (rfind returned -1) the cwd is used unchanged.
    if not path.endswith("/"):
        path += "/"
    return path


def preprocess_temperature_data(train_days=[3, 7], test_days=[7, 9], verbose=False, return_unscaled=False):
    """Load, filter, normalize and subsample the sensor temperature data.

    For a visual exploration of the temperature data check out the jupyter-notebook.

    The raw measurements are read from
    ``objective_functions/applications/temperature_data.txt`` and the sensor
    positions from ``objective_functions/applications/coordinates_sensors.txt``,
    both below the directory returned by ``get_environment_path()``.

    Processing steps:

    1. Drop rows without a mote id and split the data per sensor (ids 1..54).
    2. Keep only rows with ``temperature < 50`` and only sensors with more than
       30000 raw rows (counted before the temperature filter).
    3. Slice every sensor to the union of the training and test windows and
       remove sensor 18.
    4. Standardize the temperature with the mean and standard deviation of
       all training-window temperatures of the remaining sensors.
    5. Resample to 10 minute means and linearly interpolate gaps.

    Args:
        train_days: ``[first, last]`` day of March 2004 delimiting the training
            window. The window runs from midnight of ``first`` up to, but not
            including, midnight of ``last``. The default list is only read,
            never modified.
        test_days: ``[first, last]`` day of March 2004 delimiting the test
            window, with the same convention as ``train_days``.
        verbose: If True, print progress and summary statistics.
        return_unscaled: If True, return only the dict of unnormalized,
            non-resampled per-sensor frames covering both windows.

    Returns:
        If ``return_unscaled`` is True: a dict mapping the sensor id (as str)
        to its sliced DataFrame.

        Otherwise a 5-tuple:

        * dict of normalized, 10 minute training frames keyed by sensor id (str),
        * dict of normalized, 10 minute test frames keyed by sensor id (str),
        * ``(ids, xs, ys)`` arrays for the sensors that were kept,
        * ``(ids, xs, ys)`` arrays for all other sensors in the coordinates file,
        * dict with keys ``"mean"``, ``"stdv"`` and ``"nr_defect_sensors"``.

    Raises:
        Exception: If the temperature data or the coordinates file cannot be
            loaded. The underlying error is attached as ``__cause__``.
    """
    if verbose:
        print("Start preprocessing temperature data...")

    try:
        path = get_environment_path()
        df = pd.read_csv(path + 'objective_functions/applications/temperature_data.txt', sep=" ", header=None)
    except Exception as err:
        # Only regular errors are translated; KeyboardInterrupt/SystemExit propagate untouched.
        raise Exception("Make sure, you downloaded the temperature data "
                        "from http://db.csail.mit.edu/labdata/labdata.html, named it 'temperature_data.txt',"
                        "and copied it into .../objective_functions/applications/.") from err

    df.columns = ["date", "time", "epoch", "moteid", "temperature", "humidity", "light", "voltage"]
    df = df.drop(columns=["humidity", "light", "voltage"])

    # delete entries not assigned to one mote
    df = df[df["moteid"].notna()]
    df["moteid"] = df["moteid"].astype(int)

    n_sensors = 54
    all_sensor_data = []
    for sensor in range(1, n_sensors + 1):
        sensor_data = df[df["moteid"] == sensor]
        all_sensor_data.append(sensor_data)

    # filter out sensors with not enough measurements
    data_missing = {}
    sensor_data_ts = {}
    for i, sensor_data in enumerate(all_sensor_data, 1):
        # convert to date time
        deleted = ""
        # sensor_data is a slice of df; silence the chained-assignment check for the new column
        with ChainedAssignent():
            sensor_data['DateTime'] = pd.to_datetime(sensor_data['date'].apply(str) + ' ' + sensor_data['time'])
        sensor_data = sensor_data.drop(columns=['date', 'time'])
        ts = sensor_data.set_index('DateTime').sort_values(['DateTime'])
        # counted BEFORE the temperature filter below, so the threshold applies to raw rows
        nr_measurements = len(ts['temperature'])

        # filter temperature (keeps only rows whose temperature is below 50)
        ts = ts[ts['temperature'] < 50]

        if nr_measurements > 30000:
            sensor_data_ts[str(i)] = ts
            # bookkeeping only: data_missing is not used further in this function
            missing_measurements = ts['temperature'].isnull().sum()
            if missing_measurements > 0:
                data_missing[str(i)] = missing_measurements
        else:
            deleted = "DELETED"

        if verbose:
            print(f"Sensor ID:{i}, Temperature measurements: {nr_measurements} \t {deleted}")

    # slice only a subset of days (union of the training and the test window)
    sliced_sensor_data = {}
    for sensor in sensor_data_ts.keys():
        start = sensor_data_ts[sensor].index.searchsorted(dt.datetime(2004, 3, min(train_days[0], test_days[0])))
        end = sensor_data_ts[sensor].index.searchsorted(dt.datetime(2004, 3, max(train_days[1], test_days[1])))
        sliced_df = sensor_data_ts[sensor].iloc[start:end]
        sliced_sensor_data[sensor] = sliced_df

    # remove sensor 18 -> see data exploration in objective_functions/applications for details
    # (default of None: no KeyError if sensor 18 was already dropped by the count filter above)
    sliced_sensor_data.pop("18", None)

    # get initial data for the hyperparameter
    initial_sensor_data = {}
    unseen_sensor_data = {}
    for sensor in sliced_sensor_data.keys():
        # training window
        start = sliced_sensor_data[sensor].index.searchsorted(dt.datetime(2004, 3, train_days[0]))
        end = sliced_sensor_data[sensor].index.searchsorted(dt.datetime(2004, 3, train_days[1]))
        sliced_df = sliced_sensor_data[sensor].iloc[start:end].copy()
        initial_sensor_data[sensor] = sliced_df

        # test window
        start2 = sliced_sensor_data[sensor].index.searchsorted(dt.datetime(2004, 3, test_days[0]))
        end2 = sliced_sensor_data[sensor].index.searchsorted(dt.datetime(2004, 3, test_days[1]))
        sliced_df = sliced_sensor_data[sensor].iloc[start2:end2].copy()
        unseen_sensor_data[sensor] = sliced_df

    if return_unscaled:
        return sliced_sensor_data

    # normalize to mean 0 and stdv 1 using the training window of all remaining sensors
    all_temperature = np.empty(0)
    for sensor in initial_sensor_data.keys():
        tmp = np.asarray(initial_sensor_data[sensor]['temperature'])
        all_temperature = np.concatenate([all_temperature, tmp])
    nr_measurements = len(all_temperature)

    # calc mean and stdv
    mean_temp = np.mean(all_temperature)
    stdv_temp = np.sqrt(np.var(all_temperature))

    median_temp = np.median(all_temperature)
    quant25 = np.quantile(all_temperature, 0.25)
    quant75 = np.quantile(all_temperature, 0.75)

    if verbose:
        print("Nr. of measurements: ", nr_measurements)
        print("Mean temperature: ", mean_temp)
        print("Stdv temperature: ", stdv_temp)
        print("Median temperature: ", median_temp)
        print("25. quantile temperature: ", quant25)
        print("75. quantile temperature: ", quant75)

    # normalize
    normalized_initial_sensor_data = {}
    for sensor in initial_sensor_data.keys():
        with ChainedAssignent():
            normalized_initial_sensor_data[sensor] = initial_sensor_data[sensor].copy()
            normalized_initial_sensor_data[sensor]["temperature"] -= mean_temp
            normalized_initial_sensor_data[sensor]["temperature"] /= stdv_temp

    # subsample initial data to 10min intervals
    subsampled_initial_sensor_data = {}
    for sensor in normalized_initial_sensor_data.keys():
        with ChainedAssignent():
            df_copy = normalized_initial_sensor_data[sensor].copy()
            # NOTE(review): origin="start" is passed here but not for the unseen data below,
            # so the two sets may use differently anchored 10min bins -- confirm this is intended.
            tmp = df_copy.resample('10Min', convention='start', origin="start").mean()
            # fill gaps linearly, at most 15 consecutive empty bins, in both directions
            subsampled_initial_sensor_data[sensor] = tmp.interpolate(method='linear', order=1, limit=15,
                                                                     limit_direction='both')

    # normalize unseen sensor data (with the statistics of the training window)
    normalized_unseen_sensor_data = {}
    for sensor in unseen_sensor_data.keys():
        with ChainedAssignent():
            normalized_unseen_sensor_data[sensor] = unseen_sensor_data[sensor].copy()
            normalized_unseen_sensor_data[sensor]["temperature"] -= mean_temp
            normalized_unseen_sensor_data[sensor]["temperature"] /= stdv_temp

    # subsample unseen data to 10min intervals
    subsampled_unseen_data = {}
    for sensor in normalized_unseen_sensor_data.keys():
        with ChainedAssignent():
            df_copy = normalized_unseen_sensor_data[sensor].copy()
            tmp = df_copy.resample('10Min', convention='start', ).mean()
            subsampled_unseen_data[sensor] = tmp.interpolate(method='linear', order=1, limit=15,
                                                             limit_direction='both')

    # get coordinates of sensors
    try:
        all_coordinates = pd.read_csv(path + 'objective_functions/applications/coordinates_sensors.txt', sep=" ", header=None)
    except Exception as err:
        raise Exception("Make sure, you downloaded the coordinates sensors "
                        "from http://db.csail.mit.edu/labdata/labdata.html, named it 'coordinates_sensors.txt',"
                        "and copied it into .../objective_functions/applications/.") from err

    all_coordinates.columns = ["moteid", "x", "y", ]
    all_coordinates.index = all_coordinates["moteid"]

    remaining_sensors = list(map(int, subsampled_unseen_data.keys()))
    # take() is positional: this assumes row k-1 of the coordinates file belongs to mote k -- TODO confirm
    coordinates = all_coordinates.take([sensor_id - 1 for sensor_id in remaining_sensors])
    # drop() is label based and uses the moteid index set above
    defect_sensors = all_coordinates.drop(remaining_sensors, axis=0)

    # good sensors (coordinates are sign-flipped and scaled by 1/42 and 1/32 respectively)
    ids = np.asarray(coordinates['moteid'])
    xs = np.asarray(coordinates['x']) * -1 / 42
    ys = np.asarray(coordinates['y']) * -1 / 32
    good_sensor_coordinates = (ids, xs, ys)

    # defect sensors (every sensor in the coordinates file that was not kept)
    defect_ids = np.asarray(defect_sensors['moteid'])
    defect_xs = np.asarray(defect_sensors['x']) * -1 / 42
    defect_ys = np.asarray(defect_sensors['y']) * -1 / 32
    defect_sensor_coordinates = (defect_ids, defect_xs, defect_ys)
    return subsampled_initial_sensor_data, subsampled_unseen_data, good_sensor_coordinates, \
           defect_sensor_coordinates, {"mean": mean_temp, "stdv": stdv_temp, "nr_defect_sensors": len(defect_ids)}


def create_lengthscale_vector(day_lengthscales, night_lengthscales, overhang):
    """Build the per-sample lengthscale array for an alternating night/day schedule.

    Args:
        day_lengthscales: Values reshaped to rows of two entries; repeated for
            every sample of a day segment.
        night_lengthscales: Same as ``day_lengthscales`` but for night segments.
        overhang: Number of samples removed from the final (night) segment.

    Returns:
        numpy.ndarray: Array with two columns holding the segments stacked in
        the order night, day, night, day, night.
    """
    # Segment boundaries (hours of day) are 0 - 8 - 18 - 8 - 18 - 0; with a sample
    # frequency of 10min the segments contain 48 - 60 - 84 - 60 - 36 samples.
    segment_lengths = (48, 60, 84, 60, 36 - overhang)

    # The leading empty float block fixes the column count and dtype promotion of the result.
    blocks = [np.empty((0, 2))]
    for position, n_samples in enumerate(segment_lengths):
        # Even positions are night segments, odd positions are day segments.
        source = day_lengthscales if position % 2 else night_lengthscales
        rows = np.asarray(source).reshape(-1, 2)
        blocks.append(np.tile(rows, (n_samples, 1)))
    return np.concatenate(blocks)


def initialize_logger(path, name):
    """Return the logger ``name`` that writes DEBUG-level records to ``<path><name>.log``.

    The function is safe to call repeatedly: if the logger already has a file
    handler for the same log file, no further handler is attached, so records
    are not written multiple times.

    Args:
        path: Directory prefix of the log file. It is concatenated with the
            file name as-is, so it has to end with a path separator.
        name: Logger name; also used as the stem of the log file name.

    Returns:
        logging.Logger: The configured logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    # FileHandler stores the absolute path in ``baseFilename``; normalise the
    # target the same way so an existing handler can be recognised.
    log_file = os.path.abspath(path + f'{name}.log')
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file:
            return logger

    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    logger.addHandler(fh)
    return logger