import numpy as np
from xuance.common import Sequence, Optional, Union, Callable
from xuance.torch import Module, Tensor
from xuance.torch.utils import torch, nn, cnn_block, mlp_block, ModuleType
from typing import Sequence, Optional, Callable, Union, Any


class Basic_CNN(nn.Module):
    """1D convolutional feature extractor.

    The network is a stack of ``cnn_block_1d`` stages (one per
    ``(kernel, stride, filter)`` triple), followed by an adaptive max-pool to
    ``pool_dim`` positions per channel and a flatten. The advertised feature
    size is therefore ``filters[-1] * pool_dim``.

    Args:
        input_shape: Per-sample observation shape. The constructor swaps the
            two entries and ``forward`` permutes the last two axes of the
            batch, so the code reads this as ``(width, channels)``.
            NOTE(review): an earlier inline comment described it as
            ``(channels, width)`` - confirm against callers.
        kernels: Kernel size of each convolution stage.
        strides: Stride of each convolution stage.
        filters: Output channel count of each convolution stage.
        normalize: Forwarded unchanged to ``cnn_block_1d``.
        initialize: Forwarded unchanged to ``cnn_block_1d``.
        activation: Forwarded unchanged to ``cnn_block_1d``.
        device: Device used for the layers and for the input tensor.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 kernels: Sequence[int],
                 strides: Sequence[int],
                 filters: Sequence[int],
                 normalize: Optional[Any] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[Any] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 **kwargs):
        super().__init__()
        # Stored with index 1 first: that entry is used as the Conv1d channel count.
        self.input_shape = (input_shape[1], input_shape[0])
        self.kernels, self.strides, self.filters = kernels, strides, filters
        self.normalize = normalize
        self.initialize = initialize
        self.activation = activation
        self.device = device
        # Output width of the AdaptiveMaxPool1d layer.
        self.pool_dim = 8
        self.output_shapes = {'state': (filters[-1] * self.pool_dim,)}
        self.model = self._create_network()

    def _create_network(self):
        """Assemble conv stages, then adaptive max-pool and flatten."""
        stages = []
        shape = self.input_shape
        for kernel, stride, n_filters in zip(self.kernels, self.strides, self.filters):
            block, shape = cnn_block_1d(shape, n_filters, kernel, stride,
                                        self.normalize, self.activation, self.initialize, self.device)
            stages += block
        # AdaptiveMaxPool1d maps (batch, channels, width) to
        # (batch, channels, pool_dim); Flatten then yields
        # (batch, channels * pool_dim).
        stages += [nn.AdaptiveMaxPool1d(self.pool_dim), nn.Flatten()]
        return nn.Sequential(*stages)

    def forward(self, observations: torch.Tensor):
        """Encode a batch of observations.

        Args:
            observations: Batch whose last two axes are swapped before the
                convolutions; a Python list is converted through NumPy first.

        Returns:
            Dict with key ``'state'`` holding the flattened features.
        """
        if isinstance(observations, list):
            observations = np.array(observations)
        batch = torch.as_tensor(observations, dtype=torch.float32, device=self.device)
        # Conv1d expects (batch, channels, width), so swap the last two axes.
        return {'state': self.model(batch.permute(0, 2, 1))}

# Builds a block of 1D convolution layers, analogous to the existing cnn_block function
def cnn_block_1d(input_shape: Sequence[int], out_channels: int, kernel_size: int, stride: int,
                 normalize: Optional[Any], activation: Optional[Any], initialize: Optional[Callable[..., Tensor]],
                 device: Optional[Union[str, int, torch.device]]):
    """Build one Conv1d stage and report the shape it produces.

    Args:
        input_shape: ``(channels, width)`` of the incoming feature map.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        normalize: Optional factory called as ``normalize(out_channels)`` to
            create a normalisation layer placed after the convolution.
        activation: Optional zero-argument factory for the activation layer.
        initialize: Optional in-place initialiser applied to the convolution
            weight. When given, the bias is also reset to 0.
            NOTE(review): zeroing the bias is assumed to match the 2D
            ``cnn_block`` helper - confirm.
        device: Device on which the layers are created.

    Returns:
        ``(layers, (out_channels, width_out))`` where ``layers`` is the list
        of modules in application order.
    """
    in_channels, width_in = input_shape[0], input_shape[1]

    conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, device=device)
    if initialize is not None:
        # Previously this argument was accepted but never applied.
        initialize(conv.weight)
        nn.init.constant_(conv.bias, 0)
    layers = [conv]

    if normalize:
        norm = normalize(out_channels)
        if device is not None:
            # Keep the normalisation parameters on the same device as the
            # convolution; otherwise the forward pass fails on non-CPU devices.
            norm = norm.to(device)
        layers.append(norm)

    if activation:
        layers.append(activation())

    # Unpadded convolution: W_out = (W_in - kernel_size) // stride + 1
    width_out = (width_in - kernel_size) // stride + 1
    return layers, (out_channels, width_out)



# process the input observations with stacks of CNN layers
class Basic_CNN0(Module):
    """2D convolutional feature extractor for image-like observations.

    The network is a stack of ``cnn_block`` stages (one per
    ``(kernel, stride, filter)`` triple), followed by a global adaptive
    max-pool to 1x1 and a flatten. The advertised feature size is
    ``filters[-1]``.

    Args:
        input_shape: Per-sample observation shape; entries 0, 1, 2 are
            reordered to ``(input_shape[2], input_shape[0], input_shape[1])``,
            i.e. the input is read as height x width x channels.
        kernels: Kernel size of each convolution stage.
        strides: Stride of each convolution stage.
        filters: Output channel count of each convolution stage.
        normalize: Forwarded unchanged to ``cnn_block``.
        initialize: Forwarded unchanged to ``cnn_block``.
        activation: Forwarded unchanged to ``cnn_block``.
        device: Device used for the layers and for the input tensor.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 kernels: Sequence[int],
                 strides: Sequence[int],
                 filters: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 **kwargs):
        super().__init__()
        height, width, channels = input_shape[0], input_shape[1], input_shape[2]
        self.input_shape = (channels, height, width)
        self.kernels, self.strides, self.filters = kernels, strides, filters
        self.normalize = normalize
        self.initialize = initialize
        self.activation = activation
        self.device = device
        self.output_shapes = {'state': (filters[-1],)}
        self.model = self._create_network()

    def _create_network(self):
        """Assemble conv stages, then global max-pool and flatten."""
        modules = []
        shape = self.input_shape
        for kernel, stride, n_filters in zip(self.kernels, self.strides, self.filters):
            block, shape = cnn_block(shape, n_filters, kernel, stride,
                                     self.normalize, self.activation, self.initialize, self.device)
            modules.extend(block)
        # Pool every channel down to a single value, then drop the 1x1 axes.
        modules.extend([nn.AdaptiveMaxPool2d((1, 1)), nn.Flatten()])
        return nn.Sequential(*modules)

    def forward(self, observations: np.ndarray):
        """Encode a batch of observations.

        Args:
            observations: Batch that is divided by 255 and whose axes are
                reordered with ``(0, 3, 1, 2)`` before the convolutions.

        Returns:
            Dict with key ``'state'`` holding the flattened features.
        """
        scaled = observations / 255.0
        # Move the last axis to position 1 (channels-first for Conv2d).
        channels_first = np.transpose(scaled, (0, 3, 1, 2))
        batch = torch.as_tensor(channels_first, dtype=torch.float32, device=self.device)
        return {'state': self.model(batch)}


