from typing import Union

import numpy as np
import os
import pickle
from collections import defaultdict

from torch.utils.data import Dataset
from haversine import haversine

from ours_dataloader import OursDataReader


class EndomondoDataReader:
    def __init__(
            self,
            inputAtts,
            targetAtts,
            base_dir,
            includeUser=True,
            includeSport=False,
            includeGender=False,
            includeTemporal=False,
            includeFullTemporal=False,
            fullTemporalLength=None,
            includeDevice=False,
            fn="endomondoHR_proper.json",
            scaleVals=True,
            trimmed_workout_len=450,
            scaleTargets="scaleVals",
            zMultiple=5,
            trainValidTestFN=None
    ):
        self.filename = fn
        self.data_path = os.path.join(base_dir,"./data")
        self.metaDataFn = os.path.join(self.data_path,fn.split(".")[0] + "_metaData.pkl")

        self.scaleVals = scaleVals
        self.trimmed_workout_len = trimmed_workout_len
        if scaleTargets == "scaleVals":
            scaleTargets = scaleVals
        self.scale_targets = scaleTargets

        self.smooth_window = 1
        self.perform_target_smoothing = True

        self.isNominal = ['gender', 'sport']
        self.isDerived = ['time_elapsed', 'distance', 'derived_speed', 'since_begin', 'since_last']
        self.isSequence = ['altitude', 'heart_rate', 'latitude', 'longitude'] + self.isDerived

        self.inputAtts = inputAtts
        self.includeUser = includeUser
        self.includeSport = includeSport
        self.includeGender = includeGender
        self.includeTemporal = includeTemporal
        self.includeFullTemporal = includeFullTemporal
        self.fullTemporalLength = fullTemporalLength
        self.includeDevice = includeDevice

        self.targetAtts = ["tar_" + tAtt for tAtt in targetAtts]

        print("Input attributes: ", self.inputAtts)
        print("Target attributes: ", self.targetAtts)

        self.trainValidTestFN = trainValidTestFN
        self.zMultiple = zMultiple

        self.original_data = None
        self.idxMap = {}
        self.trainingSet = []
        self.validationSet = []
        self.testSet = []
        self.contextMap = {}

        self.variableMeans = {}
        self.variableStds = {}
        self.oneHotEncoders = {}
        self.oneHotMap = {}
        self.encodingLengths = {}
        self.numDataPoints = 0

    def preprocess_data(self):
        self.original_data_path = os.path.join(self.data_path, self.filename)
        self.processed_path = os.path.join(
            self.data_path,
            "processed_" + self.filename.split(".")[0] + ".npy"
        )

        self.loadTrainValidTest()

        if os.path.exists(self.processed_path):
            print(f"{self.processed_path} exists. Loading preprocessed data...")
            self.original_data = np.load(self.processed_path, allow_pickle=True)[0]
            self.map_workout_id()
        else:
            print("###############",self.processed_path)
            print("Load original data from:", self.original_data_path)
            with open(self.original_data_path, 'r') as f:
                self.original_data = [eval(line) for line in f]

            self.map_workout_id()
            self.derive_data()
            self.buildMetaData()
            self.scale_data()

        self.load_meta()

        self.input_dim = len(self.inputAtts)
        self.output_dim = len(self.targetAtts)
        print("Data preprocessing completed.")
        print(f"Input dim: {self.input_dim},  Output dim: {self.output_dim}")

        if self.includeFullTemporal:
            print("Precomputing full_context_input_1/2 for all workouts...")
            for idx in range(len(self.original_data)):
                fc1, fc2 = self._precompute_full_context(idx)
                self.original_data[idx]['fc1'] = fc1
                self.original_data[idx]['fc2'] = fc2
            print("Full context caching finished.")

    def _precompute_full_context(self, global_idx):
        trimmed_workout_len = self.trimmed_workout_len
        input_dim = self.input_dim
        output_dim = self.output_dim
        fullTemporalLength = self.fullTemporalLength

        if global_idx not in self.contextMap:
            prev_workouts = []
            since_last_list = []
        else:
            _, _, prev_workouts = self.contextMap[global_idx]
            since_last_list = []
            for i in range(len(prev_workouts)):
                idx = prev_workouts[i]
                if idx in self.contextMap:
                    since_last_val = self.contextMap[idx][0]
                else:
                    since_last_val = 0
                since_last_list.append(since_last_val)

        if fullTemporalLength is not None:
            prev_workouts = prev_workouts[-fullTemporalLength:]
            since_last_list = since_last_list[-fullTemporalLength:]

        N = len(prev_workouts)
        max_hist = fullTemporalLength if fullTemporalLength is not None else N

        fc1 = np.zeros([max_hist, trimmed_workout_len, input_dim + 1], dtype=np.float32)
        fc2 = np.zeros([max_hist, trimmed_workout_len, output_dim], dtype=np.float32)

        for i in range(N):
            prev_idx = prev_workouts[-(i + 1)]
            since_last_val = since_last_list[-(i + 1)]
            prev_data = self.original_data[prev_idx]

            c_in = np.zeros([input_dim, trimmed_workout_len], dtype=np.float32)
            for j, att in enumerate(self.inputAtts):
                seq = np.array(prev_data[att], dtype=np.float32)
                if len(seq) >= trimmed_workout_len:
                    seq = seq[:trimmed_workout_len]
                else:
                    padded = np.zeros(trimmed_workout_len, dtype=np.float32)
                    padded[:len(seq)] = seq
                    seq = padded
                c_in[j, :] = seq

            c_sl = np.ones([1, trimmed_workout_len], dtype=np.float32) * since_last_val
            c_in_merged = np.concatenate([c_in, c_sl], axis=0).T

            c_out = np.zeros([output_dim, trimmed_workout_len], dtype=np.float32)
            for k, tAtt in enumerate(self.targetAtts):
                seq = np.array(prev_data[tAtt], dtype=np.float32)
                if len(seq) >= trimmed_workout_len:
                    seq = seq[:trimmed_workout_len]
                else:
                    padded = np.zeros(trimmed_workout_len, dtype=np.float32)
                    padded[:len(seq)] = seq
                    seq = padded
                c_out[k, :] = seq
            c_out = c_out.T

            fc1[i, :, :] = c_in_merged
            fc2[i, :, :] = c_out

        return fc1, fc2

    def loadTrainValidTest(self):
        if not self.trainValidTestFN:
            raise ValueError("trainValidTestFN not specified. Need a pkl with (train,valid,test,contextMap).")
        with open(self.trainValidTestFN, "rb") as f:
            self.trainingSet, self.validationSet, self.testSet, self.contextMap = pickle.load(f)
            print("train/valid/test set size = {}/{}/{}".format(
                len(self.trainingSet), len(self.validationSet), len(self.testSet)))

            print("dataset split loaded.")

    def map_workout_id(self):
        """Translate workout ids into positions within ``self.original_data``.

        Builds ``self.idxMap`` (workout ``'id'`` -> list position) and rewrites
        ``trainingSet``, ``validationSet``, ``testSet`` and ``contextMap``
        (both its keys and the previous-workout lists in slot 2) so that they
        hold positions instead of ids.

        Raises:
            KeyError: if the split or the context map references a workout id
                that is not present in ``self.original_data``. Previously such
                ids were silently mapped to position 0. In that case no
                attribute other than ``idxMap`` is modified.
        """
        index_of = {workout['id']: pos for pos, workout in enumerate(self.original_data)}
        self.idxMap = index_of

        def _position(wid):
            """Strict lookup with an actionable error message."""
            try:
                return index_of[wid]
            except KeyError:
                raise KeyError(
                    f"Workout id {wid!r} is referenced by the split/context "
                    f"but is not present in the loaded data"
                ) from None

        # Resolve everything first so a failure leaves the state untouched.
        training = [_position(wid) for wid in self.trainingSet]
        validation = [_position(wid) for wid in self.validationSet]
        test = [_position(wid) for wid in self.testSet]

        converted_context = {}
        for wid, context in self.contextMap.items():
            converted_prev = [_position(old_wid) for old_wid in context[2]]
            converted_context[_position(wid)] = (context[0], context[1], converted_prev)

        self.trainingSet = training
        self.validationSet = validation
        self.testSet = test
        self.contextMap = converted_context

    def load_meta(self):
        self.buildMetaData()

    def derive_data(self):
        """Attach every attribute listed in ``self.isDerived`` to each workout.

        Values are computed by ``deriveData`` and stored under the attribute
        name in the corresponding entry of ``self.original_data``.
        """
        print("Deriving data (e.g. derived_speed, distance, etc.)...")
        for position, workout in enumerate(self.original_data):
            for name in self.isDerived:
                derived = self.deriveData(name, workout, position)
                self.original_data[position][name] = derived

    def deriveData(self, att, currentDataPoint, idx):
        """Compute one derived attribute for a single workout.

        Args:
            att: one of ``'time_elapsed'``, ``'distance'``,
                ``'derived_speed'``, ``'since_last'``, ``'since_begin'``.
            currentDataPoint: the workout record (mapping of attribute name to
                values); ``'timestamp'``, ``'latitude'`` and ``'longitude'``
                are read depending on ``att``.
            idx: position of the workout, used to look it up in
                ``self.contextMap``.

        Returns:
            * ``time_elapsed``: list of ``timestamp - first timestamp``.
            * ``distance``: list starting with 0, then the haversine distance
              between consecutive (lat, lon) points.
            * ``derived_speed``: list starting with 0, then
              ``3600 * distance / time delta``; when the time delta is not
              positive the previous speed is repeated.
            * ``since_last`` / ``since_begin``: array of length
              ``trimmed_workout_len`` filled with slot 0 / slot 1 of the
              workout's context entry (0 when it has none).

        Raises:
            Exception: for any other attribute name.
        """
        if att == 'time_elapsed':
            stamps = currentDataPoint['timestamp']
            start = stamps[0]
            return [stamp - start for stamp in stamps]

        if att == 'distance':
            lats = currentDataPoint['latitude']
            longs = currentDataPoint['longitude']
            # First sample has no predecessor, hence the leading 0.
            return [0] + [
                haversine((lats[k - 1], longs[k - 1]), (lats[k], longs[k]))
                for k in range(1, len(lats))
            ]

        if att == 'derived_speed':
            distances = self.deriveData('distance', currentDataPoint, idx)
            stamps = currentDataPoint['timestamp']
            speeds = [0]
            for k in range(1, len(stamps)):
                elapsed = stamps[k] - stamps[k - 1]
                if elapsed > 0:
                    speeds.append(3600 * distances[k] / elapsed)
                else:
                    # No time passed: carry the previous speed forward.
                    speeds.append(speeds[-1])
            return speeds

        if att in ('since_last', 'since_begin'):
            slot = 0 if att == 'since_last' else 1
            total_time = self.contextMap[idx][slot] if idx in self.contextMap else 0
            return np.ones(self.trimmed_workout_len) * total_time

        raise Exception(f"No such derived data attribute: {att}")

    def buildMetaData(self):
        if os.path.isfile(self.metaDataFn):
            self.loadSummaryFile()
        else:
            print("Building data schema from scratch...")
            variableSums = defaultdict(float)
            classLabels = defaultdict(set)

            for currData in self.original_data:
                classLabels['userId'].add(currData['userId'])
                for att in self.isNominal:
                    val = currData[att]
                    classLabels[att].add(val)

                for att in self.isSequence:
                    variableSums[att] += sum(currData[att])

            oneHotEncoders = {}
            oneHotMap = {}
            encodingLengths = {}
            oneHotEncoders['userId'], oneHotMap['userId'] = self.buildEncoder(classLabels['userId'])
            encodingLengths['userId'] = 1

            # nominal
            for att in self.isNominal:
                oneHotEncoders[att], oneHotMap[att] = self.buildEncoder(classLabels[att])
                encodingLengths[att] = len(classLabels[att])

            for att in self.isSequence:
                encodingLengths[att] = 1

            self.numDataPoints = len(self.original_data)
            self.computeMeanStd(variableSums, self.numDataPoints, self.isSequence)
            self.oneHotEncoders = oneHotEncoders
            self.oneHotMap = oneHotMap
            self.encodingLengths = encodingLengths
            self.writeSummaryFile()

    def computeMeanStd(self, varSums, numDataPoints, attributes):
        print("Computing variable means and stddev...")
        numSequencePoints = numDataPoints * 500

        variableMeans = {}
        for att in varSums:
            variableMeans[att] = varSums[att] / numSequencePoints

        varResidualSums = defaultdict(float)

        for currData in self.original_data:
            for att in attributes:
                arr = np.array(currData[att])
                diff = arr - variableMeans[att]
                varResidualSums[att] += np.sum(diff * diff)

        variableStds = {}
        for att in varResidualSums:
            variableStds[att] = np.sqrt(varResidualSums[att] / numSequencePoints + 1e-8)

        self.variableMeans = variableMeans
        self.variableStds = variableStds

    def scale_data(self, scaling=True):
        print("Scaling data (continuous attributes)...")
        base_targetAtts = ['heart_rate', 'derived_speed']
        for idx, currentDataPoint in enumerate(self.original_data):
            for tAtt in base_targetAtts:
                data_seq = currentDataPoint[tAtt]
                if self.perform_target_smoothing:
                    data_seq = self.median_smoothing(data_seq, self.smooth_window)
                if self.scale_targets:
                    data_seq = self.scaleData(data_seq, tAtt, self.zMultiple)
                self.original_data[idx]["tar_" + tAtt] = data_seq

            for att in self.isSequence:
                if f"{att}_org" not in currentDataPoint:
                    currentDataPoint[f"{att}_org"] = currentDataPoint[att][:]
                if scaling:
                    seq = currentDataPoint[att]
                    scaled_seq = self.scaleData(seq, att, self.zMultiple)
                    self.original_data[idx][att] = scaled_seq

        for d in self.original_data:
            d.pop('url', None)
            d.pop('speed', None)

        np.save(self.processed_path, [self.original_data])
        print(f"Saved preprocessed data to {self.processed_path}")

    def scaleData(self, data, att, zMultiple=2):
        """Return ``data`` as z-scores of attribute ``att`` times ``zMultiple``.

        Args:
            data: sequence of numbers.
            att: attribute name used to look up ``self.variableMeans`` and
                ``self.variableStds``.
            zMultiple: factor applied to the z-scores.

        Returns:
            A plain Python list of the scaled values.
        """
        center = self.variableMeans[att]
        spread = self.variableStds[att]
        # Floor the divisor so (near-)constant attributes do not blow up.
        if not spread > 1e-6:
            spread = 1e-6
        values = np.array(data, dtype=np.float32)
        scaled = ((values - center) / spread) * zMultiple
        return scaled.tolist()

    def median_smoothing(self, seq, window_size):
        """Apply a centred running median to ``seq``.

        Args:
            seq: sequence of numbers.
            window_size: window width; values of 1 or less disable smoothing
                and ``seq`` is returned as passed in.

        Returns:
            ``seq`` itself when smoothing is disabled, otherwise a list with
            one median per position. Windows are clipped at both ends of the
            sequence.
        """
        if window_size <= 1:
            return seq

        values = np.array(seq, dtype=np.float32)
        half = window_size // 2
        total = len(values)
        return [
            np.median(values[max(0, center - half):min(total, center + half + 1)])
            for center in range(total)
        ]

    def buildEncoder(self, classLabels):
        """Build one-hot and index encodings for a collection of labels.

        Labels are processed in sorted order.

        Args:
            classLabels: sized collection of sortable labels.

        Returns:
            ``(encoder, mapper)`` where ``encoder[label]`` is a one-hot list
            of length ``len(classLabels)`` and ``mapper[label]`` is the
            label's position in the sorted order.
        """
        width = len(classLabels)
        ordered = sorted(classLabels)

        encoder = {}
        mapper = {}
        for position, label in enumerate(ordered):
            encoder[label] = [1 if slot == position else 0 for slot in range(width)]
            mapper[label] = position
        return encoder, mapper

    def writeSummaryFile(self):
        """Pickle the current metadata to ``self.metaDataFn``.

        The reader's metadata attributes are wrapped in a
        ``metaDataEndomondo`` object, which is what ``loadSummaryFile`` reads
        back.
        """
        field_names = (
            "numDataPoints",
            "encodingLengths",
            "oneHotEncoders",
            "oneHotMap",
            "isSequence",
            "isNominal",
            "isDerived",
            "variableMeans",
            "variableStds",
        )
        # Every field is passed by keyword under its own attribute name.
        summary = metaDataEndomondo(**{name: getattr(self, name) for name in field_names})

        with open(self.metaDataFn, "wb") as handle:
            pickle.dump(summary, handle)
        print("Summary file written:", self.metaDataFn)

    def loadSummaryFile(self):
        print("Loading metadata from:", self.metaDataFn)
        with open(self.metaDataFn, "rb") as f:
            metaData = pickle.load(f)
        self.numDataPoints = metaData.numDataPoints
        self.encodingLengths = metaData.encodingLengths
        self.oneHotEncoders = metaData.oneHotEncoders
        self.oneHotMap = metaData.oneHotMap
        self.isSequence = metaData.isSequence
        self.isNominal = metaData.isNominal
        self.variableMeans = metaData.variableMeans
        self.variableStds = metaData.variableStds
        print("Metadata loaded successfully.")

