import torch
import torch.nn.functional as F
from omegaconf import OmegaConf

from utils.geometry import rays_to_plucker
from .configs import LightFieldWithXAttnDecoderConfig
from .modules.data_encoders import MultiviewEncoder
from .modules.transformer import TransformerEncoder

from .modules.decoder_with_xattn import DecoderWithCrossAttention


class LightFieldWithXAttnDecoder(torch.nn.Module):
    """Light-field model that decodes query rays against latents built from support views.

    Pipeline:
      1. ``MultiviewEncoder`` (created with ``use_plucker_coordinate=True``) turns
         support images/poses/focals into data tokens.
      2. ``TransformerEncoder`` maps those tokens to latent vectors (configured with
         ``config.num_data_tokens`` inputs and ``config.num_latent_tokens`` latents).
      3. Query rays are converted by ``rays_to_plucker`` and passed, together with the
         latents, to ``DecoderWithCrossAttention``.
    """

    def __init__(self, config: "LightFieldWithXAttnDecoderConfig"):
        super().__init__()
        # Convert the (presumably OmegaConf) config into a plain object so the
        # attribute accesses below work on the underlying config type.
        self.config = config = OmegaConf.to_object(config)  # type: LightFieldWithXAttnDecoderConfig

        self.data_encoder = MultiviewEncoder(config.data_encoder, use_plucker_coordinate=True)

        self.transformer = TransformerEncoder(
            config=config.transformer,
            input_dim=self.data_encoder.output_dim,
            output_dim=config.latent_dim,
            num_input_tokens=config.num_data_tokens,
            num_latent_tokens=config.num_latent_tokens,
        )

        self.decoder = DecoderWithCrossAttention(config.decoder)

    def _encode_support(self, support_imgs, support_poses, support_focals):
        """Encode the support views and return the transformer's latent vectors."""
        data_tokens = self.data_encoder(support_imgs, support_poses, support_focals, put_channels_last=True)
        return self.transformer(data_tokens)

    def forward(
        self,
        support_imgs,
        support_poses,
        support_focals,
        query_rays_o,
        query_rays_d,
    ):
        """Decode all query rays in a single decoder call.

        Args:
            support_imgs, support_poses, support_focals: Support-view inputs forwarded
                unchanged to ``MultiviewEncoder``.
            query_rays_o, query_rays_d: Query ray origins and directions, converted by
                ``sample_coord_input``.

        Returns:
            Whatever ``self.decoder`` returns for the full set of query rays.
        """
        latent_vectors = self._encode_support(support_imgs, support_poses, support_focals)
        coord = self.sample_coord_input(query_rays_o, query_rays_d)
        return self.decoder(coord, latents=latent_vectors)

    def forward_by_subbatch_ray(
        self,
        support_imgs,
        support_poses,
        support_focals,
        query_rays_o,
        query_rays_d,
        ray_subbatch_size=16384,
    ):
        """Like ``forward``, but decode the query rays in chunks along dim 1.

        The support views are encoded once. The rays are then sliced along dim 1 into
        chunks of at most ``ray_subbatch_size``. Each chunk is decoded, and the decoder
        outputs are concatenated along dim 1. This bounds peak decoder memory for large
        ray counts.

        Args:
            support_imgs, support_poses, support_focals: As in ``forward``.
            query_rays_o, query_rays_d: Query rays; chunking is along dim 1
                (assumed to be the ray axis — confirm against callers).
            ray_subbatch_size: Maximum number of rays per decoder call. Must be > 0.

        Returns:
            Decoder outputs for all rays, concatenated along dim 1.

        Raises:
            ValueError: If ``ray_subbatch_size`` is not positive.
        """
        # Validate before the (expensive) encoder/transformer pass.
        if ray_subbatch_size <= 0:
            raise ValueError(f"ray_subbatch_size must be positive, got {ray_subbatch_size}")

        latent_vectors = self._encode_support(support_imgs, support_poses, support_focals)

        num_rays = query_rays_o.shape[1]
        outputs_list = []
        # Run at least one chunk so an empty ray set is decoded like in ``forward``.
        # Otherwise torch.cat would be called on an empty list and raise.
        for start in range(0, max(num_rays, 1), ray_subbatch_size):
            stop = start + ray_subbatch_size
            coord = self.sample_coord_input(query_rays_o[:, start:stop], query_rays_d[:, start:stop])
            outputs_list.append(self.decoder(coord, latents=latent_vectors))

        return torch.cat(outputs_list, dim=1)

    def sample_coord_input(self, rays_o, rays_d, device=None):
        """Convert rays to the decoder's coordinate input via ``rays_to_plucker``.

        ``device`` is ignored; it is kept only for call-site compatibility.
        """
        return rays_to_plucker(rays_o, rays_d)

    def compute_loss(self, preds, targets, reduction="mean"):
        """Per-sample MSE loss and PSNR.

        The squared error ``(preds - targets) ** 2`` is flattened per batch element
        (dim 0) and averaged to get one MSE per sample. PSNR is ``-10 * log10(mse)``,
        i.e. a peak value of 1. That presumes targets are scaled to [0, 1]. A sample
        with zero MSE yields an infinite PSNR.

        Args:
            preds, targets: Tensors whose difference has batch size ``preds.shape[0]``.
            reduction: "mean", "sum", or "none" (per-sample values); applied to both
                the MSE and the PSNR.

        Returns:
            dict with "loss_total" and "mse" (the same reduced MSE) and "psnr".
        """
        assert reduction in ["mean", "sum", "none"], f"unsupported reduction: {reduction!r}"
        batch_size = preds.shape[0]
        sample_mses = torch.reshape((preds - targets) ** 2, (batch_size, -1)).mean(dim=-1)
        sample_psnrs = -10 * torch.log10(sample_mses)

        if reduction == "mean":
            total_loss, psnr = sample_mses.mean(), sample_psnrs.mean()
        elif reduction == "sum":
            total_loss, psnr = sample_mses.sum(), sample_psnrs.sum()
        else:
            total_loss, psnr = sample_mses, sample_psnrs

        return {"loss_total": total_loss, "mse": total_loss, "psnr": psnr}
