import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, List, Optional, Tuple


class StateEncoder(nn.Module):
    """Fully connected autoencoder for flat state vectors, used in AURORA.

    The encoder maps ``state_dim`` -> ``hidden_sizes`` -> ``latent_dim`` with a
    ReLU after every hidden layer; the decoder uses the hidden sizes in reverse
    order and maps back to ``state_dim``. The last layer of each network is a
    plain ``nn.Linear`` with no activation.
    """

    def __init__(
        self, state_dim: int, latent_dim: int, hidden_sizes: List[int], device: str
    ):
        """Initialize the encoder.

        Args:
            state_dim: Dimension of input states
            latent_dim: Dimension of latent space
            hidden_sizes: List of hidden layer sizes (the decoder uses them reversed)
            device: Device to run on; all parameters are moved there
        """
        super().__init__()
        self.state_dim = state_dim
        self.latent_dim = latent_dim
        self.device = device

        self.encoder = self._build_encoder(hidden_sizes)
        # Mirror the encoder so the decoder widens back out symmetrically.
        self.decoder = self._build_decoder(hidden_sizes[::-1])

        # Initialize weights
        self.apply(self._init_weights)
        self.to(device)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize ``nn.Linear`` layers orthogonally with zero biases.

        The gain ``sqrt(2)`` is the value PyTorch recommends for ReLU layers
        (``nn.init.calculate_gain("relu")``); it is applied to every linear
        layer, including the final projection.
        """
        if isinstance(module, nn.Linear):
            gain = np.sqrt(2.0)
            nn.init.orthogonal_(module.weight, gain=gain)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @staticmethod
    def _build_mlp(in_size: int, hidden_sizes: List[int], out_size: int) -> nn.Sequential:
        """Build ``Linear -> ReLU`` per hidden size, then a final ``Linear`` to ``out_size``."""
        layers: List[nn.Module] = []
        prev_size = in_size
        for size in hidden_sizes:
            layers.extend([nn.Linear(prev_size, size), nn.ReLU()])
            prev_size = size
        layers.append(nn.Linear(prev_size, out_size))
        return nn.Sequential(*layers)

    def _build_encoder(self, hidden_sizes: List[int]) -> nn.Sequential:
        """Build the encoder network ``state_dim -> hidden_sizes -> latent_dim``."""
        return self._build_mlp(self.state_dim, hidden_sizes, self.latent_dim)

    def _build_decoder(self, hidden_sizes: List[int]) -> nn.Sequential:
        """Build the decoder network ``latent_dim -> hidden_sizes -> state_dim``."""
        return self._build_mlp(self.latent_dim, hidden_sizes, self.state_dim)

    def encode(self, states: torch.Tensor) -> torch.Tensor:
        """Encode states into latent space.

        Args:
            states: States to encode [batch_size x state_dim]

        Returns:
            Latent vectors [batch_size x latent_dim]
        """
        states = states.to(self.device)
        return self.encoder(states)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors back to states.

        Args:
            latents: Latent vectors to decode [batch_size x latent_dim]

        Returns:
            Reconstructed states [batch_size x state_dim]
        """
        latents = latents.to(self.device)
        return self.decoder(latents)

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Forward pass through encoder and decoder.

        Args:
            states: Input states [batch_size x state_dim]

        Returns:
            Reconstructed states [batch_size x state_dim]
        """
        latents = self.encode(states)
        reconstructed = self.decode(latents)
        return reconstructed

    def _mean_loss(self, loader: DataLoader, n_samples: int) -> float:
        """Return the sample-weighted mean reconstruction MSE over ``loader`` (eval mode, no grad)."""
        self.eval()
        total = 0.0
        with torch.no_grad():
            for (states_batch,) in loader:
                loss = F.mse_loss(self(states_batch), states_batch)
                total += loss.item() * states_batch.shape[0]
        return total / n_samples

    def update(
        self,
        states: torch.Tensor,
        batch_size: int,
        epochs: int,
        learning_rate: float,
        validation_split: float,
        patience: int,
    ) -> Dict[str, float]:
        """Train the autoencoder on a reconstruction (MSE) objective.

        A random ``validation_split`` fraction of ``states`` is held out. After
        each epoch the held-out MSE is computed, and training stops once it has
        not improved for ``patience`` consecutive epochs. If the split leaves no
        validation samples, early stopping is disabled and all ``epochs`` run.

        NOTE: the model keeps the weights from the last epoch run, not the
        weights that achieved ``best_val_loss``.

        Args:
            states: States to train on [n_states x state_dim]
            batch_size: Training batch size
            epochs: Number of epochs to train
            learning_rate: Learning rate for optimization
            validation_split: Fraction of data to use for validation
            patience: Early stopping patience

        Returns:
            Dictionary with ``train_loss`` and ``val_loss`` of the last epoch,
            ``best_val_loss``, and ``epochs`` (number of epochs actually run).
            Losses that were never computed (no validation data, or
            ``epochs == 0``) are reported as NaN.

        Raises:
            ValueError: If ``validation_split`` leaves no training samples.
        """
        states = states.to(self.device)

        # Split data into train/val
        n_states = len(states)
        n_val = int(n_states * validation_split)
        indices = torch.randperm(n_states)
        train_states = states[indices[n_val:]]
        val_states = states[indices[:n_val]]
        if len(train_states) == 0:
            raise ValueError(
                f"validation_split={validation_split} leaves no training samples "
                f"out of {n_states} states"
            )

        train_loader = DataLoader(
            TensorDataset(train_states),
            batch_size=batch_size,
            shuffle=True,
        )
        # An empty validation loader would make the mean loss divide by zero.
        val_loader = (
            DataLoader(TensorDataset(val_states), batch_size=batch_size)
            if n_val > 0
            else None
        )

        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        best_val_loss = float("inf")
        patience_counter = 0
        train_losses: List[float] = []
        val_losses: List[float] = []
        epochs_run = 0

        for epoch in range(epochs):
            epochs_run = epoch + 1

            # Training
            self.train()
            epoch_loss = 0.0
            for (states_batch,) in train_loader:
                optimizer.zero_grad()
                loss = F.mse_loss(self(states_batch), states_batch)
                loss.backward()
                optimizer.step()
                # Weight by batch size so a short final batch is not over-counted.
                epoch_loss += loss.item() * states_batch.shape[0]
            train_losses.append(epoch_loss / len(train_states))

            if val_loader is None:
                continue  # no held-out data to early-stop on

            # Validation
            val_loss = self._mean_loss(val_loader, len(val_states))
            val_losses.append(val_loss)

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    break

        nan = float("nan")
        return {
            "train_loss": train_losses[-1] if train_losses else nan,
            "val_loss": val_losses[-1] if val_losses else nan,
            "best_val_loss": best_val_loss if val_losses else nan,
            "epochs": epochs_run,
        }


class TrajectoryEncoder(nn.Module):
    """LSTM sequence-to-sequence autoencoder for trajectories, used in AURORA.

    Trajectories are variable-length tensors of shape ``(T, state_dim)``. The
    encoder LSTM's final hidden state of its last layer is projected to a
    ``latent_dim`` vector. The decoder LSTM's initial hidden state is a linear
    projection of that latent (its cell state starts at zero), and each decoder
    output step is projected back to ``state_dim``.
    """

    def __init__(
        self,
        state_dim: int,
        latent_dim: int,
        hidden_dim: int,
        num_layers: int,
        device: str = "cpu",
        teacher_force: bool = True,
    ):
        """Initialize the trajectory encoder.

        Args:
            state_dim: Dimension of input states
            latent_dim: Dimension of latent space
            hidden_dim: Hidden dimension of LSTM
            num_layers: Number of LSTM layers
            device: Device to run on
            teacher_force: Whether ``forward`` decodes with teacher forcing
                (``decode_teacher_forcing``) or autoregressively from each
                trajectory's first state (``decode``)
        """
        super().__init__()
        self.state_dim = state_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device
        self.teacher_force = teacher_force

        # Encoder LSTM
        self.encoder_lstm = nn.LSTM(
            input_size=state_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        # Final encoder layer to map to latent space
        self.encoder_fc = nn.Linear(hidden_dim, latent_dim)

        # Decoder initial state projection: one hidden vector per LSTM layer
        self.decoder_init = nn.Linear(latent_dim, hidden_dim * num_layers)

        # Decoder LSTM
        self.decoder_lstm = nn.LSTM(
            input_size=state_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        # Output projection
        self.decoder_fc = nn.Linear(hidden_dim, state_dim)

        # Initialize weights
        self.apply(self._init_weights)
        self.to(device)

    def _init_weights(self, module: nn.Module) -> None:
        """Orthogonally initialize weights and zero all biases.

        Linear layers use gain ``sqrt(2)``; LSTM weight matrices (stacked gate
        matrices) use gain 1.0.
        """
        if isinstance(module, nn.Linear):
            gain = np.sqrt(2.0)
            nn.init.orthogonal_(module.weight, gain=gain)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if "weight" in name:
                    nn.init.orthogonal_(param, gain=1.0)
                elif "bias" in name:
                    nn.init.zeros_(param)

    def encode(self, trajectories: List[torch.Tensor]) -> torch.Tensor:
        """Encode trajectories into latent space.

        Args:
            trajectories: List of ``batch_size`` tensors, each of shape (T_i, state_dim)

        Returns:
            Latent vectors [batch_size x latent_dim]
        """
        batch_size = len(trajectories)
        lengths = [traj.shape[0] for traj in trajectories]
        max_len = max(lengths)

        # Create padded batch
        padded_trajectories = torch.zeros(
            (batch_size, max_len, self.state_dim), device=self.device
        )
        for i, traj in enumerate(trajectories):
            padded_trajectories[i, : traj.shape[0], :] = traj.to(self.device)

        # Packing makes the LSTM stop at each sequence's true length, so the
        # returned hidden state ignores padding.
        packed_trajectories = nn.utils.rnn.pack_padded_sequence(
            padded_trajectories, lengths, batch_first=True, enforce_sorted=False
        )

        _, (hidden, _) = self.encoder_lstm(packed_trajectories)

        # Use the final hidden state from the last layer
        last_hidden = hidden[-1]  # [batch_size, hidden_dim]

        return self.encoder_fc(last_hidden)  # [batch_size, latent_dim]

    def _init_decoder_state(
        self, latent: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize decoder hidden and cell states from latent vector.

        Args:
            latent: Latent vectors [batch_size x latent_dim]

        Returns:
            Tuple of (hidden_state, cell_state), each [num_layers, batch_size, hidden_dim];
            the cell state is all zeros.
        """
        batch_size = latent.shape[0]

        init_hidden = self.decoder_init(latent)  # [batch_size, hidden_dim * num_layers]

        # Reshape to the LSTM (num_layers, batch, hidden) layout
        init_hidden = init_hidden.view(batch_size, self.num_layers, self.hidden_dim)
        init_hidden = init_hidden.transpose(0, 1).contiguous()

        init_cell = torch.zeros_like(init_hidden)

        return init_hidden, init_cell

    def decode(
        self,
        latent: torch.Tensor,
        trajectory_lengths: List[int],
        initial_states: Optional[List[torch.Tensor]] = None,
    ) -> List[torch.Tensor]:
        """Decode latent vectors back to trajectories autoregressively.

        Each step's prediction is fed back as the next step's input.

        NOTE(review): with ``initial_states`` the decoder's first input is the
        initial state, whereas ``decode_teacher_forcing`` first consumes a zero
        vector and only then the first true state. A model trained with teacher
        forcing therefore sees a different input sequence here.

        Args:
            latent: Latent vectors [batch_size x latent_dim]
            trajectory_lengths: List of trajectory lengths to generate
            initial_states: Optional list of ``batch_size`` initial states, each (state_dim,)

        Returns:
            List of reconstructed trajectories, each (T_i, state_dim).
            When initial_states is provided, the first state in each trajectory is the true initial state.
        """
        batch_size = latent.shape[0]
        max_len = max(trajectory_lengths)

        hidden, cell = self._init_decoder_state(latent)

        outputs = torch.zeros((batch_size, max_len, self.state_dim), device=self.device)

        if initial_states is not None:
            # Move each state before stacking so inputs on another device work.
            first_states = torch.stack(
                [s.to(self.device) for s in initial_states], dim=0
            )  # [batch_size, state_dim]
            outputs[:, 0, :] = first_states
            current_input = first_states.unsqueeze(1)  # [batch_size, 1, state_dim]
            start_pos = 1  # position 0 is the given initial state
        else:
            current_input = torch.zeros(
                (batch_size, 1, self.state_dim), device=self.device
            )
            start_pos = 0

        for t in range(start_pos, max_len):
            lstm_out, (hidden, cell) = self.decoder_lstm(current_input, (hidden, cell))
            current_output = self.decoder_fc(lstm_out.squeeze(1))  # [batch_size, state_dim]
            outputs[:, t, :] = current_output
            # Autoregressive: the prediction becomes the next input.
            current_input = current_output.unsqueeze(1)

        return [outputs[i, : trajectory_lengths[i], :] for i in range(batch_size)]

    def decode_teacher_forcing(
        self, latent: torch.Tensor, trajectories: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """Decode with teacher forcing for training.

        The decoder input at step 0 is a zero vector and at step ``t > 0`` is the
        ground-truth state ``t - 1``.

        Args:
            latent: Latent vectors [batch_size x latent_dim]
            trajectories: List of ``batch_size`` tensors, each (T_i, state_dim)

        Returns:
            List of reconstructed trajectories, each (T_i, state_dim)
        """
        batch_size = len(trajectories)
        lengths = [traj.shape[0] for traj in trajectories]
        max_len = max(lengths)

        hidden, cell = self._init_decoder_state(latent)

        padded_inputs = torch.zeros(
            (batch_size, max_len, self.state_dim), device=self.device
        )
        # Ground truth shifted right by one; row 0 stays zero.
        for i, traj in enumerate(trajectories):
            if traj.shape[0] > 1:
                padded_inputs[i, 1 : traj.shape[0], :] = traj[:-1, :].to(self.device)

        packed_inputs = nn.utils.rnn.pack_padded_sequence(
            padded_inputs, lengths, batch_first=True, enforce_sorted=False
        )

        packed_outputs, _ = self.decoder_lstm(packed_inputs, (hidden, cell))

        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)

        outputs = self.decoder_fc(outputs)  # [batch_size, max_len, state_dim]

        return [outputs[i, : lengths[i], :] for i in range(batch_size)]

    def forward(self, trajectories: List[torch.Tensor]) -> List[torch.Tensor]:
        """Forward pass through encoder and decoder.

        Args:
            trajectories: List of trajectory tensors, each (T_i, state_dim)

        Returns:
            Reconstructed trajectories, each (T_i, state_dim)
        """
        device_trajectories = [traj.to(self.device) for traj in trajectories]

        latent = self.encode(device_trajectories)

        if self.teacher_force:
            return self.decode_teacher_forcing(latent, device_trajectories)

        lengths = [traj.shape[0] for traj in device_trajectories]
        initial_states = [traj[0] for traj in device_trajectories]
        return self.decode(latent, lengths, initial_states)

    def _summed_reconstruction_loss(
        self, batch: List[torch.Tensor]
    ) -> Tuple[Optional[torch.Tensor], int]:
        """Reconstruct ``batch`` and sum the per-trajectory MSE losses.

        Without teacher forcing the first output is the given initial state, so
        it is excluded from the loss. Trajectories with no remaining steps
        contribute nothing.

        Returns:
            ``(loss_sum, count)``; ``loss_sum`` is ``None`` when ``count`` is 0.
        """
        reconstructed = self(batch)
        loss_sum: Optional[torch.Tensor] = None
        count = 0
        for pred, target in zip(reconstructed, batch):
            target = target.to(self.device)
            if not self.teacher_force:
                pred, target = pred[1:], target[1:]
            if pred.shape[0] == 0:
                continue
            term = F.mse_loss(pred, target)
            loss_sum = term if loss_sum is None else loss_sum + term
            count += 1
        return loss_sum, count

    def update(
        self,
        trajectories: List[torch.Tensor],
        batch_size: int,
        epochs: int,
        learning_rate: float,
        validation_split: float,
        patience: int,
    ) -> Dict[str, float]:
        """Train the autoencoder on per-trajectory reconstruction MSE.

        A random ``validation_split`` fraction of trajectories is held out;
        training stops once the held-out loss has not improved for ``patience``
        consecutive epochs. If there is no held-out data (or none of it has
        steps to compare), early stopping is disabled and all ``epochs`` run.
        Batches with no trajectory contributing a loss term are skipped.

        NOTE: the model keeps the weights from the last epoch run, not the
        weights that achieved ``best_val_loss``.

        Args:
            trajectories: List of trajectory tensors, each (T_i, state_dim)
            batch_size: Training batch size
            epochs: Number of epochs to train
            learning_rate: Learning rate for optimization
            validation_split: Fraction of data to use for validation
            patience: Early stopping patience

        Returns:
            Dictionary with ``train_loss`` and ``val_loss`` of the last epoch
            (mean over contributing trajectories), ``best_val_loss``, and
            ``epochs`` (number of epochs actually run). Losses that were never
            computed are reported as NaN.

        Raises:
            ValueError: If ``validation_split`` leaves no training trajectories.
        """
        n_trajectories = len(trajectories)

        # Split data into train/val
        n_val = int(n_trajectories * validation_split)
        indices = torch.randperm(n_trajectories).tolist()
        train_trajectories = [trajectories[i] for i in indices[n_val:]]
        val_trajectories = [trajectories[i] for i in indices[:n_val]]
        if not train_trajectories:
            raise ValueError(
                f"validation_split={validation_split} leaves no training trajectories "
                f"out of {n_trajectories}"
            )

        # collate_fn=list keeps variable-length trajectories as a plain list.
        train_loader = DataLoader(
            train_trajectories,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=list,
        )
        val_loader = (
            DataLoader(val_trajectories, batch_size=batch_size, collate_fn=list)
            if val_trajectories
            else None
        )

        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        best_val_loss = float("inf")
        patience_counter = 0
        train_losses: List[float] = []
        val_losses: List[float] = []
        epochs_run = 0

        for epoch in range(epochs):
            epochs_run = epoch + 1

            # Training
            self.train()
            loss_total, n_terms = 0.0, 0
            for batch in train_loader:
                optimizer.zero_grad()
                loss_sum, count = self._summed_reconstruction_loss(batch)
                if count == 0:
                    # Nothing to predict (e.g. only length-1 trajectories without
                    # teacher forcing); there is no tensor to backpropagate.
                    continue
                loss = loss_sum / count
                loss.backward()
                optimizer.step()
                loss_total += loss_sum.item()
                n_terms += count
            train_losses.append(loss_total / n_terms if n_terms else float("nan"))

            if val_loader is None:
                continue  # no held-out data to early-stop on

            # Validation
            self.eval()
            loss_total, n_terms = 0.0, 0
            with torch.no_grad():
                for batch in val_loader:
                    loss_sum, count = self._summed_reconstruction_loss(batch)
                    if count:
                        loss_total += loss_sum.item()
                        n_terms += count
            if n_terms == 0:
                continue  # held-out trajectories have no steps to compare
            val_loss = loss_total / n_terms
            val_losses.append(val_loss)

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    break

        nan = float("nan")
        return {
            "train_loss": train_losses[-1] if train_losses else nan,
            "val_loss": val_losses[-1] if val_losses else nan,
            "best_val_loss": best_val_loss if val_losses else nan,
            "epochs": epochs_run,
        }
