import numpy as np
import torch
import torch.nn as nn
from scipy.interpolate import CubicSpline
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import pickle
from process import TrajectoryDatabase
from utils.metrics import frechet_distance, curvature_calculation
import random
import os
import math
from tqdm import tqdm


def deg_to_rad(deg_tensor):
    """Convert angles from degrees to radians.

    The conversion is a single multiplication by ``pi / 180``, so the argument
    may be anything that supports ``*`` with a Python float (plain number,
    NumPy array, torch tensor, ...). The result has whatever type that
    multiplication yields.
    """
    radians_per_degree = math.pi / 180.0
    return deg_tensor * radians_per_degree


def rad_to_deg(rad_tensor):
    """Convert angles from radians to degrees.

    The conversion is a single multiplication by ``180 / pi``, so the argument
    may be anything that supports ``*`` with a Python float (plain number,
    NumPy array, torch tensor, ...).
    """
    degrees_per_radian = 180.0 / math.pi
    return rad_tensor * degrees_per_radian


def normalize_seq(seq: np.ndarray, eps: float = 1e-8):
    """Min-max scale a sequence using its own minimum and maximum.

    Computes ``x_norm = (x - min(x)) / (max(x) - min(x) + eps)``. If the value
    range is smaller than ``eps`` (an effectively constant sequence) an
    all-zero array of the same shape is returned instead of dividing by ~0.

    Returns:
        ``(x_norm, x_min, x_max)`` where ``x_norm`` is a float64 array and
        ``x_min`` / ``x_max`` are Python floats taken from the input.
    """
    values = np.asarray(seq, dtype=np.float64)
    lo = float(values.min())
    hi = float(values.max())
    span = hi - lo

    # Degenerate range: every element maps to 0.
    if abs(span) < eps:
        flat = np.zeros_like(values, dtype=np.float64)
        return flat, lo, hi

    return (values - lo) / (span + eps), lo, hi


def inverse_normalize_seq(norm_seq: np.ndarray, x_min: float, x_max: float, eps: float = 1e-8):
    """Map min-max normalized values back into the range ``[x_min, x_max]``.

    Computes ``x = x_min + x_norm * (x_max - x_min + eps)``; the ``eps`` term
    mirrors the one added to the denominator in ``normalize_seq``.

    Returns:
        A float64 NumPy array with the same shape as ``norm_seq``.
    """
    scaled = np.asarray(norm_seq, dtype=np.float64)
    span = x_max - x_min + eps
    return x_min + scaled * span


def angle_diff_deg(a: np.ndarray, b: np.ndarray):
    """
    Mean absolute angular difference between two heading sequences, in degrees.

    Each element-wise difference ``a - b`` is wrapped into [-180, 180) before
    its absolute value is taken, so e.g. 350 deg vs 10 deg counts as 20 deg.
    Returns a Python float.
    """
    lhs = np.asarray(a, dtype=np.float64)
    rhs = np.asarray(b, dtype=np.float64)
    wrapped = np.mod(lhs - rhs + 180.0, 360.0) - 180.0
    return float(np.abs(wrapped).mean())


