import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import init


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension."""

    def forward(self, input):
        """Return ``input`` reshaped to ``(input.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous
        tensors (e.g. the result of ``permute``/``transpose``) are accepted;
        when a view is possible, ``reshape`` still returns one, so contiguous
        inputs behave exactly as before.
        """
        return input.reshape(input.size(0), -1)


class CnnBCetwork(nn.Module):
    """Convolutional policy network: image-like state -> per-action scores.

    The feature extractor is three Conv2d+ReLU stages followed by two
    fully connected layers; the ``actor`` head maps the 448-d feature vector
    to ``output_size`` unnormalised scores (softmaxed in ``get_action``).

    The first convolution hard-codes 4 input channels, and the first linear
    layer expects a 7x7x64 feature map, which is what this conv stack
    produces for e.g. 84x84 spatial inputs.
    """

    def __init__(self, input_size, output_size):
        """Build and initialise the network.

        Args:
            input_size: Accepted for interface compatibility; it is not used
                by this constructor (input geometry is hard-coded).
            output_size: Number of outputs of the actor head.
        """
        super().__init__()

        linear = nn.Linear

        self.feature = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            Flatten(),
            linear(7 * 7 * 64, 256),
            nn.ReLU(),
            linear(256, 448),
            nn.ReLU(),
        )

        self.actor = nn.Sequential(
            linear(448, 448), nn.ReLU(), linear(448, output_size)
        )

        # Pass 1: orthogonal init with gain sqrt(2) and zero bias for every
        # conv / linear layer in the model (feature extractor and actor).
        relu_gain = np.sqrt(2)
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.orthogonal_(module.weight, relu_gain)
                module.bias.data.zero_()

        # Pass 2: re-initialise the actor's linear layers with a small gain.
        # Kept as a separate pass (rather than folded into pass 1) so the
        # random-number consumption order, and hence seeded weights, is
        # unchanged.
        for layer in self.actor:
            if isinstance(layer, nn.Linear):
                init.orthogonal_(layer.weight, 0.01)
                layer.bias.data.zero_()

    def forward(self, state):
        """Return the actor head's output for a batch of states."""
        x = self.feature(state)
        policy = self.actor(x)
        return policy

    def get_action(self, state):
        """Sample one action index for a single (unbatched) state.

        A leading batch dimension is added before the forward pass and
        removed afterwards; the scores are turned into probabilities with a
        softmax and one index is drawn with ``torch.multinomial``.

        Returns:
            The sampled action index as a Python int.
        """
        # Only a Python int leaves this method, so no autograd graph is
        # needed; skipping it avoids retaining activations.
        with torch.no_grad():
            logits = self(state.unsqueeze(0))[0]
            probs = F.softmax(logits, dim=0)
            sampled_action_index = torch.multinomial(probs, 1)
        return sampled_action_index.item()
