import jax
import haiku as hk
from acme.jax import networks as networks_lib

class MLP(hk.Module):
    """Multi-layer perceptron: flatten, ReLU hidden layers, linear output.

    `widths` gives the output size of each linear layer in order. Every
    entry except the last is followed by a ReLU; the last entry is the size
    of the final linear layer, which has no activation after it.
    """

    def __init__(self, widths, name=None):
        super().__init__(name=name)
        self.widths = widths

    def __call__(self, x):
        """Assemble the layer stack and apply it to `x`."""
        # The input is flattened first, so any trailing dims are collapsed
        # before the first linear layer sees them.
        stack = [hk.Flatten()]
        for hidden_width in self.widths[:-1]:
            stack.extend((hk.Linear(hidden_width), jax.nn.relu))
        # Output layer: linear only, no non-linearity.
        stack.append(hk.Linear(self.widths[-1]))
        self.net = hk.Sequential(stack)
        return self.net(x)


class DQN(hk.Module):
    """DQN model: three conv layers, a flatten, then a dense head.

    The conv stack is fixed (32, 64 and 64 output channels). The head is
    `dense_layers` linear layers of size `width`, each followed by a ReLU,
    and a final linear layer of size `output_dim` with no activation.
    """

    def __init__(self, dense_layers=2, width=512, output_dim=1,
                 name=None):
        super().__init__(name=name)
        self.dense_layers = dense_layers
        self.width = width
        self.output_dim = output_dim

        # (output channels, kernel shape, stride) for each conv layer.
        conv_specs = (
            (32, [8, 8], 4),
            (64, [4, 4], 2),
            (64, [3, 3], 1),
        )

        # Modules are created in the same order they are applied.
        stack = []
        for channels, kernel_shape, stride in conv_specs:
            stack += [hk.Conv2D(channels, kernel_shape, stride), jax.nn.relu]
        stack.append(hk.Flatten())

        for _ in range(dense_layers):
            stack += [hk.Linear(width), jax.nn.relu]
        stack.append(hk.Linear(output_dim))

        self.net = hk.Sequential(stack)

    def __call__(self, x):
        """Apply the conv stack and dense head to `x`."""
        return self.net(x)



# class Resnet(hk.Module):

#     def __init__(self, dense_layers=2, width=512,
#                     output_dim=1, name=None):
#         super().__init__(name=name)

#         self.trunk = hk.nets.ResNet(
#                         blocks_per_group = (2,2,2,2),
#                         num_classes = width,
#                         bottleneck = False,
#                         channels_per_group = (64, 128, 256, 512),
#                         use_projection = (False, True, True, True),
#                         logits_config = None,
#                         strides = (1, 2, 2, 2)
#         )
#         layers = []
#         for _ in range(dense_layers):
#             layers.append(hk.Linear(width))
#             layers.append(jax.nn.relu)
#         layers.append(hk.Linear(output_dim))
#         self.top_net = hk.Sequential(layers)
    
#     def __call__(self, x, is_training=True):
#         z = self.trunk(x, is_training)
#         return self.top_net(z)


def build_network(network_type, dummy_obs, out_dim,
                  width=None, depth=None, prior=False):
    """Build a Haiku-transformed network wrapped as a FeedForwardNetwork.

    Args:
        network_type: Either 'dqn' or 'mlp'.
        dummy_obs: Example input handed to the transformed `init` when the
            parameters are created.
        out_dim: Size of the final linear layer.
        width: Size of each hidden layer. Required for 'mlp'; not used for
            'dqn', which keeps the DQN class default.
        depth: Number of hidden layers. Required for 'mlp'; not used for
            'dqn'.
        prior: When True the network gets one fewer hidden dense layer
            (1 instead of 2 for 'dqn', `depth - 1` instead of `depth` for
            'mlp').

    Returns:
        A `networks_lib.FeedForwardNetwork` whose `init(rng)` initialises
        parameters on `dummy_obs` and whose `apply` is the transformed apply
        function without an rng argument.

    Raises:
        ValueError: If `network_type` is not recognised, or if 'mlp' is
            requested without both `width` and `depth`.
    """
    # Validate before touching Haiku so bad arguments fail with a clear
    # message instead of an unbound local or a TypeError on None.
    if network_type == 'dqn':
        n_dense = 1 if prior else 2

        def forward(x):
            return DQN(dense_layers=n_dense, output_dim=out_dim)(x)
    elif network_type == 'mlp':
        if width is None or depth is None:
            raise ValueError(
                "network_type 'mlp' requires both `width` and `depth`; "
                f"got width={width!r}, depth={depth!r}")
        n_hidden = depth - 1 if prior else depth

        def forward(x):
            return MLP(widths=[width] * n_hidden + [out_dim])(x)
    else:
        raise ValueError(
            f"unknown network_type {network_type!r}; "
            "expected 'dqn' or 'mlp'")

    network_hk = hk.without_apply_rng(hk.transform(forward))

    return networks_lib.FeedForwardNetwork(
        init=lambda rng: network_hk.init(rng, dummy_obs),
        apply=network_hk.apply)