import pandas as pd
import numpy as np
from scipy.interpolate import CubicSpline
from scipy import signal
import warnings
warnings.filterwarnings('ignore')


def deg_to_rad(deg):
    """Return ``deg`` (degrees) expressed in radians; applies element-wise to arrays."""
    scaled = deg * np.pi
    return scaled / 180.0


def deg_to_vec(deg):
    """Map angles in degrees to unit vectors, one ``[cos, sin]`` row per angle.

    For 1-D input of length N the result has shape (N, 2); a scalar gives (1, 2).
    """
    theta = deg_to_rad(deg)
    cos_part, sin_part = np.cos(theta), np.sin(theta)
    return np.column_stack((cos_part, sin_part))


def unwrap_longitude_deg(lon_deg: np.ndarray) -> np.ndarray:
    """
    Unwrap a time-ordered longitude sequence into continuous values (degrees).

    Jumps larger than 180 degrees between consecutive samples (e.g. crossing
    +/-180 or 0/360) are corrected by a multiple of 360 degrees, so the output
    reflects the real cumulative change. The input may use any convention
    ([-180, 180], [0, 360), ...).

    Parameters
    ----------
    lon_deg : array-like of float
        Longitudes in degrees, in time order.

    Returns
    -------
    np.ndarray
        float64 array of the same shape with unwrapped longitudes. For 1-D
        input, non-finite entries (NaN) are left in place and skipped: the
        finite samples on either side of a gap are unwrapped against each
        other. Plain ``np.unwrap`` would instead turn every value after the
        first NaN into NaN.
    """
    lon_rad = np.deg2rad(np.asarray(lon_deg, dtype=np.float64))

    # Fast path (identical to plain np.unwrap): clean data or multi-dim input,
    # where np.unwrap works along the last axis.
    finite = np.isfinite(lon_rad)
    if lon_rad.ndim != 1 or finite.all():
        return np.rad2deg(np.unwrap(lon_rad))

    # 1-D input with gaps: unwrap only the finite samples so a missing
    # longitude does not poison the rest of the track.
    unwrapped = lon_rad.copy()
    if finite.any():
        unwrapped[finite] = np.unwrap(lon_rad[finite])
    return np.rad2deg(unwrapped)


