from numpy import real
import torch
import glob
import os
from torchmetrics.functional.image import structural_similarity_index_measure
from tqdm import tqdm
from monai_fid import get_features_2p5d
import nibabel as nib
from monai.metrics.fid import FIDMetric


import sys
sys.path.append("./metrics/radimagenet-models")

from radimagenet_models.models.resnet import ResNet50


# When True, compute_masked_metrics_3d replaces the masked region of every
# reconstruction with standard-normal noise (torch.randn_like) before the
# lung-window clipping is applied. NOTE(review): presumably a noise baseline /
# sanity check for the metrics — confirm intended use.
RELATIVE_DISTURBANCE = False


class NibsDataset(torch.utils.data.Dataset):
    """Dataset of (CT, reconstruction, mask) NIfTI triplets.

    Each directory is globbed for ``*.nii.gz`` files and each listing is
    sorted; the i-th entry of every sorted listing forms sample ``i``.
    Pairing is therefore purely positional, so corresponding files must sort
    into the same order in all three directories.

    Raises:
        ValueError: If the three directories hold different numbers of
            ``*.nii.gz`` files.
    """

    def __init__(self, ct_dir, recon_dir, mask_dir):
        self.ct_files = self._list_volumes(ct_dir)
        self.recon_files = self._list_volumes(recon_dir)
        self.mask_files = self._list_volumes(mask_dir)

        # Explicit check instead of ``assert`` so it still runs under
        # ``python -O``; a silent count mismatch would misalign every pair.
        counts = (len(self.ct_files), len(self.recon_files), len(self.mask_files))
        if len(set(counts)) != 1:
            raise ValueError(
                "Mismatch in number of files: "
                f"ct={counts[0]}, recon={counts[1]}, mask={counts[2]}."
            )

    @staticmethod
    def _list_volumes(directory):
        """Return the sorted ``*.nii.gz`` paths directly inside ``directory``."""
        return sorted(glob.glob(os.path.join(directory, "*.nii.gz")))

    @staticmethod
    def _load_volume(path):
        """Load a NIfTI file as a float32 tensor with a leading channel axis."""
        # get_fdata() yields a floating-point ndarray; unsqueeze(0) prepends the
        # channel axis, giving (1, *volume_shape).
        return torch.from_numpy(nib.load(path).get_fdata()).unsqueeze(0).float()

    def __len__(self):
        return len(self.ct_files)

    def __getitem__(self, idx):
        """Return the ``(ct, recon, mask)`` tensors of sample ``idx``."""
        return (
            self._load_volume(self.ct_files[idx]),
            self._load_volume(self.recon_files[idx]),
            self._load_volume(self.mask_files[idx]),
        )


def _window_to_unit(volume, low, up):
    """Clamp ``volume`` to ``[low, up]`` and rescale it linearly onto ``[0, 1]``."""
    return (torch.clamp(volume, low, up) - low) / (up - low)


def _extract_2p5d_features(volume, mask, feature_network):
    """Run ``get_features_2p5d`` on ``volume`` and return its first three outputs on CPU.

    The caller stores the three outputs as the xy / zx / yz feature sets.
    NOTE(review): that plane ordering is assumed from ``monai_fid`` — confirm.
    """
    feats = get_features_2p5d(
        volume,
        mask,
        feature_network,
        drop_empty=True,
        empty_threshold=0.5,  # NOTE(review): presumably drops slices with no mask value > 0.5
        center_slices=False,
        xy_only=False,
        use_min_max_normalization=False,  # inputs are already windowed to [0, 1]
    )
    return tuple(feats[i].detach().cpu() for i in range(3))


