# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import jax
# import flax.linen as fnn
# import jax.numpy as jnp
# from jax import random
# import numpy as np

# ###########################
# # Jax
# ###########################
# def default_init(scale: float = jnp.sqrt(2)):
#     return fnn.initializers.orthogonal(scale)

# def xavier_init():
#     return fnn.initializers.xavier_normal()

# def kaiming_init():
#     return fnn.initializers.kaiming_normal()

# class ResnetStackJax(fnn.Module):
#     num_ch: int
#     num_blocks: int
#     use_max_pooling: bool = True

#     @fnn.compact
#     def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
#         initializer = fnn.initializers.xavier_uniform()
#         conv_out = fnn.Conv(
#             features=self.num_ch,
#             kernel_size=(3, 3),
#             strides=1,
#             kernel_init=initializer,
#             padding='SAME'
#         )(observations)

#         if self.use_max_pooling:
#             conv_out = fnn.max_pool(
#                 conv_out,
#                 window_shape=(3, 3),
#                 padding='SAME',
#                 strides=(2, 2)
#             )

#         for _ in range(self.num_blocks):
#             block_input = conv_out
#             conv_out = fnn.relu(conv_out)
#             conv_out = fnn.Conv(
#                 features=self.num_ch, kernel_size=(3, 3), strides=1,
#                 padding='SAME',
#                 kernel_init=initializer)(conv_out)

#             conv_out = fnn.relu(conv_out)
#             conv_out = fnn.Conv(
#                 features=self.num_ch, kernel_size=(3, 3), strides=1,
#                 padding='SAME', kernel_init=initializer
#             )(conv_out)
#             conv_out += block_input

#         return conv_out

# class ImpalaEncoderJax(fnn.Module):
#     width: int = 1
#     use_multiplicative_cond: bool = False
#     stack_sizes: tuple = (16, 32, 32)
#     num_blocks: int = 2
#     dropout_rate: float = None

#     def setup(self):
#         stack_sizes = self.stack_sizes
#         self.stack_blocks = [
#             ResnetStackJax(
#                 num_ch=stack_sizes[i] * self.width,
#                 num_blocks=self.num_blocks,
#             )
#             for i in range(len(stack_sizes))

#         ]
#         if self.dropout_rate is not None:
#             self.dropout = fnn.Dropout(rate=self.dropout_rate)

#     @fnn.compact
#     def __call__(self, x, train=True, cond_var=None):

#         x = x.astype(jnp.float32) / 255.0
#         # x = jnp.reshape(x, (*x.shape[:-2], -1))

#         conv_out = x

#         for idx in range(len(self.stack_blocks)):
#             conv_out = self.stack_blocks[idx](conv_out)
#             if self.dropout_rate is not None:
#                 conv_out = self.dropout(conv_out, deterministic=not train)
#             if self.use_multiplicative_cond:
#                 assert cond_var is not None, "Cond var shouldn't be done when using it"
#                 print("Using Multiplicative Cond!")
#                 temp_out = fnn.Dense(conv_out.shape[-1], kernel_init=xavier_init())(cond_var)
#                 x_mult = jnp.expand_dims(jnp.expand_dims(temp_out, 1), 1)
#                 print ('x_mult shape in IMPALA:', x_mult.shape, conv_out.shape)
#                 conv_out = conv_out * x_mult

#         conv_out = fnn.relu(conv_out)
#         # print(conv_out.shape, conv_out.reshape((*x.shape[:-3], -1)).shape)
#         return conv_out.reshape((*x.shape[:-3], -1))

import torch
import torch.nn as nn
import torch.nn.functional as F

###########################
# PyTorch
###########################
def default_init(scale: float = 2.0 ** 0.5):
    """Return an in-place orthogonal initializer scaled by ``scale``.

    PyTorch counterpart of the JAX ``default_init`` kept in the commented-out
    section of this file (``fnn.initializers.orthogonal(scale)``).

    Args:
        scale: gain passed to ``torch.nn.init.orthogonal_``; defaults to
            sqrt(2).

    Returns:
        A function ``init_fn(tensor)`` that fills ``tensor`` in place with a
        (semi-)orthogonal matrix multiplied by ``scale`` and returns it.
    """
    # A plain float default (instead of a 0-d tensor) matches the annotation
    # and avoids allocating a tensor at import time; the gain is the same.
    def init_fn(tensor):
        return nn.init.orthogonal_(tensor, gain=scale)
    return init_fn

def xavier_init():
    """Return PyTorch's in-place Xavier (Glorot) normal initializer.

    The function object ``nn.init.xavier_normal_`` itself is returned (it is
    not called here); apply it to a weight tensor to initialize it.
    """
    initializer = nn.init.xavier_normal_
    return initializer

def kaiming_init():
    """Return PyTorch's in-place Kaiming (He) normal initializer.

    The function object ``nn.init.kaiming_normal_`` itself is returned (it is
    not called here); apply it to a weight tensor to initialize it.
    """
    initializer = nn.init.kaiming_normal_
    return initializer

