from dataclasses import dataclass, field
from typing import Any, List, Optional, Dict

import open_clip
import torch
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor
from torchvision.transforms import Normalize

from sf3d.models.network import get_activation
from sf3d.models.utils import BaseModule


@dataclass
class HeadSpec:
    name: str
    out_channels: int
    n_hidden_layers: int
    output_activation: Optional[str] = None
    output_bias: float = 0.0
    add_to_decoder_features: bool = False
    shape: Optional[List[int]] = None


class ClipBasedHeadEstimator(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        model: str = "ViT-B-32"
        pretrain: str = "laion2b_s34b_b79k"

        distribution: str = "beta"

        # ["mean", "mode", "sample", "sample_mean"]
        distribution_eval: str = "mode"

        activation: str = "relu"
        hidden_features: int = 512
        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.cfg.model, pretrained=self.cfg.pretrain
        )
        self.model.eval()

        # Do not add the weights in self.model to the optimizer
        for param in self.model.parameters():
            param.requires_grad = False

        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []

            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]

            head_layers = [nn.Sequential(*head_layers)]
            head_layers += [
                nn.Sequential(
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                    nn.Linear(self.cfg.hidden_features, 1),
                )
                for _ in range(2)
            ]
            heads[head.name] = nn.ModuleList(head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 3"],
        sample: bool = True,
    ) -> Dict[str, Any]:
        # Run the model
        # Resize cond_image to 224
        cond_image = nn.functional.interpolate(
            cond_image.flatten(0, 1).permute(0, 3, 1, 2),
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
        cond_image = Normalize(
            mean=open_clip.constants.OPENAI_DATASET_MEAN,
            std=open_clip.constants.OPENAI_DATASET_STD,
        )(cond_image)
        image_features = self.model.encode_image(cond_image)

        # Run the heads
        outputs = {}

        for head_dict in self.cfg.heads:
            head_name = head_dict.name
            shared_head, d1_h, d2_h = self.heads[head_name]
            shared_features = shared_head(image_features)
            d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
            if self.cfg.distribution == "normal":
                mean = d1
                var = d2
                if mean.shape[-1] == 1:
                    outputs[head_name] = torch.distributions.Normal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var),
                    )
                else:
                    outputs[head_name] = torch.distributions.MultivariateNormal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var).diag_embed(),
                    )
            elif self.cfg.distribution == "beta":
                outputs[head_name] = torch.distributions.Beta(
                    torch.nn.functional.softplus(d1 + head_dict.output_bias),
                    torch.nn.functional.softplus(d2 + head_dict.output_bias),
                )
            else:
                raise NotImplementedError

        if sample:
            for head_dict in self.cfg.heads:
                head_name = head_dict.name
                dist = outputs[head_name]

                if self.cfg.distribution_eval == "mean":
                    out = dist.mean
                elif self.cfg.distribution_eval == "mode":
                    out = dist.mode
                elif self.cfg.distribution_eval == "sample_mean":
                    out = dist.sample([10]).mean(-1)
                else:
                    # use rsample if gradient is needed
                    out = dist.rsample() if self.training else dist.sample()

                outputs[head_name] = get_activation(head_dict.output_activation)(out)
                outputs[f"{head_name}_dist"] = dist

        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilisitic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                outputs[f"decoder_{head.name}"] = outputs[head.name]
                del outputs[head.name]

        return outputs
