import torch
from torch.nn import functional as F
from torch import nn
from torch.distributions import Categorical
from config import DEVICE, DTYPE


def sample_actions(q_values: torch.Tensor, eps: float):
    """
    Draw epsilon-greedy actions from Q-values.

    Along the last dimension the arg-max action receives probability mass
    ``1 - eps`` and an additional ``eps`` is spread uniformly over all actions.

    :param q_values: Tensor (..., A)
    :param eps: probability mass spread uniformly over the A actions
    :return: sampled action indices - Tensor (...)
    """
    n_actions = q_values.shape[-1]
    greedy = q_values.argmax(dim=-1, keepdim=True)
    # Mass for the greedy action first, then the uniform exploration share on top.
    greedy_mass = torch.zeros_like(q_values).scatter_(-1, greedy, 1 - eps)
    probs = greedy_mass + eps / n_actions
    return Categorical(probs=probs).sample()


class QNetworkFC(nn.Module):
    """
    Fully connected Q-network: a Linear/ReLU stack followed by a linear output layer.
    """

    def __init__(self, inp_dim, hid_dim, out_dim, n_hid_layers=0):
        """
        :param inp_dim: size of the input feature vector
        :param hid_dim: width of every hidden layer
        :param out_dim: number of outputs (one Q-value per action)
        :param n_hid_layers: number of extra hidden Linear/ReLU blocks after the first one
        """
        super().__init__()
        # Input projection, then the optional hidden blocks, then the output head.
        stack = [nn.Linear(inp_dim, hid_dim), nn.ReLU()]
        for _ in range(n_hid_layers):
            stack += [nn.Linear(hid_dim, hid_dim), nn.ReLU()]
        stack.append(nn.Linear(hid_dim, out_dim))
        self.fc = nn.Sequential(*stack)

        # Place the parameters on the device / dtype configured in `config`.
        self.to(DEVICE)
        self.to(DTYPE)

    def forward(self, x: torch.Tensor, sample=False, eps=0.):
        """
        :param x: Tensor (..., inp_dim)
        :param sample: if True, also return epsilon-greedy actions
        :param eps: exploration rate passed to sample_actions
        :return: Q-values, or the tuple (Q-values, sampled actions) when sample is True
        """
        q_values = self.fc(x)
        if not sample:
            return q_values
        return q_values, sample_actions(q_values, eps)


class CNNWrapper(nn.Module):
    """
    Single Conv2d + ReLU layer for channel-last image batches.

    The convolution always uses the class-level ``stride`` and ``padding``;
    the kernel size is configurable per instance.
    """
    stride = 1
    padding = 1

    def __init__(self, inp_channels, out_channels=32, kernel_size=4, normalize=True):
        """
        :param inp_channels: number of channels C of the input images
        :param out_channels: number of convolution filters
        :param kernel_size: convolution kernel size, forwarded to nn.Conv2d
        :param normalize: if True, inputs are divided by 255 before the convolution
        """
        super(CNNWrapper, self).__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(inp_channels, self.out_channels, kernel_size, self.stride, self.padding)
        self.normalize = normalize

        # Place the parameters on the device / dtype configured in `config`.
        self.to(DEVICE)
        self.to(DTYPE)

    def forward(self, x: torch.Tensor):
        """
        :param x: Tensor (B, H, W, C)
        :return: ReLU-activated feature maps - Tensor (B, out_channels, H', W')
        """
        if self.normalize:
            x = x / 255
        x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
        x = F.relu(self.conv(x))
        return x

    def get_out_dim(self, inp_height, inp_width):
        """
        Number of features per sample that forward() yields once flattened.

        The spatial size is derived from the hyper-parameters stored on the
        convolution itself (standard Conv2d output-size formula) rather than
        hard-coding the kernel_size=4 / stride=1 / padding=1 case, so the
        result stays correct when a different kernel size is requested.

        :param inp_height: input image height H
        :param inp_width: input image width W
        :return: out_channels * H' * W'
        """
        conv = self.conv

        def out_size(size, axis):
            # Effective kernel extent along this axis, accounting for dilation.
            span = conv.dilation[axis] * (conv.kernel_size[axis] - 1) + 1
            return (size + 2 * conv.padding[axis] - span) // conv.stride[axis] + 1

        return self.out_channels * out_size(inp_height, 0) * out_size(inp_width, 1)


