import numpy as np
import os
import pickle
import json
from collections import defaultdict
import random

# Default input attributes read from each workout record. The order matters:
# it is the order in which feature columns are laid out for the model.
KEEP_FEATURES = [
    "speed",
    "cadence",
    "power",
    "stance_time",
    "temperature",
    "enhanced_altitude",
    "position_lat",
    "position_long",
    "step_length",
    "cycle_length16",
    "body_battery",
    "vertical_oscillation",
    "vertical_ratio",
    "performance_condition",
    "enhanced_respiration_rate",
    "Calories",
    "Distance in Meters",
    "Elevation in Meters",
    # Derived from the record timestamps during preprocessing, not stored raw.
    "time_elapsed",
]

# Default prediction target(s); the reader stores them with a "tar_" prefix.
PRED_FEATURES = ["heart_rate"]

class OursDataReader:
    """Reader that splits, scales and caches workout sequences stored as JSON.

    Files live under ``os.path.join(base_dir, "./data")``:

    * ``fn`` - raw JSON list of workout records (dicts).
    * ``trainValidTestFN`` - pickled ``(train, valid, test, contextMap)``.
    * ``processed_<stem>.npy`` - the processed (scaled) records.
    * ``<stem>_metaData.pkl`` - a pickled ``metaDataOurs`` summary.

    ``<stem>`` is ``fn`` up to its first ``"."``.
    """

    def __init__(
            self,
            base_dir,
            inputAtts=None,
            targetAtts=None,
            includeUser=True,
            includeSport=False,
            includeGender=False,
            includeDevice=False,
            includeTemporal=False,
            includeFullTemporal=False,
            fullTemporalLength=None,
            fn="processed_data_processed.json",
            scaleVals=True,
            trimmed_workout_len=450,
            scaleTargets="scaleVals",
            zMultiple=5,
            trainValidTestFN=None
    ):
        """Configure paths and options. No data is read until ``preprocess_data``.

        Args:
            base_dir: directory under which a ``data`` folder is created/used.
            inputAtts: input attribute names (default: copy of ``KEEP_FEATURES``).
            targetAtts: target attribute names, stored with a ``tar_`` prefix
                (default: copy of ``PRED_FEATURES``).
            includeUser, includeSport, includeGender, includeDevice,
            includeTemporal: stored for callers; not read within this class.
            includeFullTemporal: if true, ``preprocess_data`` attaches
                ``'fc1'``/``'fc2'`` history arrays to every record.
            fullTemporalLength: maximum number of previous workouts in the
                history arrays (None = all of them).
            fn: name of the raw JSON file inside the data folder.
            scaleVals: stored; also the default for ``scaleTargets``.
            trimmed_workout_len: length used to zero-fill missing attributes
                and for the time axis of the history arrays.
            scaleTargets: the sentinel ``"scaleVals"`` means "same as scaleVals".
            zMultiple: multiplier applied to z-scores in ``scaleData``.
            trainValidTestFN: path of the split pickle (default
                ``<data>/ours_train_valid_test.pkl``).
        """
        # Copy the attribute lists so later mutation of self.inputAtts cannot
        # leak into the module-level defaults or into a caller's list.
        inputAtts = list(KEEP_FEATURES if inputAtts is None else inputAtts)
        targetAtts = list(PRED_FEATURES if targetAtts is None else targetAtts)

        self.filename = fn
        self.data_path = os.path.join(base_dir, "./data")
        os.makedirs(self.data_path, exist_ok=True)
        self.metaDataFn = os.path.join(self.data_path, fn.split(".")[0] + "_metaData.pkl")
        self.scaleVals = scaleVals
        self.trimmed_workout_len = trimmed_workout_len
        if scaleTargets == "scaleVals":
            scaleTargets = scaleVals
        self.scale_targets = scaleTargets
        self.smooth_window = 1
        self.perform_target_smoothing = True
        self.isNominal = ['gender', 'sport_type', 'device']
        self.isDerived = []
        self.isSequence = [att for att in inputAtts if att not in self.isNominal]
        self.inputAtts = inputAtts
        self.includeUser = includeUser
        self.includeSport = includeSport
        self.includeGender = includeGender
        self.includeDevice = includeDevice

        self.includeTemporal = includeTemporal
        self.includeFullTemporal = includeFullTemporal
        self.fullTemporalLength = fullTemporalLength
        self.targetAtts = ["tar_" + tAtt for tAtt in targetAtts]
        print("Input attributes: ", self.inputAtts)
        print("Target attributes: ", self.targetAtts)
        self.trainValidTestFN = trainValidTestFN or os.path.join(self.data_path, "ours_train_valid_test.pkl")
        self.zMultiple = zMultiple
        self.original_data = None
        self.idxMap = {}
        self.trainingSet = []
        self.validationSet = []
        self.testSet = []
        self.contextMap = {}
        self.variableMeans = {}
        self.variableStds = {}
        self.oneHotEncoders = {}
        self.oneHotMap = {}
        self.encodingLengths = {}
        self.numDataPoints = 0

    def preprocess_data(self):
        """Load or build the split, load the processed-data cache and the metadata.

        Sets ``trainingSet``/``validationSet``/``testSet``, ``contextMap``,
        ``original_data``, ``idxMap``, the metadata attributes and
        ``input_dim``/``output_dim``. When ``includeFullTemporal`` is set, it
        also stores ``'fc1'``/``'fc2'`` arrays on every record.
        """
        self.original_data_path = os.path.join(self.data_path, self.filename)
        self.processed_path = os.path.join(
            self.data_path,
            "processed_" + self.filename.split(".")[0] + ".npy"
        )
        if os.path.exists(self.trainValidTestFN):
            print(f"{self.trainValidTestFN} exists. Loading split...")
            # Local cache written by _generate_split; never point this at untrusted files.
            with open(self.trainValidTestFN, "rb") as f:
                self.trainingSet, self.validationSet, self.testSet, self.contextMap = pickle.load(f)
        else:
            self._generate_split()
        if os.path.exists(self.processed_path):
            print(f"{self.processed_path} exists. Loading preprocessed data...")
            # scale_data saves ``[records]``; index 0 recovers the records.
            # allow_pickle is required for the object array (trusted local cache only).
            self.original_data = np.load(self.processed_path, allow_pickle=True)[0]
            self.map_workout_id()
        else:
            # NOTE(review): this guard makes raw processing unreachable unless
            # Python runs with -O (which strips asserts). Judging by its message,
            # the cache is meant to be built only once; kept as-is.
            assert False, "It should only execute once!"
            self._process_raw_data()
        self.load_meta()
        self.input_dim = len(self.inputAtts)
        self.output_dim = len(self.targetAtts)
        print("Data preprocessing completed.")
        print(f"Input dim: {self.input_dim},  Output dim: {self.output_dim}")
        if self.includeFullTemporal:
            print("Precomputing full_context_input_1/2 for all workouts...")
            for idx, record in enumerate(self.original_data):
                record['fc1'], record['fc2'] = self._precompute_full_context(idx)
            print("Full context caching finished.")

    @staticmethod
    def _start_time(record):
        """Return ``timestamp[0]`` if the record's timestamp is a list, else the timestamp itself."""
        ts = record['timestamp']
        return ts[0] if isinstance(ts, list) else ts

    def _generate_split(self):
        """Build and pickle a random 80/10/10 split plus the per-user contextMap.

        ``contextMap[idx] = (since_last, since_begin, prev_ids)``. Here
        ``prev_ids`` lists the same user's earlier workouts in chronological
        order, and both times are differences of start times (0 for a user's
        first workout).
        """
        print("Generating train/valid/test split and contextMap...")
        with open(self.original_data_path, 'r') as f:
            data = json.load(f)
        user2idxs = defaultdict(list)
        for idx, d in enumerate(data):
            user2idxs[d['user_id']].append(idx)
        start_times = [self._start_time(d) for d in data]
        contextMap = {}
        for idxs in user2idxs.values():
            idxs.sort(key=start_times.__getitem__)
            first_time = start_times[idxs[0]]
            for i, idx in enumerate(idxs):
                if i == 0:
                    contextMap[idx] = (0, 0, [])
                else:
                    curr_time = start_times[idx]
                    contextMap[idx] = (
                        curr_time - start_times[idxs[i - 1]],
                        curr_time - first_time,
                        idxs[:i],
                    )
        all_idxs = list(range(len(data)))
        # NOTE(review): unseeded shuffle - the split is reproducible only via the pickle cache.
        random.shuffle(all_idxs)
        n = len(all_idxs)
        n_train = int(n * 0.8)
        n_valid = int(n * 0.1)
        trainingSet = all_idxs[:n_train]
        validationSet = all_idxs[n_train:n_train + n_valid]
        testSet = all_idxs[n_train + n_valid:]
        with open(self.trainValidTestFN, "wb") as f:
            pickle.dump((trainingSet, validationSet, testSet, contextMap), f)
        self.trainingSet = trainingSet
        self.validationSet = validationSet
        self.testSet = testSet
        self.contextMap = contextMap

    def _process_raw_data(self):
        """Build records from the raw JSON, then build metadata and scale and cache them."""
        print("Processing raw data and saving to cache...")
        with open(self.original_data_path, 'r') as f:
            data = json.load(f)
        base_targets = [t[4:] for t in self.targetAtts]
        processed = []
        for d in data:
            new_d = {k: d[k] for k in ('id', 'user_id', 'gender', 'device', 'sport_type', 'timestamp') if k in d}
            if 'timestamp' in d:
                ts = d['timestamp']
                if isinstance(ts, list) and len(ts) > 0:
                    t0 = ts[0]
                    new_d['time_elapsed'] = [x - t0 for x in ts]
                elif isinstance(ts, (int, float)):
                    new_d['time_elapsed'] = [0]
            for att in self.inputAtts + base_targets:
                if att in d:
                    new_d[att] = self._nan_to_zero(d[att])
            # Attributes absent from the raw record become all-zero sequences.
            for att in self.inputAtts:
                if att not in new_d:
                    new_d[att] = [0.0] * self.trimmed_workout_len
            processed.append(new_d)
        self.original_data = processed
        self.map_workout_id()
        self.buildMetaData()
        # NOTE(review): scale_data uses its default scaling=True; self.scaleVals is not consulted.
        self.scale_data()

    @staticmethod
    def _fit_length(seq, length):
        """Return ``seq`` as a float32 array truncated or zero-padded to ``length``."""
        out = np.zeros(length, dtype=np.float32)
        arr = np.array(seq, dtype=np.float32)[:length]
        out[:len(arr)] = arr
        return out

    def _precompute_full_context(self, global_idx):
        """Stack the previous workouts of ``global_idx``'s user into history arrays.

        Returns:
            ``(fc1, fc2)``, both float32.
            ``fc1`` has shape ``(max_hist, trimmed_workout_len, input_dim + 1)``:
            each previous workout's input attributes, plus that workout's own
            ``since_last`` value from contextMap in the last channel.
            ``fc2`` has shape ``(max_hist, trimmed_workout_len, output_dim)``:
            each previous workout's target attributes.
            Row 0 is the most recent previous workout and unused rows stay zero.
            ``max_hist`` is ``fullTemporalLength``, or the number of previous
            workouts when that is None.
        """
        length = self.trimmed_workout_len
        history_len = self.fullTemporalLength
        if global_idx in self.contextMap:
            prev_workouts = self.contextMap[global_idx][2]
        else:
            prev_workouts = []
        if history_len is not None:
            # Slice from an explicit start index: ``seq[-0:]`` is the WHOLE list,
            # which made fullTemporalLength == 0 overflow the 0-row arrays.
            prev_workouts = prev_workouts[max(len(prev_workouts) - history_len, 0):]
        since_last = [
            self.contextMap[p][0] if p in self.contextMap else 0
            for p in prev_workouts
        ]
        max_hist = history_len if history_len is not None else len(prev_workouts)
        fc1 = np.zeros([max_hist, length, self.input_dim + 1], dtype=np.float32)
        fc2 = np.zeros([max_hist, length, self.output_dim], dtype=np.float32)
        # Walk newest -> oldest so that row 0 holds the most recent workout.
        for row, (prev_idx, gap) in enumerate(zip(reversed(prev_workouts), reversed(since_last))):
            prev_data = self.original_data[prev_idx]
            for j, att in enumerate(self.inputAtts):
                fc1[row, :, j] = self._fit_length(prev_data[att], length)
            fc1[row, :, self.input_dim] = gap
            for k, tAtt in enumerate(self.targetAtts):
                # Targets are read from the un-prefixed attribute name.
                fc2[row, :, k] = self._fit_length(prev_data[tAtt[4:]], length)
        return fc1, fc2

    def map_workout_id(self):
        """Index records by ``'id'`` and copy the split and context containers.

        ``idxMap`` maps workout id -> position in ``original_data``. It is a
        ``defaultdict(int)``, so an unknown id silently maps to 0. The split
        lists and contextMap are already positions into the data list (see
        ``_generate_split``), so they are copied without conversion.
        """
        self.idxMap = defaultdict(int)
        for idx, d in enumerate(self.original_data):
            self.idxMap[d['id']] = idx
        self.trainingSet = list(self.trainingSet)
        self.validationSet = list(self.validationSet)
        self.testSet = list(self.testSet)
        self.contextMap = {
            wid: (context[0], context[1], list(context[2]))
            for wid, context in self.contextMap.items()
        }

    def load_meta(self):
        """Load the metadata, building and writing it first if it is missing."""
        self.buildMetaData()

    def buildMetaData(self):
        """Load metadata from ``metaDataFn`` or compute it from ``original_data``.

        When computing, this builds one-hot encoders for ``user_id`` and the
        nominal attributes, plus means/stds for the sequence attributes, and
        then writes the summary file.
        """
        if os.path.isfile(self.metaDataFn):
            self.loadSummaryFile()
            return
        print("Building data schema from scratch...")
        variableSums = defaultdict(float)
        classLabels = defaultdict(set)
        for currData in self.original_data:
            classLabels['user_id'].add(currData['user_id'])
            for att in self.isNominal:
                if att in currData:
                    classLabels[att].add(currData[att])
            for att in self.isSequence:
                if att in currData:
                    variableSums[att] += sum(currData[att])
        oneHotEncoders = {}
        oneHotMap = {}
        encodingLengths = {}
        oneHotEncoders['user_id'], oneHotMap['user_id'] = self.buildEncoder(classLabels['user_id'])
        encodingLengths['user_id'] = 1
        for att in self.isNominal:
            oneHotEncoders[att], oneHotMap[att] = self.buildEncoder(classLabels[att])
            encodingLengths[att] = len(classLabels[att])
        for att in self.isSequence:
            encodingLengths[att] = 1
        self.numDataPoints = len(self.original_data)
        self.computeMeanStd(variableSums, self.numDataPoints, self.isSequence)
        self.oneHotEncoders = oneHotEncoders
        self.oneHotMap = oneHotMap
        self.encodingLengths = encodingLengths
        self.writeSummaryFile()

    def computeMeanStd(self, varSums, numDataPoints, attributes):
        """Compute per-attribute mean and (population) std over all sequence points.

        Args:
            varSums: attribute -> sum of all of its values in ``original_data``.
            numDataPoints: number of workouts. It is kept for interface
                compatibility; the denominators are the actual point counts.
            attributes: attribute names whose values are aggregated.

        The result is stored in ``variableMeans`` and ``variableStds``.
        Previously the denominator was a hard-coded ``numDataPoints * 500``,
        which is wrong whenever sequences are not exactly 500 points long.
        """
        print("Computing variable means and stddev...")
        pointCounts = defaultdict(int)
        for currData in self.original_data:
            for att in set(attributes).union(varSums):
                if att in currData:
                    pointCounts[att] += len(currData[att])
        variableMeans = {}
        for att in varSums:
            count = pointCounts[att]
            variableMeans[att] = varSums[att] / count if count else 0.0
        varResidualSums = defaultdict(float)
        for currData in self.original_data:
            for att in attributes:
                if att in currData:
                    diff = np.array(currData[att]) - variableMeans[att]
                    varResidualSums[att] += np.sum(diff * diff)
        variableStds = {}
        for att in varResidualSums:
            count = pointCounts[att]
            variance = varResidualSums[att] / count if count else 0.0
            # The epsilon keeps the std strictly positive for constant attributes.
            variableStds[att] = np.sqrt(variance + 1e-8)
        self.variableMeans = variableMeans
        self.variableStds = variableStds

    def scale_data(self, scaling=True):
        """Add ``tar_*`` target copies, z-scale sequence attributes and save the cache.

        The unscaled values of each sequence attribute are kept under ``<att>_org``.
        The ``tar_<att>`` entries reference the unscaled target sequences.
        """
        print("Scaling data (continuous attributes)...")
        base_targetAtts = [t[4:] for t in self.targetAtts]
        for idx, currentDataPoint in enumerate(self.original_data):
            for tAtt in base_targetAtts:
                if tAtt in currentDataPoint:
                    self.original_data[idx]["tar_" + tAtt] = currentDataPoint[tAtt]
            for att in self.isSequence:
                if att in currentDataPoint:
                    if f"{att}_org" not in currentDataPoint:
                        currentDataPoint[f"{att}_org"] = currentDataPoint[att][:]
                    if scaling:
                        self.original_data[idx][att] = self.scaleData(
                            currentDataPoint[att], att, self.zMultiple)
        np.save(self.processed_path, [self.original_data])
        print(f"Saved preprocessed data to {self.processed_path}")

    def scaleData(self, data, att, zMultiple=2):
        """Return ``zMultiple * (data - mean) / std`` as a list; NaNs are treated as 0.

        The std is floored at 1e-6.
        NOTE(review): an attribute missing from the stats uses mean 0 and std
        1e-6, which inflates values by 1e6 - confirm this is intended.
        """
        mean_ = self.variableMeans.get(att, 0)
        std_ = self.variableStds.get(att, 1e-6)
        if std_ < 1e-6:
            std_ = 1e-6
        arr = np.nan_to_num(np.array(data, dtype=np.float32), nan=0.0)
        zscore = (arr - mean_) / std_
        return (zscore * zMultiple).tolist()

    def median_smoothing(self, seq, window_size):
        """Apply a centered running median; windows are clipped at the edges.

        Returns ``seq`` unchanged when ``window_size <= 1``.
        """
        if window_size <= 1:
            return seq
        seq = np.array(seq, dtype=np.float32)
        r = window_size // 2
        return [
            np.median(seq[max(0, i - r):min(len(seq), i + r + 1)])
            for i in range(len(seq))
        ]

    def buildEncoder(self, classLabels):
        """Return ``(label -> one-hot list, label -> index)`` over the sorted labels."""
        sorted_labels = sorted(classLabels)
        encodingLength = len(sorted_labels)
        encoder_dict = {}
        mapper_dict = {}
        for i, label in enumerate(sorted_labels):
            onehot = [0] * encodingLength
            onehot[i] = 1
            encoder_dict[label] = onehot
            mapper_dict[label] = i
        return encoder_dict, mapper_dict

    def writeSummaryFile(self):
        """Pickle the current metadata to ``metaDataFn`` as a ``metaDataOurs``."""
        metaObj = metaDataOurs(
            numDataPoints=self.numDataPoints,
            encodingLengths=self.encodingLengths,
            oneHotEncoders=self.oneHotEncoders,
            oneHotMap=self.oneHotMap,
            isSequence=self.isSequence,
            isNominal=self.isNominal,
            isDerived=self.isDerived,
            variableMeans=self.variableMeans,
            variableStds=self.variableStds
        )
        with open(self.metaDataFn, "wb") as f:
            pickle.dump(metaObj, f)
        print("Summary file written:", self.metaDataFn)

    def loadSummaryFile(self):
        """Restore the metadata attributes from the pickle at ``metaDataFn``."""
        print("Loading metadata from:", self.metaDataFn)
        # Trusted local cache only: unpickling can execute arbitrary code.
        with open(self.metaDataFn, "rb") as f:
            metaData = pickle.load(f)
        self.numDataPoints = metaData.numDataPoints
        self.encodingLengths = metaData.encodingLengths
        self.oneHotEncoders = metaData.oneHotEncoders
        self.oneHotMap = metaData.oneHotMap
        self.isSequence = metaData.isSequence
        self.isNominal = metaData.isNominal
        # writeSummaryFile persists isDerived too; getattr tolerates objects lacking it.
        self.isDerived = getattr(metaData, "isDerived", [])
        self.variableMeans = metaData.variableMeans
        self.variableStds = metaData.variableStds
        print("Metadata loaded successfully.")

    def _nan_to_zero(self, seq, fill_val=0.0):
        """Return ``seq`` as a list of floats (via float32) with NaNs replaced by ``fill_val``."""
        arr = np.asarray(seq, dtype=np.float32)
        arr = np.nan_to_num(arr, nan=fill_val)
        return arr.tolist()


class metaDataOurs:
    """Plain, picklable record of a dataset summary.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name.
    """

    def __init__(
            self,
            numDataPoints,
            encodingLengths,
            oneHotEncoders,
            oneHotMap,
            isSequence,
            isNominal,
            isDerived,
            variableMeans,
            variableStds
    ):
        # Populate the instance dict in one call, keeping the attribute order
        # identical to the field-by-field assignments it replaces.
        vars(self).update(
            numDataPoints=numDataPoints,
            encodingLengths=encodingLengths,
            oneHotEncoders=oneHotEncoders,
            oneHotMap=oneHotMap,
            isSequence=isSequence,
            isNominal=isNominal,
            isDerived=isDerived,
            variableMeans=variableMeans,
            variableStds=variableStds,
        )
