import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict, Any
import pickle
import warnings
from sklearn.model_selection import train_test_split
warnings.filterwarnings('ignore')

from create_database import TrajectoryDatabase

import torch
import torch.nn as nn


class ImprovedLSTMPredictor(nn.Module):
    """
    Sequence-to-one LSTM regressor.

    A batch-first LSTM encodes the input sequence; the LSTM output at the
    final time step is projected by a linear head to one value per sample
    (used by the trainer below as a future-displacement estimate).

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size: int, hidden_size: int = 170, num_layers: int = 1):
        super().__init__()
        self.hidden_size, self.num_layers = hidden_size, num_layers

        # NOTE: the attribute names `lstm` and `out` become the state_dict key
        # prefixes of saved checkpoints; renaming them breaks checkpoint loading.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.out = nn.Linear(in_features=hidden_size, out_features=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Regress one value per sequence.

        Args:
            x: tensor of shape [batch_size, sequence_length, input_size].

        Returns:
            Tensor of shape [batch_size, 1].
        """
        sequence_states = self.lstm(x)[0]      # [batch, seq_len, hidden_size]
        final_state = sequence_states[:, -1]   # output at the last time step
        return self.out(final_state)


class MP_LSTMTrainer:
    """
    Trainer class handling data preparation, model training,
    evaluation split, and model persistence.

    Two independent regressors are trained: one for the longitude
    displacement and one for the latitude displacement. Each has its own
    target scaler and its own feature scaler.

    Expected call order (each step reads attributes set by the previous one):
        load_and_split_trajectory_database -> prepare_training_data
        -> build_models -> train_models -> save_models
    """

    # GPU index this trainer was originally configured for. It is only used
    # when that many CUDA devices are actually visible.
    _PREFERRED_CUDA_INDEX = 5

    def __init__(self, device=None):
        """
        Args:
            device: torch device string (e.g. 'cpu', 'cuda:0'). When None,
                'cuda:5' is used if at least six GPUs are visible, 'cuda:0'
                if fewer GPUs are visible, and 'cpu' if CUDA is unavailable.
        """
        self.device = self._resolve_device(device)

        # Scalers for target variables
        self.scaler_longitude = MinMaxScaler()
        self.scaler_latitude = MinMaxScaler()

        # Separate feature scalers for longitude and latitude inputs
        self.scaler_features_lon = MinMaxScaler()
        self.scaler_features_lat = MinMaxScaler()

        self.lstm_model_longitude = None
        self.lstm_model_latitude = None

    @classmethod
    def _resolve_device(cls, device):
        """Return *device* unchanged, or pick a device that actually exists when it is None."""
        if device is not None:
            return device
        if not torch.cuda.is_available():
            return 'cpu'
        # A hard-coded 'cuda:5' fails on machines with fewer than six GPUs.
        if torch.cuda.device_count() > cls._PREFERRED_CUDA_INDEX:
            return f'cuda:{cls._PREFERRED_CUDA_INDEX}'
        return 'cuda:0'

    def load_and_split_trajectory_database(self, database_path: str, test_ratio: float = 0.2, random_seed: int = 42):
        """
        Load trajectory database and split into training and test sets
        at the trajectory level.

        Splitting whole trajectories (rather than individual windows) keeps
        overlapping windows of one trajectory from landing in both sets.

        Args:
            database_path: path passed to TrajectoryDatabase.load_database.
            test_ratio: fraction of trajectories held out for testing.
            random_seed: seed for the reproducible split.
        """
        db = TrajectoryDatabase()
        db.load_database(database_path)
        all_trajectories = db.get_trajectories()

        train_trajectories, test_trajectories = train_test_split(
            all_trajectories,
            test_size=test_ratio,
            random_state=random_seed
        )

        self.train_trajectories = train_trajectories
        self.test_trajectories = test_trajectories

        # prepare_training_data consumes self.trajectories.
        self.trajectories = self.train_trajectories

        print(f"Dataset split completed:")
        print(f"  Total trajectories: {len(all_trajectories)}")
        print(f"  Training set: {len(self.train_trajectories)} trajectories")
        print(f"  Test set: {len(self.test_trajectories)} trajectories")

    def save_test_trajectories(self, filepath: str):
        """
        Save test trajectories for later evaluation or inference.

        The pickle contains the bare list of test trajectories (not wrapped
        in a dict), so loaders should expect exactly that list.
        """
        with open(filepath, 'wb') as f:
            pickle.dump(self.test_trajectories, f)

        print(f"Test trajectories saved to: {filepath}")

    def prepare_training_data(self, history_length=10, prediction_length=5):
        """
        Prepare supervised training samples from trajectory data.

        For every window start i, the input is the `history_length` steps
        before i, with features [delta, sog, cog] per step. The target is the
        delta at index i + prediction_length // 2 (the midpoint of the
        prediction window).

        Sets self.X_lon / self.X_lat with shape
        [num_samples, history_length, 3] and self.y_lon / self.y_lat with
        shape [num_samples].

        Raises:
            ValueError: if no trajectory yields a single sample.
        """
        print("Preparing training data...")

        X_longitude, y_longitude = [], []
        X_latitude, y_latitude = [], []
        target_offset = prediction_length // 2

        for trajectory in self.trajectories:
            # Convert once so that slicing and scalar indexing are positional
            # regardless of the container type stored in the database.
            deltas_lon = np.asarray(trajectory['delta_lons'])
            deltas_lat = np.asarray(trajectory['delta_lats'])
            sog = np.asarray(trajectory['sog'])
            cog = np.asarray(trajectory['cog'])
            n_points = len(deltas_lon)

            # Skip trajectories that are too short
            if n_points < history_length + prediction_length:
                continue

            for i in range(history_length, n_points - prediction_length + 1):
                support_point_idx = i + target_offset
                # Can only trigger when prediction_length < 1 (the window then reaches n_points).
                if support_point_idx >= n_points:
                    continue

                window = slice(i - history_length, i)
                X_longitude.append(np.column_stack([deltas_lon[window], sog[window], cog[window]]))
                y_longitude.append(deltas_lon[support_point_idx])
                X_latitude.append(np.column_stack([deltas_lat[window], sog[window], cog[window]]))
                y_latitude.append(deltas_lat[support_point_idx])

        if not X_longitude:
            raise ValueError(
                f"No training samples generated from {len(self.trajectories)} trajectories "
                f"(history_length={history_length}, prediction_length={prediction_length}); "
                "trajectories are likely too short."
            )

        self.X_lon = np.array(X_longitude)
        self.y_lon = np.array(y_longitude)
        self.X_lat = np.array(X_latitude)
        self.y_lat = np.array(y_latitude)

        print(f"Training data preparation completed:")
        print(f"  Longitude data: {self.X_lon.shape}, target range: [{self.y_lon.min():.6f}, {self.y_lon.max():.6f}]")
        print(f"  Latitude data: {self.X_lat.shape}, target range: [{self.y_lat.min():.6f}, {self.y_lat.max():.6f}]")
        print(f"  Input feature count: {self.X_lon.shape[2]}")

    def build_models(self, hidden_units=256, num_layers=2):
        """
        Initialize LSTM models for longitude and latitude prediction.

        Requires prepare_training_data() to have been called, because the
        input size is taken from self.X_lon.

        Args:
            hidden_units: LSTM hidden size for both models.
            num_layers: number of stacked LSTM layers for both models.
        """
        print("Building LSTM models...")

        input_size = self.X_lon.shape[2]

        self.lstm_model_longitude = ImprovedLSTMPredictor(
            input_size=input_size,
            hidden_size=hidden_units,
            num_layers=num_layers
        ).to(self.device)

        self.lstm_model_latitude = ImprovedLSTMPredictor(
            input_size=input_size,
            hidden_size=hidden_units,
            num_layers=num_layers
        ).to(self.device)

        print(f"Model input size: {input_size}, hidden units: {hidden_units}")
        print("LSTM models built")

    def train_models(self, epochs=250, batch_size=32, learning_rate=0.0001):
        """
        Train longitude and latitude models using normalized data,
        mean squared error loss, and adaptive learning rate scheduling.

        Fits all four scalers on the training arrays, stores the scaled
        arrays on self (X_lon_scaled, y_lon_scaled, ...), trains both models
        in place, and records per-epoch mean batch losses in
        self.training_history.

        Raises:
            RuntimeError: if build_models() has not been called.
        """
        if self.lstm_model_longitude is None or self.lstm_model_latitude is None:
            raise RuntimeError("Models are not built; call build_models() before train_models().")

        print("Starting model training...")

        self.X_lon_scaled = self._fit_scale_features(self.scaler_features_lon, self.X_lon)
        self.X_lat_scaled = self._fit_scale_features(self.scaler_features_lat, self.X_lat)

        self.y_lon_scaled = self.scaler_longitude.fit_transform(self.y_lon.reshape(-1, 1)).flatten()
        self.y_lat_scaled = self.scaler_latitude.fit_transform(self.y_lat.reshape(-1, 1)).flatten()

        lon_loader = self._make_loader(self.X_lon_scaled, self.y_lon_scaled, batch_size)
        lat_loader = self._make_loader(self.X_lat_scaled, self.y_lat_scaled, batch_size)

        print("Training longitude model...")
        lon_losses = self._fit_model(self.lstm_model_longitude, lon_loader, "Longitude", epochs, learning_rate)

        print("Training latitude model...")
        lat_losses = self._fit_model(self.lstm_model_latitude, lat_loader, "Latitude", epochs, learning_rate)

        self.training_history = {
            'longitude_losses': lon_losses,
            'latitude_losses': lat_losses
        }
        print("Model training completed")

    @staticmethod
    def _fit_scale_features(scaler, X):
        """Fit *scaler* over every time step of the 3-D array X; return the scaled array in X's shape."""
        flat = X.reshape(-1, X.shape[2])
        return scaler.fit_transform(flat).reshape(X.shape)

    def _make_loader(self, X_scaled, y_scaled, batch_size):
        """Wrap scaled arrays as float32 tensors on self.device in a shuffling DataLoader."""
        dataset = TensorDataset(
            torch.FloatTensor(X_scaled).to(self.device),
            torch.FloatTensor(y_scaled).to(self.device),
        )
        return DataLoader(dataset, batch_size=batch_size, shuffle=True)

    def _fit_model(self, model, loader, label, epochs, learning_rate) -> list:
        """Train *model* in place with Adam + MSE; return the mean batch loss of every epoch."""
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        # Halve the learning rate after 50 epochs without training-loss improvement.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=50)

        model.train()
        losses = []

        for epoch in range(epochs):
            epoch_loss = 0.0
            for batch_X, batch_y in loader:
                optimizer.zero_grad()
                outputs = model(batch_X)
                # squeeze(-1) maps [B, 1] -> [B] even when B == 1; a bare squeeze()
                # would yield a 0-d tensor that no longer matches batch_y's shape.
                loss = criterion(outputs.squeeze(-1), batch_y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                epoch_loss += loss.item()

            avg_loss = epoch_loss / len(loader)
            losses.append(avg_loss)
            scheduler.step(avg_loss)

            if epoch % 10 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                print(f'  {label} model Epoch [{epoch}/{epochs}], Loss: {avg_loss:.6f}, LR: {current_lr:.6f}')

        return losses

    def save_models(self, model_path: str):
        """
        Save trained models and associated scalers.

        Writes lstm_longitude.pth and lstm_latitude.pth (state dicts) and
        scalers.pkl (a dict of the four fitted scalers) into model_path,
        creating the directory if needed.

        Raises:
            RuntimeError: if the models have not been built.
        """
        import os

        if self.lstm_model_longitude is None or self.lstm_model_latitude is None:
            raise RuntimeError("No models to save; call build_models() and train_models() first.")

        os.makedirs(model_path, exist_ok=True)

        torch.save(self.lstm_model_longitude.state_dict(), os.path.join(model_path, 'lstm_longitude.pth'))
        torch.save(self.lstm_model_latitude.state_dict(), os.path.join(model_path, 'lstm_latitude.pth'))

        with open(os.path.join(model_path, 'scalers.pkl'), 'wb') as f:
            pickle.dump({
                'scaler_longitude': self.scaler_longitude,
                'scaler_latitude': self.scaler_latitude,
                'scaler_features_lon': self.scaler_features_lon,
                'scaler_features_lat': self.scaler_features_lat
            }, f)

        print(f"Models and scalers saved to {model_path}")


def main():
    """
    Run the full training pipeline with the script's fixed configuration.

    Steps:
        1. Load the trajectory database and hold out 20% of trajectories.
        2. Pickle the held-out trajectories for later evaluation.
        3. Build windowed samples (history_length=288, prediction_length=144).
        4. Build both LSTMs (170 hidden units) and train them for 200 epochs.
        5. Save model weights and scalers.
    """
    source_db = '11.9/LSTM/trajectory_for_inference.pkl'
    held_out_file = '11.9/LSTM/test_trajectories.pkl'
    output_dir = 'mp_lstm_models2'

    trainer = MP_LSTMTrainer()

    # Data split; the held-out part is persisted before any training happens.
    trainer.load_and_split_trajectory_database(source_db, test_ratio=0.2)
    trainer.save_test_trajectories(held_out_file)

    # Sample construction, model creation, and training.
    trainer.prepare_training_data(history_length=288, prediction_length=144)
    trainer.build_models(hidden_units=170)
    trainer.train_models(epochs=200, batch_size=32, learning_rate=0.0001)

    trainer.save_models(output_dir)


if __name__ == "__main__":
    main()
