import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from einops.layers.torch import Reduce
from gate.boilerplate.decorators import configurable
from gate.models.backbones import (
    GATEImageEncoder,
    GATEImageTextEncoder,
    GATETextEncoder,
)
from gate.models.backbones.timm import CLIPModelPaths, GATECLIPTextEncoder
from rich import print

from einspace.compiler import Compiler
from einspace.search_spaces import EinSpace
from einspace.utils import millify


class Network(nn.Module):
    """Wrap an architectural backbone with an identity stem and a linear head.

    The head pools the backbone output down to two dimensions and feeds the
    result to a linear layer with ``output_shape`` output features.

    Args:
        backbone: Module applied between the stem and the head.
        backbone_output_shape: Shape of the backbone output, including the
            leading batch dimension. Must have exactly 3 or 4 entries.
        output_shape: Number of output features of the final linear layer.

    Raises:
        ValueError: If ``backbone_output_shape`` has neither 3 nor 4 entries.
    """

    def __init__(self, backbone, backbone_output_shape, output_shape):
        super().__init__()
        self.backbone = backbone
        self.stem = nn.Sequential(
            # conv stem to even number of channels?
            # positional embedding?
            nn.Identity()
        )
        rank = len(backbone_output_shape)
        if rank == 3:
            # Averages over the last axis, so the linear layer consumes the
            # second axis (size backbone_output_shape[1]).
            pooling = Reduce("b s d -> b s", "mean")
        elif rank == 4:
            # Global average pooling over the two trailing axes.
            pooling = Reduce("b c h w -> b c", "mean")
        else:
            # Fail fast: without a head, forward() would only fail later with
            # an opaque AttributeError.
            raise ValueError(
                "backbone_output_shape must have 3 or 4 dimensions "
                f"(including batch), got {tuple(backbone_output_shape)}"
            )
        self.head = nn.Sequential(
            pooling,
            nn.Linear(backbone_output_shape[1], output_shape),
        )
        self.backbone_output_shape = backbone_output_shape

    def forward(self, x):
        """Run ``x`` through the stem, the backbone and the head in order."""
        out = self.stem(x)
        out = self.backbone(out)
        out = self.head(out)
        return out

    def numel(self):
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

    def num_parameters(self):
        """Return the parameter count as a string formatted with ``millify``."""
        return f"Num params: {millify(self.numel())}"


