import torch

from .configs import GINRWithXAttnDecoderConfig
from .modules.data_encoders import ImageEncoder
from .modules.transformer import TransformerEncoder
from .modules.decoder_with_xattn import DecoderWithCrossAttention
from .modules.coord_samplers import GridCoordSampler


class GINRWithXAttnDecoder(torch.nn.Module):
    """Generalizable INR whose decoder cross-attends to transformer latents.

    Pipeline (see ``forward``): ``xs`` is tokenized by a data encoder, a
    ``TransformerEncoder`` maps the data tokens to latent tokens, the latents
    are passed through ``proj_latent`` (a ``Linear`` layer when
    ``config.use_latent_projection`` is set, otherwise ``Identity``), and a
    ``DecoderWithCrossAttention`` predicts features at query coordinates by
    attending to those latents.

    Only ``config.coord_sampler.data_type == "image"`` is currently supported;
    any other value raises ``ValueError`` at construction time.
    """

    def __init__(self, config: "GINRWithXAttnDecoderConfig"):
        super().__init__()
        # Work on a copy rather than the caller's object (how deep the copy is
        # depends on the config class's ``copy`` implementation).
        self.config = config = config.copy()  # type: GINRWithXAttnDecoderConfig

        self.coord_sampler = GridCoordSampler(config.coord_sampler)

        if config.coord_sampler.data_type == "image":
            self.data_encoder = ImageEncoder(config.data_encoder)
        else:
            raise ValueError(f"Unsupported data type: {config.coord_sampler.data_type}")

        # With a latent projection the transformer emits `output_dim` features
        # that `proj_latent` maps to `latent_dim`; without one it must emit
        # `latent_dim` features directly.
        self.transformer = TransformerEncoder(
            config=config.transformer,
            input_dim=self.data_encoder.output_dim,
            output_dim=config.output_dim if config.use_latent_projection else config.latent_dim,
            num_input_tokens=config.num_data_tokens,
            num_latent_tokens=config.num_latent_tokens,
        )
        if config.use_latent_projection:
            self.proj_latent = torch.nn.Linear(config.output_dim, config.latent_dim)
        else:
            self.proj_latent = torch.nn.Identity()

        self.decoder = DecoderWithCrossAttention(config.decoder)

    @staticmethod
    def _channels_first(outputs):
        """Move the trailing (channel) axis of ``outputs`` to axis 1: (B, *S, C) -> (B, C, *S)."""
        # Derived from `outputs` itself, so query grids whose rank differs from
        # the input's (e.g. flattened (B, N, dim) coordinates) are handled too.
        return outputs.movedim(-1, 1)

    def forward(self, xs, coord=None, keep_xs_shape=True):
        """Encode ``xs`` and predict features at the query coordinates ``coord``.

        Args:
            xs (torch.Tensor): (B, input_dim, *xs_spatial_shape) input batch.
            coord (torch.Tensor, optional): (B, *coord_spatial_shape) query
                coordinates. If None, they are sampled from ``xs`` with
                ``sample_coord_input``.
            keep_xs_shape (bool): If True, the decoder output's last (channel)
                axis is moved to axis 1, giving a channels-first layout like
                ``xs``. If False, the decoder output is returned unchanged
                (channels-last).

        Returns:
            torch.Tensor: predicted features at ``coord``.
        """
        # Sample coordinates before encoding, matching the historical order.
        if coord is None:
            coord = self.sample_coord_input(xs)

        latent_vectors = self.encode_latents(xs, use_projection=False)
        # Decoder predicts every query of `coord` by cross-attending to the
        # (projected) latents.
        return self.decode_latents(
            latent_vectors,
            xs,
            apply_latent_projection=True,
            coord=coord,
            keep_xs_shape=keep_xs_shape,
        )

    def encode_latents(self, xs, use_projection=False):
        """Encode ``xs`` into latent tokens.

        Args:
            xs (torch.Tensor): (B, input_dim, *xs_spatial_shape) input batch.
            use_projection (bool): If True, also apply ``proj_latent``. Leave it
                False when the result is passed to ``decode_latents`` with its
                default ``apply_latent_projection=True``, so the projection is
                applied exactly once.

        Returns:
            torch.Tensor: the transformer's latent tokens (projected if requested).
        """
        data_tokens = self.data_encoder(xs, put_channels_last=True)
        latent_vectors = self.transformer(data_tokens)
        if use_projection:
            latent_vectors = self.proj_latent(latent_vectors)
        return latent_vectors

    def decode_latents(self, latents, xs, apply_latent_projection=True, coord=None, keep_xs_shape=True):
        """Decode latent tokens into predicted features at ``coord``.

        Args:
            latents (torch.Tensor): latent tokens, e.g. from ``encode_latents``.
            xs (torch.Tensor): input batch; only used to sample the query
                coordinates when ``coord`` is None.
            apply_latent_projection (bool): If True, pass ``latents`` through
                ``proj_latent`` before decoding.
            coord (torch.Tensor, optional): query coordinates; sampled from
                ``xs`` when None.
            keep_xs_shape (bool): If True, move the decoder output's last
                (channel) axis to axis 1 (channels-first); otherwise return the
                decoder output unchanged.

        Returns:
            torch.Tensor: predicted features at ``coord``.
        """
        if coord is None:
            coord = self.sample_coord_input(xs)
        latent_vectors = self.proj_latent(latents) if apply_latent_projection else latents

        outputs = self.decoder(coord, latents=latent_vectors)
        return self._channels_first(outputs) if keep_xs_shape else outputs

    def sample_coord_input(self, xs, coord_range=None, upsample_ratio=1.0, device=None):
        """Sample query coordinates for ``xs`` via the grid coordinate sampler.

        ``coord_range`` and ``upsample_ratio`` are forwarded to the sampler
        unchanged; ``device`` defaults to ``xs.device``.
        """
        device = device if device is not None else xs.device
        coord_inputs = self.coord_sampler(xs, coord_range, upsample_ratio, device)
        return coord_inputs

    def compute_loss(self, preds, targets, reduction="mean"):
        """Compute the MSE loss and PSNR between ``preds`` and ``targets``.

        The squared error is averaged over all non-batch elements of each
        sample, then reduced over the batch. PSNR is ``-10 * log10(mse)`` per
        sample, i.e. it assumes a peak signal value of 1 (data in [0, 1]); a
        sample with zero error yields ``inf``.

        Args:
            preds (torch.Tensor): batch-first predictions.
            targets (torch.Tensor): ground truth compared elementwise with ``preds``.
            reduction (str): "mean" or "sum" over the batch, or "none" to keep
                per-sample values of shape (B,).

        Returns:
            dict: "loss_total" and "mse" (both the reduced MSE) and "psnr".

        Raises:
            ValueError: if ``reduction`` is not "mean", "sum" or "none".
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would let an invalid value silently fall into the "none" branch.
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"Unsupported reduction: {reduction!r}")

        batch_size = preds.shape[0]
        sample_mses = torch.reshape((preds - targets) ** 2, (batch_size, -1)).mean(dim=-1)
        sample_psnrs = -10 * torch.log10(sample_mses)

        if reduction == "mean":
            total_loss, psnr = sample_mses.mean(), sample_psnrs.mean()
        elif reduction == "sum":
            total_loss, psnr = sample_mses.sum(), sample_psnrs.sum()
        else:
            total_loss, psnr = sample_mses, sample_psnrs

        return {"loss_total": total_loss, "mse": total_loss, "psnr": psnr}
