import dataclasses
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

from d3rlpy.models.encoders import EncoderFactory

class CustomEncoder(nn.Module):
    """Convolutional observation encoder that emits a flat feature vector.

    Three 3x3, stride-2 convolutions (64, 128 and 256 output channels), each
    followed by ReLU, are flattened and projected to ``features_dim`` by a
    linear layer followed by ``activation_fn``.

    Args:
        observation_shape: Observation shape without the batch axis; index 0
            is used as the number of input channels.
        features_dim: Size of the output feature vector.
        activation_fn: Module class (not an instance) instantiated after the
            final linear layer only; the convolutional stack always uses
            ReLU. Defaults to ``nn.ReLU``.
    """

    def __init__(self, observation_shape, features_dim, activation_fn=None):
        super().__init__()

        final_activation = nn.ReLU if activation_fn is None else activation_fn

        # Channel progression of the conv stack: input -> 64 -> 128 -> 256.
        widths = (observation_shape[0], 64, 128, 256)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(
                nn.Conv2d(
                    in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1
                )
            )
            stages.append(nn.ReLU())
        stages.append(nn.Flatten())
        self.cnn = nn.Sequential(*stages)

        # Push one random dummy observation through the conv stack to discover
        # the flattened width the linear layer must accept.
        with torch.no_grad():
            probe = torch.rand((1, *observation_shape))
            flat_dim = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_dim, features_dim), final_activation())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode a 4-D batch of observations into ``features_dim`` features per sample."""
        assert observations.ndim == 4
        flattened = self.cnn(observations)
        return self.linear(flattened)


class CustomEncoderWithAction(nn.Module):
    """Convolutional encoder for (observation, action) pairs.

    The observation goes through three ``Conv2d -> ReLU -> MaxPool2d(2)``
    stages (8, 16 and 32 channels), is flattened and mapped to a
    100-dimensional code by a small MLP. The action is concatenated to that
    code and the result is projected to ``features_dim``.

    Args:
        observation_shape: Observation shape without the batch axis; index 0
            is used as the number of input channels.
        features_dim: Size of the output feature vector.
        action_size: Width of the action vector concatenated to the image code.
    """

    def __init__(self, observation_shape, features_dim, action_size):
        super().__init__()

        # Channel progression of the conv stack: input -> 8 -> 16 -> 32.
        widths = (observation_shape[0], 8, 16, 32)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend(
                [
                    nn.Conv2d(
                        in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2, padding=1
                    ),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                ]
            )
        stages.append(nn.Flatten())
        self.cnn = nn.Sequential(*stages)

        # Push one random dummy observation through the conv stack to discover
        # the flattened width.
        with torch.no_grad():
            probe = torch.rand((1, *observation_shape))
            flat_dim = self.cnn(probe).shape[1]

        image_code_dim = 100
        self.cnn_encoding = nn.Sequential(
            nn.Linear(flat_dim, 50),
            nn.Tanh(),
            nn.Linear(50, image_code_dim),
        )

        self.linear = nn.Sequential(
            nn.Linear(image_code_dim + action_size, features_dim),
            nn.Tanh(),
            nn.Linear(features_dim, features_dim),
        )

    def forward(self, observations, action):
        """Encode observations, append ``action`` along dim 1, and project to features."""
        image_code = self.cnn_encoding(self.cnn(observations))
        joint = torch.cat([image_code, action], dim=1)
        return self.linear(joint)

@dataclasses.dataclass()
class CustomEncoderFactory(EncoderFactory):
    """Encoder factory that builds the convolutional ``CustomEncoder`` variants."""

    # Output feature width handed to the encoders as ``features_dim``.
    feature_size: int

    def create(self, observation_shape):
        """Build the observation-only encoder for ``observation_shape``."""
        return CustomEncoder(
            observation_shape=observation_shape,
            features_dim=self.feature_size,
        )

    def create_with_action(self, observation_shape, action_size):
        """Build the encoder that consumes both an observation and an action."""
        return CustomEncoderWithAction(
            observation_shape=observation_shape,
            features_dim=self.feature_size,
            action_size=action_size,
        )

    @staticmethod
    def get_type() -> str:
        """Return the type identifier string of this factory."""
        # NOTE(review): SimpleEncoderFactory in this file returns the same
        # "custom" identifier; if both are ever registered under their type
        # name they would collide - confirm before registering either.
        return "custom"


class QNetwork(nn.Module):
    """Network mapping an (observation, action) pair to a single scalar per sample.

    The pair is encoded by ``CustomEncoderWithAction`` into 10 features. The
    action is additionally embedded on its own into 10 features, and both are
    concatenated before the final MLP.

    Args:
        observation_shape: Observation shape without the batch axis.
        n_actions: Width of the action vector.
    """

    def __init__(self, observation_shape=(1, 50, 50), n_actions=2):
        super().__init__()
        code_dim = 10
        hidden_dim = 20

        self.encoder = CustomEncoderWithAction(
            observation_shape=observation_shape,
            features_dim=code_dim,
            action_size=n_actions,
        )

        self.action_embedding = nn.Sequential(
            nn.Linear(n_actions, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, code_dim),
        )

        # Input is the joint encoding concatenated with the action embedding.
        self.fc = nn.Sequential(
            nn.Linear(code_dim + code_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, observation, action):
        """Return one value per sample, shaped ``(batch, 1)`` by the final linear layer."""
        joint_code = self.encoder(observation, action)
        action_code = self.action_embedding(action)
        # The action re-enters here as a skip connection around the encoder.
        return self.fc(torch.cat([joint_code, action_code], dim=1))


class GaussianHead(nn.Module):
    """MLP head producing the parameters of a diagonal Gaussian over actions.

    Args:
        input_dim: Width of the incoming feature vector.
        n_actions: Dimensionality of the action space (size of mean and scale).
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, input_dim, n_actions=6, hidden_dim=128):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.mean_head = nn.Linear(hidden_dim, n_actions)
        self.log_std_head = nn.Linear(hidden_dim, n_actions)

    def forward(self, x):
        """Return ``(mean, log_std)``, both with ``n_actions`` entries on the last axis.

        Despite its name, ``log_std`` is an unconstrained raw scale parameter:
        ``get_dist`` maps it to a standard deviation with softplus, not ``exp``.
        """
        hidden = self.fc(x)
        return self.mean_head(hidden), self.log_std_head(hidden)

    def get_dist(self, mean, log_std):
        """Build the diagonal ``MultivariateNormal`` described by the head outputs.

        Args:
            mean: Mean of the distribution.
            log_std: Raw scale parameter as returned by ``forward``.

        Returns:
            ``torch.distributions.MultivariateNormal`` with covariance
            ``diag(std ** 2)``, where ``std = softplus(log_std) + 1e-5``.
        """
        # The small floor keeps the scale strictly positive even when softplus
        # underflows to zero.
        std = F.softplus(log_std) + 1e-5
        # For a diagonal covariance the Cholesky factor is simply diag(std).
        # Handing it over as scale_tril avoids squaring into a dense covariance
        # matrix that the distribution would have to factorise again.
        return D.MultivariateNormal(mean, scale_tril=torch.diag_embed(std))

    def log_prob(self, action, mean, log_std):
        """Return the log-density of ``action`` under the Gaussian from ``get_dist``."""
        return self.get_dist(mean, log_std).log_prob(action)


