import torch
import torch.nn as nn
import torch.nn.functional as F

def default_init(scale: float = 2.0 ** 0.5):
    """Build an orthogonal initializer with a fixed gain.

    The default gain is sqrt(2) as a plain Python float. The previous default
    was a 0-dim tensor created at import time, which contradicted the
    ``float`` annotation.

    Args:
        scale: gain forwarded to ``nn.init.orthogonal_``.

    Returns:
        A function that fills its tensor argument in place with a
        (semi-)orthogonal matrix scaled by ``scale`` and returns that tensor.
    """
    def init_fn(tensor):
        return nn.init.orthogonal_(tensor, gain=scale)
    return init_fn

def xavier_init():
    """Return ``nn.init.xavier_normal_`` itself (Glorot normal initializer).

    The returned callable fills a tensor in place; see the PyTorch docs of
    ``xavier_normal_`` for its arguments.
    """
    initializer = nn.init.xavier_normal_
    return initializer

def kaiming_init():
    """Return ``nn.init.kaiming_normal_`` itself (He normal initializer).

    The returned callable fills a tensor in place; see the PyTorch docs of
    ``kaiming_normal_`` for its arguments.
    """
    initializer = nn.init.kaiming_normal_
    return initializer

class ResnetStack(nn.Module):
    """A 3x3 convolution followed by a chain of residual conv blocks.

    ``conv_0`` maps ``in_channels`` to ``num_ch`` channels and is followed by
    a ReLU. It is then followed by ``num_blocks`` residual blocks, each
    computing ``relu(conv(relu(conv(h)))) + h``. Every conv uses a 3x3 kernel
    with stride 1 and padding 1, so the spatial size is preserved. Weights are
    Xavier-uniform initialized; biases keep ``nn.Conv2d``'s default init.

    Args:
        in_channels: channel count of the input tensor.
        num_ch: channel count produced by every convolution.
        num_blocks: number of residual blocks after ``conv_0``.
        use_max_pooling: stored as an attribute only; ``forward`` never reads
            it, so no pooling happens inside this module.
    """

    def __init__(self, in_channels: int, num_ch: int, num_blocks: int, use_max_pooling: bool = True):
        super().__init__()
        self.num_ch = num_ch
        self.num_blocks = num_blocks
        # NOTE(review): unused by forward(); any pooling is done by the caller.
        self.use_max_pooling = use_max_pooling

        self.conv_0 = self._conv3x3(in_channels, num_ch)

        residual_blocks = []
        for _ in range(num_blocks):
            # Build (and initialize) the two convs one after the other, keeping
            # the same RNG consumption order and Sequential indices (0 and 2).
            first_conv = self._conv3x3(num_ch, num_ch)
            second_conv = self._conv3x3(num_ch, num_ch)
            residual_blocks.append(nn.Sequential(first_conv, nn.ReLU(), second_conv, nn.ReLU()))
        self.conv_blocks = nn.ModuleList(residual_blocks)

    @staticmethod
    def _conv3x3(in_ch: int, out_ch: int) -> nn.Conv2d:
        """Create a stride-1, padding-1 3x3 conv with Xavier-uniform weights."""
        conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_uniform_(conv.weight)
        return conv

    def forward(self, observations):
        """Apply ``conv_0`` + ReLU, then each residual block with its skip add.

        ``observations`` is fed directly to ``nn.Conv2d``, so it must be
        channels-first with ``in_channels`` channels.
        """
        hidden = F.relu(self.conv_0(observations))
        for residual_block in self.conv_blocks:
            hidden = residual_block(hidden) + hidden
        return hidden

class ResnetEncoder(nn.Module):
    """Channels-last image encoder built from a sequence of ``ResnetStack`` stages.

    Each stage runs a ``ResnetStack``, then a 3x3 / stride-2 / padding-1 max
    pool. If ``dropout_rate`` is set, dropout follows the pool. The final
    feature map is flattened in (H, W, C) order.

    Args:
        in_channels: channel count of the input images (last input dim).
        stack_sizes: output channel count of each stage, in order.
        num_blocks: residual blocks per ``ResnetStack``.
        dropout_rate: dropout probability applied after every stage, or
            ``None`` to build no dropout layer at all.
    """

    def __init__(
            self,
            in_channels: int = 3,
            stack_sizes: tuple = (16, 32, 32),
            num_blocks: int = 2,
            dropout_rate: float = None
        ):

        super().__init__()
        self.stack_sizes = stack_sizes
        self.num_blocks = num_blocks
        self.dropout_rate = dropout_rate

        # Channels entering each stage: the input's, then each previous stage's output.
        channels = [in_channels, *stack_sizes]
        self.stack_blocks = nn.ModuleList(
            ResnetStack(in_channels=c_in, num_ch=c_out, num_blocks=num_blocks)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        )

        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if dropout_rate is not None:
            self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, train=True):
        """Encode channels-last images into flat feature vectors.

        Args:
            x: tensor of shape ``(*batch, H, W, C)`` with ``C == in_channels``.
                ``batch`` may be empty (a single image) or span several dims;
                leading dims are collapsed for the convolutions and restored
                afterwards.
            train: when False, dropout is skipped. Dropout is additionally a
                no-op whenever the module is in eval mode, because
                ``nn.Dropout`` follows ``self.training``.

        Returns:
            Tensor of shape ``(*batch, features)``, flattened in (H, W, C) order.
        """
        batch_shape = x.shape[:-3]
        # Conv2d needs a single batch dim: fold all leading dims, then go NHWC -> NCHW.
        conv_out = x.reshape(-1, *x.shape[-3:]).permute(0, 3, 1, 2)

        for stack in self.stack_blocks:
            conv_out = self.max_pool(stack(conv_out))
            if self.dropout_rate is not None and train:
                conv_out = self.dropout(conv_out)

        # Back to NHWC so the flattened features are ordered (H, W, C).
        conv_out = conv_out.permute(0, 2, 3, 1)
        return conv_out.reshape((*batch_shape, -1))

if __name__ == '__main__':

    import time

    encoder = ResnetEncoder(3, (16, 32, 32), 2)

    # Random channels-last input: (batch, height, width, channels).
    torch_input = torch.randn((1, 64, 64, 3))

    # Time one forward pass. perf_counter is the clock intended for interval
    # timing; no_grad skips building an autograd graph we never use.
    t0 = time.perf_counter()
    with torch.no_grad():
        torch_output = encoder(torch_input).numpy()
    print(f'torch time: {time.perf_counter() - t0}')

    print("Torch output shape:", torch_output.shape)
