from dataclasses import dataclass, field
from typing import Any, List, Optional, Dict

import torch.nn as nn
from jaxtyping import Float
from torch import Tensor

from sf3d.models.network import get_activation
from sf3d.models.utils import BaseModule


@dataclass
class HeadSpec:
    """Configuration of a single output head of ``MultiHeadEstimator``."""

    # Key of the head inside the estimator's module dict; also used as the
    # output-dict key (prefixed with "decoder_" when
    # ``add_to_decoder_features`` is set).
    name: str
    # Number of output features of the head's final linear layer.
    out_channels: int
    # Number of hidden (Linear + activation) pairs placed before the final
    # linear layer.
    n_hidden_layers: int
    # Passed to ``get_activation`` to obtain the callable applied to the
    # head output.
    output_activation: Optional[str] = None
    # Constant added to the head output before the output activation.
    output_bias: float = 0.0
    # When True, the head's output key is prefixed with "decoder_".
    add_to_decoder_features: bool = False
    # If set (and non-empty), the head output is reshaped to this shape.
    shape: Optional[List[int]] = None


class MultiHeadEstimator(BaseModule):
    """Predict a set of global quantities from a triplane.

    The three planes are stacked along the channel axis, passed through a
    shared trunk of strided 3x3 convolutions, pooled over the two spatial
    axes, and fed to one small MLP per configured ``HeadSpec``.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Channels per plane; the trunk sees 3x this many input channels.
        triplane_features: int = 1024

        # Number of (Conv2d + activation) pairs in the shared trunk.
        n_layers: int = 2
        # Width of the trunk convolutions and of the heads' hidden layers.
        hidden_features: int = 512
        # Activation used in the trunk and head hidden layers:
        # "relu" or "silu" (see ``make_activation``).
        activation: str = "relu"

        # Spatial pooling applied after the trunk: "max" or "mean".
        pool: str = "max"

        # One entry per output head; must be non-empty.
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Build the shared convolutional trunk and the per-head MLPs.

        Raises:
            AssertionError: If no heads are configured.
            NotImplementedError: If ``cfg.activation`` is not supported.
        """
        trunk = []
        trunk_features = self.cfg.triplane_features * 3
        for _ in range(self.cfg.n_layers):
            trunk.append(
                nn.Conv2d(
                    trunk_features,
                    self.cfg.hidden_features,
                    kernel_size=3,
                    padding=0,
                    stride=2,
                )
            )
            trunk.append(self.make_activation(self.cfg.activation))
            trunk_features = self.cfg.hidden_features

        self.layers = nn.Sequential(*trunk)

        assert len(self.cfg.heads) > 0, "MultiHeadEstimator needs at least one head"
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            # Start from the trunk's actual output width. For n_layers >= 1
            # this equals hidden_features; for n_layers == 0 the trunk is an
            # identity and the pooled feature keeps 3 * triplane_features
            # channels, so hard-coding hidden_features would mismatch.
            in_features = trunk_features
            for _ in range(head.n_hidden_layers):
                head_layers.append(
                    nn.Linear(in_features, self.cfg.hidden_features)
                )
                head_layers.append(self.make_activation(self.cfg.activation))
                in_features = self.cfg.hidden_features
            head_layers.append(nn.Linear(in_features, head.out_channels))
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation: str) -> nn.Module:
        """Return a new in-place activation module for ``activation``.

        Raises:
            NotImplementedError: If ``activation`` is neither "relu" nor
                "silu".
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        if activation == "silu":
            return nn.SiLU(inplace=True)
        raise NotImplementedError(f"Unsupported activation: {activation!r}")

    def forward(
        self,
        triplane: Float[Tensor, "B 3 F Ht Wt"],
    ) -> Dict[str, Any]:
        """Run the trunk, pool spatially, and evaluate every head.

        Args:
            triplane: Triplane features; the plane and feature axes are
                merged into a single channel axis before the trunk.

        Returns:
            A dict with one entry per configured head. The key is the head
            name, prefixed with "decoder_" when the head sets
            ``add_to_decoder_features``. The value is the head output plus
            ``output_bias``, passed through the callable returned by
            ``get_activation(head.output_activation)`` and, if ``head.shape``
            is set, reshaped to that shape. The reshape is applied to the
            whole batched tensor, so ``head.shape`` has to account for the
            batch dimension itself. If two heads resolve to the same key,
            the later one wins.

        Raises:
            NotImplementedError: If ``cfg.pool`` is neither "max" nor "mean".
        """
        pool = self.cfg.pool
        # Validate before running the trunk so a bad config fails fast.
        if pool not in ("max", "mean"):
            raise NotImplementedError(f"Unsupported pool: {pool!r}")

        # (B, 3, F, Ht, Wt) -> (B, 3 * F, Ht, Wt)
        stacked = triplane.reshape(
            triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
        )
        x = self.layers(stacked)

        # Collapse the two spatial axes into one feature vector per sample.
        if pool == "max":
            x = x.amax(dim=[-2, -1])
        else:
            x = x.mean(dim=[-2, -1])

        out = {}
        for head in self.cfg.heads:
            key = ("decoder_" if head.add_to_decoder_features else "") + head.name
            value = get_activation(head.output_activation)(
                self.heads[head.name](x) + head.output_bias
            )
            if head.shape:
                value = value.reshape(*head.shape)
            out[key] = value

        return out
