import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.distributed as dist
import torch.utils.data
import torch.nn.functional as F

import pointcept.utils.comm as comm
from pointcept.utils.visualization import save_point_cloud, save_lines
from pointcept.datasets import build_dataset, collate_fn

from ..default import HookBase
from ..builder import HOOKS


@HOOKS.register_module()
class InternalMatchingEvaluator(HookBase):
    """
    Internal matching visual evaluation with local and global view of point clouds.

    Every ``eval_step`` epochs the model is run on the local and the global view
    of each evaluation sample. One point is sampled from the local view and its
    L2-normalized feature is compared (dot product, i.e. cosine similarity) with
    every local point ("self" heat map) and every global point ("cross" heat
    map). Similarities are mapped to colors with ``cmap`` and written to
    tensorboard and/or ``.ply`` files, together with a line joining the sampled
    local point to its most similar global point.

    Args:
        dataset: dataset config, built with ``build_dataset``.
        oral: fraction of global points (ranked by similarity to the query,
            descending) whose boundary similarity sets the scale of the
            above-``highlight`` part of the color mapping.
        highlight: fraction of global points whose boundary similarity splits
            the colormap; similarities at or above it use the upper half.
        reject: fraction of global points, counted from the least similar end,
            whose boundary similarity sets the scale of the below-``highlight``
            part of the color mapping.
        cmap: matplotlib colormap name.
        eval_step: run the evaluation after every ``eval_step``-th epoch.
        point_size: point size of the tensorboard mesh.
        segment_ignore_index: segment label(s) never used as the query point.
        write_tb: add the colored point clouds to tensorboard.
        write_ply: write the colored point clouds and match lines as ``.ply``.
    """

    def __init__(
        self,
        dataset,
        oral=0.05,
        highlight=0.2,
        reject=0.5,
        cmap="inferno",
        eval_step=10,
        point_size=0.03,
        segment_ignore_index=None,
        write_tb=False,
        write_ply=True,
        **kwargs,
    ):
        dataset = build_dataset(dataset)
        if comm.get_world_size() > 1:
            # DistributedSampler shuffles by default, which would silently
            # override the loader's ``shuffle=False``; keep the order stable.
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, shuffle=False
            )
        else:
            sampler = None
        self.loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1,
            pin_memory=True,
            sampler=sampler,
            collate_fn=collate_fn,
        )
        self.oral = oral
        self.highlight = highlight
        self.reject = reject
        self.cmap = plt.get_cmap(cmap)
        self.eval_step = eval_step
        self.point_size = point_size
        self.segment_ignore_index = segment_ignore_index
        self.write_tb = write_tb
        self.write_ply = write_ply

    @staticmethod
    def pca_color(feat):
        """Project features onto 3 principal components, rescaled to [0, 1].

        Mean, min and max are all-reduced across ranks when running
        distributed, so every rank uses the same normalization.
        Assumes ``feat`` is an (N, C) tensor (rows reduced over dim -2).
        """
        feat_sum = feat.sum(dim=-2, keepdim=True)
        feat_num = torch.tensor(feat.shape[0], device=feat.device)
        if comm.get_world_size() > 1:
            comm.synchronize()
            dist.all_reduce(feat_sum)
            dist.all_reduce(feat_num)
        feat_mean = feat_sum / feat_num
        feat = feat - feat_mean
        u, s, v = torch.pca_lowrank(feat, center=False, q=3)
        projection = feat @ v
        min_val = projection.min(dim=-2, keepdim=True)[0]
        max_val = projection.max(dim=-2, keepdim=True)[0]
        if comm.get_world_size() > 1:
            comm.synchronize()
            dist.all_reduce(min_val, op=dist.ReduceOp.MIN)
            dist.all_reduce(max_val, op=dist.ReduceOp.MAX)
        div = torch.clamp(max_val - min_val, min=1e-6)
        return (projection - min_val) / div

    def after_epoch(self):
        if (self.trainer.epoch + 1) % self.eval_step == 0:
            self.eval()

    def eval(self):
        """Run the internal matching visualization over the evaluation loader.

        Each rank processes its share of the loader; records are gathered on
        the main process and written there. The model's previous train/eval
        mode is restored once inference is done.
        """
        self.trainer.logger.info(
            ">>>>>>>>>>> Start Internal Matching Evaluation >>>>>>>>>>>"
        )
        model = self.trainer.model
        was_training = model.training
        model.eval()
        record = {}
        try:
            for i, input_dict in enumerate(self.loader):
                global_point, local_point = self._split_views(input_dict)
                with torch.no_grad():
                    global_point = model(global_point, return_point=True)["point"]
                    local_point = model(local_point, return_point=True)["point"]
                    match = self._sample_and_match(global_point, local_point)

                name = local_point.name[0]
                if match is None:
                    self.trainer.logger.info(
                        f"Internal Matching: skip {name}, no valid point to match"
                    )
                else:
                    record[name] = self._build_record(
                        global_point, local_point, *match
                    )

                self.trainer.logger.info(
                    "Internal Matching: [{iter}/{max_iter}] ".format(
                        iter=i + 1, max_iter=len(self.loader)
                    )
                )
        finally:
            # Hand the model back in whatever mode the trainer had it in.
            model.train(was_training)

        comm.synchronize()
        record_sync = comm.gather(record, dst=0)

        if comm.is_main_process():
            merged = {}
            # Pop from the end: the first gathered dict is merged last and wins
            # on duplicate names; popping also releases each dict early.
            while record_sync:
                merged.update(record_sync.pop())
            self._write_records(merged)

        comm.synchronize()
        self.trainer.logger.info(
            ">>>>>>>>>>> End Internal Matching Evaluation >>>>>>>>>>>"
        )

    @staticmethod
    def _split_views(input_dict):
        """Move tensors to GPU and split one batch into global / local view dicts.

        Keys prefixed ``global_`` / ``local_`` go to that view with the prefix
        removed; any other key is shared by both views.
        """
        global_point, local_point = dict(), dict()
        for key, value in input_dict.items():
            if isinstance(value, torch.Tensor):
                value = value.cuda(non_blocking=True)
            if key.startswith("global_"):
                global_point[key.replace("global_", "")] = value
            elif key.startswith("local_"):
                local_point[key.replace("local_", "")] = value
            else:
                global_point[key] = value
                local_point[key] = value
        return global_point, local_point

    @staticmethod
    def _heat_map(inner, oral, highlight, reject, eps=1e-6):
        """Map similarities to colormap positions in [0, 1].

        Values at or above ``highlight`` go to [0.5, 1] through a sigmoid scaled
        by ``oral - highlight``; values below go to [0, 0.45) through a sigmoid
        scaled by ``highlight - reject`` and multiplied by 0.9, leaving a visible
        gap at the ``highlight`` boundary. Scales are clamped to ``eps`` so tied
        thresholds do not divide by zero.
        """
        shifted = inner - highlight
        upper_scale = torch.clamp(
            torch.as_tensor(oral - highlight, dtype=inner.dtype, device=inner.device),
            min=eps,
        )
        lower_scale = torch.clamp(
            torch.as_tensor(highlight - reject, dtype=inner.dtype, device=inner.device),
            min=eps,
        )
        upper = torch.sigmoid(shifted / upper_scale)
        lower = torch.sigmoid(shifted / lower_scale) * 0.9
        # ``>=`` so a similarity exactly at ``highlight`` maps to 0.5 instead of
        # being left at 0 (the darkest color).
        return torch.where(shifted >= 0, upper, lower)

    def _sample_and_match(self, global_point, local_point):
        """Sample one query point from the local view and score it against both views.

        Returns:
            ``(query_index, matched_index, self_heat, cross_heat)``: the query
            index into the local view and its most similar global index (both
            Python ints), plus 1-D heat maps for the local and global views;
            or ``None`` when no local point is eligible or the global view is
            empty.
        """
        if (
            self.segment_ignore_index is not None
            and "segment" in local_point.keys()
        ):
            local_mask = ~torch.isin(
                local_point.segment,
                torch.tensor(
                    self.segment_ignore_index, device=local_point.segment.device
                ),
            )
        else:
            local_mask = torch.ones_like(local_point.coord[:, 0]).bool()

        # as_tuple=True always yields 1-D indices; the former ``.squeeze()``
        # collapsed a single valid point into a 0-d tensor.
        valid_indices = torch.nonzero(local_mask.flatten(), as_tuple=True)[0]
        if valid_indices.numel() == 0 or global_point.feat.shape[0] == 0:
            return None
        pick = int(torch.randperm(valid_indices.numel())[0])
        query_index = int(valid_indices[pick])

        target = F.normalize(local_point.feat, p=2, dim=-1)
        refer = F.normalize(global_point.feat, p=2, dim=-1)
        query_feat = target[query_index : query_index + 1]
        inner_self = (query_feat @ target.t())[0]
        inner_cross = (query_feat @ refer.t())[0]
        # Best match on raw similarities; the sigmoid mapping can saturate
        # into ties and make argmax pick an arbitrary point.
        matched_index = int(torch.argmax(inner_cross))

        sorted_inner = torch.sort(inner_cross, descending=True)[0]
        num_global = sorted_inner.numel()
        oral = sorted_inner[min(int(num_global * self.oral), num_global - 1)]
        highlight = sorted_inner[
            min(int(num_global * self.highlight), num_global - 1)
        ]
        # Counted from the least similar end. At least one step, so a zero
        # offset cannot wrap around to the most similar point.
        reject_offset = min(max(int(num_global * self.reject), 1), num_global)
        reject = sorted_inner[num_global - reject_offset]

        self_heat = self._heat_map(inner_self, oral, highlight, reject)
        cross_heat = self._heat_map(inner_cross, oral, highlight, reject)
        return query_index, matched_index, self_heat, cross_heat

    def _build_record(
        self,
        global_point,
        local_point,
        query_index,
        matched_index,
        self_heat,
        cross_heat,
    ):
        """Lay both views out side by side and assemble one visualization record.

        The local view is shifted so its max x and max y sit at -0.05; the
        global view so its min x sits at +0.05 and max y at -0.05. Both are
        colored by their heat map. A copy of each with x and y negated carries
        the input colors.
        """
        local_coord = local_point.coord.cpu().numpy()
        local_coord[:, 1] = local_coord[:, 1] - local_coord[:, 1].max() - 0.05
        local_coord[:, 0] = local_coord[:, 0] - local_coord[:, 0].max() - 0.05

        global_coord = global_point.coord.cpu().numpy()
        global_coord[:, 1] = global_coord[:, 1] - global_coord[:, 1].max() - 0.05
        global_coord[:, 0] = global_coord[:, 0] - global_coord[:, 0].min() + 0.05

        mirror = np.array([[-1, -1, 1]])
        coord = np.expand_dims(
            np.concatenate(
                [
                    local_coord,
                    global_coord,
                    local_coord * mirror,
                    global_coord * mirror,
                ]
            ),
            axis=0,
        )
        color = np.expand_dims(
            np.concatenate(
                [
                    self.cmap(self_heat.cpu().numpy())[:, :3],
                    self.cmap(cross_heat.cpu().numpy())[:, :3],
                    local_point.color.cpu().numpy(),
                    global_point.color.cpu().numpy(),
                ]
            ),
            axis=0,
        )
        matched_coord = np.stack(
            [local_coord[query_index], global_coord[matched_index]]
        )
        return dict(
            coord=coord,
            color=color,
            split=local_point.split[0],
            matched_coord=matched_coord,
            matched_line=np.array([[0, 1]]),
        )

    def _write_records(self, record):
        """Write the gathered records to tensorboard and/or ``.ply`` files."""
        epoch = self.trainer.epoch + 1
        if self.write_ply:
            save_path = os.path.join(self.trainer.cfg.save_path, "internal_matching")
            os.makedirs(save_path, exist_ok=True)
        for name, item in record.items():
            split = item["split"]
            if self.write_tb and self.trainer.writer is not None:
                self.trainer.writer.add_mesh(
                    f"im_{split}/{name}",
                    vertices=item["coord"],
                    colors=item["color"] * 255,
                    config_dict={
                        "material": {
                            "cls": "PointsMaterial",
                            "size": self.point_size,
                        }
                    },
                    global_step=epoch,
                )
                self.trainer.writer.flush()
                self.trainer.logger.info(f"Add {name} to tensorboard")
            if self.write_ply:
                point_save_path = os.path.join(
                    save_path, f"{split}_{name}_{epoch}.ply"
                )
                line_save_path = os.path.join(
                    save_path, f"{split}_{name}_{epoch}_line.ply"
                )
                save_point_cloud(
                    item["coord"][0], item["color"][0], file_path=point_save_path
                )
                save_lines(
                    item["matched_coord"],
                    item["matched_line"],
                    file_path=line_save_path,
                )
                self.trainer.logger.info(f"Write {name} to local")
