import torch
from torch import nn
from torch.nn import functional as F


class SkipDense(nn.Module):
    """Square linear layer with an identity shortcut, in PyTorch.

    Computes ``hidden(x) + x``, so the last dimension of the input must equal
    ``units``. The weight matrix of the inner linear layer is initialised with
    Kaiming (He) normal initialisation; the bias keeps PyTorch's default init.
    """

    def __init__(self, units):
        super().__init__()
        linear = nn.Linear(units, units)
        nn.init.kaiming_normal_(linear.weight)
        self.hidden = linear

    def forward(self, x):
        shortcut = x
        transformed = self.hidden(x)
        return transformed + shortcut


# Torch
class PolicyNetwork(nn.Module):
    """Implements the policy network as an MLP with skip connections in PyTorch.

    Architecture:
    - A stack of hidden layers. Each is a ``SkipDense`` residual block when its
      width equals the previous width, otherwise an ``nn.Linear``. Each is
      followed by the activation.
    - ``LayerNorm``, then one more linear layer followed by the activation.
    - A linear output layer producing one logit per action.

    Logits of illegal actions are replaced by a very large negative value
    before the softmax, so those actions receive (numerically) zero
    probability.

    Args:
        input_size: Number of input features (size of the input's last dim).
        policy_network_layers: Layer widths. All but the last entry form the
            hidden stack; the last entry is the width of the layer applied
            after the ``LayerNorm``. Must be non-empty.
        num_actions: Number of actions, i.e. the output size.
        activation: ``"leakyrelu"`` (negative slope 0.2) or ``"relu"``.
        device: Stored as ``self.device``; the module is not moved to it here.

    Raises:
        ValueError: If ``activation`` is not supported.
    """

    def __init__(
        self,
        input_size,
        policy_network_layers,
        num_actions,
        activation="leakyrelu",
        device="cpu",
    ):
        super(PolicyNetwork, self).__init__()
        self.input_size = input_size
        self.num_actions = num_actions

        # Choose activation function
        if activation == "leakyrelu":
            self.activation = nn.LeakyReLU(0.2)
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.hidden = nn.ModuleList()
        prev_units = input_size
        for units in policy_network_layers[:-1]:
            if prev_units == units:
                # Width unchanged: use a residual block.
                self.hidden.append(SkipDense(units))
            else:
                self.hidden.append(nn.Linear(prev_units, units))
                nn.init.kaiming_normal_(self.hidden[-1].weight)
            prev_units = units

        # Layer normalization before the last layer. It is sized from
        # prev_units, the width the hidden stack actually produces. This makes
        # a single-entry layer list (empty hidden stack) work; for longer lists
        # it equals policy_network_layers[-2].
        self.normalization = nn.LayerNorm(prev_units)
        self.last_layer = nn.Linear(prev_units, policy_network_layers[-1])
        nn.init.kaiming_normal_(self.last_layer.weight)

        # Output layer: one logit per action.
        self.out_layer = nn.Linear(policy_network_layers[-1], num_actions)
        # Kept so attribute access and existing ``model.*`` state_dict keys
        # stay valid. It holds no activation modules, so ``forward`` applies
        # the layers explicitly instead of calling it.
        self.model = nn.Sequential(
            *self.hidden, self.normalization, self.last_layer, self.out_layer
        )
        self.device = device

    def forward(self, inputs, mask=None):
        """Return action probabilities.

        Args:
            inputs: Either the input tensor ``x`` or an ``(x, mask)`` pair
                (tuple or list). A mask inside the pair takes precedence over
                the ``mask`` argument.
            mask: Optional legal-action mask that must broadcast against the
                logits; entries equal to 1 mark legal actions. ``None`` means
                every action is legal.

        Returns:
            Softmax probabilities over the last dimension, shape
            ``(..., num_actions)``.
        """
        if isinstance(inputs, (tuple, list)):
            x, mask = inputs
        else:
            x = inputs

        for layer in self.hidden:
            x = self.activation(layer(x))
        x = self.normalization(x)
        x = self.activation(self.last_layer(x))
        x = self.out_layer(x)

        if mask is not None:
            # Give illegal actions a very large negative logit so that the
            # softmax assigns them zero probability. The fill value is created
            # with x's dtype and device so this also works off the CPU.
            x = torch.where(mask == 1, x, x.new_tensor(-10e20))

        # Softmax to get probabilities
        return F.softmax(x, dim=-1)