@torch.no_grad()
def compute_masked_metrics_3d(
    ct_dir,
    recon_dir,
    mask_dir,
    *,
    window_center=-600.0,
    window_width=1500.0,
    batch_size=16,
    num_workers=4,
    weights_path="./metrics/RadImageNet-ResNet50_notop.pth",
):
    """Compute masked reconstruction metrics and 2.5D FID over paired NIfTI volumes.

    Volumes are loaded with ``NibsDataset`` (paired by sorted filename),
    clipped to ``window_center +/- window_width / 2`` (a lung window by
    default) and rescaled to ``[0, 1]``. On that scale the function reports:

    * MSE / RMSE / MAE pooled over all mask-weighted voxels of all volumes
      (RMSE and MAE are also reported multiplied by the window width);
    * SSIM averaged inside each volume's mask, then across the volumes whose
      mask is non-empty;
    * FID between real and reconstructed RadImageNet ResNet50 features for
      each of the three feature sets returned by ``get_features_2p5d``.

    Args:
        ct_dir: Directory of reference ``*.nii.gz`` volumes.
        recon_dir: Directory of reconstructed ``*.nii.gz`` volumes.
        mask_dir: Directory of ``*.nii.gz`` masks. Mask values are used as
            multiplicative weights, so they are presumably binary 0/1.
        window_center: Center of the intensity window.
        window_width: Width of the intensity window.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes (0 loads in the main process).
        weights_path: Path of the ResNet50 ``state_dict`` used for FID features.

    Returns:
        A dict with the printed metrics, or ``None`` when no sample was
        processed or every mask was empty.
    """
    dataset = NibsDataset(ct_dir, recon_dir, mask_dir)
    print(f"Found {len(dataset)} files. Processing 3D volumes...")

    # The intensity window is fixed for the whole run.
    low = window_center - window_width / 2
    up = window_center + window_width / 2

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    feature_network = ResNet50()
    feature_network.load_state_dict(torch.load(weights_path, map_location="cpu"))
    feature_network.to(device)
    feature_network.eval()

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )

    total_masked_sse = 0.0    # sum of squared errors inside the masks
    total_masked_sae = 0.0    # sum of absolute errors inside the masks
    total_masked_voxels = 0   # sum of mask values (voxel count for binary masks)
    total_ssim_sum = 0.0      # sum of per-volume masked SSIM means
    total_ssim_samples = 0    # volumes that contributed an SSIM (non-empty mask)
    total_samples = 0         # volumes processed
    total_volume_size = 0     # total voxels (masked + unmasked), for coverage

    # One list of per-batch CPU feature tensors per plane: (xy, zx, yz).
    real_feats = ([], [], [])
    synth_feats = ([], [], [])

    for original, recon, mask in tqdm(dataloader, desc="Processing Batches"):
        original = original.to(device)
        recon = recon.to(device)
        mask = mask.to(device)

        if RELATIVE_DISTURBANCE:
            # Replace the masked region of the reconstruction with Gaussian noise.
            noise = torch.randn_like(recon)
            recon = recon * (1 - mask) + noise * mask

        # Defensive: add a channel axis if batches arrive without one.
        if original.ndim == 4:
            original = original.unsqueeze(1)
            recon = recon.unsqueeze(1)
            mask = mask.unsqueeze(1)

        original = _window_to_unit(original, low, up)
        recon = _window_to_unit(recon, low, up)

        ################ PROCESS METRICS ################
        for store, feats in zip(real_feats, _extract_2p5d_features(original, mask, feature_network)):
            store.append(feats)
        for store, feats in zip(synth_feats, _extract_2p5d_features(recon, mask, feature_network)):
            store.append(feats)

        # --- A. Masked MSE / MAE: zero the error outside the mask, then sum. ---
        masked_diff = (original - recon) * mask
        total_masked_sse += torch.sum(masked_diff ** 2).item()
        total_masked_sae += torch.sum(torch.abs(masked_diff)).item()
        total_masked_voxels += torch.sum(mask).item()

        # --- B. Masked SSIM: full SSIM map, averaged inside each volume's mask. ---
        # data_range=1.0 matches the [0, 1] windowing above.
        _, ssim_map = structural_similarity_index_measure(
            recon,
            original,
            data_range=1.0,
            return_full_image=True,
        )
        ssim_sums = (ssim_map * mask).flatten(1).sum(1)
        mask_counts = mask.flatten(1).sum(1)
        # Volumes with an empty mask have no defined masked SSIM; skip them
        # rather than scoring them 0 and dragging the global mean down.
        has_mask = mask_counts > 0
        total_ssim_sum += (ssim_sums[has_mask] / mask_counts[has_mask]).sum().item()
        total_ssim_samples += int(has_mask.sum().item())

        # --- C. Volume stats ---
        batch_count = original.size(0)
        total_samples += batch_count
        # original[0] is one volume including its channel axis.
        total_volume_size += original[0].numel() * batch_count

        # Release GPU tensors before the next batch is moved to the device.
        del original, recon, mask

    if total_samples == 0 or total_masked_voxels == 0:
        print("No samples or empty masks processed.")
        return None

    # FID per plane; the second feature set is labelled "XZ" in the report.
    fid = FIDMetric()
    fid_res_xy, fid_res_xz, fid_res_yz = (
        fid(torch.cat(synth, dim=0), torch.cat(real, dim=0)).item()
        for real, synth in zip(real_feats, synth_feats)
    )
    fid_mean = (fid_res_xy + fid_res_xz + fid_res_yz) / 3
    print(f"FID XY: {fid_res_xy:.4f}, FID XZ: {fid_res_xz:.4f}, FID YZ: {fid_res_yz:.4f}")
    print(f"FID Mean: {fid_mean:.4f}")

    # Pooled errors: total masked error / total mask weight.
    global_mse = total_masked_sse / total_masked_voxels
    global_mae = total_masked_sae / total_masked_voxels
    global_rmse = global_mse ** 0.5

    # Mean of per-volume masked SSIMs over volumes with a non-empty mask.
    global_ssim = total_ssim_sum / total_ssim_samples if total_ssim_samples else float("nan")

    avg_mask_voxels = total_masked_voxels / total_samples
    avg_mask_percentage = (total_masked_voxels / total_volume_size) * 100

    # The [0, 1] scale spans (up - low) raw units, so errors rescale linearly.
    window = up - low

    print("-" * 40)
    print(f"Processed {total_samples} samples across {len(dataset)} files.")
    print("-" * 40)
    print("Metrics computed on MASKED regions only:")
    print(f"Masked MSE:          {global_mse:.6f}")
    print(f"Masked RMSE:         {global_rmse:.6f}")
    print(f"Masked RMSE (in HU): {(global_rmse * window):.6f}")
    print(f"Masked MAE:          {global_mae:.6f}")
    print(f"Masked MAE (in HU):  {(global_mae * window):.6f}")
    print(f"Masked SSIM:         {global_ssim:.6f}")
    print("-" * 40)
    print("Mask Volume Statistics:")
    print(f"Avg Mask Size:    {avg_mask_voxels:.0f} voxels")
    print(f"Avg Coverage:     {avg_mask_percentage:.2f}% of total volume")
    print("-" * 40)

    return {
        "fid_xy": fid_res_xy,
        "fid_xz": fid_res_xz,
        "fid_yz": fid_res_yz,
        "fid_mean": fid_mean,
        "masked_mse": global_mse,
        "masked_rmse": global_rmse,
        "masked_rmse_hu": global_rmse * window,
        "masked_mae": global_mae,
        "masked_mae_hu": global_mae * window,
        "masked_ssim": global_ssim,
        "avg_mask_voxels": avg_mask_voxels,
        "avg_mask_percentage": avg_mask_percentage,
        "num_samples": total_samples,
    }

if __name__ == "__main__":
    import argparse

    # Defaults reproduce the previously hard-coded paths, so a bare invocation
    # behaves as before; each path can now be overridden from the command line.
    parser = argparse.ArgumentParser(
        description="Compute masked MSE/RMSE/MAE, SSIM and 2.5D FID between CT volumes and reconstructions."
    )
    parser.add_argument(
        "--ct-dir",
        default="metrics/nibs/nibs_ct_128",
        help="Directory of reference CT *.nii.gz volumes.",
    )
    parser.add_argument(
        "--recon-dir",
        default="metrics/NV-Generate-CTMR/output/nlst",
        help="Directory of reconstructed *.nii.gz volumes.",
    )
    parser.add_argument(
        "--mask-dir",
        default="metrics/nibs/nibs_masks_128_fixed",
        help="Directory of *.nii.gz masks restricting the metrics.",
    )
    args = parser.parse_args()
    compute_masked_metrics_3d(args.ct_dir, args.recon_dir, args.mask_dir)
