import argparse
import logging
import os
import random
from pathlib import Path
import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm
from PIL import Image
import sys, os
# Treat the current working directory as the project root so project-local imports resolve (run the script from the repository root)
project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)

from CLIP_utils.factory import create_model_and_transforms, get_tokenizer
from prs_hook import hook_prs_logger
from others.prompt.visualization import visual_segmentation_process, visualize_bar
from torchvision.utils import save_image
from torchvision import transforms
import cv2

# Configure global logger
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs so repeated runs are reproducible.

    Args:
        seed: Integer seed handed to every random number generator.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    switched to deterministic mode with autotuning (benchmark) disabled.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_parser_info():
    """Create and return the argument parser with command-line options.

    Options:
        --model: Model architecture name (default ``ViT-L-14``).
        --pretrained: Pretrained weights tag (default ``laion2b_s32b_b82k``).
        --save_img: Flag; when set, visualization images are saved.
        --results_dir: Directory that receives the results.
        --attens_type: Attribution method: ``Grad-CAM``, ``Path-atten`` or
            ``TDE`` (default ``TDE``).
        --cuda_id: CUDA device index, kept as a string.

    Returns:
        argparse.ArgumentParser: The configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(
        description="Compute prediction accuracy for CLIP model on ImageNet datasets in different ways.",
        add_help=True,
    )
    parser.add_argument("--model", default="ViT-L-14", type=str, help="Model architecture to load.")
    parser.add_argument("--pretrained", default="laion2b_s32b_b82k", type=str, help="Pretrained weights to load.")
    parser.add_argument("--save_img", action="store_true", help="Save visualization images")
    parser.add_argument("--results_dir", default="others/prompt/results", help="Directory to save results")
    # "Path-atten" is listed because main() implements that branch; without it
    # the branch could never be selected from the command line.
    parser.add_argument(
        "--attens_type",
        type=str,
        default="TDE",
        choices=["Grad-CAM", "Path-atten", "TDE"],
        help="Attribution method used to build the saliency map.",
    )
    parser.add_argument("--cuda_id", type=str, default="0", help="cuda id")
    return parser

def get_module(model, name):
    """Look up a submodule by its fully qualified name.

    Args:
        model: The PyTorch model whose ``named_modules()`` are searched.
        name: Exact dotted name of the wanted module.

    Returns:
        The first module whose name equals ``name``.

    Raises:
        ValueError: If no module carries that name.
    """
    not_found = object()  # sentinel, so any real module value counts as a hit
    match = next(
        (module for module_name, module in model.named_modules() if module_name == name),
        not_found,
    )
    if match is not_found:
        raise ValueError(f"Module {name} not found")
    return match

class Probe:
    """Forward-hook recorder for a single PyTorch module.

    Every forward pass of the probed module appends one entry to ``data``:
    either the module's output or its first positional input.
    """

    def __init__(self, module, target="output"):
        """Attach the recording hook to ``module``.

        Args:
            module: PyTorch module to observe.
            target: ``"output"`` records the module output; any other value
                records the first positional input.
        """
        self.data = []
        if target == "output":
            callback = self.hook_fn
        else:
            callback = self.hook_input_fn
        self.hook = module.register_forward_hook(callback)

    def hook_fn(self, module, input, output):
        """Forward hook that stores the module's output."""
        self.data.append(output)

    def hook_input_fn(self, module, input, output):
        """Forward hook that stores the module's first positional input."""
        first_input = input[0]
        self.data.append(first_input)

    def remove(self):
        """Detach the hook so later forward passes are no longer recorded."""
        self.hook.remove()

def reshape_transform(tensor, height=None, width=None):
    """Reshape a token sequence into a 2-D feature map, dropping the first token.

    Args:
        tensor: Input tensor of shape [B, N, C]. Token 0 is skipped and the
            remaining ``N - 1`` tokens are laid out row-major on the grid.
        height: Grid height. When omitted it is inferred: from ``width`` if
            that is given, otherwise as the integer square root of the patch
            count.
        width: Grid width, inferred symmetrically to ``height``.

    Returns:
        Tensor of shape [B, C, height, width].

    Raises:
        AssertionError: If ``height * width`` does not equal the number of
            patch tokens (``N - 1``).
    """
    B, N, C = tensor.shape
    n_patches = N - 1

    if height is None and width is None:
        height = width = int(np.sqrt(n_patches))
    elif height is None:
        # Honour the explicit width instead of overwriting it with the square guess.
        height = n_patches // width if width else 0
    elif width is None:
        height_given = height
        width = n_patches // height_given if height_given else 0

    # Raised explicitly (rather than via `assert`) so the check survives `python -O`
    # while callers still see the same exception type and message.
    if n_patches != height * width:
        raise AssertionError(f"Patch count mismatch: {n_patches} != {height}*{width}")

    return tensor[:, 1:, :].permute(0, 2, 1).reshape(B, C, height, width)