# torch
class RegretNetwork(nn.Module):
    """Regret network: an MLP with skip connections producing one output per action.

    Architecture:
    - A stack of hidden layers. Each is a ``SkipDense`` residual block when its
      width equals the previous width, otherwise an ``nn.Linear``. Each is
      followed by the activation.
    - ``LayerNorm``, then one more linear layer followed by the activation.
    - A linear head with ``num_actions`` outputs.

    The outputs are multiplied element-wise by the supplied mask.

    Args:
        input_size: Number of input features.
        regret_network_layers: Layer widths. All but the last entry form the
            hidden stack; the last entry is the width of the layer applied
            after the ``LayerNorm``. Must be non-empty.
        num_actions: Number of actions, i.e. the output size.
        activation: ``"leakyrelu"`` selects LeakyReLU(0.2); any other value
            selects ReLU.
        device: Stored as ``self.device``; the module is not moved to it here.
    """

    def __init__(
        self,
        input_size,
        regret_network_layers,
        num_actions,
        activation="leakyrelu",
        device="cpu",
    ):
        super(RegretNetwork, self).__init__()
        self.input_size = input_size
        self.num_actions = num_actions
        # Any value other than "leakyrelu" falls back to ReLU.
        self.activation = nn.LeakyReLU(0.2) if activation == "leakyrelu" else nn.ReLU()

        layers = []
        prev_units = input_size
        for units in regret_network_layers[:-1]:
            if prev_units == units:
                # Width unchanged: use a residual block.
                layers.append(SkipDense(units))
            else:
                layer = nn.Linear(prev_units, units)
                nn.init.kaiming_normal_(layer.weight)
                layers.append(layer)
            prev_units = units
        self.layers = nn.ModuleList(layers)

        # Sized from prev_units, the width the hidden stack actually produces.
        # This makes a single-entry layer list (empty hidden stack) work; for
        # longer lists it equals regret_network_layers[-2].
        self.normalization = nn.LayerNorm(prev_units)
        self.last_layer = nn.Linear(prev_units, regret_network_layers[-1])
        nn.init.kaiming_normal_(self.last_layer.weight)
        self.out_layer = nn.Linear(regret_network_layers[-1], num_actions)
        self.device = device

    def forward(self, inputs):
        """Return masked per-action outputs.

        Args:
            inputs: An ``(x, mask)`` pair. ``mask`` is multiplied element-wise
                with the output, so it must broadcast to
                ``(..., num_actions)``. A 0/1 mask zeroes the entries where it
                is 0.

        Returns:
            Tensor of shape ``(..., num_actions)``.
        """
        x, mask = inputs
        for layer in self.layers:
            x = self.activation(layer(x))
        x = self.normalization(x)
        x = self.activation(self.last_layer(x))
        x = self.out_layer(x)
        return mask * x


class ValueNetwork(nn.Module):
    """Value network: an MLP with skip connections producing a single output.

    Architecture:
    - A stack of hidden layers. Each is a ``SkipDense`` residual block when its
      width equals the previous width, otherwise an ``nn.Linear``. Each is
      followed by the activation.
    - ``LayerNorm``, then one more linear layer followed by the activation.
    - A linear head with one output.

    Args:
        input_size: Number of input features.
        val_network_layers: Layer widths. All but the last entry form the
            hidden stack; the last entry is the width of the layer applied
            after the ``LayerNorm``. Must be non-empty.
        activation: ``"leakyrelu"`` selects LeakyReLU(0.2); any other value
            selects ReLU.
        device: Stored as ``self.device``; the module is not moved to it here.
    """

    def __init__(
        self, input_size, val_network_layers, activation="leakyrelu", device="cpu"
    ):
        super(ValueNetwork, self).__init__()
        self.input_size = input_size
        # Any value other than "leakyrelu" falls back to ReLU.
        self.activation = nn.LeakyReLU(0.2) if activation == "leakyrelu" else nn.ReLU()

        layers = []
        prev_units = input_size
        for units in val_network_layers[:-1]:
            if prev_units == units:
                # Width unchanged: use a residual block.
                layers.append(SkipDense(units))
            else:
                layer = nn.Linear(prev_units, units)
                nn.init.kaiming_normal_(layer.weight)
                layers.append(layer)
            prev_units = units
        self.layers = nn.ModuleList(layers)

        # Sized from prev_units, the width the hidden stack actually produces.
        # This makes a single-entry layer list (empty hidden stack) work; for
        # longer lists it equals val_network_layers[-2].
        self.normalization = nn.LayerNorm(prev_units)
        self.last_layer = nn.Linear(prev_units, val_network_layers[-1])
        nn.init.kaiming_normal_(self.last_layer.weight)
        self.out_layer = nn.Linear(val_network_layers[-1], 1)
        self.device = device

    def forward(self, inputs):
        """Return value estimates.

        Args:
            inputs: An ``(x, mask)`` pair; the mask is ignored. ``x`` is cast
                with ``.float()`` before the forward pass.

        Returns:
            Tensor of shape ``(..., 1)``.
        """
        x, _ = inputs  # Mask is not used in ValueNetwork
        x = x.float()
        for layer in self.layers:
            x = self.activation(layer(x))
        x = self.normalization(x)
        x = self.activation(self.last_layer(x))
        x = self.out_layer(x)
        return x
