import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """Convolutional network mapping a stacked image observation to one value per action.

    The layout is three Conv2d+ReLU stages, a Flatten, a Linear+ReLU
    projection to ``args.hidden_dim``, ``args.linear_layers`` further
    Linear+ReLU stages of the same width, and a final Linear with
    ``env.single_action_space.n`` outputs.

    Args:
        env: Object exposing ``single_action_space.n``, the size of the output layer.
        args: Object exposing ``hidden_dim`` and ``linear_layers``.
    """

    def __init__(self, env, args):
        super().__init__()
        width = args.hidden_dim

        # Each row is (in_channels, out_channels, kernel_size, stride).
        conv_specs = (
            (4, 32, 8, 4),
            (32, 64, 4, 2),
            (64, 64, 3, 1),
        )
        stack = []
        for in_ch, out_ch, kernel, stride in conv_specs:
            stack.extend((nn.Conv2d(in_ch, out_ch, kernel, stride=stride), nn.ReLU()))
        stack.append(nn.Flatten())

        # 3136 == 64 * 7 * 7, the flattened size the conv stack above
        # produces for an 84x84 input.
        stack.extend((nn.Linear(3136, width), nn.ReLU()))
        for _ in range(args.linear_layers):
            stack.extend((nn.Linear(width, width), nn.ReLU()))

        stack.append(nn.Linear(width, env.single_action_space.n))
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Run the network on ``x`` after dividing it by 255.

        The division presumably rescales 8-bit pixel values to [0, 1] —
        confirm against the caller.
        """
        scaled = x / 255.0
        return self.network(scaled)


class EnsembleQNetwork(nn.Module):
    """A set of independently constructed ``QNetwork`` members evaluated on the same input.

    Members are registered as submodules named ``network_0`` ...
    ``network_{num_ensemble - 1}``.

    Args:
        env: Forwarded unchanged to every ``QNetwork``.
        args: Forwarded unchanged to every ``QNetwork``; ``args.num_ensemble``
            sets the number of members.
    """

    def __init__(self, env, args):
        super().__init__()
        self.num_ensemble = args.num_ensemble

        for member_id in range(self.num_ensemble):
            self.add_module(f'network_{member_id}', QNetwork(env, args))

    def _members(self):
        """Yield the member networks in index order."""
        for member_id in range(self.num_ensemble):
            yield getattr(self, f'network_{member_id}')

    def forward(self, x, idx=None):
        """Return Q-values for ``x``.

        Args:
            x: Input passed as-is to the member network(s).
            idx: Index of a single member to query. When ``None`` the
                outputs of all members are averaged.

        Returns:
            The selected member's output, or the element-wise mean over members.
        """
        # Compare against None explicitly: idx == 0 is a valid member index.
        if idx is not None:
            chosen = getattr(self, f'network_{idx}')
            return chosen(x)

        stacked = torch.stack([member(x) for member in self._members()], dim=1)
        return stacked.mean(dim=1)

    def get_all(self, x):
        """Return a list with each member's output for ``x``, in member order."""
        return [member(x) for member in self._members()]


class TreeQNetwork(nn.Module):
    """Multi-head Q-network: one shared backbone feeding several independent heads.

    A ``QNetwork`` is built and its layers are grouped into blocks by
    ``make_into_blocks``. The first ``total_blocks - blocks_in_head`` blocks
    become the shared backbone. For every head a fresh ``QNetwork`` is built
    and only its trailing ``blocks_in_head`` blocks are kept, so heads do not
    share parameters with each other. Heads are registered as submodules named
    ``head_0`` ... ``head_{num_heads - 1}``.

    Args:
        env: Forwarded unchanged to every ``QNetwork``.
        args: Forwarded unchanged to every ``QNetwork``; ``args.num_ensemble``
            sets the number of heads and ``args.blocks_in_head`` the number of
            blocks placed in each head.
    """

    def __init__(self, env, args):
        super().__init__()
        self.num_heads = args.num_ensemble
        self.blocks_in_head = args.blocks_in_head

        trunk_blocks = make_into_blocks(list(QNetwork(env, args).network.children()))
        self.total_blocks = len(trunk_blocks)
        assert 0 < self.blocks_in_head < self.total_blocks, "blocks_in_head must be between 0 and total blocks"
        self.blocks_shared = self.total_blocks - self.blocks_in_head

        shared_layers = [
            layer
            for block in trunk_blocks[:self.blocks_shared]
            for layer in block
        ]
        self.backbone = nn.Sequential(*shared_layers)

        for head_id in range(self.num_heads):
            # A new QNetwork per head supplies freshly constructed layers.
            donor_blocks = make_into_blocks(list(QNetwork(env, args).network.children()))
            tail_layers = [
                layer
                for block in donor_blocks[self.blocks_shared:]
                for layer in block
            ]
            setattr(self, f'head_{head_id}', nn.Sequential(*tail_layers))

    def _heads(self):
        """Yield the head modules in index order."""
        for head_id in range(self.num_heads):
            yield getattr(self, f'head_{head_id}')

    def forward(self, x, idx=None):
        """Return Q-values for ``x``.

        Args:
            x: Input; it is divided by 255 before entering the backbone.
            idx: Index of a single head to query. When ``None`` the outputs
                of all heads are averaged.

        Returns:
            The selected head's output, or the element-wise mean over heads.
        """
        features = self.backbone(x / 255.0)
        # Compare against None explicitly: idx == 0 is a valid head index.
        if idx is not None:
            chosen = getattr(self, f'head_{idx}')
            return chosen(features)

        stacked = torch.stack([head(features) for head in self._heads()], dim=1)
        return stacked.mean(dim=1)

    def get_all(self, x):
        """Return a list with each head's output for ``x``, in head order."""
        features = self.backbone(x / 255.0)
        return [head(features) for head in self._heads()]


def make_into_blocks(layers):
    """Group a flat sequence of layers into consecutive blocks.

    Recognised blocks are:

    * ``Conv2d, ReLU, Flatten`` (3 layers)
    * ``Conv2d, ReLU`` (2 layers)
    * ``Linear, ReLU`` (2 layers)
    * ``Linear`` alone, only as the very last layer (1 layer)

    Args:
        layers: Sequence of ``nn.Module`` layers in execution order.

    Returns:
        A list of tuples, each holding the layers of one block, in order.

    Raises:
        ValueError: If a layer is neither ``Conv2d`` nor ``Linear`` where a
            block should start, if a ``Conv2d`` is not followed by a ``ReLU``
            (including when it is the last layer), or if a ``Linear`` is
            followed by something other than a ``ReLU``.
    """
    blocks = []
    total = len(layers)
    i = 0
    while i < total:
        layer = layers[i]
        # Slicing never raises, so looking ahead near the end of the
        # sequence just yields fewer followers instead of an IndexError.
        followers = layers[i + 1:i + 3]
        relu_next = len(followers) >= 1 and isinstance(followers[0], nn.ReLU)

        if isinstance(layer, nn.Conv2d):
            if not relu_next:
                raise ValueError("Unexpected layer after Conv2d")
            flatten_after = len(followers) == 2 and isinstance(followers[1], nn.Flatten)
            width = 3 if flatten_after else 2
        elif isinstance(layer, nn.Linear):
            if relu_next:
                width = 2
            elif len(followers) == 0:
                # A bare Linear is only valid as the final layer.
                width = 1
            else:
                raise ValueError("Unexpected layer after Linear")
        else:
            raise ValueError("Unexpected layer type in backbone")

        blocks.append(tuple(layers[i:i + width]))
        i += width
    return blocks
