from functools import partial

import pytorch_lightning as pl  # pylint: disable=unused-import
import torch
from pytorch_lightning import Callback

import wandb
from src.viz.point_clouds import plot_batch_pcd


def test_push_point_clouds(test_batch, pl_module):
    """Push a batch of source data through the module's map ``map_t``.

    Runs without gradient tracking. When ``pl_module.cfg.ema`` is truthy the
    map is evaluated inside ``pl_module.ema_map.average_parameters()``;
    otherwise it is evaluated with the map's current weights.

    Args:
        test_batch: A 3-tuple whose first element is ignored and whose second
            and third elements are the source and target data.
        pl_module: Object exposing ``device``, ``cfg.ema``, ``map_t`` and, when
            ``cfg.ema`` is truthy, ``ema_map``.

    Returns:
        Tuple ``(src_data, tx_data, trg_data)``: the source data moved to
        ``pl_module.device``, the output of ``map_t`` on it, and the target
        data exactly as given (it is not moved to the device).
    """
    with torch.no_grad():
        _, src_data, trg_data = test_batch
        src_data = src_data.to(pl_module.device)
        if pl_module.cfg.ema:
            with pl_module.ema_map.average_parameters():
                tx_data = pl_module.map_t(src_data)
        else:
            # Without EMA, use the map's current weights; this branch is
            # required so ``tx_data`` is always bound before returning.
            tx_data = pl_module.map_t(src_data)
    return src_data, tx_data, trg_data


# The name matches pytest's ``test_*`` pattern; mark it as not-a-test so that
# importing it into a test module does not make pytest try to run it.
test_push_point_clouds.__test__ = False


class PointCloudViz(Callback):
    """Periodically plot source, pushforward and target point clouds.

    At the start of a batch whose ``global_step`` is a multiple of
    ``log_interval``, the callback:

    1. fetches ``n_test_samples`` test samples from ``trainer.datamodule``;
    2. pushes them through the module via ``test_push_point_clouds``;
    3. plots each of the three clouds to a PNG file named after the step,
       in the current working directory.

    When a wandb run is active, the three images are also logged together
    under the ``"generated point cloud"`` key.
    """

    def __init__(self, log_interval, n_test_samples, emb_map_path) -> None:
        super().__init__()
        # Plot whenever global_step is a multiple of this value
        # (assumed to be a positive int).
        self.log_interval = log_interval
        # Passed to the datamodule's get_test_samples().
        self.n_test_samples = n_test_samples
        # Stored on the instance; not read anywhere in this class.
        self.emb_map_path = emb_map_path

    def on_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        """Plot (and optionally log to wandb) test point clouds every ``log_interval`` steps."""
        # NOTE(review): newer Lightning releases deprecate/remove
        # ``on_batch_start`` in favour of ``on_train_batch_start``; confirm the
        # pinned version still invokes this hook.
        step = pl_module.global_step
        if step % self.log_interval != 0:
            return

        samples = trainer.datamodule.get_test_samples(self.n_test_samples)
        src_data, tx_data, trg_data = test_push_point_clouds(samples, pl_module)

        # (tensor, output file, wandb caption), in plotting order.
        entries = (
            (src_data, f"source_{step}.png", "source"),
            (tx_data, f"pf_{step}.png", "pushforward"),
            (trg_data, f"target_{step}.png", "target"),
        )
        for cloud, path, _ in entries:
            plot_batch_pcd(cloud.detach().cpu(), path, n_col=3)

        if wandb.run is None:
            return
        wandb.log(
            {
                "generated point cloud": [
                    wandb.Image(path, caption=caption) for _, path, caption in entries
                ]
            }
        )