class ImageEincoder(GATEImageEncoder):
    """Image encoder whose backbone is compiled from an architecture dict.

    The backbone is produced by ``Compiler().compile(architecture_dict)``.
    During construction a dummy batch of zeros is pushed through it to
    discover the output shape; when ``num_projection_features`` is given, a
    pooling + linear projection head sized from that shape is created.

    Args:
        architecture_dict: Architecture description handed to the compiler.
        full_image_shape: Input shape as ``(c, h, w)``, without batch axis.
        num_projection_features: Output width of the projection head, or
            ``None`` to skip building a head.

    Raises:
        ValueError: If a projection head is requested and the backbone output
            has neither 3 nor 4 dimensions (including batch).
    """

    def __init__(
        self,
        architecture_dict: dict,
        full_image_shape: list | tuple,
        num_projection_features: Optional[int] = None,
    ):
        super(ImageEincoder, self).__init__()
        self.backbone = Compiler().compile(architecture_dict)
        self.full_image_shape = full_image_shape  # (c, h, w)
        self._num_projection_features = num_projection_features

        self.build()

    def build(self):
        """Probe the backbone output shape and create the optional head."""
        # The dummy pass only exists to read shapes, so skip autograd tracking.
        with torch.no_grad():
            out = self.backbone(torch.zeros((1, *self.full_image_shape)))
        backbone_output_shape = out.shape
        self.backbone_output_shape = backbone_output_shape

        if self.num_projection_features is not None:
            rank = len(backbone_output_shape)
            if rank == 3:
                # Averages over the last axis; the linear layer then consumes
                # the second axis (size backbone_output_shape[1]).
                pooling = Reduce("b s d -> b s", "mean")
            elif rank == 4:
                # Global average pooling over the two trailing axes.
                pooling = Reduce("b c h w -> b c", "mean")
            else:
                # Previously this fell through and crashed on a missing
                # ``self.head`` attribute; report the real cause instead.
                raise ValueError(
                    "Cannot build a projection head for backbone output "
                    f"shape {tuple(backbone_output_shape)}; expected 3 or 4 "
                    "dimensions (including batch)"
                )
            self.head = nn.Sequential(
                pooling,
                nn.Linear(
                    backbone_output_shape[1], self.num_projection_features
                ),
            )
            with torch.no_grad():
                out = self.head(out)

        print(
            f"Built Eincoder with output shape: {out.shape}, "
            f"backbone output shape: {backbone_output_shape}, "
            f"and backbone: {self.backbone}"
        )

    def forward(self, x):
        """Return the backbone output under both feature keys.

        The projection head is not applied here; it is exposed through
        ``projection_layer``.
        """
        out = self.backbone(x)

        return {"features": out, "raw_features": out}

    @property
    def projection_layer(self):
        """The projection head, or ``None`` when no projection was requested."""
        return self.head if self.num_projection_features is not None else None

    @property
    def num_projection_features(self):
        """Requested projection width (``None`` if no head was built)."""
        return self._num_projection_features

    @property
    def num_features(self):
        """Size of axis 1 of the probed backbone output."""
        return self.backbone_output_shape[1]

    @property
    def num_raw_features(self):
        """Size of axis 1 of the probed backbone output."""
        return self.backbone_output_shape[1]

    @property
    def image_shape(self):
        """Spatial part ``(h, w)`` of ``full_image_shape``."""
        return self.full_image_shape[1:]

    def transforms(self, x):
        """Resize input image(s) to ``image_shape`` and convert to tensors.

        Args:
            x: A single image (PIL image or tensor) or a list/tuple of them.

        Returns:
            A tensor for a single image, or a list of tensors for a sequence.
        """
        # ``image_shape`` is already (h, w); slicing it again would leave only
        # (w,), which Resize treats as "match the smaller edge".
        resize = transforms.Resize(tuple(self.image_shape))
        to_tensor = transforms.ToTensor()

        def _apply(image):
            image = resize(image)
            # ToTensor rejects tensors, so inputs that are already tensors
            # are only resized.
            if isinstance(image, torch.Tensor):
                return image
            return to_tensor(image)

        if isinstance(x, (list, tuple)):
            return [_apply(xi) for xi in x]
        return _apply(x)


@configurable(
    group="encoder",
    name="eincoder",
)
class ImageTextEincoder(GATEImageTextEncoder, nn.Module):
    """Pair an ``ImageEincoder`` with a CLIP text encoder.

    Args:
        architecture_dict: Architecture description for the image backbone.
        clip_model_name: Model name passed to ``GATECLIPTextEncoder``.
        image_size: Side length of the square input image. It sizes the image
            encoder's dummy build pass and is forwarded to the parent class.
            ``None`` falls back to 224 for the image encoder only.
        num_projection_features: Projection width shared by both encoders,
            or ``None`` for no projection.
    """

    def __init__(
        self,
        architecture_dict: dict,
        clip_model_name: str = CLIPModelPaths.openai_b_16,
        image_size: Optional[int] = 224,
        num_projection_features: Optional[int] = None,
    ):
        nn.Module.__init__(self)
        # Build the image encoder for the requested resolution rather than a
        # hard-coded 224, so its probed output shape (and any head sized from
        # it) matches the configured image size.
        side = 224 if image_size is None else image_size
        image_embedding = ImageEincoder(
            architecture_dict=architecture_dict,
            full_image_shape=(3, side, side),
            num_projection_features=num_projection_features,
        )
        text_embedding = GATECLIPTextEncoder(
            model_name=clip_model_name,
            num_projection_features=num_projection_features,
        )
        GATEImageTextEncoder.__init__(
            self,
            image_embedding,
            text_embedding,
            image_size,
            num_projection_features,
        )
