import os
import cv2
import torch
import numpy as np

from models.yolo import Model
from utils.torch_utils import select_device
from utils.general import nms_visiual

from pytorch_grad_cam import LayerCAM
from pytorch_grad_cam.utils.image import show_cam_on_image


def decode(x, model):
    """Convert raw per-level detection-head outputs into one tensor of boxes.

    Every channel is passed through a sigmoid. The first four channels are
    then mapped to a box centre (xy) and a box size (wh) using the cell grid,
    the level's stride and the level's anchors; the remaining channels are
    returned as plain sigmoid outputs.

    Args:
        x: Sequence of tensors, one per detection level, each shaped
            ``(bs, na, ny, nx, no)``.
        model: Object whose last layer ``model.model[-1]`` exposes ``anchors``
            and ``stride``, both indexable by level. Anchors are multiplied
            by the stride here, so they are presumably stored in stride
            units -- confirm against the detection head.

    Returns:
        Tensor of shape ``(bs, total_boxes, no)`` where ``total_boxes`` is the
        sum of ``na * ny * nx`` over all levels, concatenated along dim 1.
    """
    head = model.model[-1]
    anchors = head.anchors
    stride = head.stride

    decoded_levels = []
    for i, raw in enumerate(x):
        bs, na, ny, nx, no = raw.shape
        # Clone so the in-place slice writes below act on a private copy of
        # the sigmoid output rather than on the tensor sigmoid produced.
        y = raw.sigmoid().clone()

        # Cell-index grid where grid[..., iy, ix, :] == (ix, iy). Built with
        # broadcast aranges instead of torch.meshgrid, which warns when called
        # without ``indexing`` and whose default is documented to change.
        col_idx = torch.arange(nx, device=raw.device).view(1, nx).expand(ny, nx)
        row_idx = torch.arange(ny, device=raw.device).view(ny, 1).expand(ny, nx)
        grid = torch.stack((col_idx, row_idx), 2).float().view(1, 1, ny, nx, 2)

        anchor_grid = anchors[i].clone().view(1, na, 1, 1, 2).to(raw.device) * stride[i]

        # Centre: sigmoid output mapped to (-0.5, 1.5) around the cell, then scaled by stride.
        y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + grid) * stride[i]
        # Size: sigmoid output mapped to (0, 4) times the anchor size.
        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid

        decoded_levels.append(y.view(bs, -1, no))
    return torch.cat(decoded_levels, dim=1)


class SingleBoxTarget:
    """CAM target that scores one predicted box as objectness * best class score.

    An instance is called with the per-image model output and returns a scalar
    tensor that the CAM method back-propagates.

    The output it receives in this file comes from ``decode()``, which has
    already applied a sigmoid to every channel. The scores are therefore
    multiplied as they are; applying a second sigmoid would squash them into a
    narrow band and distort the gradients the CAM is built from.
    """

    def __init__(self, index):
        """Store the row index of the box to explain.

        Args:
            index: Row of the box inside the per-image output tensor.
        """
        self.index = index

    def __call__(self, model_output):
        """Return the confidence of the selected box.

        Args:
            model_output: Tensor indexed as ``[box, channel]`` where channel 4
                is read as objectness and channels 5 onward as per-class
                scores, all expected to be sigmoid-activated already.

        Returns:
            Scalar tensor: objectness times the largest class score.
        """
        box = model_output[self.index]
        objectness = box[4]
        best_class = box[5:].max()
        return objectness * best_class


class WrappedModel(torch.nn.Module):
    """Single-input adapter around the two-stream detector.

    Takes one channel-stacked tensor, feeds its two halves to the wrapped
    model as separate arguments, and returns decoded box predictions so the
    wrapper can be used wherever a one-tensor-in, one-tensor-out module is
    expected.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the detector on a stacked input and decode its raw head output.

        Args:
            x: Tensor whose dim 1 holds the first stream in channels 0-2 and
                the second stream in the remaining channels.

        Returns:
            The tensor produced by ``decode`` for element 1 of the model's
            first return value.
        """
        # Channels 0-2 go in as the first argument, everything after as the second.
        visible, infrared = x[:, :3], x[:, 3:]
        # The model returns a pair; only the first element is used here.
        head_output, _ = self.model(visible, infrared)
        # Element 1 of that output is what decode() consumes
        # (presumably the raw per-level tensors -- confirm against the model).
        return decode(head_output[1], self.model)