class QHyperNetworkCNN(nn.Module):
    """
    Two-layer Q-network whose weights are produced by hyper-networks.

    A CNNWrapper encodes the image ``s``; three small MLP heads map that
    encoding to the first-layer weights, the first-layer bias and the
    second-layer weights that are then applied to ``x``.
    """

    def __init__(self, inp_size, inp_dim, hid_dim, out_dim):
        """
        :param inp_size: (height, width, channels) of the image fed to the hyper-networks
        :param inp_dim: feature size of the per-agent input x
        :param hid_dim: hidden width of the generated network and of the hyper-network heads
        :param out_dim: number of Q-values per agent
        """
        super().__init__()
        self.inp_dim, self.hid_dim, self.out_dim = inp_dim, hid_dim, out_dim
        inp_height, inp_width, inp_channels = inp_size

        self.hyper_conv = CNNWrapper(inp_channels, out_channels=16)
        fc_inp_dim = self.hyper_conv.get_out_dim(inp_height, inp_width)

        def make_head(n_outputs):
            # Small MLP mapping the flattened conv features to `n_outputs` generated values.
            return nn.Sequential(nn.Linear(fc_inp_dim, hid_dim),
                                 nn.ReLU(),
                                 nn.Linear(hid_dim, n_outputs))

        # Heads are created in the order w1, w2, b1 so seeded initialisation is unchanged.
        self.hyper_w1 = make_head(inp_dim * hid_dim)
        self.hyper_w2 = make_head(hid_dim * out_dim)
        self.hyper_b1 = make_head(hid_dim)

        def init_weights(m):
            # Small-variance weights and zero bias for every linear layer of the heads.
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 1e-3)
                nn.init.constant_(m.bias, 0)

        for head in (self.hyper_w1, self.hyper_w2, self.hyper_b1):
            head.apply(init_weights)

        # Place the parameters on the device / dtype configured in `config`.
        self.to(DEVICE)
        self.to(DTYPE)

    def forward(self, x: torch.Tensor, s: torch.Tensor):
        """
        :param x: Tensor (B, N, inp_dim)
        :param s: Tensor (B, H, W, C)
        :return: Q-values - Tensor (B, N, A)
        """
        s = self.hyper_conv(s).flatten(1, 3)  # B, -1

        w1 = self.hyper_w1(s).view(-1, self.inp_dim, self.hid_dim)  # B, I, H
        b1 = self.hyper_b1(s).view(-1, 1, self.hid_dim)  # B, 1, H
        x = F.elu(torch.bmm(x, w1) + b1)  # (B, N, I) x (B, I, H) + (B, 1, H) = (B, N, H)

        # The second generated layer has weights only; no bias head exists for it.
        w2 = self.hyper_w2(s).view(-1, self.hid_dim, self.out_dim)  # B, H, A
        q_values = torch.bmm(x, w2)  # (B, N, H) x (B, H, A) = (B, N, A)

        return q_values


class QHyperNetworkCNNShared(nn.Module):
    """
    Hyper-network Q-function with one encoder shared by all agents.

    Every agent observation goes through the same CNNWrapper and linear layer;
    the resulting features are turned into Q-values by a QHyperNetworkCNN that
    is conditioned on the image ``s``.
    """

    def __init__(self, inp_size, hid_dim, out_dim, global_size, n_agents):
        """
        :param inp_size: (height, width, channels) of one agent observation
        :param hid_dim: hidden width; per-agent features have size hid_dim * n_agents
        :param out_dim: number of Q-values per agent
        :param global_size: (height, width, channels) of the image s
        :param n_agents: number of agents
        """
        # Initialise nn.Module before assigning any attribute on self.
        super(QHyperNetworkCNNShared, self).__init__()
        inp_height, inp_width, inp_channels = inp_size
        self.n_agents = n_agents
        self.conv = CNNWrapper(inp_channels)
        fc_inp_dim = self.conv.get_out_dim(inp_height, inp_width)
        self.dim_red_layer = nn.Linear(fc_inp_dim, hid_dim * n_agents)
        self.q_network = QHyperNetworkCNN(global_size, hid_dim * n_agents, hid_dim, out_dim)

        # Place the parameters on the device / dtype configured in `config`.
        self.to(DEVICE)
        self.to(DTYPE)

    def forward(self, x: torch.Tensor, s: torch.Tensor, sample=False, eps=0.):
        """
        :param x: Tensor (B, N, H, W, C)
        :param s: Tensor passed unchanged to the hyper-network
        :param sample: if True, also return epsilon-greedy actions
        :param eps: exploration rate passed to sample_actions
        :return: Q-values (B, N, A), or the tuple (Q-values, sampled actions) when sample is True
        """
        n_samples, n_agents, state_shape = x.shape[0], x.shape[1], x.shape[2:]

        # reshape (not view) so non-contiguous batches are accepted as well.
        x = x.reshape(-1, *state_shape)  # (B, N, H, W, C) -> (B*N, H, W, C)
        x = self.conv(x).flatten(1, 3)  # (B*N, -1)
        x = self.dim_red_layer(x).reshape(n_samples, n_agents, -1)  # (B, N, hid_dim * n_agents)
        q_values = self.q_network(x, s)  # B, N, A

        if sample:
            return q_values, sample_actions(q_values, eps)
        return q_values