class TrajectoryDatabase:
    """Build a database of fixed-length, sliding-window trajectory segments from AIS records.

    Each segment holds ``trajectory_length`` consecutive AIS reports of one ship
    (432 by default). Consecutive windows start ``step_size`` reports apart
    (7 by default). Segments are stored as dicts in ``self.trajectories``.

    Parameters
    ----------
    time_interval : int
        Stored as ``self.time_interval``; not used by any method of this class.
        NOTE(review): presumably a sampling interval in seconds -- confirm.
    min_sog : float
        Stored as ``self.min_sog``; no speed filtering is currently applied.
    min_coast_distance : float
        Stored as ``self.min_coast_distance``; no coast-distance filtering is
        currently applied.
    trajectory_length : int
        Number of consecutive AIS reports per segment (must be >= 1).
    step_size : int
        Offset, in reports, between the starts of consecutive segments
        (must be >= 1).
    """

    def __init__(self, time_interval=300, min_sog=1.5, min_coast_distance=80000,
                 trajectory_length=432, step_size=7):
        if trajectory_length < 1:
            raise ValueError(f"trajectory_length must be >= 1, got {trajectory_length}")
        if step_size < 1:
            raise ValueError(f"step_size must be >= 1, got {step_size}")
        self.time_interval = time_interval
        self.trajectory_length = trajectory_length
        self.step_size = step_size
        self.min_sog = min_sog
        self.min_coast_distance = min_coast_distance
        self.df = None  # set by load_data()
        self.trajectories = []

    def load_data(self, data_paths):
        """
        Load AIS CSV files (one or several) into ``self.df``.

        Parameters
        ----------
        data_paths : str | os.PathLike | iterable of those
            Each CSV must contain the columns ``mmsi, postime, lat, lon, sog, cog``;
            any other columns are ignored.

        Raises
        ------
        ValueError
            If no paths are given.
        """
        import os

        if isinstance(data_paths, (str, os.PathLike)):
            data_paths = [data_paths]
        data_paths = list(data_paths)
        if not data_paths:
            raise ValueError("load_data() requires at least one data path")

        print("Loading data...")

        dfs = []
        for path in data_paths:
            print(f"  Reading file: {path}")
            dfs.append(pd.read_csv(
                path,
                usecols=['mmsi', 'postime', 'lat', 'lon', 'sog', 'cog']
            ))

        self.df = pd.concat(dfs, ignore_index=True)

        # postime is NOT parsed to datetime, so this sorts on the raw CSV values.
        # NOTE(review): correct only if postime is numeric or a lexically
        # sortable timestamp string -- confirm against the data.
        self.df = self.df.sort_values(['mmsi', 'postime']).reset_index(drop=True)

        # Keep the raw longitude; segments later unwrap it per ship.
        self.df['lon_original'] = self.df['lon']

        print(f"Data loading completed, total {len(self.df)} records")
        print(f"Number of ships: {self.df['mmsi'].nunique()}")

    def preprocess_and_create_database(self):
        """Split every ship's track into fixed-length segments and append them to ``self.trajectories``.

        Calling this more than once appends again; existing segments are kept.

        Raises
        ------
        RuntimeError
            If no data has been loaded (``load_data`` not called).
        """
        if getattr(self, 'df', None) is None:
            raise RuntimeError("No AIS data loaded; call load_data() first")

        print("Creating trajectory database...")

        total_trajectories = 0

        # A single pass over the frame. Groups come out in first-appearance order,
        # like the former mmsi.unique() loop, without an O(n) boolean mask per ship.
        for ship_id, ship_data in self.df.groupby('mmsi', sort=False):
            # Stable sort so reports sharing a timestamp keep their loaded order.
            ship_data = ship_data.sort_values('postime', kind='stable')

            ship_trajectories = self._split_trajectory(ship_data)
            self.trajectories.extend(ship_trajectories)
            total_trajectories += len(ship_trajectories)

            print(f"Ship {ship_id} generated {len(ship_trajectories)} trajectory segments")

        print(f"Database creation completed, total {total_trajectories} trajectory segments")

    def _split_trajectory(self, ship_data: pd.DataFrame) -> list:
        """Cut one ship's time-ordered reports into overlapping fixed-length segments.

        ``ship_data`` holds rows of a single ship, sorted by ``postime``, with
        columns ``mmsi, lat, lon_original, sog, cog``. A ship with fewer than
        ``trajectory_length`` reports yields no segments.
        """
        trajectories = []
        window = self.trajectory_length
        n = len(ship_data)
        if n < window:
            return trajectories

        mmsi = ship_data['mmsi'].iloc[0]

        # Unwrap the whole track once, so deltas inside any window reflect the
        # real cumulative longitude change (no jumps at +/-180 or 0/360).
        lon_unwrapped = unwrap_longitude_deg(ship_data['lon_original'].values)

        # Extract the columns once instead of copying a DataFrame per window.
        lats = ship_data['lat'].values
        lons = ship_data['lon_original'].values
        sogs = ship_data['sog'].values
        cogs = ship_data['cog'].values

        for start_idx in range(0, n - window + 1, self.step_size):
            stop_idx = start_idx + window
            # .copy() keeps each segment's arrays independent. Overlapping
            # windows would otherwise share one underlying buffer.
            seg_lats = lats[start_idx:stop_idx].copy()
            seg_lons = lons[start_idx:stop_idx].copy()
            seg_sog = sogs[start_idx:stop_idx].copy()
            seg_cog = cogs[start_idx:stop_idx].copy()

            start_lat = seg_lats[0]
            start_lon_original = seg_lons[0]

            # Real cumulative longitude change relative to the window start
            # (continuous values; no wrapping, no shortest-angle difference).
            delta_lons = lon_unwrapped[start_idx:stop_idx] - lon_unwrapped[start_idx]
            delta_lats = seg_lats - start_lat

            lat_rad = deg_to_rad(seg_lats)
            # NOTE(review): this uses the raw (wrapped) longitude, not the
            # unwrapped one -- confirm this is intended for the model features.
            lon_rad = deg_to_rad(seg_lons)
            cog_vec = deg_to_vec(seg_cog)

            # Per-step features: [lat_rad, lon_rad, sog*cos(cog), sog*sin(cog)].
            merging = np.concatenate([
                lat_rad[:, None],
                lon_rad[:, None],
                seg_sog[:, None] * cog_vec
            ], axis=1)

            trajectories.append({
                'mmsi': mmsi,
                'segment_id': f"{mmsi}_{start_idx}",
                'lons': seg_lons,
                'lats': seg_lats,
                'delta_lons': delta_lons.astype(np.float64),
                'delta_lats': delta_lats,
                'start_lon': start_lon_original,
                'start_lat': start_lat,
                'sog': seg_sog,
                'cog': seg_cog,
                'merging_features': merging.astype(np.float32),
                'lat_rad': lat_rad.astype(np.float32),
                'lon_rad': lon_rad.astype(np.float32),
                'cog_vec': cog_vec.astype(np.float32)
            })

        return trajectories

    def get_trajectories(self):
        """Return the list of segment dicts built so far (the live list, not a copy)."""
        return self.trajectories

    def save_database(self, file_path: str):
        """Pickle ``self.trajectories`` to ``file_path``."""
        import pickle
        with open(file_path, 'wb') as f:
            pickle.dump(self.trajectories, f)
        print(f"Trajectory database saved to {file_path}")

    def load_database(self, file_path: str):
        """Replace ``self.trajectories`` with the pickled list stored at ``file_path``.

        Security: ``pickle.load`` can execute arbitrary code. Only load files
        produced by ``save_database`` from a trusted source.
        """
        import pickle
        with open(file_path, 'rb') as f:
            self.trajectories = pickle.load(f)
        print(f"Trajectory database loaded from {file_path}, total {len(self.trajectories)} trajectories")


if __name__ == "__main__":
    import os

    # NOTE(review): min_sog / min_coast_distance are stored by TrajectoryDatabase
    # but no filtering on them is applied as written.
    db = TrajectoryDatabase(min_sog=1.5, min_coast_distance=80000)

    data_paths = [
        'data/210238000.csv',
        'data/210279000.csv',
        'data/356285000.csv',
        'data/414062000.csv',
        'data/414066000.csv',
        'data/636015239.csv'
    ]
    output_path = 'LSTM-NOAA/test_trajectories.pkl'

    # Create the output directory before the expensive processing so saving
    # cannot fail at the very end with FileNotFoundError.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    db.load_data(data_paths)
    db.preprocess_and_create_database()
    db.save_database(output_path)