class ResnetStackTorch(nn.Module):
    """One IMPALA residual stack operating on ``(N, C, H, W)`` tensors.

    Layout: 3x3 conv -> optional 3x3 / stride-2 max-pool -> ``num_blocks``
    residual units, each computing ``unit(x) + x`` where
    ``unit = ReLU -> conv -> ReLU -> conv``. Every convolution is 3x3,
    stride 1, padding 1, with ``num_ch`` output channels; weights are
    re-initialized with Xavier-uniform, biases keep PyTorch's default init.

    Args:
        in_channels: number of channels of the input tensor.
        num_ch: number of channels produced by every convolution.
        num_blocks: number of residual units.
        use_max_pooling: whether to downsample after the first convolution.
    """

    def __init__(self, in_channels: int, num_ch: int, num_blocks: int, use_max_pooling: bool = True):
        super().__init__()
        self.num_ch = num_ch
        self.num_blocks = num_blocks
        self.use_max_pooling = use_max_pooling

        def conv3x3(channels_in):
            # Build the layer, then overwrite its weight with Xavier-uniform.
            layer = nn.Conv2d(in_channels=channels_in, out_channels=num_ch, kernel_size=3, stride=1, padding=1)
            nn.init.xavier_uniform_(layer.weight)
            return layer

        self.conv1 = conv3x3(in_channels)

        if use_max_pooling:
            self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # The convs sit at Sequential indices 1 and 3; those indices are part
        # of the state_dict keys (e.g. ``conv_blocks.0.1.weight``).
        residual_units = [
            nn.Sequential(nn.ReLU(), conv3x3(num_ch), nn.ReLU(), conv3x3(num_ch))
            for _ in range(num_blocks)
        ]
        self.conv_blocks = nn.ModuleList(residual_units)

    def forward(self, observations):
        """Apply the stack to ``observations`` and return the feature map."""
        features = self.conv1(observations)
        if self.use_max_pooling:
            features = self.max_pool(features)
        for unit in self.conv_blocks:
            # Residual connection around each unit.
            features = unit(features) + features
        return features

class ImpalaEncoderTorch(nn.Module):
    """PyTorch port of the IMPALA CNN encoder (JAX version kept in comments).

    Takes channels-last images of shape ``(..., H, W, C)``, runs them through
    one ``ResnetStackTorch`` per entry of ``stack_sizes`` and returns the
    flattened features with the leading dimensions preserved.

    Args:
        in_channels: number of channels ``C`` of the input images.
        width: multiplier applied to every entry of ``stack_sizes`` (the
            stacks' output channels); the input channels are not scaled.
        use_multiplicative_cond: not supported by this port; ``forward``
            raises ``NotImplementedError`` when it is set.
        stack_sizes: output channels (before ``width`` scaling) of each stack.
        num_blocks: residual blocks per stack.
        dropout_rate: if not None, dropout with this probability is applied
            after every stack when ``forward`` is called with ``train=True``.
    """

    def __init__(
            self,
            in_channels: int = 3,
            width: int = 1,
            use_multiplicative_cond: bool = False,
            stack_sizes: tuple = (16, 32, 32),
            num_blocks: int = 2,
            dropout_rate: float = None
        ):
        super().__init__()
        self.width = width
        self.use_multiplicative_cond = use_multiplicative_cond
        self.stack_sizes = stack_sizes
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate

        # Only each stack's *output* channels are scaled by ``width``. The
        # first stack consumes the raw image channels, and every later stack
        # consumes the (scaled) output of the one before it.
        self.stack_blocks = nn.ModuleList()
        prev_channels = in_channels
        for size in stack_sizes:
            out_channels = size * width
            self.stack_blocks.append(
                ResnetStackTorch(in_channels=prev_channels, num_ch=out_channels, num_blocks=num_blocks)
            )
            prev_channels = out_channels

        if dropout_rate is not None:
            self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, train=True, cond_var=None, normalize=False):
        """Encode channels-last images into flat feature vectors.

        Args:
            x: image tensor of shape ``(..., H, W, C)``; any number of leading
                dimensions (including none) is accepted.
            train: dropout (if configured) is applied only when True.
                ``nn.Dropout`` is additionally a no-op in ``eval()`` mode.
            cond_var: unused; multiplicative conditioning is not implemented.
            normalize: if True, cast ``x`` to float32 and divide by 255.

        Returns:
            Tensor of shape ``(*x.shape[:-3], features)``. The features are
            flattened in channels-last (H, W, C) order.

        Raises:
            NotImplementedError: if ``use_multiplicative_cond`` is set.
        """
        if normalize:
            x = x.float() / 255.0

        lead_shape = x.shape[:-3]
        # Fold all leading dims into a single batch dim and switch to NCHW,
        # the layout nn.Conv2d expects.
        conv_out = x.reshape(-1, *x.shape[-3:]).permute(0, 3, 1, 2)

        for stack in self.stack_blocks:
            conv_out = stack(conv_out)
            if self.dropout_rate is not None and train:
                conv_out = self.dropout(conv_out)
            if self.use_multiplicative_cond:
                raise NotImplementedError()

        conv_out = F.relu(conv_out)
        # Back to channels-last before flattening so the feature order matches
        # the NHWC layout of the JAX encoder; flatten(1) also copes with an
        # empty batch, where reshape(..., -1) would be ambiguous.
        conv_out = conv_out.permute(0, 2, 3, 1).flatten(start_dim=1)
        return conv_out.reshape(*lead_shape, conv_out.shape[-1])

if __name__ == '__main__':
    import time

    # Smoke test / rough timing for the PyTorch encoder only: the JAX
    # reference implementation (and the jax / numpy imports it needs) is
    # commented out at the top of this file, so it cannot be run here.
    torch.manual_seed(0)
    torch_model = ImpalaEncoderTorch()

    # Random channels-last input: (batch, H, W, C).
    torch_input = torch.randn(1, 64, 64, 3)

    t0 = time.perf_counter()
    with torch.no_grad():
        torch_output = torch_model(torch_input)
    print(f'torch time: {time.perf_counter() - t0}')

    print("Torch output shape:", torch_output.shape)