class QHyperNetworkCNNActionShared(nn.Module):
    """
    Hyper-network Q-function with a shared encoder and action conditioning.

    Each agent's features are the concatenation of its encoded observation,
    the embeddings of all entries of ``a_p`` and the embeddings of the entries
    of ``a`` belonging to the other agents (its own entry is zeroed).
    """

    def __init__(self, inp_size, hid_dim, out_dim, global_size, n_agents):
        """
        :param inp_size: (height, width, channels) of one agent observation
        :param hid_dim: size of the encoded observation and hidden width of the hyper-network
        :param out_dim: number of Q-values per agent
        :param global_size: (height, width, channels) of the image s
        :param n_agents: number of agents
        """
        # Initialise nn.Module before assigning any attribute on self.
        super(QHyperNetworkCNNActionShared, self).__init__()
        inp_height, inp_width, inp_channels = inp_size
        self.n_agents = n_agents
        self.conv = CNNWrapper(inp_channels)
        fc_inp_dim = self.conv.get_out_dim(inp_height, inp_width)
        self.dim_red_layer = nn.Linear(fc_inp_dim, hid_dim)
        hid_dim_per = hid_dim // n_agents
        self.action_embed = nn.Sequential(nn.Linear(1, hid_dim_per), nn.ReLU(), nn.Linear(hid_dim_per, hid_dim_per))
        self.q_network = QHyperNetworkCNN(global_size, hid_dim + hid_dim_per * n_agents * 2, hid_dim, out_dim)

        # Place the parameters on the device / dtype configured in `config`.
        self.to(DEVICE)
        self.to(DTYPE)

    def _embed_actions(self, actions, n_agents, dtype, hide_own):
        """Embed scalar actions and tile them so every agent gets all N embeddings."""
        actions = actions.to(dtype)
        if actions.ndim == 2:
            actions = actions.unsqueeze(-1)

        emb = self.action_embed(actions)  # (B, N, hid_dim_per)
        emb = emb.unsqueeze(1).repeat_interleave(n_agents, dim=1)  # (B, N, N, hid_dim_per)
        if hide_own:
            # Zero the diagonal: agent i does not see its own embedding.
            own = torch.eye(n_agents, n_agents, dtype=emb.dtype, device=emb.device)
            emb = emb * (1 - own).unsqueeze(-1)  # broadcast over batch and features
        return emb.reshape(*emb.shape[:2], -1)  # (B, N, N*hid_dim_per)

    def forward(self, x: torch.Tensor, s: torch.Tensor, a_p: torch.Tensor, a: torch.Tensor, sample=False, eps=0.):
        """
        :param x: Tensor (B, N, H, W, C)
        :param s: Tensor passed unchanged to the hyper-network
        :param a_p: Tensor (B, N) or (B, N, 1); embedded and shown to every agent unmasked
        :param a: Tensor (B, N) or (B, N, 1); embedded with each agent's own entry zeroed
        :param sample: if True, also return epsilon-greedy actions
        :param eps: exploration rate passed to sample_actions
        :return: Q-values (B, N, A), or the tuple (Q-values, sampled actions) when sample is True
        """
        n_samples, n_agents, state_shape = x.shape[0], x.shape[1], x.shape[2:]

        # reshape (not view) so non-contiguous batches are accepted as well.
        x = x.reshape(-1, *state_shape)  # (B, N, H, W, C) -> (B*N, H, W, C)
        x = self.conv(x).flatten(1, 3)  # (B*N, -1)
        x = self.dim_red_layer(x).reshape(n_samples, n_agents, -1)  # (B, N, hid_dim)

        a_p = self._embed_actions(a_p, n_agents, x.dtype, hide_own=False)  # (B, N, N*hid_dim_per)
        a = self._embed_actions(a, n_agents, x.dtype, hide_own=True)  # (B, N, N*hid_dim_per)

        x = torch.cat([x, a_p, a], dim=-1)  # (B, N, hid_dim + 2*N*hid_dim_per)
        q_values = self.q_network(x, s)  # B, N, A

        if sample:
            return q_values, sample_actions(q_values, eps)
        return q_values


