# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import random
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import print_log
from torch import Tensor, optim
from torchsparse import SparseTensor

from mmseg.registry import MODELS
from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
                         OptSampleList, SampleList, add_prefix)
from .base import BaseSegmentor
from ..UncerDetector import UncerDetector

@MODELS.register_module()
class ACDC(BaseSegmentor):
    """Segmentor combining a device segmentor, a sparse refiner and a cloud
    segmentor.

    Pipeline used by ``encode_decode`` (inference) and ``loss`` (training):

    1. The input batch is rescaled with :func:`mmcv.imrescale` to ``scale``
       and passed to ``segmentor_d.encode_decode_entropy``. This returns
       low-resolution logits and logits resized (``mode='nearest'``) to the
       target size.
    2. The per-pixel softmax entropy of the resized logits is computed. The
       ``upsample_percent`` fraction of pixels with the highest entropy is
       selected.
    3. The selected input pixels are converted to a sparse point set and run
       through ``refiner.backbone`` / ``refiner.decode_head``. The result is
       scattered back into a dense map and added to the logits at the
       selected positions only.
    4. The combined logits are decoded by ``segmentor_d.decode_head``
       (``predict_by_feat`` at test time, ``loss_by_feat`` in training).

    In ``_forward`` the logits of ``segmentor_c`` replace the device logits
    when the fraction of low-resolution pixels whose entropy exceeds
    ``entropy_threshold`` is larger than ``cloud_threshold``.

    .. code:: text

     loss(): _rescale_inputs() -> extract_feat() -> _refinement_residual()
             -> _decode_head_forward_train()
     predict(): inference() -> postprocess_result()
     inference(): whole_inference()/slide_inference()
     whole_inference()/slide_inference(): encode_decode()

    Args:
        segmentor_d (ConfigType): Config of the device segmentor. Its
            ``decode_head`` provides ``num_classes``, ``align_corners`` and
            ``out_channels`` and decodes the final logits.
        segmentor_c (ConfigType): Config of the cloud segmentor, used by
            ``_forward`` when the device logits are too uncertain.
        refiner (ConfigType): Config of the sparse refiner. It must expose
            ``backbone(features, coords)`` and ``decode_head.forward``.
        uncer_detector (ConfigType): Keyword arguments forwarded to
            :class:`UncerDetector` together with ``model=segmentor_d``.
        train_cfg (OptConfigType): The config for training. Defaults to None.
        test_cfg (OptConfigType): The config for testing. ``inference``
            reads ``mode``; slide mode also reads ``stride`` and
            ``crop_size``. Defaults to None.
        data_preprocessor (OptConfigType): The pre-process config of
            :class:`BaseDataPreprocessor`. Defaults to None.
        init_cfg (OptMultiConfig): The weight initialized config for
            :class:`BaseModule`. Defaults to None.
        entropy_threshold (float): Entropy above which a low-resolution
            pixel counts as uncertain in ``_forward``. Defaults to 0.5.
        cloud_threshold (float): Fraction of uncertain pixels above which
            ``_forward`` uses the cloud segmentor. Defaults to 0.8.
        upsample_percent (float): Fraction of pixels (highest entropy first)
            sent to the refiner. Defaults to 0.2.
        scale (tuple[int, int]): Scale passed to :func:`mmcv.imrescale`
            for the device-model input. Defaults to (256, 256).
    """  # noqa: E501

    def __init__(self,
                 segmentor_d: ConfigType,
                 segmentor_c: ConfigType,
                 refiner: ConfigType,
                 uncer_detector: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 entropy_threshold: float = 0.5,
                 cloud_threshold: float = 0.8,
                 upsample_percent: float = 0.2,
                 scale=(256, 256)):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        self.segmentor_d = MODELS.build(segmentor_d)
        self.segmentor_c = MODELS.build(segmentor_c)

        self._init_refiner(refiner)

        # The detector is built around the device segmentor.
        self.uncer_detector = UncerDetector(
            model=self.segmentor_d, **uncer_detector)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.entropy_threshold = entropy_threshold
        self.cloud_threshold = cloud_threshold
        self.upsample_percent = upsample_percent
        # Output geometry comes from the device decode head.
        self.num_classes = self.segmentor_d.decode_head.num_classes
        self.align_corners = self.segmentor_d.decode_head.align_corners
        self.out_channels = self.segmentor_d.decode_head.out_channels
        self.scale = scale

        # Bookkeeping counters: ``num_use_cloud`` counts ``_forward`` calls
        # routed to the cloud segmentor; ``num_sample`` counts
        # ``encode_decode`` calls (one per crop in slide mode).
        self.num_use_cloud = 0
        self.num_sample = 0

    def _init_refiner(self, refiner: ConfigType) -> None:
        """Build the sparse refiner (``SparseRefine``) from its config."""
        self.refiner = MODELS.build(refiner)

    def extract_feat(self, inputs: Tensor, size) -> Tuple[Tensor, Tensor]:
        """Run the device segmentor on ``inputs``.

        Returns:
            tuple[Tensor, Tensor]: The low-resolution logits and the logits
            resized to ``size``.
        """
        logits, x = self._segmentor_d_low_predict(inputs, size)
        return logits, x

    def _segmentor_d_low_predict(self, inputs, size):
        """Call ``segmentor_d.encode_decode_entropy`` with nearest resizing."""
        logits, resized_logits = self.segmentor_d.encode_decode_entropy(
            inputs, size, mode='nearest')
        return logits, resized_logits

    def _rescale_inputs(self, inputs: Tensor) -> Tensor:
        """Rescale every image of an (N, C, H, W) batch with
        :func:`mmcv.imrescale` to ``self.scale``.

        Each image is processed separately, so batch sizes other than 1 are
        supported. The resize runs in numpy, so no gradient flows through
        it; ``detach`` makes that explicit and avoids ``.numpy()`` errors
        on tensors that require grad.
        """
        resized = []
        for img in inputs:
            img_np = img.detach().permute(1, 2, 0).cpu().numpy()  # HWC
            img_np = mmcv.imrescale(img_np, self.scale)
            resized.append(torch.from_numpy(img_np).permute(2, 0, 1))
        return torch.stack(resized, dim=0).to(inputs.device)

    def _refinement_residual(self, inputs: Tensor, x: Tensor) -> Tensor:
        """Return the sparse-refiner correction for the most uncertain pixels.

        Selects the ``upsample_percent`` highest-entropy pixels of ``x`` and
        refines the corresponding pixels of ``inputs``. The refined logits
        are returned masked to those pixels (zero elsewhere), ready to be
        added to ``x``. Assumes ``x`` and ``inputs`` share the same spatial
        size.
        """
        _, selected_mask = self.upsample_entropy_selector(
            x, self.upsample_percent)
        selected_mask = selected_mask.unsqueeze(1)  # (N, 1, H, W)
        refined_x = self.refine(inputs * selected_mask)
        return refined_x * selected_mask

    def encode_decode(self, inputs: Tensor,
                      batch_img_metas: List[dict]) -> Tensor:
        """Encode images and decode them into segmentation logits.

        Args:
            inputs (Tensor): Input images of shape (N, C, H, W).
            batch_img_metas (List[dict]): Image meta info. ``img_shape`` of
                the first entry sets the size the device logits are resized
                to.

        Returns:
            Tensor: Logits from ``segmentor_d.decode_head.predict_by_feat``.
        """
        resized_inputs = self._rescale_inputs(inputs)

        img_shape = batch_img_metas[0]['img_shape']
        _, x = self.extract_feat(resized_inputs, img_shape)
        residual = self._refinement_residual(inputs, x)

        # NOTE(review): this score is currently unused. The cloud-offload
        # branch that compared it with ``self.cloud_threshold`` and called
        # ``self.segmentor_c`` is disabled in this path. The call is kept
        # in case ``UncerDetector`` is stateful -- TODO confirm and drop.
        uncertainty_score = self.uncer_detector(x)[0] * 100  # noqa: F841

        final_x = x + residual
        self.num_sample += 1

        return self.segmentor_d.decode_head.predict_by_feat(
            final_x, batch_img_metas)

    def _decode_head_forward_train(self, inputs: List[Tensor],
                                   data_samples: SampleList) -> dict:
        """Compute the device decode-head loss on the refined logits."""
        losses = dict()
        loss_decode = self.segmentor_d.decode_head.loss_by_feat(
            inputs, data_samples)

        losses.update(add_prefix(loss_decode, 'decode'))
        return losses

    def refine(self, selected_pixels: Tensor) -> Tensor:
        """Run the sparse refiner on the non-zero pixels of
        ``selected_pixels`` and scatter its outputs back into a dense map.

        Args:
            selected_pixels (Tensor): (N, C, H, W) images in which
                non-selected pixels are zero.

        Returns:
            Tensor: (N, num_classes, H, W) map. Positions without a point
            are 0.
        """
        batch_size, _, h, w = selected_pixels.shape
        features, coords = self.convert_to_sparse_tensor(selected_pixels)
        if coords.shape[0] == 0:
            # No point selected: skip the refiner instead of feeding it an
            # empty point set, and contribute no correction.
            return selected_pixels.new_zeros(
                (batch_size, self.num_classes, h, w))

        refined_features = self.refiner.backbone(features, coords)
        refined_x = self.refiner.decode_head.forward(refined_features)

        # Match refined_x's dtype/device so the indexed assignment below
        # does not fail on a dtype mismatch (e.g. under mixed precision).
        output = refined_x.new_zeros((batch_size, self.num_classes, h, w))

        # coords columns are (x, y, z, batch); see convert_to_sparse_tensor.
        batch_indices = coords[:, 3].long()
        y_indices = coords[:, 1].long()
        x_indices = coords[:, 0].long()

        # The advanced indices are separated by a slice, so the indexed view
        # has shape (num_points, num_classes). This assumes refined_x holds
        # one row per point -- TODO confirm against the refiner head.
        output[batch_indices, :, y_indices, x_indices] = refined_x

        return output

    def convert_to_sparse_tensor(self, image_tensor):
        """Convert a dense (batch_size, C, H, W) image tensor into point
        features and integer coordinates for the sparse refiner (intended
        for a torchsparse MinkUNet backbone).

        A pixel becomes a point when any of its channels is non-zero.

        Args:
            image_tensor (torch.Tensor): Input tensor with shape
                (batch_size, C, H, W).

        Returns:
            features (Tensor): Point features in shape (N, C).
            coords (Tensor): int32 coordinates in shape (N, 4), columns in
                the order (x_idx, y_idx, z_idx=0, batch_idx).
        """
        batch_size, channels, height, width = image_tensor.shape

        # Flatten (H, W); reshape also works for non-contiguous inputs.
        flattened_tensor = image_tensor.reshape(batch_size, channels, -1)

        # Test each channel rather than the channel sum: with normalized
        # (signed) inputs, a selected pixel's channels can cancel to 0.
        non_zero_mask = (flattened_tensor != 0).any(dim=1)  # (B, H * W)

        batch_indices, flat_indices = non_zero_mask.nonzero(as_tuple=True)
        y_coords = flat_indices // width
        x_coords = flat_indices % width

        coords = torch.stack(
            [x_coords, y_coords,
             torch.zeros_like(batch_indices), batch_indices],
            dim=1)

        # Advanced indices around a slice -> shape (N, C).
        features = flattened_tensor[batch_indices, :, flat_indices]

        return features, coords.to(torch.int32)

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        The device logits are computed at ``data_samples[0].pad_shape``,
        the sparse-refiner correction is added, and the device decode head
        computes the loss.

        Args:
            inputs (Tensor): Input images.
            data_samples (list[:obj:`SegDataSample`]): The seg data samples.
                It usually includes information such as `metainfo` and
                `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        resized_inputs = self._rescale_inputs(inputs)

        pad_shape = data_samples[0].pad_shape
        _, x = self.extract_feat(resized_inputs, pad_shape)
        final_x = x + self._refinement_residual(inputs, x)

        return self._decode_head_forward_train(final_x, data_samples)

    @staticmethod
    def _get_batch_img_metas(inputs: Tensor,
                             data_samples: OptSampleList) -> List[dict]:
        """Return the metainfo of ``data_samples``. When it is None, build
        one default meta per image (all shapes equal to the input size, no
        padding)."""
        if data_samples is not None:
            return [data_sample.metainfo for data_sample in data_samples]
        return [
            dict(
                ori_shape=inputs.shape[2:],
                img_shape=inputs.shape[2:],
                pad_shape=inputs.shape[2:],
                padding_size=[0, 0, 0, 0])
        ] * inputs.shape[0]

    def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`], optional): The seg data
                samples. It usually includes information such as `metainfo`
                and `gt_sem_seg`.

        Returns:
            list[:obj:`SegDataSample`]: Segmentation results of the
            input images. Each SegDataSample usually contain:

            - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
            - ``seg_logits``(PixelData): Predicted logits of semantic
                segmentation before normalization.
        """
        batch_img_metas = self._get_batch_img_metas(inputs, data_samples)

        seg_logits = self.inference(inputs, batch_img_metas)

        return self.postprocess_result(seg_logits, data_samples)

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        The device logits are computed at the input size. When the
        fraction of uncertain low-resolution pixels exceeds
        ``cloud_threshold`` (see ``model_entropy_selector``), they are
        replaced by ``segmentor_c.encode_decode`` logits. The sparse-refiner
        correction, always derived from the device logits, is then added.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            Tensor: Forward output of model without any post-processes.
        """
        resized_inputs = self._rescale_inputs(inputs)

        img_shape = inputs.shape[2:]
        logits, x = self.extract_feat(resized_inputs, img_shape)
        # Deterministic routing: use the entropy-based decision instead of
        # a random draw, so repeated forwards on the same input agree.
        _, use_cloud = self.model_entropy_selector(
            logits, self.entropy_threshold, self.cloud_threshold)
        residual = self._refinement_residual(inputs, x)

        if use_cloud:
            self.num_use_cloud += 1
            # encode_decode expects image metas, not SegDataSamples.
            batch_img_metas = self._get_batch_img_metas(inputs, data_samples)
            x = self.segmentor_c.encode_decode(inputs, batch_img_metas)

        return x + residual

    def slide_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.

        Args:
            inputs (tensor): the tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'. Note: ``img_shape`` of the
                first entry is overwritten with each crop's shape.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = self.out_channels
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift the last window back so it stays crop-sized.
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                # change the image shape to patch shape
                batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
                # the output of encode_decode is seg logits tensor map
                # with shape [N, C, H, W]
                crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
                preds += F.pad(crop_seg_logit,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        # Average overlapping window predictions.
        seg_logits = preds / count_mat

        return seg_logits

    def whole_inference(self, inputs: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Inference with full image.

        Args:
            inputs (Tensor): The tensor should have a shape NxCxHxW, which
                contains all images in the batch.
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', and 'pad_shape'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """

        seg_logits = self.encode_decode(inputs, batch_img_metas)

        return seg_logits

    def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
        """Inference with slide/whole style.

        ``test_cfg.mode`` selects the style and defaults to ``'whole'`` when
        absent.

        Args:
            inputs (Tensor): The input image of shape (N, 3, H, W).
            batch_img_metas (List[dict]): List of image metainfo where each may
                also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
                'ori_shape', 'pad_shape', and 'padding_size'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.

        Returns:
            Tensor: The segmentation results, seg_logits from model of each
                input image.
        """
        mode = self.test_cfg.get('mode', 'whole')
        assert mode in ['slide', 'whole'], \
            f'Only "slide" or "whole" test mode are supported, but got ' \
            f'{mode}.'
        ori_shape = batch_img_metas[0]['ori_shape']
        if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas):
            print_log(
                'Image shapes are different in the batch.',
                logger='current',
                level=logging.WARN)
        if mode == 'slide':
            seg_logit = self.slide_inference(inputs, batch_img_metas)
        else:
            seg_logit = self.whole_inference(inputs, batch_img_metas)

        return seg_logit

    def aug_test(self, inputs, batch_img_metas, rescale=True):
        """Test with augmentations by averaging the logits of each view.

        Only rescale=True is supported.

        Args:
            inputs (list[Tensor]): One input batch per augmentation.
            batch_img_metas (list[list[dict]]): Metas matching ``inputs``.
            rescale (bool): Must be True.

        Returns:
            list[Tensor]: Per-image argmax predictions.
        """
        # aug_test rescale all imgs back to ori_shape for now
        assert rescale
        # ``inference`` takes (inputs, batch_img_metas) only; accumulate
        # in place to save memory.
        seg_logit = self.inference(inputs[0], batch_img_metas[0])
        for i in range(1, len(inputs)):
            seg_logit += self.inference(inputs[i], batch_img_metas[i])
        seg_logit /= len(inputs)
        seg_pred = seg_logit.argmax(dim=1)
        # unravel batch dim
        return list(seg_pred)

    @staticmethod
    def _softmax_entropy(logits: Tensor) -> Tensor:
        """Per-pixel entropy of the softmax over dim 1: (B, C, H, W) ->
        (B, H, W).

        Uses ``log_softmax`` so a probability that underflows to 0
        contributes ``0 * finite`` instead of ``0 * log(0) = NaN``.
        """
        log_probs = F.log_softmax(logits, dim=1)
        return -(log_probs.exp() * log_probs).sum(dim=1)

    def upsample_entropy_selector(self, out, upsample_percent=0.2):
        """Select the highest-entropy pixels of ``out``.

        Args:
            out (Tensor): Logits of shape (B, C, H, W).
            upsample_percent (float): Fraction of the H*W pixels of each
                image to select. The resulting count is capped at H*W; a
                count that rounds down to 0 selects nothing.

        Returns:
            tuple[Tensor, Tensor]: Per-pixel entropy (B, H, W) and a boolean
            mask (B, H, W). The mask is True where the entropy is >= the
            k-th largest entropy of that image, so ties may select more
            than k pixels.
        """
        batch_size, _, h, w = out.shape
        upsample_entropy = self._softmax_entropy(out)  # (B, H, W)

        num_elements = h * w
        top_k = min(int(upsample_percent * num_elements), num_elements)
        if top_k <= 0:
            return upsample_entropy, torch.zeros_like(
                upsample_entropy, dtype=torch.bool)

        flat_entropy = upsample_entropy.reshape(batch_size, -1)
        topk_values, _ = torch.topk(
            flat_entropy, top_k, dim=1, largest=True, sorted=False)
        # k-th largest value per image, shaped (B, 1, 1) to broadcast.
        upsample_threshold = topk_values.min(
            dim=1, keepdim=True)[0].unsqueeze(-1)

        return upsample_entropy, upsample_entropy >= upsample_threshold

    def model_entropy_selector(self, logits, entropy_threshold,
                               cloud_threshold):
        """Decide whether the device logits are too uncertain.

        Args:
            logits (Tensor): Logits of shape (B, C, H, W).
            entropy_threshold (float): Entropy above which a pixel counts
                as uncertain.
            cloud_threshold (float): Fraction of uncertain pixels (over the
                whole batch) above which the decision is True.

        Returns:
            tuple[Tensor, Tensor]: Per-pixel entropy (B, H, W) and a 0-dim
            boolean tensor.
        """
        model_entropy = self._softmax_entropy(logits)  # (B, H, W)
        model_uncertain_ratio = (
            model_entropy > entropy_threshold).float().mean()

        return model_entropy, model_uncertain_ratio > cloud_threshold