class ActionProbabilityModel(nn.Module):
    """Predicts the parameters of a diagonal Gaussian over actions from an observation.

    A 3-element ``observation_shape`` selects a convolutional encoder; any
    other length selects a small MLP encoder over ``observation_shape[0]``
    inputs. The encoder output goes through ``ff`` and then a ``GaussianHead``.

    Args:
        n_actions: Dimensionality of the action space.
        observation_shape: Observation shape without the batch axis. For the
            convolutional branch, index 0 is the number of input channels.
        hidden_dim: Width of the hidden layer in ``ff``. NOTE: it is not
            forwarded to ``GaussianHead``, which keeps its own default width.
    """

    def __init__(self, n_actions=6, observation_shape=(1, 50, 50), hidden_dim=128):
        super().__init__()

        if len(observation_shape) == 3:
            # Take the channel count from the observation shape rather than
            # assuming single-channel input; a hard-coded 1 made the shape
            # probe below fail for any multi-channel observation.
            in_channels = observation_shape[0]
            self.encoder = nn.Sequential(
                nn.Conv2d(in_channels, 8, 3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
                nn.Flatten()
            )

            # Push one random dummy observation through the encoder to
            # discover the flattened width.
            with torch.no_grad():
                n_flatten = self.encoder(torch.rand((1, *observation_shape))).shape[1]

        else:
            self.encoder = nn.Sequential(
                nn.Linear(observation_shape[0], 10),
                nn.Tanh(),
                nn.Linear(10, 10)
            )
            n_flatten = 10

        self.ff = nn.Sequential(
            nn.Linear(n_flatten, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, 10)
        )

        # Attribute name kept as ``gmm`` so existing state_dict keys still
        # match; the head models a single Gaussian.
        self.gmm = GaussianHead(input_dim=10, n_actions=n_actions)

    def forward(self, obs):
        """Return ``(mean, log_std)`` of the action distribution for ``obs``."""
        x = self.encoder(obs)
        x = self.ff(x)

        mean, log_std = self.gmm(x)
        return mean, log_std

    def log_prob(self, obs, action):
        """Return the log-density of ``action`` under the distribution predicted for ``obs``."""
        means, log_std = self(obs)
        log_prob = self.gmm.log_prob(action, means, log_std)
        return log_prob

    def expected_action(self, obs):
        """Return the mean of the predicted action distribution."""
        mean, _ = self(obs)
        return mean


class SimpleEncoder(nn.Module):
    """Two-layer ReLU MLP encoder for vector observations.

    Args:
        observation_shape: Observation shape; index 0 is the input width.
        feature_size: Width of both linear layers and of the output.
    """

    def __init__(self, observation_shape, feature_size):
        super().__init__()
        self.feature_size = feature_size
        input_dim = observation_shape[0]
        self.fc1 = nn.Linear(input_dim, feature_size)
        self.fc2 = nn.Linear(feature_size, feature_size)

    def forward(self, x):
        """Apply ``fc1`` then ``fc2``, each followed by ReLU."""
        out = x
        for layer in (self.fc1, self.fc2):
            out = F.relu(layer(out))
        return out

class SimpleEncoderWithAction(nn.Module):
    """Two-layer ReLU MLP encoder over a concatenated (observation, action) vector.

    Args:
        observation_shape: Observation shape; index 0 is the observation width.
        action_size: Width of the action vector.
        feature_size: Width of both linear layers and of the output.
    """

    def __init__(self, observation_shape, action_size, feature_size):
        super().__init__()
        self.feature_size = feature_size
        joint_dim = observation_shape[0] + action_size
        self.fc1 = nn.Linear(joint_dim, feature_size)
        self.fc2 = nn.Linear(feature_size, feature_size)

    def forward(self, x, action):
        """Concatenate ``x`` and ``action`` along dim 1, then apply both ReLU layers."""
        out = torch.cat([x, action], dim=1)
        for layer in (self.fc1, self.fc2):
            out = F.relu(layer(out))
        return out

@dataclasses.dataclass()
class SimpleEncoderFactory(EncoderFactory):
    """Encoder factory that builds the MLP-based ``SimpleEncoder`` variants."""

    # Width of the hidden and output layers of the created encoders.
    feature_size: int

    def create(self, observation_shape):
        """Build the observation-only MLP encoder for ``observation_shape``."""
        return SimpleEncoder(
            observation_shape=observation_shape,
            feature_size=self.feature_size,
        )

    def create_with_action(self, observation_shape, action_size):
        """Build the MLP encoder that consumes both an observation and an action."""
        return SimpleEncoderWithAction(
            observation_shape=observation_shape,
            action_size=action_size,
            feature_size=self.feature_size,
        )

    @staticmethod
    def get_type() -> str:
        """Return the type identifier string of this factory."""
        # NOTE(review): CustomEncoderFactory in this file returns the same
        # "custom" identifier; if both are ever registered under their type
        # name they would collide - confirm before registering either.
        return "custom"