def gradient_to_grad_cam_saliency(activation):
    """Build a saliency map from an activation and the gradient stored on it.

    Args:
        activation: Activation tensor whose ``.grad`` attribute holds the
            matching gradient (same shape as the activation).

    Returns:
        The element-wise product of activation and gradient, summed over
        dim 1 (kept as a singleton dim) and passed through ReLU so only
        positive evidence remains.
    """
    weighted = activation * activation.grad
    channel_sum = weighted.sum(dim=1, keepdim=True)
    return F.relu(channel_sum)

def _write_image(path, image, params=None):
    """Write ``image`` to ``path`` with OpenCV, logging a warning if the write fails."""
    written = cv2.imwrite(path, image, params) if params else cv2.imwrite(path, image)
    if not written:
        # cv2.imwrite reports failure only through its return value.
        logger.warning("Failed to write image: %s", path)
    return written


def _normalize_unit(values):
    """Clip negatives and min-max normalise a NumPy map to [0, 1]; a flat or non-finite map yields zeros."""
    clipped = np.clip(values, 0, None)
    low = clipped.min()
    span = clipped.max() - low
    if not span > 0:  # also true when the map contains NaN
        return np.zeros_like(clipped)
    return (clipped - low) / span


def _mean_threshold_mask(normalized):
    """Return a uint8 0/255 mask marking pixels strictly above the map's mean."""
    return (normalized > normalized.mean()).astype(np.uint8) * 255


def _to_uint8_map(values):
    """Scale a NumPy map by its maximum into uint8 [0, 255]; maps peaking at <= 1e-6 become all zeros."""
    values = np.nan_to_num(values, nan=0.0, posinf=1.0, neginf=0.0)
    peak = values.max()
    if peak > 1e-6:
        return np.clip(255.0 * values / peak, 0, 255.0).astype(np.uint8)
    return np.zeros_like(values, dtype=np.uint8)


def _save_saliency_outputs(heat_values, mask, image_path, out_dir, alpha, prefix="",
                           model_size=(224, 224), high_res_size=(1024, 1024)):
    """Write mask, heatmap, overlays and an upscaled copy of the input image into ``out_dir``.

    Files written: ``<prefix>mask.jpg``, ``<prefix>heatmap.jpg``,
    ``<prefix>overlay_original_size.jpg``, ``<prefix>overlay.jpg`` and ``input.jpg``.
    The overlay files and ``input.jpg`` are skipped when OpenCV cannot read ``image_path``.
    """
    # Create the directory before the first write: cv2.imwrite does not create
    # missing directories and only signals the failure via its return value.
    os.makedirs(out_dir, exist_ok=True)

    _write_image(os.path.join(out_dir, f"{prefix}mask.jpg"), mask)

    heatmap = cv2.applyColorMap(_to_uint8_map(heat_values), cv2.COLORMAP_JET)  # BGR format
    _write_image(os.path.join(out_dir, f"{prefix}heatmap.jpg"), heatmap)

    original = cv2.imread(image_path)
    if original is None:
        logger.warning("Could not read %s with OpenCV; skipping overlay images", image_path)
        return

    base = cv2.resize(original, model_size)
    if heatmap.shape[:2] != base.shape[:2]:
        heatmap = cv2.resize(heatmap, (base.shape[1], base.shape[0]))

    overlay = cv2.addWeighted(base, 1 - alpha, heatmap, alpha, 0)
    _write_image(os.path.join(out_dir, f"{prefix}overlay_original_size.jpg"), overlay)

    overlay_path = os.path.join(out_dir, f"{prefix}overlay.jpg")
    high_res_overlay = cv2.resize(overlay, high_res_size, interpolation=cv2.INTER_CUBIC)
    _write_image(overlay_path, high_res_overlay, [cv2.IMWRITE_JPEG_QUALITY, 95])
    logger.info(
        "Heatmap overlay image saved to: %s (resolution: %sx%s)",
        overlay_path, high_res_size[0], high_res_size[1],
    )

    upscaled_input = cv2.resize(original, high_res_size, interpolation=cv2.INTER_CUBIC)
    _write_image(os.path.join(out_dir, "input.jpg"), upscaled_input)


