"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""

import os
import time
import json
import torch
import shutil
import ctypes
import numpy as np
from collections import defaultdict
from scipy.spatial.transform import Rotation

from utils.from_scipy import compute_euler_from_matrix
from utils.logging import get_logger
from .constants import *

# Module-level logger; used below for warnings about skipped sequences and dataset statistics.
log = get_logger(__name__)

class SequencesDataset:
    """
    A template class for sequences dataset in TLIO training.

    Each subclass is expected to load data in a different way, but from the same
    on-disk format: ``<data_path>/<split>_list.txt`` names one sequence directory
    per line, and each sequence directory holds a ``<sensor>_description.json``
    for every sensor file used. Subclasses must implement ``load_data_chunk``
    and the IntegrateRoninCallback helpers at the bottom of this class.
    """

    def __init__(
        self,
        data_path,
        split,
        genparams,
        only_n_sequence=-1,
        sequence_subset=None,
        normalize_sensor_data=True,
        verbose=False,
    ):
        """
        Args:
            data_path: root directory holding ``<split>_list.txt`` and one
                sub-directory per sequence. If the path contains "bodyframe",
                windows are built with body-frame targets.
            split: split name (e.g. "train"); selects the list file and enables
                training-only behavior (``only_n_sequence``, row jitter).
            genparams: config object; this class reads ``window_size``,
                ``data_style``, ``input_sensors``, ``decimator`` and
                ``express_in_t0_yaw_normalized_frame`` from it.
            only_n_sequence: if > 0 and split == "train", keep only the first N sequences.
            sequence_subset: optional collection of sequence ids to keep; every
                id must be present in the list file.
            normalize_sensor_data: whether each window goes through ``normalize_feats``.
            verbose: log dataset statistics after indexing.
        """
        self.data_path = data_path
        self.split = split
        self.genparams = genparams

        self.only_n_sequence = only_n_sequence
        self.sequence_subset = sequence_subset
        self.normalize_sensor_data = normalize_sensor_data
        self.verbose = verbose

        # The list of relevant sensor file names based on data_style
        self.sensor_file_basenames = self.get_sensor_file_basenames()

        # Index the mem-mapped files and open them (data is not read from disk here)
        self.load_list()
        if self.verbose:
            self.log_dataset_info()

        self.inferred_velocities = None  # To save the inferred velocities
        self.inferred_times = None  # To save the inferred timestep
        self.inferred_vel = None

    def get_base_sensor_name(self):
        """Return the base sensor file name (the first entry of ``sensor_file_basenames``)."""
        return self.sensor_file_basenames[0]

    def load_list(self):
        """
        Read the split's sequence list and every sequence's sensor descriptions.

        Populates ``self.data_list`` (sequence ids) and ``self.data_descriptions``
        (one ``{sensor_basename: description_dict}`` per kept sequence, index-aligned
        with ``self.data_list``). Sequences whose base sensor has fewer rows than
        ``genparams.window_size`` are dropped with a warning.

        Raises:
            AssertionError: if called from a DataLoader worker, or if some id in
                ``sequence_subset`` is missing from the list file.
        """
        assert torch.utils.data.get_worker_info() is None, "load_list() can only be called in main proc!"

        with open(os.path.join(self.data_path, f"{self.split}_list.txt")) as f:
            # dtype=str keeps boolean indexing below valid even when the file is empty.
            list_info = np.array([s.strip() for s in f if s.strip()], dtype=str)

        # For picking exactly some particular sequences
        if self.sequence_subset is not None:
            to_keep = np.array([s in self.sequence_subset for s in list_info], dtype=bool)
            assert np.count_nonzero(to_keep) == len(self.sequence_subset), \
                    "Could not find some sequences from sequence_subset in data list"
            list_info = list_info[to_keep]

        if self.split == "train" and self.only_n_sequence > 0:
            list_info = list_info[:self.only_n_sequence]

        # Keep a plain empty list when there is nothing to load
        # (i.e., if you don't want to do test or val or something).
        self.data_list = list_info if len(list_info) > 0 else []

        # Load the descriptions of all the data (column info and num rows)
        self.data_descriptions = []
        seqs_to_remove = set()  # The seq ids, not the indices
        for seq_id in self.data_list:
            seq_desc = self._load_seq_descriptions(seq_id)
            if seq_desc is None:
                seqs_to_remove.add(seq_id)
            else:
                self.data_descriptions.append(seq_desc)

        # Remove too short sequences from list
        if seqs_to_remove:
            self.data_list = np.array([seq for seq in self.data_list if seq not in seqs_to_remove])

    def _load_seq_descriptions(self, seq_id):
        """Return ``{sensor_basename: description}`` for one sequence, or None if it is too short."""
        seq_desc = {}
        for i, sensor_basename in enumerate(self.sensor_file_basenames):
            desc_path = os.path.join(self.data_path, seq_id, sensor_basename + "_description.json")
            with open(desc_path, 'r') as f:
                d = json.load(f)
            # Only the base sensor (index 0) is checked against the window length in rows.
            if i == 0 and d["num_rows"] < self.genparams.window_size:
                log.warning(f"Sequence {seq_id} being ignored since it is too short ({d['num_rows']} rows)")
                return None
            seq_desc[sensor_basename] = d
        return seq_desc

    def get_sensor_file_basenames(self):
        """
        Return the sensor file base names to load for ``genparams.data_style``.

        Raises:
            ValueError: for an unknown ``data_style``.
        """
        if self.genparams.data_style == "aligned":
            return [COMBINED_SENSOR_NAME]
        elif self.genparams.data_style == "resampled":
            return ['combined_flight_data']
        elif self.genparams.data_style == "raw":
            return [s for s in ALL_SENSORS_LIST if s in self.genparams.input_sensors]
        else:
            raise ValueError(f"Invalid data_style {self.genparams.data_style}")

    def log_dataset_info(self):
        """
        Log total duration, sequence count and min/max row counts for this split.

        Also sets ``self.min_num_rows`` / ``self.max_num_rows`` (over all sensor
        files; None when the split is empty).
        """
        cumulated_duration_hrs = 0
        self.max_num_rows = None
        self.min_num_rows = None
        for desc in self.data_descriptions:
            for j, sensor_basename in enumerate(self.sensor_file_basenames):
                sensor_desc = desc[sensor_basename]
                if j == 0:
                    # Count each sequence's duration once (from the base sensor); summing
                    # it over every sensor file would multiply it by the number of sensors.
                    cumulated_duration_hrs += 1e-6 * (sensor_desc["t_end_us"] - sensor_desc["t_start_us"]) / 60 / 60
                num_rows = sensor_desc["num_rows"]
                self.max_num_rows = num_rows if self.max_num_rows is None else max(num_rows, self.max_num_rows)
                self.min_num_rows = num_rows if self.min_num_rows is None else min(num_rows, self.min_num_rows)

        log.info(
            f"Cumulated {self.split} dataset duration is {cumulated_duration_hrs:.3f} hours"
        )
        log.info(
            f"Number of {self.split} sequences is {len(self.data_descriptions)}"
        )
        log.info(f"Min/max sequences length={self.min_num_rows}, {self.max_num_rows}")

    def poses_to_target(self, rot, pos):
        """
        Compute per-step rotation and displacement relative to the first pose of the window.

        Args:
            rot: (T, 4) quaternions, passed to ``scipy Rotation.from_quat``.
            pos: (T, 3) positions.

        Returns:
            (targ_dR_World, targ_dt_World): (T, 3, 3) rotations ``R_W_i @ R_W_0^T`` and
            (T, 3) displacements ``pos - pos[0]``.
        """
        R_W_i = Rotation.from_quat(rot).as_matrix()
        R_W_0 = R_W_i[0:1]  # keep a leading axis so it broadcasts against R_W_i

        # NOTE R_W_i @ R_W_0.transpose() looks strange, but it is the delta rotation between the two times
        # aligned with the world frame instead of body frame.
        targ_dR_World = R_W_i @ R_W_0.transpose([0, 2, 1])
        targ_dt_World = pos - pos[0:1]  # Displacement in global frame
        return targ_dR_World, targ_dt_World

    @staticmethod
    def _extract_combined_feats(data_chunk):
        """
        Build the feature dict from one window of the combined-format sensor file.

        Columns read: 14:17 (noisy velocity), 4:7 (acceleration) and 17:26
        (covariance, 9 values). Every feature is a float32 ``[channels, window]`` array.
        """
        noisy_vel_feat = data_chunk[:, 14:17]
        covariance_feat = data_chunk[:, 17:26]
        acc_feature = data_chunk[:, 4:7]

        vel_and_cov = np.concatenate((noisy_vel_feat, covariance_feat), axis=1)
        vel_and_acc = np.concatenate((noisy_vel_feat, acc_feature), axis=1)
        vel_acc_cov = np.concatenate((noisy_vel_feat, acc_feature, covariance_feat), axis=1)
        return {
            "velocity": noisy_vel_feat.astype(np.float32).T,
            # Variants that append covariance and/or acceleration channels to the velocity.
            "vel_and_cov": vel_and_cov.astype(np.float32).T,
            "vel_and_acc": vel_and_acc.astype(np.float32).T,
            "vel_acc_cov": vel_acc_cov.astype(np.float32).T,
        }

    def unpack_data_window(self, seq_data, seq_desc, row):
        """
        Slice one window from every sensor file and extract features and GT.

        Args:
            seq_data: ``{sensor_basename: 2-D array}`` for one sequence; column 0 holds
                timestamps (treated as microseconds when matching sensors).
            seq_desc: ``{sensor_basename: description dict}`` for the same sequence.
            row: first row of the window in the base sensor's array.

        Returns:
            (feats, (ts_us, rot, pos, vel)): features from ``_extract_combined_feats``,
            plus the base sensor's ``(window, 1)`` timestamps and float32 GT rot/pos/vel.
        """
        feats = {}
        ts_us_base_sensor = None

        # With resampled data, the "approximate_frequency" in the json file is exact,
        # so can quickly index the timestamps of sensors in different memmap files.
        base_sensor_freq = seq_desc[self.get_base_sensor_name()]["approximate_frequency_hz"]
        base_sensor_window_start_time = None
        for i, sensor_name in enumerate(self.sensor_file_basenames):
            if i == 0:
                sensor_row = row
                window_size = self.genparams.window_size
            else:
                # Index the row based on sensor frequency.
                sensor_freq = seq_desc[sensor_name]["approximate_frequency_hz"]
                sensor_seq_start_time = seq_desc[sensor_name]["t_start_us"]
                # TODO off by one possible here from rounding/flooring
                sensor_row = int(1e-6 * (base_sensor_window_start_time - sensor_seq_start_time) * sensor_freq)
                # TODO should calculate all the window sizes at startup so that don't
                # accidentally get an off-by-one window size error from float errors
                window_size = int(self.genparams.window_size * sensor_freq / base_sensor_freq)

            data_chunk = seq_data[sensor_name][sensor_row:sensor_row + window_size]
            # Make sure idx was valid with sufficient padding for window
            assert data_chunk.shape[0] == window_size, \
                f"{sensor_name}: expected {window_size} rows from row {sensor_row}, got {data_chunk.shape[0]}"

            if i == 0:
                # Timestamps, the window start time and GT all come from the base sensor;
                # other sensors are indexed relative to this start time.
                ts_us_base_sensor = np.copy(data_chunk[:, 0:1])
                base_sensor_window_start_time = data_chunk[0, 0]
                # GT layout: quaternion 7:11, position 11:14, velocity 14:17.
                # NOTE(review): GT velocity reads the same columns (14:17) as the "noisy"
                # velocity feature -- confirm this is intended.
                rot = data_chunk[:, 7:11]
                pos = data_chunk[:, 11:14]
                vel = data_chunk[:, 14:17]

            # NOTE: the column layout assumes the combined sensor file; if several sensor
            # files are loaded, the features of the last one are kept.
            feats = self._extract_combined_feats(data_chunk)

        gt_data = ts_us_base_sensor, rot.astype(np.float32), pos.astype(np.float32), vel.astype(np.float32)
        return feats, gt_data

    def data_chunk_from_seq_data(self, seq_data, seq_desc, row):
        """
        Build the training window starting at ``row`` from already-opened sequence data.

        If ``"bodyframe"`` appears in ``self.data_path``, body-frame outputs are produced
        (``targ_dt_Body``, ``vel_Body``), and features whose key contains "imu" are
        rotated into the window's last IMU frame. Otherwise world-frame outputs are
        produced (``vel_World``).

        Returns:
            (base_sensor_description, windows): ``windows["main"]`` holds timestamps,
            features, relative rotation/displacement targets, GT velocity, the
            ``R_world_gla`` rotation and GT absolute positions (``targ_pos_World``).
        """
        body_frame = "bodyframe" in self.data_path

        feats, gt_data = self.unpack_data_window(seq_data, seq_desc, row)

        # Normalize the raw sensor data into something better for learning (sensor-dependent)
        if self.normalize_sensor_data:
            feats = self.normalize_feats(feats)

        ts_us, rot, pos, vel = gt_data
        targ_dR_World, targ_dt_World = self.poses_to_target(rot, pos)

        R_world_gla = np.eye(3)
        targ_dt_Body = None
        if body_frame:
            R_n_W = Rotation.from_quat(rot).as_matrix()  # (T, 3, 3)
            # World displacement re-expressed in each step's own body frame: R_n^T @ dt_n.
            targ_dt_Body = np.einsum("bij,bj->bi", R_n_W.transpose(0, 2, 1), targ_dt_World)

            # Frame in which the rotated features are expressed; enable at most one.
            align_last_imu = True
            align_first_imu = False
            assert not (align_last_imu and align_first_imu), "Choose at most one IMU alignment frame"
            if align_last_imu:
                # R_last^T @ R_n for every step n.
                R_n_last = np.einsum('ij,tjk->tik', R_n_W[-1].T, R_n_W)
            if align_first_imu:
                # R_first^T @ R_n for every step n.
                R_n_first = np.einsum('ij,tjk->tik', R_n_W[0].T, R_n_W)

            # Only features keyed "imu" are rotated; rows 0:3 are left untouched.
            for k in feats:
                if "imu" not in k:
                    continue
                if align_last_imu:
                    feats[k][3:6] = np.einsum("tij,jt->it", R_n_last, feats[k][3:6])
                    if feats[k].shape[0] == 9:
                        feats[k][6:9] = np.einsum("tij,jt->it", R_n_last, feats[k][6:9])
                if align_first_imu:
                    feats[k][3:] = np.einsum("tij,jt->it", R_n_first, feats[k][3:])
        elif self.genparams.express_in_t0_yaw_normalized_frame:
            # Disabled path: this raises unless assertions are stripped (python -O).
            assert False, "express_in_t0_yaw_normalized_frame is not supported"
            R_W_0 = Rotation.from_quat(rot[0:1]).as_matrix()
            angles_t0 = compute_euler_from_matrix(
                R_W_0, "xyz", extrinsic=True
            )
            ri_z = angles_t0[0, 2]
            c = np.cos(ri_z)
            s = np.sin(ri_z)
            R_world_gla = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
            targ_dt_World = np.einsum("ji,tj->ti", R_world_gla, targ_dt_World)

            # Only IMU and mag data need to be rotated (not barometer)
            for k in feats:
                if "imu" in k:
                    assert feats[k].shape[0] == 6
                    feats[k][:3] = np.einsum("ji,jt->it", R_world_gla, feats[k][:3])
                    feats[k][3:] = np.einsum("ji,jt->it", R_world_gla, feats[k][3:])
                elif "mag" in k:
                    assert feats[k].shape[0] == 3
                    feats[k] = np.einsum("ji,jt->it", R_world_gla, feats[k])

        # may return multiple windows, so place them all in here for convenience.
        main = {
            "ts_us": ts_us,
            "feats": feats,
            "targ_dR_World": targ_dR_World.astype(np.float32),
            "targ_dt_World": targ_dt_World.astype(np.float32),
        }
        if body_frame:
            main["targ_dt_Body"] = targ_dt_Body.astype(np.float32)
            main["vel_Body"] = vel.astype(np.float32)
        else:
            main["vel_World"] = vel.astype(np.float32)
        main["R_world_gla"] = R_world_gla
        # Absolute GT positions from the base sensor, exposed under a dedicated key.
        main["targ_pos_World"] = pos.astype(np.float32)
        windows = {"main": main}

        return seq_desc[self.get_base_sensor_name()], windows

    def normalize_feats(self, feats):
        """
        Return a normalized copy of ``feats``; the input arrays are not modified.

        For every feature, non-finite values (NaN/inf) are replaced by 0. Keys that
        contain "mag" are multiplied by 10,000 (tesla -> gauss, assuming tesla input).
        For keys that contain "barom", the first row (pressure) is divided by 100,000
        (pascal -> bar). IMU features are deliberately not range-normalized: scaling
        by the sensor range (1000 deg/s gyro, 8 g accel) made values too small and
        disrupted the bias perturbation logic.
        """
        new_feats = {}
        for sensor_name, feat in feats.items():
            # Features are [channels, time] arrays.
            new_feat = np.copy(feat)
            # Check for nan/inf here (sometimes can pop up in the data)
            new_feat[~np.isfinite(new_feat)] = 0.0
            if "mag" in sensor_name:
                assert new_feat.shape[0] == 3
                # Gauss values are closer to 1 in magnitude (Earth's field is around .25-.65 Gauss,
                # and can be negative here since the magnetometer returns a field vector, not a magnitude).
                GAUSS_PER_TESLA = 10_000
                new_feat[:3] = new_feat[:3] * GAUSS_PER_TESLA
            if "barom" in sensor_name:
                assert new_feat.shape[0] == 2
                # Only the pressure row is converted; min/max-based normalization to [-1,1]
                # gave very small differences for normal daily values, so it is not used.
                PA_IN_BAR = 100_000
                new_feat[0] /= PA_IN_BAR  # convert pa to bar

            new_feats[sensor_name] = new_feat

        return new_feats

    def load_data_chunk(self, seq_idx, row):
        """Load one window; subclasses must return ``(meta_dict, windows)``."""
        raise NotImplementedError("Did not override load_data_chunk!!!")

    def load_and_preprocess_data_chunk(self, seq_idx, row_in_seq, num_rows_in_seq):
        """
        Load the window at ``row_in_seq`` of sequence ``seq_idx`` as a flat dict.

        In the "train" split, the row is jittered forward by up to
        ``genparams.decimator - 1`` rows (clamped to ``num_rows_in_seq - 1``).
        This improves data coverage while still respecting the decimator and indexing.

        Returns:
            ``{"seq_id": ..., **windows["main"]}``.
        """
        if self.split == "train":
            # NOTE(review): the clamp is to the last row, not to the last valid window
            # start; confirm callers pass a bound that already accounts for window_size.
            row_in_seq = min(num_rows_in_seq - 1, row_in_seq + np.random.randint(self.genparams.decimator))
        _, windows = self.load_data_chunk(seq_idx, row_in_seq)

        ret = {
            "seq_id": self.data_list[seq_idx],
        }
        ret.update(windows["main"])  # Main target and GT data corresponding to seq_idx and row_in_seq
        return ret

    ##########################################################################
    # Functions needed by IntegrateRoninCallback
    ##########################################################################

    def get_ts_last_imu_us(self, seq_idx=0):
        raise NotImplementedError("Did not override get_ts_last_imu_us!!!")

    def get_gt_traj_center_window_times(self, seq_idx=0):
        raise NotImplementedError("Did not override get_gt_traj_center_window_times!!!")

    def get_gt_traj_end_window_times(self, seq_idx=0):
        raise NotImplementedError("Did not override get_gt_traj_end_window_times!!!")