class QNetworkCNN(nn.Module):
    """
    Q-network for image inputs: a CNNWrapper encoder followed by a QNetworkFC head.
    """

    def __init__(self, inp_size, hid_dim, out_dim, n_hid_layers, **conv_params):
        """
        :param inp_size: (height, width, channels) of the input image
        :param hid_dim: hidden width of the fully connected head
        :param out_dim: number of Q-values
        :param n_hid_layers: number of extra hidden layers in the head
        :param conv_params: keyword arguments forwarded to CNNWrapper
        """
        super().__init__()
        height, width, channels = inp_size
        self.conv = CNNWrapper(channels, **conv_params)
        # The head consumes the flattened convolution output.
        flat_dim = self.conv.get_out_dim(height, width)
        self.q_network = QNetworkFC(flat_dim, hid_dim, out_dim, n_hid_layers)

        # Place the parameters on the device / dtype configured in `config`.
        self.to(DEVICE)
        self.to(DTYPE)

    def forward(self, x: torch.Tensor):
        """
        :param x: Tensor (B, H, W, C)
        :return: Q-values - Tensor (B, A)
        """
        features = self.conv(x)
        flat = features.flatten(1, 3)
        return self.q_network(flat)


class QNetworkCNNUnshared(nn.Module):
    """
    One independent QNetworkCNN per agent (no parameter sharing).
    """

    def __init__(self, inp_size, hid_dim, out_dim, n_hid_layers, n_agents, **conv_params):
        """
        :param inp_size: (height, width, channels) of one agent observation
        :param hid_dim: hidden width of each network's head
        :param out_dim: number of Q-values per agent
        :param n_hid_layers: number of extra hidden layers in each head
        :param n_agents: number of agents, i.e. number of networks created
        :param conv_params: keyword arguments forwarded to every QNetworkCNN
        """
        super().__init__()
        per_agent_nets = [QNetworkCNN(inp_size, hid_dim, out_dim, n_hid_layers, **conv_params)
                          for _ in range(n_agents)]
        self.q_networks = nn.ModuleList(per_agent_nets)

    def forward(self, x: torch.Tensor, sample=False, eps=0.):
        """
        :param x: Tensor (B, N, H, W, C); slice i along dim 1 is fed to network i
        :param sample: if True, also return epsilon-greedy actions
        :param eps: exploration rate passed to sample_actions
        :return: Q-values (B, N, A), or the tuple (Q-values, sampled actions) when sample is True
        """
        per_agent_q = [net(x[:, idx]) for idx, net in enumerate(self.q_networks)]
        q_values = torch.stack(per_agent_q, 1)

        if not sample:
            return q_values
        return q_values, sample_actions(q_values, eps)


class QNetworkCNNShared(QNetworkCNN):
    """
    QNetworkCNN applied to every agent with shared parameters.

    The agent dimension is folded into the batch dimension, the parent network
    is evaluated once, and the result is unfolded back to (B, N, A).
    """

    def __init__(self, inp_size, hid_dim, out_dim, n_hid_layers, **conv_params):
        """
        :param inp_size: (height, width, channels) of one agent observation
        :param hid_dim: hidden width of the fully connected head
        :param out_dim: number of Q-values per agent
        :param n_hid_layers: number of extra hidden layers in the head
        :param conv_params: keyword arguments forwarded to CNNWrapper
        """
        super().__init__(inp_size, hid_dim, out_dim, n_hid_layers, **conv_params)

    def forward(self, x: torch.Tensor, sample=False, eps=0.):
        """
        :param x: Tensor (B, N, H, W, C)
        :param sample: if True, also return epsilon-greedy actions
        :param eps: exploration rate passed to sample_actions
        :return: Q-values (B, N, A), or the tuple (Q-values, sampled actions) when sample is True
        """
        n_samples, n_agents, state_shape = x.shape[0], x.shape[1], x.shape[2:]
        # reshape (not view) so non-contiguous batches are accepted as well.
        x = x.reshape(-1, *state_shape)  # (B, N, H, W, C) -> (B*N, H, W, C)

        q_values = super().forward(x)  # B * N, A
        q_values = q_values.reshape(n_samples, n_agents, -1)  # B, N, A

        if sample:
            return q_values, sample_actions(q_values, eps)
        return q_values


