"""
Eta in DDIM.

Can be learned but always fixed to 1 during training and 0 during eval right now.

"""

import torch
from model.common.mlp import MLP


class EtaFixed(torch.nn.Module):

    def __init__(
        self,
        base_eta=0.5,
        min_eta=0.1,
        max_eta=1.0,
        **kwargs,
    ):
        super().__init__()
        self.eta_logit = torch.nn.Parameter(torch.ones(1))
        self.min = min_eta
        self.max = max_eta

        # initialize such that eta = base_eta
        self.eta_logit.data = torch.atanh(
            torch.tensor([2 * (base_eta - min_eta) / (max_eta - min_eta) - 1])
        )

    def __call__(self, cond):
        """Match input batch size, but do not depend on input"""
        sample_data = cond["state"] if "state" in cond else cond["rgb"]
        B = len(sample_data)
        device = sample_data.device
        eta_normalized = torch.tanh(self.eta_logit)

        # map to min and max from [-1, 1]
        eta = 0.5 * (eta_normalized + 1) * (self.max - self.min) + self.min
        return torch.full((B, 1), eta.item()).to(device)


class EtaAction(torch.nn.Module):

    def __init__(
        self,
        action_dim,
        base_eta=0.5,
        min_eta=0.1,
        max_eta=1.0,
        **kwargs,
    ):
        super().__init__()
        # initialize such that eta = base_eta
        self.eta_logit = torch.nn.Parameter(
            torch.ones(action_dim)
            * torch.atanh(
                torch.tensor([2 * (base_eta - min_eta) / (max_eta - min_eta) - 1])
            )
        )
        self.min = min_eta
        self.max = max_eta

    def __call__(self, cond):
        """Match input batch size, but do not depend on input"""
        sample_data = cond["state"] if "state" in cond else cond["rgb"]
        B = len(sample_data)
        device = sample_data.device
        eta_normalized = torch.tanh(self.eta_logit)

        # map to min and max from [-1, 1]
        eta = 0.5 * (eta_normalized + 1) * (self.max - self.min) + self.min
        return eta.repeat(B, 1).to(device)


class EtaState(torch.nn.Module):

    def __init__(
        self,
        input_dim,
        mlp_dims,
        activation_type="ReLU",
        out_activation_type="Identity",
        base_eta=0.5,
        min_eta=0.1,
        max_eta=1.0,
        gain=1e-2,
        **kwargs,
    ):
        super().__init__()
        self.base = base_eta
        self.min_res = min_eta - base_eta
        self.max_res = max_eta - base_eta
        self.mlp_res = MLP(
            [input_dim] + mlp_dims + [1],
            activation_type=activation_type,
            out_activation_type=out_activation_type,
        )

        # initialize such that mlp(x) = 0
        for m in self.mlp_res.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_normal_(m.weight, gain=gain)
                m.bias.data.fill_(0)

    def __call__(self, cond):
        if "rgb" in cond:
            raise NotImplementedError(
                "State-based eta not implemented for image-based training!"
            )
        # flatten history
        B = len(cond["state"])
        state = cond["state"].view(B, -1)

        # forward pass
        eta_res = self.mlp_res(state)
        eta_res = torch.tanh(eta_res)  # [-1, 1]
        eta = eta_res + self.base  # [0, 2]
        return torch.clamp(eta, self.min_res + self.base, self.max_res + self.base)


class EtaStateAction(torch.nn.Module):

    def __init__(
        self,
        input_dim,
        mlp_dims,
        action_dim,
        activation_type="ReLU",
        out_activation_type="Identity",
        base_eta=1,
        min_eta=1e-3,
        max_eta=2,
        gain=1e-2,
        **kwargs,
    ):
        super().__init__()
        self.base = base_eta
        self.min_res = min_eta - base_eta
        self.max_res = max_eta - base_eta
        self.mlp_res = MLP(
            [input_dim] + mlp_dims + [action_dim],
            activation_type=activation_type,
            out_activation_type=out_activation_type,
        )

        # initialize such that mlp(x) = 0
        for m in self.mlp_res.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_normal_(m.weight, gain=gain)
                m.bias.data.fill_(0)

    def __call__(self, cond):
        if "rgb" in cond:
            raise NotImplementedError(
                "State-action-based eta not implemented for image-based training!"
            )
        # flatten history
        B = len(cond["state"])
        state = cond["state"].view(B, -1)

        # forward pass
        eta_res = self.mlp_res(state)
        eta_res = torch.tanh(eta_res)  # [-1, 1]
        eta = eta_res + self.base
        return torch.clamp(eta, self.min_res + self.base, self.max_res + self.base)