def _token_scores_to_map(token_scores, patch_size=14):
    """Turn per-token scores into an upsampled 2-D NumPy map.

    Uses sample 0 and score column 0, reshapes the tokens to a square grid and
    bilinearly upsamples it by ``patch_size``.
    """
    size = int(token_scores.shape[1] ** 0.5)  # Square root of patch count
    grid = token_scores[0, :, 0].reshape(size, size).detach().cpu()
    upsampled = F.interpolate(grid[None, None], scale_factor=patch_size, mode="bilinear")
    return upsampled[0, 0].numpy()


def _find_last_visual_block(model):
    """Return the largest block index among ``visual.transformer.resblocks.<i>.ln_1`` modules, or -1."""
    return max(
        (
            int(name.split(".")[3])
            for name, _ in model.named_modules()
            if "visual.transformer.resblocks" in name and ".ln_1" in name
        ),
        default=-1,
    )


def _compute_grad_cam(model, image, text_tokens, output_size=(224, 224)):
    """Compute the gradient-weighted saliency map at ``ln_1`` of the last visual block.

    Returns:
        A 2-D NumPy array of size ``output_size``, or ``None`` (after logging
        an error) when the target layer or its gradients are unavailable.
    """
    last_block = _find_last_visual_block(model)
    if last_block == -1:
        logger.error("Cannot find transformer blocks in the model")
        return None
    logger.info(f"Using last transformer block: {last_block}")

    target_layer = f"visual.transformer.resblocks.{last_block}.ln_1"
    logger.info(f"Target layer: {target_layer}")
    probe = Probe(get_module(model, target_layer), target="output")
    image.requires_grad = True  # Ensure input image requires gradients

    try:
        with torch.set_grad_enabled(True):  # Explicitly enable gradient computation
            image_features = model.encode_image(image)
            text_features = model.encode_text(text_tokens)

            # Captured outputs are non-leaf tensors; ask autograd to keep their gradients.
            for layer_output in probe.data:
                layer_output.retain_grad()

            image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)

            logit_scale = model.logit_scale.exp()
            logits_per_image = logit_scale * image_features_norm @ text_features_norm.T

            model.zero_grad()
            logits_per_image[0, 0].backward(retain_graph=True)
    finally:
        # Always detach the hook so it does not keep recording on later forward passes.
        probe.remove()

    if not probe.data or probe.data[0].grad is None:
        logger.error("Failed to get gradients, please check model setup")
        logger.info(f"Probe data length: {len(probe.data)}")
        if probe.data:
            captured = probe.data[0]
            logger.info(f"First probe data shape: {captured.shape}")
            logger.info(f"First probe data requires_grad: {captured.requires_grad}")
            logger.info(f"First probe data is leaf: {captured.is_leaf}")
        return None

    captured = probe.data[0]
    # NOTE(review): reshape_transform treats the captured tensor as [B, N, C] with a
    # leading class token — confirm the CLIP_utils transformer is batch-first.
    activation = reshape_transform(captured.detach()).clone().requires_grad_(True)
    activation.grad = reshape_transform(captured.grad.detach()).clone()
    saliency = gradient_to_grad_cam_saliency(activation)

    saliency = F.interpolate(saliency, size=output_size, mode="bilinear", align_corners=True)
    return saliency.detach().cpu().numpy()[0, 0]


def _collect_prs_contributions(model, image, text_tokens, device):
    """Run the hooked image encoder and return ``(attentions, mlps, text_features)``.

    ``text_features`` are L2-normalised along the last dim; ``attentions`` and
    ``mlps`` are whatever ``prs.finalize`` returns for the image embeddings.
    """
    prs = hook_prs_logger(model, device)
    prs.reinit()
    with torch.no_grad():
        image_embeddings = model.encode_image(image.to(device), attn_method="head", normalize=False)
        text_features = model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        attentions, mlps = prs.finalize(representation=image_embeddings)
    return attentions, mlps, text_features


