import math
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel

# Reference for the image-patching logic:
# https://huggingface.co/recursionpharma/OpenPhenom/blob/main/RxRx3-core_inference.ipynb


class OpenPhenomEncoder(nn.Module):
    """Wrapper around the pretrained ``recursionpharma/OpenPhenom`` model.

    Images are instance-normalised, cut into square crops and fed to the
    pretrained encoder. The wrapper exposes pooled features (``forward``,
    ``forward_features``), token grids (``encode_latent``) and pixel
    reconstructions (``encode`` / ``decode`` / ``encode_decode``).

    Parameters
    ----------
    feature_dim : int
        Width of each encoder token (default 384).
    patch_dim : int
        Side length in pixels of one ViT patch (default 16).
    channels : int
        Number of image channels (default 6).
    crops : int
        Number of crops per image used by ``cropify``. It must be a perfect
        square for the crop-based methods (default 4, i.e. a 2 x 2 grid).
    img_dim : int
        Side length of the square input image used by ``cropify`` (default 512).
    return_channelwise_embeddings : bool
        Assigned to the pretrained model's attribute of the same name.
    """

    def __init__(
        self,
        feature_dim: Optional[int] = 384,
        patch_dim: Optional[int] = 16,
        channels: Optional[int] = 6,
        crops: Optional[int] = 4,
        img_dim: Optional[int] = 512,
        return_channelwise_embeddings: bool = False,
    ):
        super().__init__()

        self.encoder = AutoModel.from_pretrained(
            "recursionpharma/OpenPhenom", trust_remote_code=True
        )
        self.encoder.return_channelwise_embeddings = return_channelwise_embeddings
        # Side of the tiles that ``forward`` cuts with ``patch_image``; fixed at
        # 256 regardless of ``img_dim`` / ``crops``.
        self.patch_size = 256
        self.feature_dim = feature_dim
        self.patch_dim = patch_dim
        self.channels = channels
        self.crops = crops
        self.img_dim = img_dim
        # Crops form a sqrt(crops) x sqrt(crops) grid: 512 / 2 = 256.
        self.cropped_img_dim = self.img_dim // math.isqrt(self.crops)
        # ViT patches per crop and channel: (256 / 16) ** 2 = 256.
        self.patches = (self.cropped_img_dim // self.patch_dim) ** 2
        self.instance_norm = nn.InstanceNorm2d(self.channels, affine=False, eps=1e-6)
        self.embed_dim = feature_dim * channels

    def _crop_grid_side(self) -> int:
        """Return the side of the square crop grid (2 for ``crops == 4``).

        Raises
        ------
        ValueError
            If ``self.crops`` is not a positive perfect square.
        """
        side = math.isqrt(self.crops)
        if side < 1 or side * side != self.crops:
            raise ValueError(
                f"crops must be a positive perfect square, got {self.crops}"
            )
        return side

    def iter_border_patches(self, width: int, height: int, patch_size: int):
        """Yield the top-left ``(x, y)`` corner of every full tile.

        Tiles are ``patch_size`` x ``patch_size`` and non-overlapping; any
        remainder at the right/bottom edge is skipped. ``x`` (column offset)
        is the outer loop, ``y`` (row offset) the inner loop.
        """
        for x in range(0, width - patch_size + 1, patch_size):
            for y in range(0, height - patch_size + 1, patch_size):
                yield x, y

    def patch_image(self, image_array: torch.Tensor) -> torch.Tensor:
        """Tile one ``(C, H, W)`` image into ``self.patch_size`` square tiles.

        Returns a tensor of shape ``(n_tiles, C, patch_size, patch_size)`` with
        tiles ordered column by column (see ``iter_border_patches``).
        """
        # Layout is (C, H, W): dim 1 is the height (rows, indexed by y) and
        # dim 2 the width (columns, indexed by x).
        _, height, width = image_array.shape
        tiles = [
            image_array[:, y : y + self.patch_size, x : x + self.patch_size]
            for x, y in self.iter_border_patches(width, height, self.patch_size)
        ]
        # torch.stack copies, so the slices need no explicit clone().
        return torch.stack(tiles)

    def cropify(self, im: torch.Tensor) -> torch.Tensor:
        """
        Split images into a grid of non-overlapping square crops.

        With the defaults, (B, 6, 512, 512) becomes (B*4, 6, 256, 256). The
        crops of each image are ordered row-major (for a 2 x 2 grid:
        top-left, top-right, bottom-left, bottom-right).

        Parameters
        ----------
        im : torch.Tensor
            Input image tensor of shape (B, channels, img_dim, img_dim).

        Returns
        -------
        torch.Tensor
            Crops of shape (B * crops, channels, cropped_img_dim, cropped_img_dim).

        Raises
        ------
        ValueError
            If ``crops`` is not a positive perfect square.
        """
        B = im.shape[0]
        side = self._crop_grid_side()
        # (B, C, H, W) -> (B, C, crop_row, h, crop_col, w)
        img = im.reshape(
            B, self.channels, side, self.cropped_img_dim, side, self.cropped_img_dim
        )
        # Bring the crop-grid dims forward: (B, crop_row, crop_col, C, h, w)
        img = img.permute(0, 2, 4, 1, 3, 5)
        # Fold the crop grid into the batch dimension.
        return img.reshape(
            B * self.crops, self.channels, self.cropped_img_dim, self.cropped_img_dim
        )

    def unflatten_tokens(
        self,
        tokens: torch.Tensor,
        patch_size: int,
        num_modalities: int = 1,
        channel_agnostic: bool = False,
    ) -> torch.Tensor:
        """
        Unflattens tokens (N, L, patch_size**2 * C) into image tensor (N, C, H, W)
        where H = W = sqrt(L / num_modalities) * patch_size.

        Parameters
        ----------
        tokens : torch.Tensor
            Input token tensor of shape (N, L, patch_size**2 * C).
        patch_size : int
            The side length of each patch.
        num_modalities : int, optional
            Number of modalities (used as the number of channels in this case).
        channel_agnostic : bool, optional
            If True, tokens are laid out as (modality, row, col), each token
            holding one patch of one modality (default: False).

        Returns
        -------
        img : torch.Tensor
            Reconstructed image tensor of shape (N, C, H, W).

        Raises
        ------
        ValueError
            If L is not divisible by ``num_modalities`` or the per-modality
            token count is not a perfect square.
        """
        n_tokens = tokens.shape[1]
        if n_tokens % num_modalities:
            raise ValueError(
                f"Token count {n_tokens} is not divisible by num_modalities={num_modalities}."
            )
        # Spatial grid size of tokens per modality.
        grid_tokens = n_tokens // num_modalities
        h = w = math.isqrt(grid_tokens)
        if h * w != grid_tokens:
            raise ValueError(
                "The square root of the number of tokens per modality is not an integer."
            )

        if channel_agnostic:
            # (N, num_modalities, h, w, patch_size, patch_size)
            x = tokens.reshape(tokens.shape[0], -1, h, w, patch_size, patch_size)
            # -> (N, num_modalities, h, patch_size, w, patch_size)
            x = x.permute(0, 1, 2, 4, 3, 5)
        else:
            # (N, h, w, patch_size, patch_size, C) -> (N, C, h, patch_size, w, patch_size)
            x = tokens.reshape(tokens.shape[0], h, w, patch_size, patch_size, -1)
            x = x.permute(0, 5, 1, 3, 2, 4)
        # Merge the patch grid back into full-resolution rows and columns.
        return x.reshape(x.shape[0], -1, h * patch_size, w * patch_size)

    def uncropify(self, crops: torch.Tensor) -> torch.Tensor:
        """
        Reassemble crops produced by ``cropify`` into full images.

        With the defaults, (B*4, 6, 256, 256) becomes (B, 6, 512, 512).

        Parameters
        ----------
        crops : torch.Tensor
            Crops of shape (B * crops, channels, cropped_img_dim, cropped_img_dim),
            row-major within each image.

        Returns
        -------
        torch.Tensor
            Reconstructed image tensor of shape (B, channels, img_dim, img_dim).

        Raises
        ------
        ValueError
            If ``self.crops`` is not a positive perfect square.
        """
        side = self._crop_grid_side()
        B = crops.shape[0] // self.crops
        # Restore the crop grid: (B, crop_row, crop_col, C, h, w)
        img = crops.reshape(
            B, side, side, self.channels, self.cropped_img_dim, self.cropped_img_dim
        )
        # Inverse of cropify's permutation: (B, C, crop_row, h, crop_col, w)
        img = img.permute(0, 3, 1, 4, 2, 5)
        return img.reshape(B, self.channels, self.img_dim, self.img_dim)

    @staticmethod
    def _restore_token_order(
        tokens: torch.Tensor, ind_restore: torch.Tensor
    ) -> torch.Tensor:
        """Reorder tokens (N, L, D) so that ``out[:, i] = tokens[:, ind_restore[:, i]]``."""
        index = ind_restore.unsqueeze(-1).expand(-1, -1, tokens.shape[-1])
        return torch.gather(tokens, 1, index)

    def _crop_tokens_to_grid(self, tokens: torch.Tensor) -> torch.Tensor:
        """Stitch per-crop token grids into one token grid per image.

        Maps (B, crops, channels, patches, D) to (B, channels, g * p, g * p, D),
        where g is the crop-grid side and p = sqrt(patches). Crops are expected
        in ``cropify``'s row-major order. Tokens within a crop are expected
        row-major, which is the layout ``unflatten_tokens`` also assumes.
        """
        B = tokens.shape[0]
        g = self._crop_grid_side()
        p = math.isqrt(self.patches)
        grid = tokens.reshape(B, g, g, self.channels, p, p, self.feature_dim)
        # Interleave crop and token axes: (B, C, crop_row, tok_row, crop_col, tok_col, D)
        grid = grid.permute(0, 3, 1, 4, 2, 5, 6)
        return grid.reshape(B, self.channels, g * p, g * p, self.feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the OpenPhenom encoder.

        Each image is tiled into ``patch_size`` x ``patch_size`` crops. The
        crops are embedded with the pretrained model's ``predict``, and the crop
        embeddings of each image are averaged.

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Features of shape (B, D). D is the width of ``predict``'s output
            (384 per the original comment; depends on the pretrained model).
        """
        B, _, H, W = x.shape
        n_crops = (H // self.patch_size) * (W // self.patch_size)
        x = self.instance_norm(x)
        # (B * n_crops, C, patch_size, patch_size)
        tiles = torch.vstack([self.patch_image(img) for img in x])
        out = self.encoder.predict(tiles)  # one embedding row per crop
        # Average the crop embeddings belonging to each image.
        return out.reshape(B, n_crops, out.shape[1]).mean(dim=1)

    def encode_latent(
        self,
        x: torch.Tensor,
        collapse_crops: Optional[bool] = True,
    ) -> torch.Tensor:
        """
        Encodes the input image tensor into a grid of patch tokens.

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor of shape (B, channels, img_dim, img_dim).
        collapse_crops : bool, optional
            If True (default), stitch the per-crop token grids together in
            image layout; otherwise keep one token grid per crop.

        Returns
        -------
        torch.Tensor
            (B, channels, G, G, feature_dim) with G = sqrt(crops) * sqrt(patches)
            (32 with the defaults) when ``collapse_crops`` is True. Otherwise
            (B, crops, channels, sqrt(patches), sqrt(patches), feature_dim).
        """
        B = x.shape[0]
        normed_x = self.instance_norm(x)
        cropped_x = self.cropify(normed_x)
        latent, _, ind_restore = self.encoder.encoder.forward_masked(cropped_x, 0.0)
        # Drop the leading token (presumably a CLS token without a patch position).
        tokens = latent[:, 1:, :]
        # forward_masked returns the index that restores the original token
        # order (``decode`` relies on it). MAE-style masking shuffles tokens even
        # at mask ratio 0, so un-shuffle before treating tokens as (channel, row,
        # col). This is a no-op when the tokens were not shuffled.
        tokens = self._restore_token_order(tokens, ind_restore)
        tokens = tokens.reshape(
            B, self.crops, self.channels, self.patches, self.feature_dim
        )
        if collapse_crops:
            return self._crop_tokens_to_grid(tokens)
        side = math.isqrt(self.patches)
        return tokens.reshape(
            B, self.crops, self.channels, side, side, self.feature_dim
        )

    def encode(self, x: torch.Tensor):
        """
        Run the pretrained encoder on the crops of ``x`` (no masking).

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor of shape (B, channels, img_dim, img_dim).

        Returns
        -------
        latent: torch.Tensor
            Raw encoder output, one token sequence per crop, as returned by
            ``forward_masked``. ``encode_latent`` treats it as
            (B * crops, 1 + channels * patches, feature_dim).
        ind_restore: torch.Tensor
            Index tensor that restores the original token order; pass it to ``decode``.
        """
        normed_x = self.instance_norm(x)
        cropped_x = self.cropify(normed_x)
        latent, _, ind_restore = self.encoder.encoder.forward_masked(cropped_x, 0.0)
        return latent, ind_restore

    def decode(
        self,
        latent: torch.Tensor,
        ind_restore: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Decodes raw encoder output back into images.

        Parameters
        ----------
        latent : torch.Tensor
            Raw encoder output as returned by ``encode`` (one sequence per crop,
            i.e. B * crops rows).
        ind_restore : torch.Tensor, optional
            Token-order restore index as returned by ``encode``. If None, the
            identity order is used. That is only correct when the tokens in
            ``latent`` were never shuffled.

        Returns
        -------
        torch.Tensor
            Reconstructed image tensor of shape (B, channels, img_dim, img_dim).
        """
        if ind_restore is None:
            # Identity order with one row per latent sequence. ``latent`` already
            # has one row per crop, so it must not be multiplied by ``crops`` again.
            ind_restore = (
                torch.arange(self.channels * self.patches, device=latent.device)
                .unsqueeze(0)
                .repeat(latent.shape[0], 1)
            )
        reconstruction = self.encoder.decode_to_reconstruction(
            latent,
            ind_restore,
            self.encoder.encoder_decoder_proj,
            self.encoder.decoder,
            self.encoder.decoder_pred,
        )
        # Presumably (B * crops, channels * patches, patch_dim ** 2) -> (B * crops, channels, h, w)
        x_hat = self.unflatten_tokens(
            reconstruction,
            patch_size=self.patch_dim,
            num_modalities=self.channels,
            channel_agnostic=True,
        )
        # Reassemble the crops into (B, channels, img_dim, img_dim).
        return self.uncropify(x_hat)

    def encode_decode(self, x: torch.Tensor):
        """
        Encodes the input image tensor and decodes it back into an image tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor of shape (B, channels, img_dim, img_dim).

        Returns
        -------
        torch.Tensor
            Reconstructed image tensor of shape (B, channels, img_dim, img_dim).
        """
        latent, ind_restore = self.encode(x)
        return self.decode(latent, ind_restore)

    def forward_features(self, x: torch.Tensor, flatten: bool = True):
        """
        Extract per-channel features by average-pooling the token grid.

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor of shape (B, channels, img_dim, img_dim).
        flatten : bool, optional
            If True (default), concatenate the per-channel features.

        Returns
        -------
        torch.Tensor
            (B, channels * feature_dim) if ``flatten`` is True, otherwise
            (B, channels, feature_dim).
        """
        latent = self.encode_latent(x)  # (B, channels, G, G, feature_dim)
        pooled = latent.mean(dim=[2, 3])  # (B, channels, feature_dim)
        if flatten:
            pooled = pooled.reshape(pooled.shape[0], -1)
        return pooled


if __name__ == "__main__":
    # Smoke test: build the wrapper and print the output shape of each entry point.
    encoder = OpenPhenomEncoder()
    x = torch.randn(2, 6, 512, 512)
    # Inference only: skip autograd bookkeeping, which would otherwise keep
    # every activation of three full-size passes alive.
    with torch.no_grad():
        out = encoder(x)
        print(out.shape)
        latent = encoder.encode_latent(x)
        print(latent.shape)
        out = encoder.encode_decode(x)
        print(out.shape)