class metaDataEndomondo:
    """Plain container for the dataset metadata that gets pickled to disk.

    Every constructor argument is stored unchanged under an attribute of the
    same name.
    """

    def __init__(
            self,
            numDataPoints,
            encodingLengths,
            oneHotEncoders,
            oneHotMap,
            isSequence,
            isNominal,
            isDerived,
            variableMeans,
            variableStds
    ):
        # Dataset size.
        self.numDataPoints = numDataPoints

        # Encoding information.
        self.encodingLengths, self.oneHotEncoders, self.oneHotMap = (
            encodingLengths, oneHotEncoders, oneHotMap
        )

        # Attribute groupings.
        self.isSequence, self.isNominal, self.isDerived = (
            isSequence, isNominal, isDerived
        )

        # Scaling statistics.
        self.variableMeans, self.variableStds = variableMeans, variableStds

class EndomondoDataset(Dataset):
    """Torch ``Dataset`` view over one split of a pre-processed data reader.

    Each item is ``(inputs_dict, outputs, workout_id)`` where

    * ``inputs_dict["input"]`` is a float32 array of shape
      ``(trimmed_workout_len, len(inputAtts))``,
    * ``outputs`` is a float32 array of shape
      ``(trimmed_workout_len, len(targetAtts))``,
    * further entries are added to ``inputs_dict`` according to the reader's
      ``include*`` switches (see ``__getitem__``).

    The reader must already expose ``original_data``, the split index lists,
    ``contextMap`` and ``oneHotMap``.
    """

    def __init__(
        self,
        data_reader: "Union[EndomondoDataReader, OursDataReader]",
        mode="train"
    ):
        """Select the split and copy the relevant reader settings.

        Args:
            data_reader: a reader whose data has been pre-processed.
            mode: ``"train"``, ``"valid"`` or ``"test"``.

        Raises:
            ValueError: for any other ``mode``.
        """
        super().__init__()
        self.data_reader = data_reader
        self.mode = mode

        if mode == "train":
            self.indices = data_reader.trainingSet
        elif mode == "valid":
            self.indices = data_reader.validationSet
        elif mode == "test":
            self.indices = data_reader.testSet
        else:
            raise ValueError("Invalid mode. Must be one of ['train','valid','test']")

        self.inputAtts = data_reader.inputAtts
        self.targetAtts = data_reader.targetAtts
        self.includeUser = data_reader.includeUser
        self.includeSport = data_reader.includeSport
        self.includeGender = data_reader.includeGender
        self.includeTemporal = data_reader.includeTemporal
        self.trimmed_workout_len = data_reader.trimmed_workout_len
        self.includeFullTemporal = data_reader.includeFullTemporal
        self.fullTemporalLength = data_reader.fullTemporalLength
        self.includeDevice = data_reader.includeDevice

        self.input_dim = len(self.inputAtts)
        self.output_dim = len(self.targetAtts)

    def __len__(self):
        """Return the number of workouts in the selected split."""
        return len(self.indices)

    def _fit_length(self, values):
        """Trim or zero-pad ``values`` to ``trimmed_workout_len`` float32 items."""
        seq = np.array(values, dtype=np.float32)
        target_len = self.trimmed_workout_len
        if len(seq) >= target_len:
            return seq[:target_len]
        padded = np.zeros(target_len, dtype=np.float32)
        padded[:len(seq)] = seq
        return padded

    def _time_elapsed_channel(self, values):
        """Return the ``time_elapsed`` input channel.

        Long enough sequences are trimmed. Shorter ones are replaced by a
        constant channel holding their last value; an empty sequence yields
        zeros instead of failing on ``values[-1]``.
        """
        target_len = self.trimmed_workout_len
        if len(values) >= target_len:
            return np.array(values[:target_len])
        if len(values) == 0:
            return np.zeros(target_len, dtype=np.float32)
        # NOTE(review): short workouts lose their per-step values here and get
        # the final elapsed time everywhere. Kept as-is because models may
        # have been trained on it - confirm whether this is intended.
        return np.ones(target_len) * values[-1]

    def _stack_attributes(self, record, attribute_names):
        """Return a ``(trimmed_workout_len, n_attributes)`` float32 array."""
        stacked = np.zeros(
            [self.trimmed_workout_len, len(attribute_names)], dtype=np.float32
        )
        for col, name in enumerate(attribute_names):
            stacked[:, col] = self._fit_length(record[name])
        return stacked

    def _category_channel(self, key, value):
        """Return a ``(trimmed_workout_len, 1)`` channel with the index of ``value``.

        Raises:
            ValueError: if ``value`` has no entry in the reader's
                ``oneHotMap[key]`` (previously this surfaced as an opaque
                ``TypeError`` from multiplying by ``None``).
        """
        index = self.data_reader.oneHotMap[key].get(value)
        if index is None:
            raise ValueError(
                f"Value {value!r} of attribute '{key}' has no entry in the reader's oneHotMap."
            )
        return np.ones([self.trimmed_workout_len, 1], dtype=np.float32) * index

    @staticmethod
    def _first_present_key(record, candidate_keys):
        """Return the first of ``candidate_keys`` found in ``record``, else ``None``."""
        for key in candidate_keys:
            if key in record:
                return key
        return None

    def __getitem__(self, idx):
        """Assemble the model inputs and targets of one workout.

        Args:
            idx: position within the selected split.

        Returns:
            ``(inputs_dict, outputs, workoutid)``. ``inputs_dict`` always has
            ``"input"``, ``"lat_seq"`` and ``"lon_seq"``; depending on the
            reader's switches it also has ``"user_input"``, ``"sport_input"``,
            ``"gender_input"``, ``"device_input"``, ``"context_input_1"`` /
            ``"context_input_2"`` (most recent previous workout) and
            ``"full_context_input_1"`` / ``"full_context_input_2"``.

        Raises:
            ValueError: if the user / sport key is missing from the record or
                a categorical value is unknown to the reader.
        """
        global_idx = self.indices[idx]
        raw_data = self.data_reader.original_data[global_idx]
        workoutid = raw_data['id']
        seq_len = self.trimmed_workout_len

        inputs = np.zeros([seq_len, self.input_dim], dtype=np.float32)
        for col, att in enumerate(self.inputAtts):
            if att == 'time_elapsed':
                inputs[:, col] = self._time_elapsed_channel(raw_data[att])
            else:
                inputs[:, col] = self._fit_length(raw_data[att])

        outputs = np.zeros([seq_len, self.output_dim], dtype=np.float32)
        for col, tAtt in enumerate(self.targetAtts):
            outputs[:, col] = self._fit_length(raw_data[tAtt])

        inputs_dict = {"input": inputs}

        if self.includeUser:
            key = self._first_present_key(raw_data, ['userId', 'user_id'])
            if key is None:
                raise ValueError("User Info not exists in the dataset!")
            inputs_dict["user_input"] = self._category_channel(key, raw_data[key])

        if self.includeSport:
            key = self._first_present_key(raw_data, ['sport', 'sport_type'])
            if key is None:
                raise ValueError("Sport Info not exists in the dataset!")
            inputs_dict["sport_input"] = self._category_channel(key, raw_data[key])

        if self.includeGender:
            inputs_dict["gender_input"] = self._category_channel('gender', raw_data['gender'])

        if self.includeDevice:
            inputs_dict["device_input"] = self._category_channel('device', raw_data['device'])

        if self.includeTemporal:
            context = self.data_reader.contextMap.get(global_idx)
            prev_workouts = context[2] if context is not None else []

            if len(prev_workouts) == 0:
                # No history: all-zero context of the expected shape.
                context_input_1 = np.zeros([seq_len, self.input_dim + 1], dtype=np.float32)
                context_input_2 = np.zeros([seq_len, self.output_dim], dtype=np.float32)
            else:
                # Here the extra channel carries the *current* workout's
                # slot-0 context value.
                since_last_val = context[0]
                prev_data = self.data_reader.original_data[prev_workouts[-1]]

                context_input_1 = np.zeros([seq_len, self.input_dim + 1], dtype=np.float32)
                context_input_1[:, :self.input_dim] = self._stack_attributes(prev_data, self.inputAtts)
                context_input_1[:, self.input_dim] = since_last_val
                context_input_2 = self._stack_attributes(prev_data, self.targetAtts)

            inputs_dict["context_input_1"] = context_input_1
            inputs_dict["context_input_2"] = context_input_2

        if self.includeFullTemporal:
            fc1 = raw_data.get('fc1', None)
            fc2 = raw_data.get('fc2', None)
            if fc1 is None or fc2 is None:
                # Compute lazily and cache on the record for later accesses.
                fc1, fc2 = self.data_reader._precompute_full_context(global_idx)
                raw_data['fc1'] = fc1
                raw_data['fc2'] = fc2
            inputs_dict["full_context_input_1"] = fc1
            inputs_dict["full_context_input_2"] = fc2

        # NOTE(review): these are returned as stored (trimmed but not padded),
        # so workouts shorter than trimmed_workout_len give shorter sequences
        # than the other entries - confirm that consumers expect this.
        inputs_dict["lat_seq"] = raw_data["latitude"][:seq_len]
        inputs_dict["lon_seq"] = raw_data["longitude"][:seq_len]

        return inputs_dict, outputs, workoutid

    @staticmethod
    def pad_full_context(batch):
        """Collate function that pads the full-context tensors across a batch.

        ``full_context_input_1`` / ``full_context_input_2`` are padded along
        their first axis with ``pad_sequence``; every other input entry and
        the outputs are stacked with ``torch.stack`` (so they must have equal
        shapes within the batch).

        Args:
            batch: list of ``(inputs_dict, outputs, workout_id)`` items.

        Returns:
            ``(batch_inputs, batch_outputs, batch_ids)`` with tensors in the
            first two and a plain list of ids in the third.
        """
        import torch
        from torch.nn.utils.rnn import pad_sequence

        batch_inputs = [item[0] for item in batch]
        batch_outputs = [item[1] for item in batch]
        batch_ids = [item[2] for item in batch]

        padded_keys = ("full_context_input_1", "full_context_input_2")
        batch_inputs_stacked = {
            key: pad_sequence(
                [torch.as_tensor(inp[key]) for inp in batch_inputs], batch_first=True
            )
            for key in padded_keys
        }
        for k in batch_inputs[0]:
            if k not in padded_keys:
                batch_inputs_stacked[k] = torch.stack(
                    [torch.as_tensor(inp[k]) for inp in batch_inputs]
                )
        batch_outputs = torch.stack([torch.as_tensor(b) for b in batch_outputs])
        return batch_inputs_stacked, batch_outputs, batch_ids