def _path_attention_scores(attentions, text_features, num_last_layers=12):
    """Project the attention contributions of the last layers onto the text features."""
    # Skip index 0 on dim 2, then sum over dims 1 and 3.
    # presumably dims are [batch, layer, token, head, dim] — TODO confirm against prs_hook
    attentions_collapse = attentions[:, -num_last_layers:, 1:].sum(axis=(1, 3))
    return torch.matmul(attentions_collapse, text_features.T)


def _tde_scores(attentions, mlps, text_features):
    """Compute the TDE per-token logits against the text features."""
    num_tokens = attentions.shape[2] - 1
    attentions = attentions.sum(dim=(1, 3))

    # Term shared by every token: MLP contributions (mean over dim 0, summed over the
    # next dim) plus the mean of attention index 0, both divided by the token count.
    shared = (
        torch.sum(torch.mean(mlps[:, :, :], dim=0), dim=0) / num_tokens
        + torch.mean(attentions[:, 0, :], dim=0) / num_tokens
    )
    tokens_effect = attentions[:, 1:, :] + shared.repeat(attentions.shape[0], num_tokens, 1)

    logits_token = 100 * torch.matmul(tokens_effect, text_features.T)
    p_fg = torch.sigmoid(logits_token)  # Independent probability of each token for class

    w_fg = p_fg.clone()
    # With a threshold of 0 this leaves sigmoid outputs unchanged; kept as the
    # place where a stricter cut-off would be applied.
    w_fg[w_fg <= 0] = 0

    c_z = tokens_effect * w_fg
    c_z = c_z / c_z.norm(dim=1, keepdim=True)
    return 100 * torch.matmul(c_z, text_features.T)


def main(args):
    """Compute and save a text-conditioned saliency visualisation for one image.

    Loads the CLIP model named by ``args.model`` / ``args.pretrained``, encodes
    the fixed input image and prompt, builds a saliency map with the method in
    ``args.attens_type`` (``Grad-CAM``, ``Path-atten`` or ``TDE``) and writes
    mask, heatmap and overlay images to ``<results_dir>/<attens_type>/``.

    Any exception is logged with its traceback and swallowed, so the function
    always returns ``None``.

    Args:
        args: Parsed command-line arguments (see ``get_parser_info``).
    """
    try:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
        )
        device = torch.device(f"cuda:{args.cuda_id}" if torch.cuda.is_available() else "cpu")

        # Create results directory if it doesn't exist
        if args.save_img:
            os.makedirs(os.path.join(args.results_dir, "input"), exist_ok=True)
        alpha = 0.7  # Blend ratio between heatmap and original image

        model, _, preprocess = create_model_and_transforms(args.model, pretrained=args.pretrained)
        model.to(device)
        model.eval()
        tokenizer = get_tokenizer(args.model)

        image_path = "others/prompt/results/input.jpg"
        # Context manager closes the file handle once preprocessing has read the pixels.
        with Image.open(image_path) as pil_image:
            image = preprocess(pil_image).unsqueeze(0).to(device)

        prompt_list = ["a photo of a sheep."]
        text_tokens = tokenizer(prompt_list).to(device)

        out_dir = f"{args.results_dir}/{args.attens_type}"

        if args.attens_type == "Grad-CAM":
            saliency = _compute_grad_cam(model, image, text_tokens)
            if saliency is None:
                return
            # The mask thresholds the min-max normalised map at its mean; the heatmap is
            # scaled from the raw (non-negative) saliency by its maximum.
            mask = _mean_threshold_mask(_normalize_unit(saliency))
            _save_saliency_outputs(saliency, mask, image_path, out_dir, alpha, prefix="gradcam_")
            return

        if args.attens_type == "Path-atten":
            attentions, _, text_features = _collect_prs_contributions(model, image, text_tokens, device)
            token_scores = _path_attention_scores(attentions, text_features)
        elif args.attens_type == "TDE":
            attentions, mlps, text_features = _collect_prs_contributions(model, image, text_tokens, device)
            token_scores = _tde_scores(attentions, mlps, text_features)
        else:
            logger.error("Unsupported attens_type: %s", args.attens_type)
            return

        relevance = _normalize_unit(_token_scores_to_map(token_scores, patch_size=14))
        _save_saliency_outputs(relevance, _mean_threshold_mask(relevance), image_path, out_dir, alpha)

    except Exception as e:
        logger.exception(f"Execution failed: {e}")

if __name__ == "__main__":
    # Parse the command line first, then fix the RNG seed (42) before running.
    cli_args = get_parser_info().parse_args()
    set_seed(42)
    main(cli_args)