class ImprovedLSTMPredictor(nn.Module):
    """LSTM regressor that predicts the support point displacement.

    A (possibly stacked) batch-first LSTM encodes the input sequence; the
    hidden state of the last time step is projected to a single scalar by a
    linear layer.
    """

    def __init__(self, input_size: int, hidden_size: int = 170, num_layers: int = 1):
        """
        Args:
            input_size: number of features per time step.
            hidden_size: width of the LSTM hidden state.
            num_layers: number of stacked LSTM layers.
        """
        super().__init__()
        self.hidden_size, self.num_layers = hidden_size, num_layers

        # Submodule names (`lstm`, `out`) are part of the checkpoint format:
        # state_dict keys are derived from them.
        recurrent_config = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.lstm = nn.LSTM(**recurrent_config)
        self.out = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, seq_len, input_size) tensor to a (batch, 1) prediction."""
        sequence_states = self.lstm(x)[0]
        # Only the final time step feeds the regression head.
        final_state = sequence_states[:, -1, :]
        return self.out(final_state)


class MPLSTMInference:
    """MP-LSTM Inference Engine - Using ImprovedLSTMPredictor"""

    def __init__(self, model_path: str, device='cuda:7' if torch.cuda.is_available() else 'cpu'):
        """Configure the window sizes and load the trained models.

        Args:
            model_path: directory handed to ``load_models``.
            device: torch device the models are moved to.
                NOTE(review): the default picks GPU index 7 whenever CUDA is
                available; on hosts with fewer GPUs pass ``device`` explicitly.
        """
        self.device = device

        # Window sizes, expressed in trajectory samples.
        self.history_length, self.prediction_length = 288, 144
        self.support_point_position = 72  # Support point at middle of prediction period

        self.load_models(model_path)

        # Filled in by find_reference_trajectory(); nothing matched yet.
        self.best_reference, self.best_reference_idx = None, -1

        # When True, candidates are pre-screened by _soft_candidate_filter().
        self.enable_soft_filter = True

    def load_models(self, model_path: str):
        """Load the fitted scalers and the two trained LSTM models.

        Reads ``scalers.pkl``, ``lstm_longitude.pth`` and ``lstm_latitude.pth``
        from ``model_path``. The target scalers (``scaler_longitude`` /
        ``scaler_latitude``) are required; the feature scalers are optional and
        stored as ``None`` when absent from the pickle. Both models are moved
        to ``self.device`` and switched to eval mode.
        """
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file -- only point this at scalers.pkl from a trusted source.
        with open(f'{model_path}/scalers.pkl', 'rb') as f:
            scalers = pickle.load(f)

        # Target scalers: mandatory, a missing key raises KeyError.
        self.scaler_longitude = scalers['scaler_longitude']
        self.scaler_latitude = scalers['scaler_latitude']

        # Feature scalers: optional, default to None.
        for optional_key in ('scaler_features_lon', 'scaler_features_lat', 'scaler_features'):
            setattr(self, optional_key, scalers.get(optional_key, None))

        input_size = 3  # [delta, sog, cog]

        def build_predictor():
            # Architecture must match the one used when the checkpoints were saved.
            network = ImprovedLSTMPredictor(
                input_size=input_size,
                hidden_size=170,
                num_layers=2
            )
            return network.to(self.device)

        self.lstm_model_longitude = build_predictor()
        self.lstm_model_latitude = build_predictor()

        checkpoints = (
            (self.lstm_model_longitude, f'{model_path}/lstm_longitude.pth'),
            (self.lstm_model_latitude, f'{model_path}/lstm_latitude.pth'),
        )
        for network, weights_file in checkpoints:
            network.load_state_dict(torch.load(weights_file, map_location=self.device))

        # Inference only: switch both models to evaluation mode.
        for network, _ in checkpoints:
            network.eval()

        print("Models loaded")
        print(f"Model structure: input_size={input_size}, hidden_size=170, num_layers=2")

    def load_reference_database(self, database_path: str):
        """Load the reference (training set) trajectories from ``database_path``.

        The trajectories returned by ``TrajectoryDatabase`` are stored in
        ``self.reference_trajectories``.

        Raises:
            FileNotFoundError: if ``database_path`` does not exist.
        """
        path_exists = os.path.exists(database_path)
        if not path_exists:
            raise FileNotFoundError(f"Reference database file not found: {database_path}")

        database = TrajectoryDatabase()
        database.load_database(database_path)

        loaded = database.get_trajectories()
        self.reference_trajectories = loaded
        print(f"Loaded {len(self.reference_trajectories)} reference trajectory segments")

    def load_test_database(self, database_path: str):
        """Load the test set trajectories from ``database_path``.

        The trajectories returned by ``TrajectoryDatabase`` are stored in
        ``self.test_trajectories``.

        Raises:
            FileNotFoundError: if ``database_path`` does not exist.
        """
        path_exists = os.path.exists(database_path)
        if not path_exists:
            raise FileNotFoundError(f"Test database file not found: {database_path}")

        database = TrajectoryDatabase()
        database.load_database(database_path)

        loaded = database.get_trajectories()
        self.test_trajectories = loaded
        print(f"Loaded {len(self.test_trajectories)} test trajectory segments")

    def _soft_candidate_filter(self,
                               cur_dlon_hist: np.ndarray,
                               cur_dlat_hist: np.ndarray,
                               cur_sog_hist: np.ndarray,
                               cur_cog_hist: np.ndarray,
                               ref_dlon_hist: np.ndarray,
                               ref_dlat_hist: np.ndarray,
                               ref_sog_hist: np.ndarray,
                               ref_cog_hist: np.ndarray) -> bool:

        cur_dx = float(cur_dlon_hist[-1] - cur_dlon_hist[0])
        cur_dy = float(cur_dlat_hist[-1] - cur_dlat_hist[0])
        ref_dx = float(ref_dlon_hist[-1] - ref_dlon_hist[0])
        ref_dy = float(ref_dlat_hist[-1] - ref_dlat_hist[0])

        cur_disp = math.hypot(cur_dx, cur_dy)
        ref_disp = math.hypot(ref_dx, ref_dy)

        if cur_disp < 1e-9 and ref_disp < 1e-9:
            pass
        else:
            ratio = (ref_disp + 1e-9) / (cur_disp + 1e-9)
            if ratio < 0.25 or ratio > 4.0:
                return False

        cur_sog_mean = float(np.mean(cur_sog_hist))
        ref_sog_mean = float(np.mean(ref_sog_hist))
        if abs(cur_sog_mean - ref_sog_mean) > 5.0:  # knots
            return False

        cog_diff = angle_diff_deg(cur_cog_hist, ref_cog_hist)
        if cog_diff > 60.0:
            return False

        return True

    def find_reference_trajectory(self, current_trajectory: dict, current_start_idx: int):
        """
        Find the reference history window most similar to the current one.

        Every admissible window of every trajectory in
        ``self.reference_trajectories`` is scored by the sum of the longitude
        and latitude MSEs computed on min-max normalized delta sequences; the
        lowest score wins. Skipped candidates: trajectories whose
        ``segment_id`` equals the current one, trajectories shorter than
        history + prediction length, and (when ``self.enable_soft_filter`` is
        set) windows rejected by ``_soft_candidate_filter``.

        Side effects: stores the winning trajectory in ``self.best_reference``
        and its window start in ``self.best_reference_idx`` (``None`` / ``-1``
        when nothing matched).

        Returns:
            ``(ref_future_delta_lons, ref_future_delta_lats, ref_hist_lons,
            ref_hist_lats, ref_fut_lons, ref_fut_lats)``. The first two are the
            winner's future deltas, min-max normalized by their own range and
            rescaled to the current history's min/max. All six entries are
            ``None`` when no candidate was accepted.
        """
        hist_len = self.history_length
        pred_len = self.prediction_length
        window_len = hist_len + pred_len

        # Current history window (per the original comments the deltas are
        # cumulative displacements from the segment's first point).
        cur_window = slice(current_start_idx, current_start_idx + hist_len)
        cur_dlon_hist = current_trajectory['delta_lons'][cur_window]
        cur_dlat_hist = current_trajectory['delta_lats'][cur_window]
        cur_sog_hist = current_trajectory['sog'][cur_window]
        cur_cog_hist = current_trajectory['cog'][cur_window]

        cur_dlon_norm, cur_dlon_min, cur_dlon_max = normalize_seq(cur_dlon_hist)
        cur_dlat_norm, cur_dlat_min, cur_dlat_max = normalize_seq(cur_dlat_hist)

        lowest_mse = float('inf')
        winner = None
        winner_start = -1

        for candidate in self.reference_trajectories:
            # Never match a segment against itself.
            # NOTE(review): if neither dict carries 'segment_id' both .get()
            # calls return None and every candidate is skipped -- confirm the
            # databases always provide this key.
            if (candidate.get('segment_id') == current_trajectory.get('segment_id')
                    or len(candidate['delta_lons']) < window_len):
                continue

            n_offsets = len(candidate['delta_lons']) - hist_len - pred_len + 1
            for offset in range(n_offsets):
                ref_window = slice(offset, offset + hist_len)
                ref_dlon_hist = candidate['delta_lons'][ref_window]
                ref_dlat_hist = candidate['delta_lats'][ref_window]
                ref_sog_hist = candidate['sog'][ref_window]
                ref_cog_hist = candidate['cog'][ref_window]

                if self.enable_soft_filter and not self._soft_candidate_filter(
                        cur_dlon_hist, cur_dlat_hist, cur_sog_hist, cur_cog_hist,
                        ref_dlon_hist, ref_dlat_hist, ref_sog_hist, ref_cog_hist):
                    continue

                # Compare shapes, not absolute scales: both sides are min-max
                # normalized by their own range before the MSE is taken.
                ref_dlon_norm = normalize_seq(ref_dlon_hist)[0]
                ref_dlat_norm = normalize_seq(ref_dlat_hist)[0]
                lon_err = float(np.mean((cur_dlon_norm - ref_dlon_norm) ** 2))
                lat_err = float(np.mean((cur_dlat_norm - ref_dlat_norm) ** 2))
                score = lon_err + lat_err

                if score < lowest_mse:
                    lowest_mse, winner, winner_start = score, candidate, offset

        self.best_reference = winner
        self.best_reference_idx = winner_start

        if winner is None:
            return (None,) * 6

        hist_slice = slice(winner_start, winner_start + hist_len)
        fut_begin = winner_start + hist_len
        fut_slice = slice(fut_begin, fut_begin + pred_len)

        # Future deltas of the winner: normalize by their own range, then
        # rescale with the min/max of the *current* history window.
        fut_dlon_norm = normalize_seq(winner['delta_lons'][fut_slice])[0]
        fut_dlat_norm = normalize_seq(winner['delta_lats'][fut_slice])[0]
        ref_future_delta_lons = inverse_normalize_seq(
            fut_dlon_norm, cur_dlon_min, cur_dlon_max).astype(np.float64)
        ref_future_delta_lats = inverse_normalize_seq(
            fut_dlat_norm, cur_dlat_min, cur_dlat_max).astype(np.float64)

        # Raw lon/lat windows of the winner (for visualization/debugging).
        ref_hist_lons = winner['lons'][hist_slice]
        ref_hist_lats = winner['lats'][hist_slice]
        ref_fut_lons = winner['lons'][fut_slice]
        ref_fut_lats = winner['lats'][fut_slice]

        return (ref_future_delta_lons, ref_future_delta_lats,
                ref_hist_lons, ref_hist_lats,
                ref_fut_lons, ref_fut_lats)

    def prepare_input_features(self, trajectory: dict, start_idx: int):
        """Prepare input features - keep only first three features"""
        deltas_lon = trajectory['delta_lons'][start_idx:start_idx + self.history_length]
        deltas_lat = trajectory['delta_lats'][start_idx:start_idx + self.history_length]
        sog = trajectory['sog'][start_idx:start_idx + self.history_length]
        cog = trajectory['cog'][start_idx:start_idx + self.history_length]

        features_lon = np.column_stack([deltas_lon, sog, cog])
        features_lat = np.column_stack([deltas_lat, sog, cog])

        return features_lon, features_lat

    def predict_support_point(self, features_lon: np.ndarray, features_lat: np.ndarray):
        """Predict support point displacement"""

        if self.scaler_features_lon is not None:
            input_lon_scaled = self.scaler_features_lon.transform(features_lon).reshape(1, self.history_length, -1)
        elif self.scaler_features is not None:
            input_lon_scaled = self.scaler_features.transform(
                features_lon.reshape(-1, features_lon.shape[1])
            ).reshape(1, self.history_length, -1)
        else:
            raise KeyError("Missing scaler_features_lon or scaler_features in scalers.pkl")

        if self.scaler_features_lat is not None:
            input_lat_scaled = self.scaler_features_lat.transform(features_lat).reshape(1, self.history_length, -1)
        elif self.scaler_features is not None:
            input_lat_scaled = self.scaler_features.transform(
                features_lat.reshape(-1, features_lat.shape[1])
            ).reshape(1, self.history_length, -1)
        else:
            raise KeyError("Missing scaler_features_lat or scaler_features in scalers.pkl")

        input_lon_tensor = torch.as_tensor(input_lon_scaled, dtype=torch.float32, device=self.device)
        input_lat_tensor = torch.as_tensor(input_lat_scaled, dtype=torch.float32, device=self.device)

        with torch.no_grad():
            pred_delta_lon_scaled = self.lstm_model_longitude(input_lon_tensor)
            pred_delta_lat_scaled = self.lstm_model_latitude(input_lat_tensor)

        pred_delta_lon_scaled = pred_delta_lon_scaled.detach().cpu().numpy()
        pred_delta_lat_scaled = pred_delta_lat_scaled.detach().cpu().numpy()

        pred_delta_lon = self.scaler_longitude.inverse_transform(
            pred_delta_lon_scaled.reshape(-1, 1)).flatten()[0]
        pred_delta_lat = self.scaler_latitude.inverse_transform(
            pred_delta_lat_scaled.reshape(-1, 1)).flatten()[0]

        return pred_delta_lon, pred_delta_lat

    def cubic_spline_interpolation(self, start_point: tuple, support_point: tuple,
                                   destination_point: tuple, num_points: int):
        """Cubic spline interpolation to build complete trajectory"""
        t_control = np.array([0, 0.5, 1.0])
        lon_control = np.array([start_point[0], support_point[0], destination_point[0]])
        lat_control = np.array([start_point[1], support_point[1], destination_point[1]])

        t_new = np.linspace(0, 1, num_points)

        cs_lon = CubicSpline(t_control, lon_control)
        cs_lat = CubicSpline(t_control, lat_control)

        predicted_lons = cs_lon(t_new)
        predicted_lats = cs_lat(t_new)

        return predicted_lons, predicted_lats

    def mp_lstm_predict(self, trajectory: dict, start_idx: int = 0):
        """MP-LSTM multi-step prediction"""

        if start_idx + self.history_length + self.prediction_length > len(trajectory['lons']):
            raise ValueError("Trajectory length insufficient for prediction")

        # Current prediction starting point (end of history)
        start_lon = trajectory['lons'][start_idx + self.history_length - 1]
        start_lat = trajectory['lats'][start_idx + self.history_length - 1]

        # First point of current segment (reference for cumulative delta baseline)
        first_lon = trajectory['lons'][start_idx]
        first_lat = trajectory['lats'][start_idx]

        (ref_future_delta_lons, ref_future_delta_lats,
         ref_history_lons, ref_history_lats,
         ref_future_lons, ref_future_lats) = self.find_reference_trajectory(trajectory, start_idx)

        if ref_future_delta_lons is None:
            print("Warning: No suitable reference trajectory found, using simple extrapolation")
            dest_delta_lon = float(np.mean(trajectory['delta_lons'][start_idx:start_idx + self.history_length]))
            dest_delta_lat = float(np.mean(trajectory['delta_lats'][start_idx:start_idx + self.history_length]))
            ref_history_lons = None
            ref_history_lats = None
            ref_future_lons = None
            ref_future_lats = None
        else:
            # In cumulative delta definition: displacement from tm→tn = delta(tn) - delta(tm)
            dest_delta_lon = float(ref_future_delta_lons[-1] - ref_future_delta_lons[0])
            dest_delta_lat = float(ref_future_delta_lats[-1] - ref_future_delta_lats[0])

        # Destination point lat/lon (starting from prediction start point, adding destination displacement)
        dest_lon = start_lon + dest_delta_lon
        dest_lat = start_lat + dest_delta_lat

        # Support point: still using LSTM output (training target should be consistent with cumulative delta definition)
        features_lon, features_lat = self.prepare_input_features(trajectory, start_idx)
        support_delta_lon, support_delta_lat = self.predict_support_point(features_lon, features_lat)

        # support_delta_* is cumulative displacement relative to segment's first point
        support_lon = first_lon + support_delta_lon
        support_lat = first_lat + support_delta_lat

        predicted_lons, predicted_lats = self.cubic_spline_interpolation(
            start_point=(start_lon, start_lat),
            support_point=(support_lon, support_lat),
            destination_point=(dest_lon, dest_lat),
            num_points=self.prediction_length
        )

        true_start_idx = start_idx + self.history_length
        true_end_idx = true_start_idx + self.prediction_length
        true_lons = trajectory['lons'][true_start_idx:true_end_idx]
        true_lats = trajectory['lats'][true_start_idx:true_end_idx]

        orig_lons = trajectory['lons'][start_idx:start_idx + self.history_length]
        orig_lats = trajectory['lats'][start_idx:start_idx + self.history_length]

        half_length = self.prediction_length // 2

        assert len(predicted_lons) == 144 and len(predicted_lats) == 144, "predicted must be length 144"
        assert len(true_lons) == 144 and len(true_lats) == 144, "true must be length 144"

        return {
            'orig_lons': orig_lons,
            'orig_lats': orig_lats,
            'predicted_lons': predicted_lons,
            'predicted_lats': predicted_lats,
            'true_lons': true_lons,
            'true_lats': true_lats,
            'control_points': {
                'start': (start_lon, start_lat),
                'support': (support_lon, support_lat),
                'destination': (dest_lon, dest_lat)
            },
            'reference_used': ref_future_delta_lons is not None,
            'half_predicted_lons': predicted_lons[:half_length],
            'half_predicted_lats': predicted_lats[:half_length],
            'half_true_lons': true_lons[:half_length],
            'half_true_lats': true_lats[:half_length],
            'ref_history_lons': ref_history_lons,
            'ref_history_lats': ref_history_lats,
            'ref_future_lons': ref_future_lons,
            'ref_future_lats': ref_future_lats,
            'best_reference': self.best_reference,
            'best_reference_idx': self.best_reference_idx
        }

    def evaluate_on_test_set(self, num_tests: int = 20, batch_size: int = 64, verbose_every: int = 0):
        """
        Evaluate prediction performance on the test set with batched metrics.

        For every test trajectory that is long enough, one window start is
        drawn with ``random.randint`` (seed ``random`` beforehand for
        reproducible numbers), a prediction is made with ``mp_lstm_predict``
        and the (prediction_length, 2) predicted/true tracks are buffered.
        Metrics are computed per batch of ``batch_size`` samples and averaged
        per sample:
          * MSE between predicted and true coordinates after ``deg_to_rad``,
          * Fréchet distance (``frechet_distance``),
          * squared difference of ``curvature_calculation(...)["smoothed_curvatures"]``.

        Args:
            num_tests: currently unused -- every test trajectory is evaluated.
                Kept so existing callers keep working.
            batch_size: number of samples per metric batch.
            verbose_every: if non-zero, print progress every this many
                trajectories.

        Returns:
            dict with ``avg_position_loss``, ``avg_frechet_distance``,
            ``avg_curvature_difference`` and ``sample_count``, or ``None`` if
            no trajectory could be evaluated.

        Raises:
            ValueError: if no test trajectories have been loaded.
        """
        if not hasattr(self, 'test_trajectories') or not self.test_trajectories:
            raise ValueError("Please load test trajectory database first")

        print("Evaluating prediction performance on test set (batch metric calculation)...")

        total_loss = 0.0
        total_fd = 0.0
        total_cvt = 0.0
        sample_count = 0

        pred_buf = []
        true_buf = []

        required_length = self.history_length + self.prediction_length
        track_shape = (self.prediction_length, 2)
        criterion = torch.nn.MSELoss()

        def flush_batch():
            """Fold the buffered samples into the running totals and empty the buffers."""
            nonlocal total_loss, total_fd, total_cvt, sample_count
            if not pred_buf:
                return

            pred_bt2 = torch.stack(pred_buf, dim=0)
            true_bt2 = torch.stack(true_buf, dim=0)
            batch_n = pred_bt2.shape[0]

            assert pred_bt2.shape[1:] == track_shape, \
                f"pred batch shape must be (B,{track_shape[0]},2), got {pred_bt2.shape}"
            assert true_bt2.shape[1:] == track_shape, \
                f"true batch shape must be (B,{track_shape[0]},2), got {true_bt2.shape}"

            # MSELoss returns the mean over the whole batch. The totals are
            # divided by the number of samples at the end, so each batch mean
            # must be weighted by its batch size (otherwise the reported
            # average shrinks with batch_size).
            batch_mse = criterion(deg_to_rad(pred_bt2), deg_to_rad(true_bt2)).item()
            total_loss += batch_mse * batch_n

            total_fd += frechet_distance(pred_bt2, true_bt2).sum().item()

            # Same weighting as above.
            # NOTE(review): assumes "smoothed_curvatures" holds the same number
            # of values for every sample in the batch -- confirm in utils.metrics.
            curvature_gap = (curvature_calculation(pred_bt2)["smoothed_curvatures"] -
                             curvature_calculation(true_bt2)["smoothed_curvatures"])
            total_cvt += torch.mean(curvature_gap ** 2).item() * batch_n

            sample_count += batch_n

            pred_buf.clear()
            true_buf.clear()

        for i, trajectory in enumerate(tqdm(self.test_trajectories)):
            if len(trajectory['lons']) < required_length:
                continue

            max_start_idx = len(trajectory['lons']) - required_length
            start_idx = random.randint(0, max_start_idx)

            result = self.mp_lstm_predict(trajectory, start_idx)

            pred_np = np.column_stack([result['predicted_lons'], result['predicted_lats']]).astype(np.float32)
            true_np = np.column_stack([result['true_lons'], result['true_lats']]).astype(np.float32)

            assert pred_np.shape == track_shape, \
                f"single pred must be ({track_shape[0]},2), got {pred_np.shape}"
            assert true_np.shape == track_shape, \
                f"single true must be ({track_shape[0]},2), got {true_np.shape}"

            pred_buf.append(torch.from_numpy(pred_np))
            true_buf.append(torch.from_numpy(true_np))

            if len(pred_buf) >= batch_size:
                flush_batch()

            if verbose_every and (i + 1) % verbose_every == 0:
                print(f"Processed up to test trajectory index {i}, current cumulative samples={sample_count + len(pred_buf)}")

        # Last, possibly partial, batch.
        flush_batch()

        if sample_count == 0:
            print("Warning: No successful predictions")
            return None

        avg_loss = total_loss / sample_count
        avg_fd = total_fd / sample_count
        avg_cvt = total_cvt / sample_count

        print("\n=== Test Results Summary ===")
        print(f"Successfully tested {sample_count} trajectories")
        print(f"Average position loss (MSE): {avg_loss:.8f}")
        print(f"Average Fréchet distance: {avg_fd:.6f}")
        print(f"Average curvature difference: {avg_cvt:.6f}")

        return {
            'avg_position_loss': avg_loss,
            'avg_frechet_distance': avg_fd,
            'avg_curvature_difference': avg_cvt,
            'sample_count': sample_count
        }

    def plot_prediction_result(self, result: dict, trajectory_id: str = ""):
        """Plot one prediction, save it as a PNG and display it.

        Draws the history, the true future, the predicted future and the three
        spline control points from ``result`` (as returned by
        ``mp_lstm_predict``). The figure is written to
        ``prediction_result_{trajectory_id}.png`` in the current working
        directory, then shown, then closed.

        Args:
            result: prediction dict with ``orig_*``, ``true_*``,
                ``predicted_*`` and ``control_points`` entries.
            trajectory_id: label used in the plot title and the file name.
        """
        fig = plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.plot(result['orig_lons'], result['orig_lats'], 'b-', label='History', linewidth=2)
        plt.plot(result['true_lons'], result['true_lats'], 'g-', label='True Future', linewidth=2)
        plt.plot(result['predicted_lons'], result['predicted_lats'], 'r--', label='Predicted Future', linewidth=2)

        cp = result['control_points']
        plt.plot(cp['start'][0], cp['start'][1], "#496292", marker='o', markersize=8, label='Start')
        plt.plot(cp['support'][0], cp['support'][1], "#b7282e", marker='o', markersize=8, label='Support')
        plt.plot(cp['destination'][0], cp['destination'][1], "#376439", marker='o', markersize=8, label='Destination')

        plt.xlabel('Longitude')
        plt.ylabel('Latitude')
        plt.title(f'Trajectory Prediction {trajectory_id}')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()

        # Save BEFORE show(): once show() returns the figure may already have
        # been destroyed, and a later savefig() would write an empty canvas.
        output_path = f"prediction_result_{trajectory_id}.png"
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Prediction results saved as: {output_path}")

        plt.show()

        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main():
    """Main inference function.

    Loads the models and both trajectory databases, plots and scores a
    prediction for every 10th test trajectory among the first 100 (fewer if
    the test set is smaller), then runs the full test set evaluation.
    """
    inference = MPLSTMInference('mp_lstm_models2')

    inference.load_reference_database('/LSTM/trajectory_database_288.pkl')

    test_db_path = '/LSTM-NOAA/test_trajectories.pkl'
    inference.load_test_database(test_db_path)

    if not hasattr(inference, 'test_trajectories') or not inference.test_trajectories:
        print("No test trajectory data available")
        return

    # Never index past the end of the test set.
    num_available = len(inference.test_trajectories)
    required_length = inference.history_length + inference.prediction_length

    for test_idx in range(0, min(100, num_available), 10):
        test_trajectory = inference.test_trajectories[test_idx]

        # mp_lstm_predict raises ValueError on short trajectories; skip them
        # instead of aborting the whole demo run.
        if len(test_trajectory['lons']) < required_length:
            print(f"Skipping test trajectory {test_idx}: shorter than {required_length} points")
            continue

        result = inference.mp_lstm_predict(test_trajectory, start_idx=0)
        inference.plot_prediction_result(result, f"Test_Trajectory_{test_idx}")

        lon_mse = mean_squared_error(result['true_lons'], result['predicted_lons'])
        lat_mse = mean_squared_error(result['true_lats'], result['predicted_lats'])
        total_mse = (lon_mse + lat_mse) / 2

        print("\nSingle trajectory prediction results:")
        print(f"Total MSE: {total_mse:.8f}")
        print(f"Longitude MSE: {lon_mse:.8f}")
        print(f"Latitude MSE: {lat_mse:.8f}")
        print(f"Reference trajectory used: {result['reference_used']}")

    print("\nStarting comprehensive test set evaluation (batch metrics)...")
    evaluation_results = inference.evaluate_on_test_set(num_tests=20, batch_size=128, verbose_every=0)

    if evaluation_results:
        print("\n" + "=" * 50)
        print("MP-LSTM Model Evaluation Results:")
        print(f"Sample count: {evaluation_results['sample_count']}")
        print(f"Average position loss (MSE): {evaluation_results['avg_position_loss']:.8f}")
        print(f"Average Fréchet distance: {evaluation_results['avg_frechet_distance']:.6f}")
        print(f"Average curvature difference: {evaluation_results['avg_curvature_difference']:.6f}")
        print("=" * 50)


# Run the inference/evaluation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()