import torch
import torch.nn as nn
import torch.nn.functional as F

from .tools import symlog, symexp, weight_init
from .networks import TransposeCNN

from typing import Callable

class Output:
    """Interface for the distribution-like objects returned by prediction heads.

    Subclasses wrap raw network outputs and expose a training loss plus ways
    to read predictions back out. Every method here raises
    ``NotImplementedError``; a subclass overrides only the ones it supports.
    """

    def loss(self, target):
        """Return the (unreduced) loss of this output against ``target``."""
        raise NotImplementedError

    def pred(self):
        """Return the summary statistic of this output (e.g. its mean)."""
        raise NotImplementedError

    def sample(self):
        """Draw a random sample from the distribution."""
        raise NotImplementedError

    @property
    def mode(self):
        """The most likely value of the distribution."""
        raise NotImplementedError

class MSE(Output):
    """Squared-error output whose prediction is the raw network ``mean``.

    Args:
        mean: Network output tensor.
        squash: Optional callable applied to the target before the loss is
            computed. When falsy, the identity is used.
    """

    def __init__(self, mean, squash=None):
        self.mean = mean
        if squash:
            self.squash = squash
        else:
            self.squash = lambda value: value

    def loss(self, target):
        """Unreduced squared error between ``mean`` and ``squash(target)``.

        The squashed target is detached, so no gradient flows into it.
        """
        expected, actual = self.mean.shape, target.shape
        assert actual == expected, f"Target shape {target.shape} does not match mean shape {self.mean.shape}"
        goal = self.squash(target).detach()
        return F.mse_loss(self.mean, goal, reduction="none")

    def pred(self):
        """Return ``mean`` unchanged (it lives in squashed-target space)."""
        return self.mean

class SymlogMSE(Output):
    """Squared error computed in symlog space.

    The loss compares ``mean`` directly with ``symlog(target)``. ``pred``
    applies ``symexp`` to ``mean`` (presumably the inverse of ``symlog``).
    """

    def __init__(self, mean):
        self.mean = mean

    def loss(self, target):
        """Unreduced squared error between ``mean`` and detached ``symlog(target)``."""
        shapes_match = target.shape == self.mean.shape
        assert shapes_match, f"Target shape {target.shape} does not match mean shape {self.mean.shape}"
        squashed_target = symlog(target).detach()
        return F.mse_loss(self.mean, squashed_target, reduction="none")

    def pred(self):
        """Return ``symexp(mean)``."""
        return symexp(self.mean)

class SymexpMSE(Output):
    """Squared error computed after mapping ``mean`` through ``symexp``.

    Unlike ``SymlogMSE``, the comparison happens in raw target space.
    ``loss`` is the element-wise squared error between ``symexp(mean)`` and
    ``target``, and ``pred`` returns ``symexp(mean)``.
    """

    def __init__(self, mean):
        self.mean = mean

    def loss(self, target):
        """Unreduced squared error between ``symexp(mean)`` and detached ``target``."""
        assert self.mean.shape == target.shape, f"Target shape {target.shape} does not match mean shape {self.mean.shape}"
        decoded = symexp(self.mean)
        return F.mse_loss(decoded, target.detach(), reduction="none")

    def pred(self):
        """Return ``symexp(mean)``."""
        return symexp(self.mean)

class Binary(Output):
    """Bernoulli output parameterized by logits.

    Args:
        logits: Unnormalized log-odds. ``probs`` caches ``sigmoid(logits)``.
    """

    def __init__(self, logits: torch.Tensor):
        self.logits = logits
        self.probs = logits.sigmoid()

    def loss(self, target):
        """Unreduced binary cross-entropy of ``logits`` against detached ``target``."""
        detached = target.detach()
        return F.binary_cross_entropy_with_logits(self.logits, detached, reduction="none")

    def pred(self):
        """Return the success probabilities."""
        return self.probs

    def sample(self):
        """Draw 0/1 samples with the cached probabilities."""
        return torch.bernoulli(self.probs)

    @property
    def mode(self):
        """1.0 where the logit is strictly positive, else 0.0."""
        positive = self.logits > 0
        return positive.float()