class AC_CNN_Atari(Module):
    """Convolution + fully connected encoder with orthogonal initialisation.

    Convolution stages are built with ``cnn_block`` and fully connected
    stages with ``mlp_block``. Both are called with ``None`` for the
    normalisation and initialiser arguments, and the first layer of every
    block is then re-initialised by ``_init_layer``. The ``normalize`` and
    ``initialize`` constructor arguments are stored on the instance but are
    not used when building the network.

    Args:
        input_shape: Per-sample observation shape; entries 0, 1, 2 are
            reordered to ``(input_shape[2], input_shape[0], input_shape[1])``,
            i.e. the input is read as height x width x channels.
        kernels: Kernel size of each convolution stage.
        strides: Stride of each convolution stage.
        filters: Output channel count of each convolution stage.
        normalize: Stored only (see above).
        initialize: Stored only (see above).
        activation: Activation factory forwarded to ``cnn_block`` and
            ``mlp_block``.
        device: Device used for the layers and for the input tensor.
        fc_hidden_sizes: Widths of the fully connected layers appended after
            the flatten. May be empty, in which case the features are the
            flattened convolution output.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 input_shape: Sequence[int],
                 kernels: Sequence[int],
                 strides: Sequence[int],
                 filters: Sequence[int],
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 fc_hidden_sizes: Sequence[int] = (),
                 **kwargs):
        super(AC_CNN_Atari, self).__init__()
        self.input_shape = (input_shape[2], input_shape[0], input_shape[1])  # Channels x Height x Width
        self.kernels = kernels
        self.strides = strides
        self.filters = filters
        self.normalize = normalize
        self.initialize = initialize
        self.activation = activation
        self.device = device
        self.fc_hidden_sizes = fc_hidden_sizes
        self.model = self._create_network()
        # With no FC layers (the declared default) the feature size is the
        # flattened convolution output; indexing fc_hidden_sizes[-1]
        # unconditionally raised IndexError for that default.
        if len(fc_hidden_sizes) > 0:
            feature_dim = fc_hidden_sizes[-1]
        else:
            feature_dim = self._flatten_dim
        self.output_shapes = {'state': (feature_dim,)}

    def _init_layer(self, layer, gain=np.sqrt(2), bias=0.0):
        """Orthogonally initialise ``layer.weight`` and set its bias to a constant.

        Args:
            layer: Module exposing ``weight`` and ``bias``.
            gain: Scaling factor passed to ``nn.init.orthogonal_``.
            bias: Constant written into the bias, when the layer has one.

        Returns:
            The same ``layer`` object, initialised in place.
        """
        nn.init.orthogonal_(layer.weight, gain=gain)
        if layer.bias is not None:  # layers created with bias=False have no bias tensor
            nn.init.constant_(layer.bias, bias)
        return layer

    def _create_network(self):
        """Assemble conv stages, a flatten, and the fully connected stages.

        Also records the flattened feature count in ``self._flatten_dim``.
        """
        layers = []
        input_shape = self.input_shape
        for k, s, f in zip(self.kernels, self.strides, self.filters):
            cnn, input_shape = cnn_block(input_shape, f, k, s, None, self.activation, None, self.device)
            # assumes the first entry of the block is the convolution layer - TODO confirm
            cnn[0] = self._init_layer(cnn[0])
            layers.extend(cnn)
        layers.append(nn.Flatten())
        # np.int was removed in NumPy 1.24; the builtin int gives the same value.
        self._flatten_dim = int(np.prod(input_shape))
        input_shape = (self._flatten_dim,)
        for h in self.fc_hidden_sizes:
            mlp, input_shape = mlp_block(input_shape[0], h, None, self.activation, None, self.device)
            # assumes the first entry of the block is the linear layer - TODO confirm
            mlp[0] = self._init_layer(mlp[0])
            layers.extend(mlp)
        return nn.Sequential(*layers)

    def forward(self, observations: np.ndarray):
        """Encode a batch of observations.

        Args:
            observations: Batch that is divided by 255 and whose axes are
                reordered with ``(0, 3, 1, 2)`` before the convolutions.

        Returns:
            Dict with key ``'state'`` holding the encoded features.
        """
        observations = observations / 255.0
        tensor_observation = torch.as_tensor(np.transpose(observations, (0, 3, 1, 2)), dtype=torch.float32,
                                             device=self.device)
        return {'state': self.model(tensor_observation)}
