from functools import partial
from typing import Callable, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


def encode_observation(obs: torch.Tensor, hidden_dim: int) -> torch.Tensor:
    """Encode observations as magnitude-scaled one-hot vectors.

    Each observation is assigned to one of ``hidden_dim`` buckets via
    ``floor(obs * hidden_dim)`` (clamped into ``[0, hidden_dim - 1]``), and the
    single active entry of the resulting one-hot vector is set to
    ``obs * hidden_dim``.

    Args:
        obs: Observation tensor; the original author assumes each value is a
            real number between 0 and 1.
        hidden_dim: Number of buckets, i.e. the size of the new trailing axis.

    Returns:
        A float tensor of shape ``obs.shape + (hidden_dim,)``.
    """
    bucket_index = (obs * hidden_dim).long().clamp(0, hidden_dim - 1)
    encoded = F.one_hot(bucket_index, num_classes=hidden_dim).float()

    # Scale in place so the result keeps the float32 dtype produced by
    # `.float()` regardless of the dtype of `obs`.
    scale = obs.unsqueeze(-1) * hidden_dim
    return encoded.mul_(scale)


class BaseRewardModel(nn.Module):
    """Feed-forward reward model trained from pairwise preference comparisons.

    The model is split into ``network`` (everything up to the penultimate
    features) and ``last_layer`` (a final ``nn.Linear`` head). Three layouts
    are possible:

    * ``use_encoding=True``: ``network`` is the bucketed encoding produced by
      ``encode_observation`` and ``last_layer`` maps ``hidden_dim`` to
      ``output_dim``.
    * ``num_layers > 1``: ``network`` is an MLP with ``num_layers - 1`` linear
      layers (ReLU, and optionally BatchNorm1d, between them) and
      ``last_layer`` maps ``hidden_dim`` to ``output_dim``.
    * otherwise: ``network`` is the identity and ``last_layer`` maps
      ``state_dim`` directly to ``output_dim``.

    Args:
        state_dim: Size of the trailing axis of the input state tensor.
        output_dim: Number of outputs of ``last_layer``.
        num_layers: Total number of linear layers, counting ``last_layer``.
        hidden_dim: Width of the hidden layers (or of the encoding).
        use_batchnorm: Insert ``nn.BatchNorm1d`` before each in-network ReLU.
        use_encoding: Replace the MLP body with ``encode_observation``;
            requires ``num_layers == 1``.
    """

    network: Callable[[torch.Tensor], torch.Tensor]

    def __init__(
        self,
        *,
        state_dim: int,
        output_dim: int = 1,
        num_layers: int = 4,
        hidden_dim: int = 128,
        use_batchnorm: bool = False,
        use_encoding: bool = False
    ):
        super().__init__()

        if use_encoding:
            assert num_layers == 1, "Encoding must be used with a single layer"
            # A partial of the module-level function (rather than a lambda
            # defined here) keeps the model picklable, e.g. for
            # torch.save(model) or when sending it to worker processes.
            self.network = partial(encode_observation, hidden_dim=hidden_dim)
            self.last_layer = nn.Linear(hidden_dim, output_dim)
        elif num_layers > 1:
            layers: List[nn.Module] = []
            for layer_index in range(num_layers - 1):
                in_features = state_dim if layer_index == 0 else hidden_dim
                layers.append(nn.Linear(in_features, hidden_dim))
                # The final in-network linear layer gets no activation here;
                # forward() applies the ReLU in front of last_layer.
                if layer_index < num_layers - 2:
                    if use_batchnorm:
                        layers.append(nn.BatchNorm1d(hidden_dim))
                    layers.append(nn.ReLU())
            self.network = nn.Sequential(*layers)
            # The head is kept separate from `network` (original note:
            # "separated for lexicase stuff").
            self.last_layer = nn.Linear(hidden_dim, output_dim)
        else:
            # No hidden layers: the head acts directly on the state.
            self.network = nn.Identity()
            self.last_layer = nn.Linear(state_dim, output_dim)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the reward for each state, with a trailing size-1 axis removed.

        A ReLU is applied to the output of ``network`` in every layout,
        including the identity layout, where it zeroes negative input features.
        """
        features = F.relu(self.network(state))
        return self.last_layer(features).squeeze(-1)

    def get_penultimate_layer(self, state: torch.Tensor) -> torch.Tensor:
        """Return the post-ReLU features that are fed into ``last_layer``.

        ``squeeze(-1)`` only has an effect when the feature axis has size 1.
        """
        return F.relu(self.network(state)).squeeze(-1)

    def preference_logp(
        self, state0: torch.Tensor, state1: torch.Tensor, preferences: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the log probability of the given preference comparisons according to the
        model. If preferences[i] == 0, then state0 is preferred to state1, and vice
        versa.

        The probability that state0 is preferred is
        ``sigmoid(reward(state0) - reward(state1))``.
        """

        reward_diff = self.forward(state0) - self.forward(state1)
        # Flip the sign where state1 is the preferred one, so that a positive
        # difference always means "the model agrees with the label".
        reward_diff[preferences == 1] *= -1
        # log(sigmoid(x)) written as -softplus(-x).
        return -F.softplus(-reward_diff)