class QNetworkCNNAction(nn.Module):
    """
    Q-network for image inputs that is additionally conditioned on a scalar action.

    The encoded observation and the embedded action (both of size hid_dim) are
    concatenated and passed to a QNetworkFC head.
    """

    def __init__(self, inp_size, hid_dim, out_dim, n_hid_layers, **conv_params):
        """
        :param inp_size: (height, width, channels) of the input image
        :param hid_dim: size of the observation and action embeddings and hidden width of the head
        :param out_dim: number of Q-values
        :param n_hid_layers: number of extra hidden layers in the head
        :param conv_params: keyword arguments forwarded to CNNWrapper
        """
        super().__init__()
        height, width, channels = inp_size
        self.conv = CNNWrapper(channels, **conv_params)
        flat_dim = self.conv.get_out_dim(height, width)
        self.dim_red_layer = nn.Linear(flat_dim, hid_dim)
        self.action_embed = nn.Sequential(nn.Linear(1, hid_dim), nn.ReLU(), nn.Linear(hid_dim, hid_dim))
        # The head sees the observation and action embeddings side by side.
        self.q_network = QNetworkFC(hid_dim * 2, hid_dim, out_dim, n_hid_layers)

        # Place the parameters on the device / dtype configured in `config`.
        self.to(DEVICE)
        self.to(DTYPE)

    def forward(self, x: torch.Tensor, a: torch.Tensor):
        """
        :param x: Tensor (B, H, W, C)
        :param a: Tensor (B,) or (B, 1)
        :return: Q-values - Tensor (B, A)
        """
        obs = self.dim_red_layer(self.conv(x).flatten(1, 3))

        act = a.to(obs.dtype)
        if act.ndim == 1:
            # Give the scalar actions a trailing feature dimension for the Linear layer.
            act = act.unsqueeze(-1)
        act = self.action_embed(act)

        joint = torch.cat([obs, act], dim=-1)
        return self.q_network(joint)


class QNetworkCNNActionUnshared(nn.Module):
    """
    One independent QNetworkCNNAction per agent (no parameter sharing).
    """

    def __init__(self, inp_size, hid_dim, out_dim, n_hid_layers, n_agents, **conv_params):
        """
        :param inp_size: (height, width, channels) of one agent observation
        :param hid_dim: embedding size and hidden width of each network
        :param out_dim: number of Q-values per agent
        :param n_hid_layers: number of extra hidden layers in each head
        :param n_agents: number of agents, i.e. number of networks created
        :param conv_params: keyword arguments forwarded to every QNetworkCNNAction
        """
        super().__init__()
        per_agent_nets = [QNetworkCNNAction(inp_size, hid_dim, out_dim, n_hid_layers, **conv_params)
                          for _ in range(n_agents)]
        self.q_networks = nn.ModuleList(per_agent_nets)

    def forward(self, x: torch.Tensor, a: torch.Tensor, sample=False, eps=0.):
        """
        :param x: Tensor (B, N, H, W, C); slice i along dim 1 is fed to network i
        :param a: Tensor (B, N, ...); slice i along dim 1 is fed to network i
        :param sample: if True, also return epsilon-greedy actions
        :param eps: exploration rate passed to sample_actions
        :return: Q-values (B, N, A), or the tuple (Q-values, sampled actions) when sample is True
        """
        per_agent_q = [net(x[:, idx], a[:, idx]) for idx, net in enumerate(self.q_networks)]
        q_values = torch.stack(per_agent_q, 1)

        if not sample:
            return q_values
        return q_values, sample_actions(q_values, eps)


