from functools import partial

import wandb
from src.logger.jam_wandb import prefix_metrics_keys
from src.models.base_model import BaseModule
from src.viz.point_clouds import plot_batch_pcd

# pylint: disable=abstract-method,too-many-ancestors,arguments-renamed,line-too-long,arguments-differ,unused-argument


class PointCloudModule(BaseModule):
    """Module specialisation for batches of partial/complete point clouds.

    Relies on attributes supplied elsewhere (presumably by ``BaseModule``):
    ``global_step``, ``cfg``, ``ema_map``, ``map_t``, ``cost_func`` and
    ``log_dict``.
    """

    def get_real_data(self, batch):
        """Split a training batch into source and target point clouds.

        Args:
            batch: A 3-tuple whose first element is ignored and whose second
                and third elements are the partial and complete point clouds.

        Returns:
            ``(src_data, trg_data)``: the first chunk of the partial clouds
            and the second chunk of the complete clouds.
        """
        _, partial_pcd, complete_pcd = batch
        # partial_pcd: (B, 2048, 3)
        # complete_pcd: (B, 2048, 3)
        # Source comes from the first chunk and target from the second, so the
        # two sides are drawn from different parts of the batch.
        # NOTE(review): assumes ``chunk(2)`` yields exactly two pieces (true
        # for torch tensors with B >= 2) -- confirm against the data loader.
        src_data, _ = partial_pcd.chunk(2)
        _, trg_data = complete_pcd.chunk(2)
        if self.global_step == 1:
            # Dump one visual sample of the training data, only at step 1.
            plot_pcd = partial(plot_batch_pcd, n_col=8)
            plot_pcd(src_data.detach().cpu(), "source.png")
            plot_pcd(trg_data.detach().cpu(), "target.png")
            if wandb.run is not None:
                # Upload the images only when a wandb run is active.
                wandb.log(
                    {
                        "train data": [
                            wandb.Image("source.png", caption="source"),
                            wandb.Image("target.png", caption="target"),
                        ]
                    }
                )
        return src_data, trg_data

    def validation_step(self, batch, _):
        """Log ``cost_func`` values between source, mapped and target clouds.

        Args:
            batch: A 3-tuple whose first element is ignored and whose second
                and third elements are the source and target point clouds.
            _: Unused second positional argument (presumably the batch index).
        """
        # evaluate Chamfer loss between generated and target
        _, src_data, trg_data = batch
        if self.cfg.ema:
            # Evaluate the map under the EMA-averaged parameters.
            with self.ema_map.average_parameters():
                tx_data = self.map_t(src_data)
        else:
            # Without EMA, evaluate the map with its current parameters.
            # Previously ``tx_data`` was left unbound on this path, which
            # raised ``UnboundLocalError`` below.
            tx_data = self.map_t(src_data)
        src_trg_chamfer = self.cost_func(src_data, trg_data)
        src_pf_chamfer = self.cost_func(src_data, tx_data)
        pf_trg_chamfer = self.cost_func(tx_data, trg_data)
        log_info = prefix_metrics_keys(
            {
                "baseline chamfer": src_trg_chamfer,
                "chamfer with source": src_pf_chamfer,
                "difference from baseline chamfer": abs(
                    src_trg_chamfer - src_pf_chamfer
                ),
                "chamfer with target": pf_trg_chamfer,
            },
            "validation_loss",
        )
        self.log_dict(log_info)