class MeanAndVarianceRewardModel(BaseRewardModel):
    """Reward model that predicts a mean and a log standard deviation per state.

    The base model is built with ``output_dim=2``; column 0 of the head output
    is the reward mean and column 1 (shifted and clamped) is the log standard
    deviation.

    Args:
        max_std: Upper bound on the predicted standard deviation; the log-std
            is clamped to ``log(max_std)``. The default of infinity disables
            the clamp.
        **kwargs: Forwarded to ``BaseRewardModel`` (``output_dim`` must not be
            passed, it is fixed to 2 here).
    """

    def __init__(self, *args, max_std=np.inf, **kwargs):
        super().__init__(*args, output_dim=2, **kwargs)
        self.max_std = max_std

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch, 2)`` tensor of ``[mean, log_std]`` per state."""
        # Go through last_layer (the 2-unit head) like the other reward models.
        # Reading columns 0/1 straight from `self.network(state)` would leave
        # the head unused and fail when the network is the identity over a
        # 1-dimensional state.
        mean_and_log_std = self.last_layer(F.relu(self.network(state)))
        mean = mean_and_log_std[:, 0]
        # Constant shift of the raw log-std output; presumably chosen so the
        # initial standard deviation starts small -- TODO confirm.
        log_std = mean_and_log_std[:, 1] - 2
        log_std = log_std.clamp(max=np.log(self.max_std))
        return torch.stack([mean, log_std], dim=1)

    def preference_logp(
        self, state0: torch.Tensor, state1: torch.Tensor, preferences: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the approximate log probability of the given preference
        comparisons. If preferences[i] == 0, then state0 is preferred to
        state1, and vice versa.

        The difference of the two predicted means is standardized by the
        combined standard deviation of the two predictions, and the normal CDF
        of that z-score is approximated with a logistic function.
        """
        output0 = self.forward(state0)
        output1 = self.forward(state1)
        mean0 = output0[:, 0]
        log_std0 = output0[:, 1]
        mean1 = output1[:, 0]
        log_std1 = output1[:, 1]

        diff_mean = mean0 - mean1
        # Flip the sign where state1 is the preferred one.
        diff_mean[preferences == 1] *= -1
        # Variances add for the difference of the two predictions.
        var_combined = torch.exp(log_std0) ** 2 + torch.exp(log_std1) ** 2
        z = diff_mean / torch.sqrt(var_combined)
        # Logistic approximation of log(Normal(0, 1).cdf(z)), written as
        # -softplus(-k * z).
        # Based on approximation here: https://stats.stackexchange.com/a/452121
        return -F.softplus(-z * np.sqrt(2 * np.pi))


class CategoricalRewardModel(BaseRewardModel):
    """Reward model that outputs a categorical distribution over reward atoms.

    The atoms are ``num_atoms`` evenly spaced values from 0 to 1. The
    ``comparison_matrix`` buffer holds, for every pair of atoms ``(i, j)``,
    1 if atom ``j`` is larger than atom ``i``, 0 if it is smaller, and 0.5 if
    the two are equal.

    Args:
        num_atoms: Number of atoms; when omitted, 20 is used for
            ``state_dim == 1`` and 8 otherwise.
        state_dim: Size of the trailing axis of the input state tensor.
        **kwargs: Forwarded to ``BaseRewardModel`` (``output_dim`` and
            ``use_batchnorm`` are fixed here and must not be passed).
    """

    comparison_matrix: torch.Tensor

    def __init__(
        self, *args, num_atoms: Optional[int] = None, state_dim: int, **kwargs
    ):
        if num_atoms is None:
            num_atoms = 20 if state_dim == 1 else 8
        super().__init__(
            *args,
            output_dim=num_atoms,
            use_batchnorm=False,
            state_dim=state_dim,
            **kwargs,
        )

        atom_values = torch.linspace(0, 1, num_atoms)
        column_atoms = atom_values[None, :]
        row_atoms = atom_values[:, None]
        # Entry (i, j): 1.0 where atom j beats atom i, 0.0 where it loses.
        comparison_matrix = (column_atoms > row_atoms).to(atom_values.dtype)
        # Ties count as half a win.
        comparison_matrix[column_atoms == row_atoms] = 0.5
        self.register_buffer("comparison_matrix", comparison_matrix)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the softmax distribution over atoms for each state."""
        logits = self.last_layer(F.relu(self.network(state)))
        return F.softmax(logits, dim=-1)

    def preference_logp(
        self, state0: torch.Tensor, state1: torch.Tensor, preferences: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the log probability of the given preference comparisons. Where
        preferences[i] == 0 the probability that state0 wins is used, otherwise
        the probability that state1 wins.
        """
        dist0 = self.forward(state0)
        dist1 = self.forward(state1)
        # Probability that the atom drawn for state1 beats the atom drawn for
        # state0 (ties count as one half).
        state1_wins = ((dist0 @ self.comparison_matrix) * dist1).sum(dim=1)

        state0_preferred = preferences == 0
        prob = state1_wins.clone()
        prob[state0_preferred] = 1 - state1_wins[state0_preferred]
        return prob.log()


class ClassifierRewardModel(BaseRewardModel):
    """Preference classifier that scores a pair of states jointly.

    The two states are concatenated along the last axis, so the underlying
    base model is built with an input size of ``state_dim * 2``.
    """

    def __init__(self, *args, state_dim, **kwargs):
        super().__init__(*args, state_dim=state_dim * 2, **kwargs)

    def forward(self, state0: torch.Tensor, state1: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Return the head output (logits) for the concatenated state pair.

        Unlike ``BaseRewardModel.forward``, the trailing axis is not squeezed.
        """
        paired_states = torch.cat([state0, state1], dim=-1)
        features = F.relu(self.network(paired_states))
        return self.last_layer(features)

    def preference_logp(
        self, state0: torch.Tensor, state1: torch.Tensor, preferences: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the log probability of the given preference comparisons according to the
        model. If preferences[i] == 0, then state0 is preferred to state1, and vice
        versa.

        The logit is the log-odds that the preference label is 1, so it is
        negated for the comparisons labelled 0.
        """

        logits = self.forward(state0, state1)
        state0_preferred = preferences == 0
        logits[state0_preferred] = -logits[state0_preferred]
        # log(sigmoid(x)) written as -softplus(-x).
        return F.softplus(-logits).neg()
