import torch
import torch.nn as nn
import torch.optim as optim

# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


class TrajectoryGRU(nn.Module):
    """GRU encoder over a state history with MLP heads for per-step outputs.

    The previous states and the current state are concatenated along the time
    axis and encoded with a single-layer GRU. The final hidden state is then
    passed through a two-layer Tanh MLP. Two linear heads follow:

    * an "mse" head. It is reshaped to ``[batch, k, output_mse_size]`` in
      deterministic mode. When ``Probabilistic=True`` it is split into
      ``mean`` and ``std`` tensors of shape ``[batch, k, 2]``.
    * a "ce" head of logits, reshaped to ``[batch, k, 2]``.

    Args:
        state_dim: Size of each state vector (the GRU input size).
        hidden_dim: GRU hidden size. The MLP narrows it to ``hidden_dim // 2``
            and then to ``hidden_dim // 4``.
        m: Stored as ``self.m``; not used inside this class.
        k: Number of output steps produced by each head.
        action_type: Selects the per-step width of the deterministic mse head
            (one of the keys of ``_MSE_SIZES``). Ignored when
            ``Probabilistic=True``.
        Probabilistic: If True, the mse head predicts a mean and a log-std
            (2 values each per step) and ``forward`` returns ``(mean, std,
            ce_logits)``.
        dropout: If True, apply ``nn.Dropout(0.1)`` after the first MLP layer.

    Raises:
        ValueError: If ``Probabilistic`` is False and ``action_type`` is not
            a supported value.
    """

    # Per-step width of the deterministic mse head for each supported action_type.
    _MSE_SIZES = {
        "fb_cos_sin": 3,
        "cos_sin": 2,
        "ce_cos_sin": 5,
        "ce_cos_sin_speed_direction": 5,
    }

    def __init__(
        self,
        state_dim,
        hidden_dim,
        m,
        k,
        action_type="fb_cos_sin",
        Probabilistic=False,
        dropout=False,
    ):
        super().__init__()
        self.Probabilistic = Probabilistic
        self.action_type = action_type
        self.m = m
        self.k = k
        self.hidden_dim = hidden_dim

        self.gru = nn.GRU(
            input_size=state_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True
        )

        self.do_dropout = dropout
        self.dropout = nn.Dropout(0.1) if dropout else None

        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
        # NOTE: despite its name this is a Tanh activation. The attribute name
        # is kept so existing code that refers to `model.relu` keeps working.
        self.relu = nn.Tanh()

        if self.Probabilistic:
            # Each step gets 2 mean values and 2 log-std values (4 * k in total).
            # The original never set output_mse_size here, so forward() crashed
            # on the reshape. With width 2, forward() reshapes to
            # [batch, 2*k, 2] and chunks along dim=1 into mean and log_std,
            # each of shape [batch, k, 2].
            self.output_mse_size = 2
            output_size = 4 * k
        else:
            try:
                self.output_mse_size = self._MSE_SIZES[action_type]
            except KeyError:
                raise ValueError(
                    f"Unsupported action_type {action_type!r}; expected one of "
                    f"{sorted(self._MSE_SIZES)}"
                ) from None
            output_size = self.output_mse_size * k

        self.fc3_mse = nn.Linear(hidden_dim // 4, output_size)
        self.fc3_ce = nn.Linear(hidden_dim // 4, 2 * k)

    def forward(self, prev_states, current_state, return_inner_state=False):
        """Encode the state history and predict per-step outputs.

        Args:
            prev_states: ``[batch, seq_len, state_dim]`` history. The GRU uses
                ``batch_first=True``.
            current_state: ``[batch, state_dim]``. It is appended as the last
                time step.
            return_inner_state: If True, also return the
                ``[batch, hidden_dim // 4]`` MLP features. This only applies in
                deterministic mode; the probabilistic return value does not
                include them.

        Returns:
            In deterministic mode, ``(mse_output, ce_logits, inner_state)``:

            * ``mse_output``: ``[batch, k, output_mse_size]``
            * ``ce_logits``: ``[batch, k, 2]``
            * ``inner_state``: ``[batch, hidden_dim // 4]``, or ``None``

            In probabilistic mode, ``(mean, std, ce_logits)``, each of shape
            ``[batch, k, 2]``. ``std`` is strictly positive.
        """
        # [batch, state_dim] -> [batch, 1, state_dim], then append along time.
        x = torch.cat([prev_states, current_state.unsqueeze(1)], dim=1)

        batch_size = x.size(0)
        # NOTE(review): the initial hidden state is drawn from N(0, 1) on every
        # call, so outputs vary between calls even in eval mode. This is kept
        # to preserve existing behavior. nn.GRU's default (zeros) would make
        # inference deterministic.
        h0 = torch.randn(1, batch_size, self.hidden_dim, device=x.device)
        _, hidden = self.gru(x, h0)

        # hidden has shape [num_layers, batch, hidden_dim]; take the top layer.
        h_last = hidden[-1]

        x = self.relu(self.fc1(h_last))
        if self.do_dropout:
            x = self.dropout(x)
        x = self.relu(self.fc2(x))

        inner_state = x if return_inner_state else None

        mse_output = self.fc3_mse(x).reshape(batch_size, -1, self.output_mse_size)
        ce_logits = self.fc3_ce(x).reshape(batch_size, -1, 2)

        if self.Probabilistic:
            # mse_output is [batch, 2*k, 2]. The first k steps are means and
            # the last k steps are log-stds.
            mean, log_std = mse_output.chunk(2, dim=1)
            std = torch.exp(log_std)
            return mean, std, ce_logits
        return mse_output, ce_logits, inner_state

