import torch
import torch.nn as nn
import gym
import numpy as np
from gym import spaces
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from countbased.models.idm import NatureCNN

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and hand it back.

    The weight tensor receives an orthogonal initialisation scaled by ``std``;
    every bias entry is set to ``bias_const``.

    :param layer: module exposing ``weight`` and ``bias`` tensors
        (e.g. ``nn.Linear`` or ``nn.Conv2d``).
    :param std: gain forwarded to ``torch.nn.init.orthogonal_``.
    :param bias_const: constant written into the bias.
    :return: the very same ``layer`` object, so the call can be nested inline
        inside ``nn.Sequential(...)``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer

class CustomCNN(nn.Module):
    """
    Three stride-2 3x3 convolutions (32, 64, 64 channels) followed by a
    ReLU-activated linear projection.

    :param subspace: (gym.Space) space of a single image input; only its
        ``shape`` is read, and it must be ``(channels, height, width)``.
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units in the last layer.
    """
    def __init__(self, subspace, features_dim: int = 256):
        super(CustomCNN, self).__init__()
        n_input_channels = subspace.shape[0]
        self.cnn = nn.Sequential(
            layer_init(nn.Conv2d(n_input_channels, 32, kernel_size=3, stride=2, padding=1)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)),
            nn.ReLU(),
            nn.Flatten()
        )

        # Compute the flattened conv output size with one dummy forward pass.
        # Probe with the space's real spatial size: a hard-coded 32x32 probe
        # yields a wrong Linear input size for any other resolution. Only the
        # shape matters, so a zero tensor is used instead of random data
        # (which would needlessly advance the global NumPy RNG).
        with torch.no_grad():
            n_flatten = self.cnn(torch.zeros(1, *subspace.shape)).shape[1]

        self.linear = nn.Sequential(layer_init(nn.Linear(n_flatten, features_dim)), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of images ``(B, C, H, W)`` to ``(B, features_dim)`` features."""
        return self.linear(self.cnn(observations))

class CustomAugmentedExtractorCNN(BaseFeaturesExtractor):
    """Encode each entry of a Dict observation and concatenate the features.

    Per-key encoders:

    * ``"heatmap_and_pos"`` -> ``CustomCNN`` with 256 output features.
    * any other key with fewer than 100 elements (e.g. a ``(1, 5, 5)`` partial
      observation) -> ``nn.Flatten()``; contributes ``prod(shape)`` features.
    * any other key (e.g. ``(1, 32, 32)`` images) -> ``CustomCNN`` with 256
      output features.

    :param observation_space: (spaces.Dict) dictionary observation space.
    """
    def __init__(self, observation_space: spaces.Dict):
        super().__init__(observation_space, features_dim=1)

        extractors = {}
        total_concat_size = 0

        for key, subspace in observation_space.spaces.items():
            if key == "heatmap_and_pos":
                extractors[key] = CustomCNN(subspace, features_dim=256)
                total_concat_size += 256
                continue

            # Plain Python int so no NumPy scalar leaks into _features_dim.
            input_shape_flatten = int(np.prod(subspace.shape))

            # this is in the case of partial observability of obs of size (1,5,5)
            if input_shape_flatten < 100:
                extractors[key] = nn.Flatten()
                # Flatten emits prod(shape) features per sample; counting a
                # hard-coded 25 is only correct for 25-element observations.
                hidden_dim = input_shape_flatten

            # this is in the case of (1,32,32) images
            else:
                hidden_dim = 256
                extractors[key] = CustomCNN(subspace, features_dim=hidden_dim)

            total_concat_size += hidden_dim

        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

    def forward(self, observations) -> torch.Tensor:
        """Encode every key and concatenate the results along the feature axis."""
        encoded_tensor_list = []

        # self.extractors contains the nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))

        # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
        return torch.cat(encoded_tensor_list, dim=1)

class CustomAugmentedExtractorMLP(BaseFeaturesExtractor):
    """Encode each entry of a Dict observation with a small MLP and concatenate.

    With ``n_in = subspace.shape[0]`` the per-key encoders are:

    * ``"pos"``     -> Linear(n_in, 32) + ReLU                  (32 features)
    * ``"heatmap"`` -> two Linear+ReLU layers of width n_in // 4 (n_in // 4 features)
    * anything else -> two Linear+ReLU layers of width 64 when n_in < 100,
      otherwise of width n_in // 4

    Only the first dimension of each subspace shape is read, so entries are
    presumably flat vectors -- TODO confirm against the environment.

    :param observation_space: (spaces.Dict) dictionary observation space.
    """

    @staticmethod
    def _two_layer_mlp(n_in, width):
        """Build Linear(n_in, width) -> ReLU -> Linear(width, width) -> ReLU."""
        return nn.Sequential(
            nn.Linear(n_in, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        )

    def __init__(self, observation_space: spaces.Dict):
        super().__init__(observation_space, features_dim=1)

        encoders = {}
        widths = []
        for name, sub in observation_space.spaces.items():
            n_in = sub.shape[0]

            if name == "pos":
                encoders[name] = nn.Sequential(nn.Linear(n_in, 32), nn.ReLU())
                widths.append(32)
                continue

            # "heatmap" always shrinks by 4; other keys only when large.
            width = n_in // 4 if (name == "heatmap" or n_in >= 100) else 64
            encoders[name] = self._two_layer_mlp(n_in, width)
            widths.append(width)

        self.extractors = nn.ModuleDict(encoders)
        self._features_dim = sum(widths)

    def forward(self, observations) -> torch.Tensor:
        """Return a ``(B, self._features_dim)`` tensor of concatenated encodings."""
        # Encoders are applied in ModuleDict (insertion) order and stacked on dim 1.
        return torch.cat(
            [encoder(observations[name]) for name, encoder in self.extractors.items()],
            dim=1,
        )
    
class CustomAugmentedExtractorCNNGodot(BaseFeaturesExtractor):
    """Encode each entry of a Dict observation and concatenate the features.

    Per-key encoders:

    * ``"heatmap"``    -> ``CustomCNN`` with 256 output features.
    * ``"vector_obs"`` -> ``nn.Flatten()``; contributes ``prod(shape)`` features.
    * any other key    -> Flatten, then two Linear+ReLU layers of width 64.

    :param observation_space: (spaces.Dict) dictionary observation space.
    """
    def __init__(self, observation_space: spaces.Dict):
        super().__init__(observation_space, features_dim=1)

        extractors = {}
        total_concat_size = 0

        for key, subspace in observation_space.spaces.items():
            if key == "heatmap":
                extractors[key] = CustomCNN(subspace, features_dim=256)
                total_concat_size += 256
            elif key == "vector_obs":
                # Cast to a plain int: a NumPy scalar here would leak into
                # self._features_dim, which SB3 treats as a Python int.
                input_shape_flatten = int(np.prod(subspace.shape))
                extractors[key] = nn.Flatten()
                total_concat_size += input_shape_flatten
            else:
                input_shape_flatten = int(np.prod(subspace.shape))

                hidden_dim = 64
                extractors[key] = nn.Sequential(
                    nn.Flatten(),
                    nn.Linear(input_shape_flatten, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                )

                total_concat_size += hidden_dim

        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

    def forward(self, observations) -> torch.Tensor:
        """Encode every key and concatenate the results along the feature axis."""
        encoded_tensor_list = []

        # self.extractors contains the nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))

        # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
        return torch.cat(encoded_tensor_list, dim=1)
    

class CustomEllipsoidCNN(nn.Module):
    """Conv/BatchNorm/ReLU/MaxPool tower for single-channel square images.

    :param hidden_dim: (int) height and width of the ``(1, hidden_dim, hidden_dim)``
        input image. Despite the name, this is the *input* size, not a feature
        width; the module always outputs 512 features.
    """
    def __init__(self, hidden_dim = 512):
        super(CustomEllipsoidCNN, self).__init__()

        modules = []
        in_channels = 1
        out_channels = 8
        out_size = hidden_dim

        # One conv block per halving of the spatial size until it is <= 4.
        # Use `> 4` rather than `!= 4`: with `!=` the loop never terminates when
        # hidden_dim is not 4 * 2**k (e.g. 100 -> ... -> 3 -> 1 -> 0 -> 0 ...).
        while out_size > 4:
            modules.append(nn.Conv2d(in_channels, out_channels, 3, 1, padding=1))
            modules.append(nn.BatchNorm2d(out_channels))
            modules.append(nn.ReLU())
            modules.append(nn.MaxPool2d(2, 2))

            out_size = out_size // 2
            in_channels = out_channels
            # Double the channel count, capped at 64 (kept as a plain int).
            out_channels = min(out_channels * 2, 64)

        modules.append(nn.Flatten())

        self.cnn = nn.Sequential(*modules)

        # Compute the flattened size with one dummy forward pass. Run it in eval
        # mode so BatchNorm does not fold the probe's statistics into its
        # running mean/var, then restore the previous mode.
        was_training = self.cnn.training
        self.cnn.eval()
        with torch.no_grad():
            n_flatten = self.cnn(torch.zeros(1, 1, hidden_dim, hidden_dim)).shape[1]
        self.cnn.train(was_training)

        self.linear = nn.Sequential(nn.Flatten(), layer_init(nn.Linear(n_flatten, 512)), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map ``(B, 1, hidden_dim, hidden_dim)`` images to ``(B, 512)`` features."""
        return self.linear(self.cnn(observations))

class CustomAugmentedExtractorNatureCNN(BaseFeaturesExtractor):
    """Encode each entry of a Dict observation with a CNN and concatenate.

    * ``"ellipsoid"`` -> ``CustomEllipsoidCNN(hidden_dim=hidden_dim)``
    * any other key   -> ``NatureCNN(hidden_dim=512)``

    Every encoder is counted as contributing 512 features (for NatureCNN this
    assumes ``hidden_dim=512`` sets its output width -- it is defined in
    ``countbased.models.idm``; TODO confirm).

    :param observation_space: (spaces.Dict) only its keys are read here.
    :param hidden_dim: (int) forwarded to ``CustomEllipsoidCNN``, where it is the
        side length of the square ellipsoid input.
    """
    def __init__(self, observation_space: spaces.Dict, hidden_dim=512):
        super().__init__(observation_space, features_dim=1)

        encoders = {}
        for name in observation_space.spaces:
            if name == "ellipsoid":
                encoders[name] = CustomEllipsoidCNN(hidden_dim=hidden_dim)
            else:
                encoders[name] = NatureCNN(hidden_dim=512)

        self.extractors = nn.ModuleDict(encoders)
        # 512 features per encoder, whichever kind it is.
        self._features_dim = 512 * len(encoders)

    def forward(self, observations) -> torch.Tensor:
        """Return a ``(B, self._features_dim)`` tensor of concatenated encodings."""
        # Encoders run in ModuleDict (insertion) order and are stacked on dim 1.
        return torch.cat(
            [encoder(observations[name]) for name, encoder in self.extractors.items()],
            dim=1,
        )