import os
import torch
import torch.distributed as dist
import torch.utils.data

import pointcept.utils.comm as comm
from pointcept.utils.visualization import save_point_cloud
from pointcept.datasets import build_dataset, collate_fn

from ..default import HookBase
from ..builder import HOOKS


@HOOKS.register_module()
class PCAEvaluator(HookBase):
    """Hook that periodically visualises point features as PCA colours.

    Every ``eval_step`` epochs the hook runs the trainer's model over its own
    data loader, maps the returned point features to three channels with
    :meth:`pca_color`, and, on the main process, optionally pushes the
    coloured clouds to the trainer's TensorBoard writer and/or writes them as
    ``.ply`` files under ``<trainer.cfg.save_path>/pca``.
    """

    def __init__(
        self,
        dataset,
        batch_size_per_gpu=1,
        num_workers_per_gpu=1,
        eval_step=10,
        point_size=0.03,
        write_tb=False,
        write_ply=True,
    ):
        """Build the evaluation data loader and store the hook options.

        Args:
            dataset: Dataset config; passed to ``build_dataset``.
            batch_size_per_gpu: Loader batch size. Must be 1.
            num_workers_per_gpu: Number of loader worker processes.
            eval_step: Run the evaluation every ``eval_step`` epochs.
            point_size: Point size used for the TensorBoard mesh material.
            write_tb: Whether to add the clouds to ``trainer.writer``.
            write_ply: Whether to write the clouds to disk as ``.ply``.
        """
        assert (
            batch_size_per_gpu == 1
        ), "PCAEvaluator only supports batch_size_per_gpu == 1"
        dataset = build_dataset(dataset)
        if comm.get_world_size() > 1:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            sampler = None
        self.loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size_per_gpu,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            sampler=sampler,
            collate_fn=collate_fn,
        )
        self.eval_step = eval_step
        self.point_size = point_size
        self.write_tb = write_tb
        self.write_ply = write_ply

    @staticmethod
    def pca_color(feat):
        """Project features onto 3 principal directions, scaled to [0, 1].

        The features are mean-centred, projected onto the basis returned by
        ``torch.pca_lowrank(..., q=3)`` and min-max normalised per channel.
        When running with more than one process, the mean and the min/max
        are all-reduced across ranks; the PCA basis itself is computed from
        the local features only.

        Args:
            feat: Feature tensor; the second-to-last dimension is reduced
                over and ``feat.shape[0]`` is used as the sample count
                (assumes a 2-D ``(num_points, channels)`` tensor - TODO
                confirm against callers).

        Returns:
            Tensor with the same leading shape as ``feat`` and 3 channels.
        """
        feat_sum = feat.sum(dim=-2, keepdim=True)
        feat_num = torch.tensor(feat.shape[0], device=feat.device)
        if comm.get_world_size() > 1:
            comm.synchronize()
            dist.all_reduce(feat_sum)
            dist.all_reduce(feat_num)
        feat_mean = feat_sum / feat_num
        feat = feat - feat_mean
        # Already centred above (with the cross-rank mean), hence center=False.
        _, _, v = torch.pca_lowrank(feat, center=False, q=3)
        projection = feat @ v
        min_val = projection.min(dim=-2, keepdim=True)[0]
        max_val = projection.max(dim=-2, keepdim=True)[0]
        if comm.get_world_size() > 1:
            comm.synchronize()
            dist.all_reduce(min_val, op=dist.ReduceOp.MIN)
            dist.all_reduce(max_val, op=dist.ReduceOp.MAX)
        # Clamp the range so constant channels do not divide by zero.
        div = torch.clamp(max_val - min_val, min=1e-6)
        return (projection - min_val) / div

    @staticmethod
    def save_ply(data_dict, save_path):
        """Write the images and global/local views of one sample to disk.

        Files are written to ``<save_path>/<data_dict["name"]>/`` as
        ``img_{i}.png``, ``major_view_{i}.ply`` and ``local_view_{i}.ply``.

        Args:
            data_dict: Mapping with a ``"name"`` entry and, optionally,
                ``"imgs"``, ``"global_coord"``/``"global_color"`` and
                ``"local_coord"``/``"local_color"`` sequences. Missing
                optional entries are skipped.
            save_path: Root directory for the output.
        """
        from PIL import Image

        sample_dir = os.path.join(save_path, str(data_dict["name"]))
        os.makedirs(sample_dir, exist_ok=True)

        # Save images.
        # NOTE(review): assumes each image is a channel-first torch tensor
        # with values in [0, 1] - confirm against the dataset.
        for i, img in enumerate(data_dict.get("imgs", ())):
            img = (img * 255).permute(1, 2, 0).detach().cpu()
            img_pil = Image.fromarray(img.numpy().astype("uint8"))
            img_pil.save(os.path.join(sample_dir, f"img_{i}.png"))

        # Save global ("major") and local views. The number of views is taken
        # from the data itself, since a static method has no instance state.
        # NOTE(review): colours are handed to save_point_cloud unchanged;
        # confirm the value range that helper expects.
        for prefix, key in (("major_view", "global"), ("local_view", "local")):
            coords = data_dict.get(f"{key}_coord", ())
            colors = data_dict.get(f"{key}_color", ())
            for i, (coord, color) in enumerate(zip(coords, colors)):
                save_point_cloud(
                    coord=coord,
                    color=color,
                    file_path=os.path.join(sample_dir, f"{prefix}_{i}.ply"),
                )

    def after_epoch(self):
        """Run :meth:`eval` when the finished epoch is a multiple of ``eval_step``."""
        if (self.trainer.epoch + 1) % self.eval_step == 0:
            self.eval()

    def eval(self):
        """Compute PCA colours for every sample and export them.

        Each rank processes its share of the loader; the per-sample records
        are then gathered on rank 0, which performs the TensorBoard and
        ``.ply`` output.
        """
        self.trainer.logger.info(
            ">>>>>>>>>>>>>>>> Start PCA Evaluation >>>>>>>>>>>>>>>>"
        )
        self.trainer.model.eval()
        record = {}
        for i, input_dict in enumerate(self.loader):
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with torch.no_grad():
                point = self.trainer.model(input_dict, return_point=True)["point"]
                # Walk back up the pooling hierarchy, copying features from
                # each pooled level to its parent via the inverse index.
                while "pooling_parent" in point.keys():
                    assert "pooling_inverse" in point.keys()
                    parent = point.pop("pooling_parent")
                    inverse = point.pop("pooling_inverse")
                    parent.feat = point.feat[inverse]
                    point = parent
                coord = point.coord
                color = self.pca_color(point.feat)

                # A leading batch dimension is added for SummaryWriter.add_mesh.
                record[point.name[0]] = dict(
                    coord=coord.unsqueeze(0).cpu().numpy(),
                    color=color.unsqueeze(0).cpu().numpy(),
                    split=point.split[0],
                )
            self.trainer.logger.info(
                "PCA: [{iter}/{max_iter}] ".format(
                    iter=i + 1, max_iter=len(self.loader)
                )
            )
        comm.synchronize()
        record_sync = comm.gather(record, dst=0)

        if comm.is_main_process():
            # Merge the per-rank records, releasing each one as it is consumed.
            record = {}
            while record_sync:
                record.update(record_sync.pop())

            if self.write_ply:
                save_path = os.path.join(self.trainer.cfg.save_path, "pca")
                os.makedirs(save_path, exist_ok=True)

            for name in record.keys():
                split = record[name]["split"]
                coord = record[name]["coord"]
                color = record[name]["color"]
                if self.write_tb and self.trainer.writer is not None:
                    self.trainer.writer.add_mesh(
                        f"pca_{split}/{name}",
                        vertices=coord,
                        colors=color * 255,
                        config_dict={
                            "material": {
                                "cls": "PointsMaterial",
                                "size": self.point_size,
                            }
                        },
                        global_step=self.trainer.epoch + 1,
                    )
                    self.trainer.writer.flush()
                    self.trainer.logger.info(f"Add {split}/{name} to tensorboard")
                if self.write_ply:
                    # Drop the leading batch dimension added for TensorBoard.
                    save_point_cloud(
                        coord=coord[0],
                        color=color[0],
                        file_path=os.path.join(save_path, f"{name}.ply"),
                    )
                    self.trainer.logger.info(f"Write {name} to local")

        comm.synchronize()
        self.trainer.logger.info(
            "<<<<<<<<<<<<<<<<< End PCA Evaluation <<<<<<<<<<<<<<<<<"
        )