class QNetworkCNNSusShared(nn.Module):
    """
    Shared-parameter Q-network conditioned on a per-agent scalar ``sus`` signal.

    Each agent's features are its encoded observation concatenated with the
    embedded ``sus`` values of the other agents (its own entry is zeroed).
    """

    def __init__(self, inp_size, hid_dim, out_dim, n_hid_layers, n_agents, **conv_params):
        """
        :param inp_size: (height, width, channels) of one agent observation
        :param hid_dim: size of the encoded observation and hidden width of the head
        :param out_dim: number of Q-values per agent
        :param n_hid_layers: number of extra hidden layers in the head
        :param n_agents: number of agents
        :param conv_params: keyword arguments forwarded to CNNWrapper
        """
        # Initialise nn.Module before assigning any attribute on self.
        super(QNetworkCNNSusShared, self).__init__()
        inp_height, inp_width, inp_channels = inp_size
        self.n_agents = n_agents
        self.conv = CNNWrapper(inp_channels, **conv_params)
        fc_inp_dim = self.conv.get_out_dim(inp_height, inp_width)
        self.dim_red_layer = nn.Linear(fc_inp_dim, hid_dim)
        hid_dim_per = hid_dim // n_agents
        self.sus_embed = nn.Linear(1, hid_dim_per)
        self.q_network = QNetworkFC(hid_dim + hid_dim_per * n_agents, hid_dim, out_dim, n_hid_layers)

        # Place the parameters on the device / dtype configured in `config`.
        self.to(DEVICE)
        self.to(DTYPE)

    def forward(self, x: torch.Tensor, sus: torch.Tensor = None, sample=False, eps=0.):
        """
        :param x: Tensor (B, N, H, W, C)
        :param sus: Tensor (B, N) or (B, N, 1); all zeros are used when omitted
        :param sample: if True, also return epsilon-greedy actions
        :param eps: exploration rate passed to sample_actions
        :return: Q-values (B, N, A), or the tuple (Q-values, sampled actions) when sample is True
        """
        n_samples, n_agents, state_shape = x.shape[0], x.shape[1], x.shape[2:]

        # reshape (not view) so non-contiguous batches are accepted as well.
        x = x.reshape(-1, *state_shape)  # (B, N, H, W, C) -> (B*N, H, W, C)
        x = self.conv(x).flatten(1, 3)  # (B*N, -1)
        x = self.dim_red_layer(x).reshape(n_samples, n_agents, -1)  # (B, N, hid_dim)

        if sus is None:
            sus = torch.zeros((x.shape[0], x.shape[1], 1), dtype=x.dtype, device=x.device)
        else:
            sus = sus.to(x.dtype)
            if sus.ndim == 2:
                sus = sus.unsqueeze(-1)

        sus = self.sus_embed(sus)  # (B, N, hid_dim_per)
        sus = sus.unsqueeze(1).repeat_interleave(n_agents, dim=1)  # (B, N, N, hid_dim_per)
        # Zero the diagonal so agent i only keeps the other agents' embeddings;
        # the (N, N, 1) mask broadcasts over the batch and feature dimensions.
        own = torch.eye(n_agents, n_agents, dtype=sus.dtype, device=sus.device)
        sus = sus * (1 - own).unsqueeze(-1)
        sus = sus.reshape(*sus.shape[:2], -1)  # (B, N, N*hid_dim_per)

        x = torch.cat([x, sus], dim=-1)  # (B, N, hid_dim + N*hid_dim_per)
        q_values = self.q_network(x)  # B, N, A

        if sample:
            return q_values, sample_actions(q_values, eps)
        return q_values


class ClassifierNetworkFC(nn.Module):
    """
    Fully connected classifier producing, for each action, a distribution over outcomes.
    """

    def __init__(self, inp_dim, hid_dim, n_actions, n_outcomes, n_hid_layers=0):
        """
        :param inp_dim: size of the input feature vector
        :param hid_dim: width of every hidden layer
        :param n_actions: number of actions
        :param n_outcomes: number of outcome classes per action
        :param n_hid_layers: number of extra hidden Linear/ReLU blocks after the first one
        """
        super().__init__()
        self.n_actions, self.n_outcomes = n_actions, n_outcomes

        # Input projection, optional hidden blocks, then one logit per (action, outcome) pair.
        stack = [nn.Linear(inp_dim, hid_dim), nn.ReLU()]
        for _ in range(n_hid_layers):
            stack += [nn.Linear(hid_dim, hid_dim), nn.ReLU()]
        stack.append(nn.Linear(hid_dim, n_actions * n_outcomes))
        self.fc = nn.Sequential(*stack)

        # Place the parameters on the device / dtype configured in `config`.
        self.to(DEVICE)
        self.to(DTYPE)

    def forward(self, x: torch.Tensor):
        """
        :param x: Tensor (B, inp_dim)
        :return: tuple (logits, probs), both Tensor (B, n_actions, n_outcomes);
                 probs is the softmax of logits over the outcome dimension
        """
        batch_size = x.shape[0]
        flat_logits = self.fc(x)
        logits = flat_logits.view(batch_size, self.n_actions, self.n_outcomes)
        return logits, F.softmax(logits, -1)
