import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
from matplotlib import pyplot as plt
from torch.autograd import grad
from collections import OrderedDict
from timm.utils import AverageMeter
from tqdm import tqdm

from .defaults import create_ddp_model
import pointcept.utils.comm as comm
from pointcept.datasets import build_dataset, collate_fn
from pointcept.models import build_model
from pointcept.utils.logger import get_root_logger
from pointcept.utils.registry import Registry
from pointcept.utils.misc import make_dirs

# Registry of visualizer classes; classes in this module are added to it with
# the ``@VISUALIZER.register_module()`` decorator.
VISUALIZER = Registry("visualizer")


class VisualizerBase:
    def __init__(self, cfg, model=None, test_loader=None, verbose=False) -> None:
        """Set up logging, the model and the test dataloader.

        Args:
            cfg: Configuration object. This method reads ``cfg.save_path`` and
                ``cfg.resume`` (and ``cfg.pretty_text`` when ``verbose`` is
                set); the builder methods read further fields.
            model: Ready-to-use model. When ``None`` it is created with
                :meth:`build_model`.
            test_loader: Ready-to-use dataloader. When ``None`` it is created
                with :meth:`build_test_loader`.
            verbose: When true, the save path and the full config are logged.
        """
        torch.multiprocessing.set_sharing_strategy("file_system")

        # Append to an existing log when resuming, otherwise start a new one.
        log_path = os.path.join(cfg.save_path, "visualize.log")
        if cfg.resume:
            log_mode = "a"
        else:
            log_mode = "w"
        self.logger = get_root_logger(log_file=log_path, file_mode=log_mode)
        self.logger.info("=> Loading config ...")

        self.cfg, self.verbose = cfg, verbose
        if verbose:
            self.logger.info(f"Save path: {cfg.save_path}")
            self.logger.info(f"Config:\n{cfg.pretty_text}")

        if model is not None:
            self.model = model
        else:
            self.logger.info("=> Building model ...")
            self.model = self.build_model()

        if test_loader is not None:
            self.test_loader = test_loader
        else:
            self.logger.info("=> Building test dataset & dataloader ...")
            self.test_loader = self.build_test_loader()

    def build_model(self):
        model = build_model(self.cfg.model)
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.logger.info(f"Num params: {n_parameters}")
        model = create_ddp_model(
            model.cuda(),
            broadcast_buffers=False,
            find_unused_parameters=self.cfg.find_unused_parameters,
        )
        if os.path.isfile(self.cfg.weight):
            self.logger.info(f"Loading weight at: {self.cfg.weight}")
            checkpoint = torch.load(self.cfg.weight)
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                if key.startswith("module."):
                    if comm.get_world_size() == 1:
                        key = key[7:]  # module.xxx.xxx -> xxx.xxx
                else:
                    if comm.get_world_size() > 1:
                        key = "module." + key  # xxx.xxx -> module.xxx.xxx
                weight[key] = value
            model.load_state_dict(weight, strict=True)
            self.logger.info(
                "=> Loaded weight '{}' (epoch {})".format(
                    self.cfg.weight, checkpoint["epoch"]
                )
            )
        else:
            raise RuntimeError("=> No checkpoint found at '{}'".format(self.cfg.weight))
        return model

    def build_test_loader(self):
        """Create the dataloader over the ``cfg.data.test`` dataset.

        A ``DistributedSampler`` is attached only when more than one process
        is running. ``cfg.batch_size_test_per_gpu`` is used for both the batch
        size and the number of worker processes, and batches are assembled by
        the class's own ``collate_fn``.

        Returns:
            torch.utils.data.DataLoader: Non-shuffling loader with pinned memory.
        """
        dataset = build_dataset(self.cfg.data.test)
        sampler = (
            torch.utils.data.distributed.DistributedSampler(dataset)
            if comm.get_world_size() > 1
            else None
        )
        per_gpu = self.cfg.batch_size_test_per_gpu
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=per_gpu,
            shuffle=False,
            num_workers=per_gpu,
            pin_memory=True,
            sampler=sampler,
            collate_fn=self.__class__.collate_fn,
        )

    def compute_rollout_attention(self, all_layer_matrices, start_layer=0):
        num_tokens = all_layer_matrices[3].shape[1]
        batch_size = all_layer_matrices[3].shape[0]
        eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[3].device)

        all_layer_matrices = all_layer_matrices[3] + eye
        # matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
        #                 for i in range(len(all_layer_matrices))]
        joint_attention = all_layer_matrices
        # for i in range(start_layer + 1, len(matrices_aug)):
        #     joint_attention = matrices_aug[i].bmm(joint_attention)
        return joint_attention

    def generate_raw_attn(self, model, point_cloud, start_layer=3):
        point_cloud['feat'].requires_grad_()
        logits = model(point_cloud)["seg_logits"]
        all_layer_attentions = []
        for layer in model.backbone.dec:
            attn_heads = layer.GlobalMamba0.mixer.xai_vector
            attn_heads = (attn_heads - attn_heads.min()) / (attn_heads.max() - attn_heads.min())
            avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
            all_layer_attentions.append(avg_heads)
        p = all_layer_attentions[3]
        return p.clamp(min=0).squeeze().unsqueeze(0), logits

    def generate_mamba_attr(self, model, point_cloud, start_layer=15):
        point_cloud['feat'].requires_grad_()
        logits = model(point_cloud)["seg_logits"]
        logits = F.softmax(logits, -1)
        index = np.argmax(logits.cpu().data.numpy(), axis=-1)
        one_hot = np.zeros_like(logits.cpu().data.numpy())
        for i in range(index.shape[0]):
            one_hot[i, index[i]] = 1
        one_hot = torch.from_numpy(one_hot).to(logits.device).float()
        loss = torch.sum(one_hot * logits)
        model.zero_grad()
        loss.backward(retain_graph=True)
        all_layer_attentions = []
        for layer in model.backbone.dec:
            attn_heads = layer.GlobalMamba0.mixer.xai_vector
            s = layer.GlobalMamba0.get_gradients().squeeze().detach()
            s = s.clamp(min=0).max(dim=1)[0].unsqueeze(0)
            s = (s - s.min()) / (s.max() - s.min())
            attn_heads = (attn_heads - attn_heads.min()) / (attn_heads.max() - attn_heads.min())
            avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
            fused = avg_heads * s
            all_layer_attentions.append(fused)
        rollout = self.compute_rollout_attention(all_layer_attentions, start_layer)
        p = rollout.mean(dim=1).unsqueeze(0)
        return p.clamp(min=0).squeeze().unsqueeze(0), logits

    def generate_rollout(self, model, point_cloud, start_layer=15, num_layers=24):
        point_cloud['feat'].requires_grad_()
        logits = model(point_cloud)["seg_logits"]
        all_layer_attentions = []
        for layer in range(num_layers):
            attn_heads = model.layers[layer].mixer.xai_b
            avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
            all_layer_attentions.append(avg_heads)
        rollout = self.compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
        p = rollout.mean(dim=1).unsqueeze(0)
        return p.clamp(min=0).squeeze().unsqueeze(0), logits

    def generate_visualization(self, point_cloud, transformer_attribution, save_path):
        # 确保 transformer_attribution 和 point_cloud['coord'] 形状一致
        coords = point_cloud['coord'].detach().cpu().numpy()
        features = point_cloud['feat'].detach().cpu().numpy()
        # assert transformer_attribution.shape[0] == coords.shape[0]
        transformer_attribution = transformer_attribution.permute(1, 0)

        transformer_attribution = transformer_attribution.cpu().numpy()
        print(transformer_attribution)
        # 将坐标、特征和注意力权重保存到 npy 文件中
        np.savez(save_path, coord=coords, feat=features, attr=transformer_attribution)

        return save_path

    def get_input_grad(self, feat, pred):
        out_size = pred.size()
        central_point = torch.nn.functional.relu(pred[out_size[0] // 2, :]).sum()
        grad = torch.autograd.grad(central_point, feat, allow_unused=True)[0]
        grad = torch.nn.functional.relu(grad)
        # grad = grad / (grad.norm() + 1e-6)
        aggregated = grad.sum(1)
        # grad_map = aggregated.cpu().numpy()

        return aggregated

    def test(self):
        """Placeholder entry point; always raises ``NotImplementedError``.

        NOTE(review): the subclass visible in this file defines ``visualize``
        rather than overriding this method -- confirm which name callers use.
        """
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch):
        """Collate ``batch`` with the module-level ``collate_fn`` from ``pointcept.datasets``.

        Inside this body the bare name ``collate_fn`` resolves to the imported
        module-level function, not to this static method. Its result is
        returned; the previous ``raise`` of that result could only fail with a
        ``TypeError``, because the collated batch is not an exception.
        """
        return collate_fn(batch)


@VISUALIZER.register_module()
class ERFVisualizer(VisualizerBase):
    def visualize(self):
        """Save an input-gradient map for the first fragment of every test sample.

        For each item of ``self.test_loader`` (batch size must be 1) the first
        entry of its ``fragment_list`` is collated, moved to the GPU and run
        through the model with gradients enabled on ``feat``. The result of
        :meth:`get_input_grad` is written together with the fragment's
        coordinates to ``<cfg.save_path>/visualization/<name>_vis_grad_map.npz``
        under the keys ``attr`` and ``coord``.

        Dataset-specific output directories (and, for NuScenes, a
        ``visual.json`` metadata file) are created on the main process first.
        """
        assert self.test_loader.batch_size == 1
        logger = get_root_logger()
        logger.info(">>>>>>>>>>>>>>>> Start Visualization >>>>>>>>>>>>>>>>")

        self.model.eval()

        save_path = os.path.join(self.cfg.save_path, "visualization")
        make_dirs(save_path)

        dataset_type = self.cfg.data.test.type
        if comm.is_main_process():
            if dataset_type in (
                "ScanNetDataset",
                "ScanNet200Dataset",
                "SemanticKITTIDataset",
            ):
                make_dirs(os.path.join(save_path, "visual"))
            elif dataset_type == "NuScenesDataset":
                import json

                make_dirs(os.path.join(save_path, "visual", "lidarseg", "test"))
                make_dirs(os.path.join(save_path, "visual", "test"))
                submission = dict(
                    meta=dict(
                        use_camera=False,
                        use_lidar=True,
                        use_radar=False,
                        use_map=False,
                        use_external=False,
                    )
                )
                with open(
                    os.path.join(save_path, "visual", "test", "visual.json"), "w"
                ) as f:
                    json.dump(submission, f, indent=4)
        comm.synchronize()

        for data_dict in self.test_loader:
            # Batch size is 1 and this class's collate_fn returns the raw list.
            data_dict = data_dict[0]
            fragment_list = data_dict.pop("fragment_list")
            # Labels are not needed for a gradient map; tolerate their absence
            # instead of allocating a scene-sized buffer from them.
            data_dict.pop("segment", None)
            data_name = data_dict.pop("name")
            vis_save_path = os.path.join(save_path, "{}_vis".format(data_name))

            # Only the first fragment batch is visualized.
            fragment_batch_size = 1
            for i in tqdm(range(1)):
                s_i = i * fragment_batch_size
                e_i = min((i + 1) * fragment_batch_size, len(fragment_list))
                input_dict = collate_fn(fragment_list[s_i:e_i])

                for key, value in input_dict.items():
                    if isinstance(value, torch.Tensor):
                        input_dict[key] = value.cuda(non_blocking=True)

                # The gradient is taken with respect to the input features.
                input_dict['feat'].requires_grad_()
                pred_part = self.model(input_dict)["seg_logits"]  # (n, k)
                grad_map_part = self.get_input_grad(input_dict['feat'], pred_part)

                np.savez(
                    vis_save_path + '_grad_map',
                    attr=grad_map_part.detach().cpu().numpy(),
                    coord=input_dict['coord'].detach().cpu().numpy(),
                )

                if self.cfg.empty_cache:
                    torch.cuda.empty_cache()

        logger.info("<<<<<<<<<<<<<<<<< End Visualization <<<<<<<<<<<<<<<<<")

    @staticmethod
    def collate_fn(batch):
        """Identity collate: return the list of samples unchanged.

        ``visualize`` relies on this to pull the single sample out of each
        batch with ``data_dict[0]``.
        """
        return batch