def preprocess(path, size=(1024, 1024)):
    """Load an image and return it as a normalised channel-first tensor and array.

    The image is read with ``cv2.imread`` using its default flags, resized,
    scaled to ``[0, 1]`` and transposed from HWC to CHW. The channel order is
    left exactly as OpenCV delivers it (BGR); no colour conversion is done.

    Args:
        path: Path of the image file to read.
        size: ``(width, height)`` passed to ``cv2.resize``. Defaults to
            ``(1024, 1024)``.

    Returns:
        Tuple ``(tensor, array)``: a float32 ``torch.Tensor`` and the float32
        ``numpy.ndarray`` it was created from, both channel-first.

    Raises:
        OSError: If OpenCV could not read or decode the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None; fail here with the
        # path instead of letting cv2.resize raise an opaque assertion.
        raise OSError(f"Failed to read image: {path}")
    img = cv2.resize(img, size)
    img = img.astype(np.float32) / 255.0
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    return torch.from_numpy(img).float(), img


def multi_target_heatmap(cam_class, wrapped_model, input_tensor, indices, base_image, save_path, target_layers):
    """Accumulate one CAM per selected box and save the blended overlay.

    A CAM is computed for every index in ``indices`` using a
    ``SingleBoxTarget``. The maps are summed, min-max normalised to
    ``[0, 1]``, blended onto ``base_image`` and written to ``save_path``.

    Args:
        cam_class: CAM class to instantiate; it is called with ``model``,
            ``target_layers`` and ``use_cuda`` keyword arguments.
        wrapped_model: Module handed to the CAM as ``model``.
        input_tensor: Input batch forwarded by the CAM; only the first
            image's map is used.
        indices: Iterable of box indices; each is converted with ``int``.
        base_image: HxWxC background image. It is blended with
            ``use_rgb=True`` and the result is converted RGB->BGR before
            writing, so it is expected in RGB order. Values above 1.0 are
            treated as 0-255 and rescaled.
        save_path: Destination file for the overlay.
        target_layers: Layers passed to the CAM as ``target_layers``.

    Returns:
        None. A message is printed if the image could not be written.
    """
    H, W = base_image.shape[:2]
    total_heatmap = np.zeros((H, W), dtype=np.float32)

    # Build the CAM object once: its construction does not depend on the
    # target, so re-creating it for every box only repeats the setup work.
    cam = cam_class(model=wrapped_model, target_layers=target_layers, use_cuda=True)
    try:
        for idx in indices:
            target = SingleBoxTarget(int(idx))
            # [0] selects the map of the first (only) image in the batch.
            total_heatmap += cam(input_tensor=input_tensor, targets=[target])[0]
    finally:
        # Drop the CAM explicitly, as the original per-iteration code did.
        del cam

    # Min-max normalise; the epsilon keeps a constant map from dividing by zero.
    low = total_heatmap.min()
    high = total_heatmap.max()
    total_heatmap = (total_heatmap - low) / (high - low + 1e-6)

    if base_image.max() > 1.0:
        base_image = base_image / 255.0
    overlay = show_cam_on_image(base_image.astype(np.float32), total_heatmap, use_rgb=True)

    # cv2.imwrite reports failure through its return value rather than raising.
    if not cv2.imwrite(save_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)):
        print(f"[!] Failed to write {save_path}")


# ====== main ======
def main():
    """Generate a saliency overlay for every RGB/IR image pair in two folders.

    Loads the detector from a checkpoint, then for each image file in
    ``RGB_DIR`` that has a same-named file in ``IR_DIR``: runs detection,
    keeps the boxes that survive NMS, and writes a LayerCAM overlay for those
    boxes to ``SAVE_DIR`` as ``<stem>_salient.png``. ``RGB_DIR``, ``IR_DIR``
    and ``WEIGHTS`` are empty placeholders that must be filled in before
    running.
    """
    RGB_DIR = ''
    IR_DIR = ''
    SAVE_DIR = './visiual/'

    os.makedirs(SAVE_DIR, exist_ok=True)

    WEIGHTS = ''
    DEVICE = select_device('cuda:0')

    # NOTE(review): torch.load unpickles the checkpoint (it stores a full
    # model object), so only load weights from a trusted source.
    ckpt = torch.load(WEIGHTS, map_location=DEVICE)
    cfg = ckpt['model'].yaml
    model = Model(cfg, ch=6).to(DEVICE)
    # strict=False: keys that do not match between checkpoint and model are ignored.
    model.load_state_dict(ckpt['model'].float().state_dict(), strict=False)
    model.eval()

    wrapped_model = WrappedModel(model)

    # Layer whose activations the CAM is computed on.
    fused_target_layers = [model.model[19]]

    image_list = sorted(os.listdir(RGB_DIR))

    for filename in image_list:
        if not filename.lower().endswith(('.jpg', '.png', '.jpeg', '.bmp')):
            continue

        rgb_path = os.path.join(RGB_DIR, filename)
        ir_path = os.path.join(IR_DIR, filename)

        if not os.path.exists(ir_path):
            print(f"[!] IR not found for {filename}")
            continue

        img_rgb_t, img_rgb = preprocess(rgb_path)
        img_ir_t, _ = preprocess(ir_path)

        # Stack both streams along the channel axis and add the batch dimension.
        input_tensor = torch.cat([img_rgb_t, img_ir_t], dim=0).unsqueeze(0).to(DEVICE)

        # This pass only selects boxes, so no autograd graph is needed;
        # the CAM runs its own forward/backward passes later.
        with torch.no_grad():
            decoded = wrapped_model(input_tensor)
        nms_out = nms_visiual(decoded, conf_thres=0.3, iou_thres=0.4)[0]

        if nms_out.shape[0] == 0:
            print(f"[!] No detections in {filename}")
            continue

        # The last column of the NMS output is used as each kept box's index
        # into the decoded predictions.
        indices = nms_out[:, -1].long()

        # Use the visible-light image as the overlay background.
        # preprocess() keeps OpenCV's BGR channel order, while
        # multi_target_heatmap blends with use_rgb=True and converts RGB->BGR
        # before writing, so flip the channels to RGB here; otherwise red and
        # blue end up swapped in the saved background.
        background = np.ascontiguousarray(np.transpose(img_rgb, (1, 2, 0))[..., ::-1])

        multi_target_heatmap(
            LayerCAM,
            wrapped_model,
            input_tensor,
            indices,
            background,
            os.path.join(SAVE_DIR, f"{os.path.splitext(filename)[0]}_salient.png"),
            fused_target_layers
        )

# Run the batch visualisation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