class TwoHot(Output):
    """Categorical output over fixed ``bins`` trained with two-hot targets.

    A scalar target is encoded as a distribution that puts weight only on
    the two bins bracketing it. The closer bin gets the larger share. The
    loss is the cross-entropy of ``softmax(logits)`` against that encoding.
    ``pred`` returns the expectation of ``bins`` under ``softmax(logits)``.

    Args:
        logits: Tensor whose last dimension has ``num_bins`` entries.
        bins: Bin values along the last dimension. They are assumed to be
            sorted ascending, because the bracketing search in ``loss`` relies
            on it.
    """

    def __init__(self, logits: torch.Tensor, bins: torch.Tensor):
        self.logits = logits
        self.bins = bins
        self.num_bins = bins.shape[-1]

    def loss(self, target):
        """Return the two-hot cross-entropy, reduced over bins with keepdim.

        ``target`` presumably has shape ``(..., 1)`` (TODO confirm with
        callers). The encoding built from it must match ``logits`` in shape.
        """
        target = target.detach()
        n = self.num_bins
        # Index of the last bin <= target, and index of the first bin > target.
        # For ascending bins, above == below + 1 before clamping. Targets outside
        # the bin range get clamped so both indices land on the edge bin.
        below = (self.bins <= target).sum(-1, keepdim=True) - 1
        above = n - (self.bins > target).sum(-1, keepdim=True)
        below = torch.clamp(below, min=0, max=n - 1)
        above = torch.clamp(above, min=0, max=n - 1)
        # When both indices collapse onto one bin, use equal dummy distances.
        # That gives weights 0.5 + 0.5, i.e. full mass on that single bin.
        equal = below == above
        dist_to_below = torch.where(equal, 1, torch.abs(self.bins[below] - target))
        dist_to_above = torch.where(equal, 1, torch.abs(self.bins[above] - target))
        total = dist_to_below + dist_to_above
        # Each bin's weight is the *other* bin's distance, so closer bins weigh more.
        weight_below = dist_to_above / total
        weight_above = dist_to_below / total

        twohot = (
            F.one_hot(below.squeeze(-1), num_classes=n) * weight_below
            + F.one_hot(above.squeeze(-1), num_classes=n) * weight_above
        )

        assert self.logits.shape == twohot.shape, (self.logits.shape, twohot.shape)
        logprobs = F.log_softmax(self.logits, dim=-1)
        return -(twohot * logprobs).sum(-1, keepdim=True)

    def pred(self):
        """Return the expected bin value under ``softmax(logits)``, keepdim.

        Mathematically this equals ``(probs * bins).sum(-1, keepdim=True)``.
        The difference is the order of summation: mirrored terms (first with
        last, second with second-to-last, ...) are added pairwise before the
        reduction. Symexp-spaced bins reach very large magnitudes. A naive
        float32 sum of such terms leaves accumulation noise even when
        probabilities and bins are symmetric (e.g. uniform probabilities at
        init). With pairing, those opposite-sign terms cancel exactly.
        """
        probs = F.softmax(self.logits, dim=-1)
        weighted = probs * self.bins
        n = weighted.shape[-1]
        half = n // 2
        # lower[j] pairs with upper[j]; they are equidistant from the center bin.
        lower = weighted[..., :half].flip(-1)
        upper = weighted[..., n - half:]
        result = (lower + upper).sum(-1, keepdim=True)
        if n % 2 == 1:
            # Odd bin count: the unpaired center bin is added separately.
            result = result + weighted[..., half:half + 1]
        return result

class Head(nn.Module):
    """Base module that wraps raw network outputs in ``Output`` objects.

    Each method wraps a tensor in one ``Output`` subclass. Subclasses may
    select the wrapper by method name.
    """

    def mse(self, x: torch.Tensor, squash: Callable | None = None) -> Output:
        """Wrap ``x`` as an ``MSE`` output. ``squash`` is applied to targets."""
        return MSE(mean=x, squash=squash)

    def symlog_mse(self, x: torch.Tensor) -> Output:
        """Wrap ``x`` as a ``SymlogMSE`` output."""
        return SymlogMSE(mean=x)

    def symexp_mse(self, x: torch.Tensor) -> Output:
        """Wrap ``x`` as a ``SymexpMSE`` output."""
        return SymexpMSE(mean=x)

    def symexp_twohot(self, logits: torch.Tensor) -> Output:
        """Wrap ``logits`` as a ``TwoHot`` output over ``self.bins``.

        The subclass must provide a ``bins`` attribute.
        """
        bins = self.bins
        return TwoHot(logits=logits, bins=bins)

    def binary(self, x: torch.Tensor) -> Output:
        """Wrap ``x`` (logits) as a ``Binary`` output."""
        return Binary(logits=x)

class MLPHead(Head):
    """MLP head whose result is wrapped by the ``Head`` method named ``output``.

    Args:
        output: Name of the ``Head`` method used to wrap the final layer's
            result, e.g. ``"mse"``, ``"symlog_mse"``, ``"symexp_mse"``,
            ``"symexp_twohot"`` or ``"binary"``.
        in_dim: Input feature size.
        hidden_dim: Width of each hidden layer.
        hidden_layers: Number of Linear (+ optional LayerNorm) + activation blocks.
        out_dim: Output size. Must be 1 for ``"symexp_twohot"``, which
            instead emits ``num_bins`` logits.
        act: Name of an activation class in ``torch.nn`` (instantiated with
            no arguments).
        use_layernorm: Insert ``nn.LayerNorm`` after each hidden Linear.
        use_symlog: Apply ``symlog`` to the input before the MLP.
        out_scale: ``scale`` passed to ``weight_init`` for the output layer.
        device: Device on which layers and bins are created.
        num_bins: Number of two-hot bins for ``"symexp_twohot"``. The bins
            are ``symexp`` of ``num_bins`` points evenly spaced in [-20, 20].
    """

    def __init__(
        self,
        output: str,
        in_dim: int,
        hidden_dim: int,
        hidden_layers: int,
        out_dim: int,
        act: str = "SiLU",
        use_layernorm: bool = True,
        use_symlog: bool = False,
        out_scale: float = 1.0,
        device: torch.device = torch.device("cpu"),
        num_bins: int = 255,
    ):
        super().__init__()
        self.output = output
        self.use_symlog = use_symlog

        layers = []
        for _ in range(hidden_layers):
            layers.append(nn.Linear(in_dim, hidden_dim, device=device))
            if use_layernorm:
                layers.append(nn.LayerNorm(hidden_dim, device=device))
            layers.append(getattr(nn, act)())
            in_dim = hidden_dim
        self._mlp = nn.Sequential(*layers)
        weight_init(self._mlp)

        # output layer
        if output == "symexp_twohot":
            assert out_dim == 1
            self.num_bins = num_bins
            bins = torch.linspace(start=-20, end=20, steps=self.num_bins, device=device)
            # A plain tensor attribute would stay on its creation device when
            # the module is moved with .to()/.cuda(). TwoHot would then mix
            # devices. A buffer moves with the module; persistent=False keeps
            # the state_dict keys (and old checkpoints) unchanged.
            # NOTE: module-wide dtype casts such as .half() also cast buffers.
            self.register_buffer("bins", symexp(bins), persistent=False)
            self._out = nn.Linear(in_dim, self.num_bins, device=device)
        else:
            self._out = nn.Linear(in_dim, out_dim, device=device)
        weight_init(self._out, scale=out_scale)

    # Implemented as ``forward`` (not by overriding ``__call__``), so
    # nn.Module.__call__ still runs registered hooks; ``head(x)`` is unchanged.
    def forward(self, x: torch.Tensor) -> Output:
        """Run the MLP and wrap the result with the ``self.output`` method."""
        if self.use_symlog:
            x = symlog(x)
        x = self._mlp(x)
        x = self._out(x)
        return getattr(self, self.output)(x)

class FigureHead(Head):
    """Image decoder head: ``TransposeCNN`` followed by an ``MSE`` output.

    Targets are divided by 255 before the loss. They are presumably raw 0-255
    pixel intensities (TODO confirm), so the decoder's output lives in that
    rescaled space.

    Args:
        in_dim: Input feature size.
        out_shape: Output image shape forwarded to ``TransposeCNN``
            (presumably ``(C, H, W)``; TODO confirm).
        depth, mults, kernel, act, use_layernorm, device: Forwarded unchanged
            to ``TransposeCNN``.
    """

    def __init__(
        self,
        in_dim: int,
        out_shape: tuple[int, int, int],
        depth: int,
        mults: list[int],
        kernel: int,
        act: str = "SiLU",
        use_layernorm: bool = True,
        device: torch.device = torch.device("cpu"),
    ):
        super().__init__()
        self._transpose_cnn = TransposeCNN(
            in_dim=in_dim,
            out_shape=out_shape,
            depth=depth,
            mults=mults,
            kernel=kernel,
            act=act,
            use_layernorm=use_layernorm,
            device=device,
        )

    # Implemented as ``forward`` (not by overriding ``__call__``), so
    # nn.Module.__call__ still runs registered hooks; ``head(x)`` is unchanged.
    def forward(self, x: torch.Tensor) -> Output:
        """Decode ``x`` and wrap it as an ``MSE`` output with a /255 target squash."""
        decoded = self._transpose_cnn(x)
        return self.mse(decoded, squash=lambda pixels: pixels / 255)
