"""
Utility functions for diffusion model training and evaluation.
Common code extracted from diff_*.py files to reduce duplication.
"""
import os
import torch
import torch.nn.functional as F
from torch.amp import autocast
import torchvision.transforms as T
from torchvision.utils import make_grid
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score
from torchmetrics.image import StructuralSimilarityIndexMeasure
from data.pairs_dataset.dataset import PERCEPTUAL_FEATURES, NUM_TREATMENTS


# Constants
# --- numerical-stability epsilons / clamps ---
EPSILON_DELTA_T: float = 1e-8  # added to sqrt(alpha_bar_t) in the x0 reconstruction denominator
EPSILON_PROB: float = 1e-6  # not referenced in this module's visible code; presumably used by callers
EPSILON_IPW: float = 1e-7  # lower bound on the IPW weight sum before normalizing weights
MIN_PROB_CLAMP: float = 1e-4  # not referenced in this module's visible code; presumably used by callers

# --- standard ImageNet per-channel statistics, passed to torchvision Normalize ---
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def safe_tensor_log(value, fallback_value, device):
    """Return ``value`` unchanged unless it is NaN.

    Args:
        value: Scalar tensor to check (``torch.isnan`` must reduce to a single bool).
        fallback_value: Number used to build the replacement tensor.
        device: Device on which the replacement tensor is created.

    Returns:
        ``value`` itself when it is not NaN, otherwise a new tensor holding
        ``fallback_value`` (a warning is printed in that case).
    """
    if not torch.isnan(value):
        return value
    print(f"WARN: NaN detected, using fallback value {fallback_value}")
    return torch.tensor(fallback_value, device=device)


def calculate_ssim(generated_images, target_images, device):
    """Return the SSIM between two image batches as a scalar tensor.

    The metric is built with ``data_range=2.0`` (a value span of 2, e.g.
    images in [-1, 1]). A NaN score, or any exception raised while building
    or evaluating the metric, yields a ``0.0`` tensor on ``device``.
    """
    try:
        metric = StructuralSimilarityIndexMeasure(data_range=2.0)
        metric = metric.to(device)
        score = metric(generated_images, target_images)
        return safe_tensor_log(score, 0.0, device)
    except Exception as err:
        print(f"WARN: SSIM calculation failed: {err}")
        return torch.tensor(0.0, device=device)


def extract_batch_data(batch, device):
    """Return a new dict with every tensor in ``batch`` moved to ``device``.

    Non-tensor entries (lists, strings, numbers, ...) are carried over as-is;
    key order is preserved.
    """
    return {
        key: (value.to(device) if isinstance(value, torch.Tensor) else value)
        for key, value in batch.items()
    }


def calculate_predicted_x0(pred_noise, timesteps, noisy_latents, scheduler):
    """Recover the predicted clean sample x0 from a noise prediction.

    With ``a_t = scheduler.alphas_cumprod[timesteps]`` this computes
    ``x0 = (x_t - sqrt(1 - a_t) * eps) / (sqrt(a_t) + EPSILON_DELTA_T)``,
    where ``x_t`` is ``noisy_latents`` and ``eps`` is ``pred_noise``.

    Args:
        pred_noise: Predicted noise, same shape as ``noisy_latents``.
        timesteps: Per-sample timesteps (or a single timestep) indexing
            ``scheduler.alphas_cumprod``.
        noisy_latents: Noised latents whose first dimension is the batch.
            Any rank is supported (e.g. 4-D image or 5-D video latents).
        scheduler: Object exposing an ``alphas_cumprod`` tensor.

    Returns:
        Tensor of the same shape as ``noisy_latents``.
    """
    alphas_cumprod = scheduler.alphas_cumprod.to(pred_noise.device)
    alpha_prod_t = alphas_cumprod[timesteps].flatten()
    # Shape the per-sample coefficients as (B, 1, ..., 1) so they broadcast
    # along the batch axis for latents of any rank, not only 4-D.
    broadcast_shape = (-1,) + (1,) * (noisy_latents.ndim - 1)
    sqrt_alpha_prod = (alpha_prod_t ** 0.5).reshape(broadcast_shape)
    sqrt_one_minus_alpha_prod = ((1 - alpha_prod_t) ** 0.5).reshape(broadcast_shape)
    # The epsilon guards against division by ~0 when a_t is tiny (very noisy steps).
    return (noisy_latents - sqrt_one_minus_alpha_prod * pred_noise) / (sqrt_alpha_prod + EPSILON_DELTA_T)


def calculate_feature_loss(self, pred_noise, timesteps, noisy_latents, later_features, feature_predictors, 
                          vae, latent_scale, trainer_precision, device):
    """Calculate an auxiliary perceptual-feature loss on the model's predicted x0.

    Steps:

    1. Reconstruct the predicted clean latents from the noise prediction.
    2. Decode them with the VAE and rescale from [-1, 1] to [0, 1].
    3. Normalize with ImageNet statistics.
    4. Classify with each pre-trained feature predictor.

    Each predictor contributes the cross-entropy between its logits and the
    ``later_features`` labels for its feature. The result is the mean over
    predictors whose loss was finite.

    Args:
        self: Object exposing ``scheduler.alphas_cumprod`` (presumably the
            module that calls this helper).
        pred_noise: Noise predicted by the diffusion model.
        timesteps: Diffusion timesteps used to noise the latents.
        noisy_latents: Noised latents fed to the model.
        later_features: Label tensor indexed as ``[:, feature_index]``;
            negative entries (e.g. -1) mark missing labels and are ignored.
        feature_predictors: Mapping of feature name (looked up in
            ``PERCEPTUAL_FEATURES``) to a classifier module with a ``head``,
            ``fc`` or ``classifier`` layer.
        vae: VAE whose ``decode(...).sample`` yields images.
        latent_scale: Factor the latents are divided by before decoding.
        trainer_precision: Precision string; autocast is enabled when it
            starts with "16" or "bf16".
        device: Device for the labels and the returned tensor.

    Returns:
        Scalar tensor with the mean feature loss. It is 0.0 when there are no
        predictors, no scorable labels, or a NaN/Inf was encountered.
    """
    feature_loss = torch.tensor(0.0, device=device)

    if len(feature_predictors) == 0:
        return feature_loss

    try:
        pred_latents = calculate_predicted_x0(pred_noise, timesteps, noisy_latents, self.scheduler)
        if not torch.isfinite(pred_latents).all():
            print("WARN: NaN/Inf in pred_latents for feature loss")
            return feature_loss

        generated_images = vae.decode(pred_latents / latent_scale).sample
        if torch.isnan(generated_images).any():
            print("WARN: NaN in generated_images for feature loss")
            return feature_loss

        use_autocast = trainer_precision.startswith(("16", "bf16"))
        imagenet_norm_transform = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        # Decoded images are rescaled from [-1, 1] to [0, 1] before ImageNet normalization.
        gen_images_0_to_1 = (generated_images * 0.5) + 0.5

        f_loss_total = 0.0
        f_counts = 0
        for feat_name, predictor in feature_predictors.items():
            try:
                feat_idx = PERCEPTUAL_FEATURES.index(feat_name)
                gt = later_features[:, feat_idx].long().to(device)
                # Negative labels (e.g. -1) mark missing annotations.
                valid_mask = gt >= 0
                if valid_mask.sum() == 0:
                    continue

                head_layer = getattr(predictor, 'head',
                                     getattr(predictor, 'fc',
                                             getattr(predictor, 'classifier', None)))
                if head_layer is None:
                    # Same handling as calculate_ite_metrics instead of an AttributeError on None.weight.
                    print(f"WARN: Could not find head layer for predictor {feat_name}")
                    continue

                with autocast('cuda', enabled=use_autocast):
                    # Feed the predictor inputs in its own parameter dtype (e.g. fp16 weights).
                    processed_gen = imagenet_norm_transform(gen_images_0_to_1.to(head_layer.weight.dtype))
                    pred_logits = predictor(processed_gen)
                    if torch.isnan(pred_logits).any():
                        continue
                    # Labels beyond the predictor's class count cannot be scored.
                    valid_mask &= gt < pred_logits.shape[-1]
                    if valid_mask.sum() == 0:
                        continue
                    current_loss = F.cross_entropy(pred_logits[valid_mask].float(), gt[valid_mask])

                if not torch.isnan(current_loss):
                    f_loss_total += current_loss
                    f_counts += 1
            except (ValueError, IndexError, KeyError):
                # ValueError covers feature names missing from PERCEPTUAL_FEATURES.
                print(f"Warn: Skipping feature {feat_name} in feature loss calculation.")
                continue
            except Exception as e:
                print(f"Err predictor {feat_name}: {e}")

        if f_counts > 0:
            feature_loss = f_loss_total / f_counts

    except Exception as e:
        # Best-effort auxiliary loss: never let it abort the training step.
        print(f"Err feature loss block: {e}")

    return safe_tensor_log(feature_loss, 0.0, device)


def _ite_penalty_metrics(feature_names, with_ipw, with_treated, penalty=3.0):
    """Build the fixed penalty metrics reported when image generation failed."""
    metrics = {}
    for feat_name in feature_names:
        keys = [f'ite_{feat_name}', f'model_ite_{feat_name}']
        if with_ipw:
            keys += [f'ipw_ite_{feat_name}', f'ipw_model_ite_{feat_name}']
        if with_treated:
            keys += [f'ite_{feat_name}_treated', f'model_ite_{feat_name}_treated']
            if with_ipw:
                keys += [f'ipw_ite_{feat_name}_treated', f'ipw_model_ite_{feat_name}_treated']
        for key in keys:
            metrics[key] = penalty
    return metrics


def _ite_abs_errors(pred_vals, target_vals, earlier_vals, later_vals, mask):
    """Return (|pred delta - true delta|, |pred delta - model delta|) over ``mask``."""
    baseline = earlier_vals[mask]
    pred_delta = pred_vals[mask] - baseline
    gt_delta = later_vals[mask] - baseline
    model_delta = target_vals[mask] - baseline
    return torch.abs(pred_delta - gt_delta), torch.abs(pred_delta - model_delta)


def _ipw_weighted_mean(errors, weights):
    """Return the sum of ``errors`` weighted by ``weights`` normalized to sum to 1."""
    weights = weights.to(errors.device)
    w_norm = weights / weights.sum().clamp(min=EPSILON_IPW)
    return (errors * w_norm).sum().item()


def calculate_ite_metrics(generated_images, target_images, earlier_features, later_features,
                         feature_predictors, trainer_precision, device, treat_mask=None, ipw=None):
    """Calculate individual-treatment-effect (ITE) errors for each feature predictor.

    For every feature, the predictor classifies the generated and the target
    images. Deltas are measured against the ``earlier_features`` label:

    * ``ite_<feat>``: mean |(pred_gen - earlier) - (later - earlier)|.
    * ``model_ite_<feat>``: mean |(pred_gen - earlier) - (pred_target - earlier)|.
    * ``ipw_*``: the same errors weighted by normalized ``ipw`` (when given).
    * ``*_treated``: the same metrics restricted to ``treat_mask`` (when given).

    Samples whose earlier or later label is -1 are excluded.

    If generated and target images are (nearly) identical, generation is
    assumed to have failed and every metric is set to a 3.0 penalty.

    Args:
        generated_images: Generated images; the ``* 0.5 + 0.5`` rescale implies
            a [-1, 1] range.
        target_images: Ground-truth images, same layout as ``generated_images``.
        earlier_features: Labels before treatment, indexed ``[:, feature_index]``.
        later_features: Labels after treatment, same layout.
        feature_predictors: Mapping of feature name to classifier module.
        trainer_precision: Precision string; autocast is enabled for "16*"/"bf16*".
        device: Device the labels are moved to.
        treat_mask: Optional per-sample treated indicator. Any dtype; it is
            cast to bool.
        ipw: Optional per-sample inverse-propensity weights (tensor).

    Returns:
        Dict of metric name to float. It is empty when there are no predictors
        or ``earlier_features`` lacks the needed columns.
    """
    if len(feature_predictors) == 0:
        return {}

    # Nearly identical images indicate that generation fell back to the target.
    # Compare in float32 so mixed dtypes compare cleanly.
    if torch.allclose(generated_images.float(), target_images.float(), rtol=1e-5, atol=1e-6):
        print("WARN ITE: Generation failed (images nearly identical), returning penalty metrics")
        return _ite_penalty_metrics(feature_predictors.keys(), ipw is not None, treat_mask is not None)

    known_indices = [PERCEPTUAL_FEATURES.index(f) for f in feature_predictors if f in PERCEPTUAL_FEATURES]
    min_expected_earlier_dim = max(known_indices, default=-1) + 1
    if earlier_features.shape[1] < min_expected_earlier_dim:
        return {}

    use_autocast = trainer_precision.startswith(("16", "bf16"))
    imagenet_norm_transform = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    # Loop-invariant: rescale [-1, 1] -> [0, 1] once for all predictors.
    gen_images_0_to_1 = (generated_images * 0.5) + 0.5
    target_images_0_to_1 = (target_images * 0.5) + 0.5

    metrics = {}
    for feat_name, predictor in feature_predictors.items():
        try:
            feat_idx = PERCEPTUAL_FEATURES.index(feat_name)
            earlier_feat_gt = earlier_features[:, feat_idx].long().to(device)
            later_feat_gt = later_features[:, feat_idx].long().to(device)

            base_mask = (earlier_feat_gt != -1) & (later_feat_gt != -1)
            if base_mask.sum() == 0:
                continue

            with torch.no_grad(), autocast('cuda', enabled=use_autocast):
                head_layer = getattr(predictor, 'head',
                                     getattr(predictor, 'fc',
                                             getattr(predictor, 'classifier', None)))
                if head_layer is None:
                    print(f"WARN ITE: Could not find head layer for predictor {feat_name}")
                    continue

                weight_dtype = head_layer.weight.dtype
                pred_feat_logits = predictor(imagenet_norm_transform(gen_images_0_to_1.to(weight_dtype)))
                target_feat_logits = predictor(imagenet_norm_transform(target_images_0_to_1.to(weight_dtype)))

            if torch.isnan(pred_feat_logits).any() or torch.isnan(target_feat_logits).any():
                print(f"WARN ITE {feat_name}: NaN in predictor logits.")
                continue

            pred_vals = pred_feat_logits.argmax(dim=-1).float()
            target_vals = target_feat_logits.argmax(dim=-1).float()
            earlier_vals = earlier_feat_gt.float()
            later_vals = later_feat_gt.float()

            ite_err, model_ite_err = _ite_abs_errors(pred_vals, target_vals, earlier_vals, later_vals, base_mask)
            metrics[f'ite_{feat_name}'] = ite_err.mean().item()
            metrics[f'model_ite_{feat_name}'] = model_ite_err.mean().item()

            # Move weights/masks next to base_mask before indexing: a mask on one
            # device cannot index a tensor on another.
            ipw_on_device = ipw.to(base_mask.device) if ipw is not None else None

            if ipw_on_device is not None:
                ipw_all = ipw_on_device[base_mask]
                metrics[f'ipw_ite_{feat_name}'] = _ipw_weighted_mean(ite_err, ipw_all)
                metrics[f'ipw_model_ite_{feat_name}'] = _ipw_weighted_mean(model_ite_err, ipw_all)

            if treat_mask is not None:
                # Cast to bool: an integer 0/1 mask combined with `&` would yield an
                # integer tensor that indexes by position instead of masking.
                treated_mask = base_mask & treat_mask.to(device=base_mask.device, dtype=torch.bool)
                if treated_mask.sum() > 0:
                    ite_err_tr, model_ite_err_tr = _ite_abs_errors(
                        pred_vals, target_vals, earlier_vals, later_vals, treated_mask)
                    metrics[f'ite_{feat_name}_treated'] = ite_err_tr.mean().item()
                    metrics[f'model_ite_{feat_name}_treated'] = model_ite_err_tr.mean().item()

                    if ipw_on_device is not None:
                        ipw_tr = ipw_on_device[treated_mask]
                        metrics[f'ipw_ite_{feat_name}_treated'] = _ipw_weighted_mean(ite_err_tr, ipw_tr)
                        metrics[f'ipw_model_ite_{feat_name}_treated'] = _ipw_weighted_mean(model_ite_err_tr, ipw_tr)

        except (ValueError, IndexError, KeyError) as e:
            print(f"WARN ITE: Skipping feature {feat_name} due to error: {e}")
            continue
        except Exception as e:
            print(f"WARN ITE: Error calculating ITE for {feat_name}: {e}")

    return metrics


def log_qualitative_images(pipe, input_imgs, target_imgs, prompts, logger, current_epoch, 
                          suffix="", strength=0.75, guidance_scale=7.5, device='cuda',
                          num_inference_steps=50):
    """Generate images with an img2img pipeline and log them with the logger.

    Logs the input grid, target grid, generated grid and the prompts via
    ``logger.experiment.add_image`` / ``add_text``. Failures are reported and
    swallowed, because this is best-effort qualitative logging.

    Args:
        pipe: Callable pipeline returning an object with ``.images`` (PIL
            images); nothing is logged when it is None.
        input_imgs: Input image batch; the rescale below implies a [-1, 1] range.
        target_imgs: Target image batch in the same range.
        prompts: Prompt(s) passed to the pipeline (a list, or a single string).
        logger: Object exposing ``experiment.add_image`` / ``experiment.add_text``.
        current_epoch: Step value used for logging.
        suffix: String appended to every logged tag.
        strength: img2img strength forwarded to the pipeline.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        device: Unused. Kept for backward compatibility; the generated tensors
            are only needed on CPU to build the logged grid.
        num_inference_steps: Denoising steps forwarded to the pipeline.
    """
    if pipe is None:
        print("Warn: Pipe not ready for qualitative logging.")
        return

    num_images = input_imgs.shape[0]
    print(f"Generating qualitative images ({num_images} samples)...")

    try:
        transform_to_pil = T.ToPILImage()
        # Map [-1, 1] -> [0, 1] and clamp. ToPILImage scales floats by 255 and
        # casts to uint8, so out-of-range values would not saturate.
        input_pil = [transform_to_pil(((img * 0.5) + 0.5).clamp(0.0, 1.0).cpu().float()) for img in input_imgs]

        with torch.no_grad(), autocast('cuda', enabled=False):
            generated_pil = pipe(prompt=prompts, image=input_pil, strength=strength,
                                 guidance_scale=guidance_scale,
                                 num_inference_steps=num_inference_steps).images

        to_tensor_transform = T.ToTensor()
        # Kept on CPU: these tensors only feed the logged grid.
        generated_tensors = torch.stack([to_tensor_transform(img) for img in generated_pil])

        input_grid = make_grid(input_imgs.detach().cpu(), normalize=True, value_range=(-1, 1))
        target_grid = make_grid(target_imgs.detach().cpu(), normalize=True, value_range=(-1, 1))
        generated_grid = make_grid(generated_tensors)

        logger.experiment.add_image(f"Input{suffix}", input_grid, current_epoch)
        logger.experiment.add_image(f"Target{suffix}", target_grid, current_epoch)
        logger.experiment.add_image(f"Generated{suffix}", generated_grid, current_epoch)

        # A single string prompt must not be enumerated character by character.
        prompt_list = [prompts] if isinstance(prompts, str) else prompts
        prompts_text = "\n".join(f"Img {i+1}: {p}" for i, p in enumerate(prompt_list))
        logger.experiment.add_text(f"Val/Prompts{suffix}", prompts_text, current_epoch)

        print("Qualitative images logged.")

    except Exception as e:
        print(f"Warn: Qualitative logging failed: {e}")


def calculate_discriminator_metrics(treatment_logits, interval_labels):
    """Calculate discriminator accuracy and mean per-treatment ROC-AUC.

    Args:
        treatment_logits: Raw logits with one column per treatment, indexed
            ``[:, i]`` (presumably shape ``(batch, NUM_TREATMENTS)``; confirm).
        interval_labels: Binary targets with the same layout.

    Returns:
        ``(acc, auc)``:

        * ``acc``: accuracy over all flattened entries, using a 0.5
          probability threshold.
        * ``auc``: mean ROC-AUC over the first ``NUM_TREATMENTS`` columns that
          contain both classes. It is 0.0 if no column qualifies or the
          computation fails.
    """
    # detach() + float(): Tensor.numpy() refuses tensors that require grad and
    # has no bfloat16 equivalent. Either can occur when this is called on
    # training-step outputs.
    probs = torch.sigmoid(treatment_logits.detach().float())
    preds = (probs >= 0.5).int()
    labels_np = interval_labels.detach().cpu().numpy()
    preds_np = preds.cpu().numpy()
    probs_np = probs.cpu().numpy()

    # Overall accuracy treats every (sample, treatment) entry independently.
    acc = accuracy_score(labels_np.flatten(), preds_np.flatten())

    try:
        aucs = []
        for i in range(NUM_TREATMENTS):
            # ROC-AUC is undefined for a column with a single class.
            if len(np.unique(labels_np[:, i])) > 1:
                aucs.append(roc_auc_score(labels_np[:, i], probs_np[:, i]))
        auc = np.mean(aucs) if aucs else 0.0
    except Exception as e:
        print(f"WARN: Could not calculate AUC: {e}")
        auc = 0.0

    return acc, auc


def aggregate_test_metrics(test_outputs):
    """Aggregate per-batch test metrics into batch-size-weighted averages.

    Args:
        test_outputs: Iterable of per-batch dicts mapping metric name to a
            scalar, plus an optional ``'batch_size'`` entry (missing counts as
            0). ``None`` entries (failed steps) are ignored.

    Returns:
        Always a 2-tuple ``(aggregated_metrics, total_samples)``, so callers
        can unpack it unconditionally. Each metric is averaged over the batches
        that reported it, weighted by ``batch_size``. A metric whose total
        weight is 0 maps to NaN. Returns ``({}, 0)`` when there is nothing
        valid to aggregate.
    """
    if not test_outputs:
        print("No valid outputs collected during testing.")
        return {}, 0

    valid_outputs = [out for out in test_outputs if out is not None]
    if not valid_outputs:
        print("All test steps failed or returned None.")
        return {}, 0

    metric_keys = {k for out in valid_outputs for k in out} - {'batch_size'}
    total_samples = sum(out.get('batch_size', 0) for out in valid_outputs)

    print(f"\n--- Aggregating Test Results ({len(valid_outputs)} Batches, {total_samples:.0f} Samples) ---")

    aggregated_metrics = {}
    for key in sorted(metric_keys):
        # Non-empty by construction: every key came from at least one output.
        key_outputs = [out for out in valid_outputs if key in out]
        # Missing batch_size counts as 0, consistent with total_samples above.
        total_weight = sum(out.get('batch_size', 0) for out in key_outputs)
        total_value = sum(out[key] * out.get('batch_size', 0) for out in key_outputs)
        aggregated_metrics[key] = total_value / total_weight if total_weight > 0 else float('nan')

    return aggregated_metrics, total_samples


def save_test_results(aggregated_metrics, total_samples, results_file_path, ckpt_path=None):
    """Print the final test metrics and optionally write them to a text file.

    The output has a header (checkpoint, total samples, separator) followed
    by one ``name: value`` line per metric. Numeric values use six decimals;
    values that cannot take a float format (e.g. strings, None) are written
    with ``str()`` instead of crashing.

    Args:
        aggregated_metrics: Mapping of metric name to value.
        total_samples: Number of samples aggregated (formatted without decimals).
        results_file_path: Destination file. Parent directories are created.
            Nothing is written when falsy.
        ckpt_path: Checkpoint path shown in the header ("N/A" when falsy).
    """
    print("--- Final Test Metrics ---")
    results_str = [
        f"Checkpoint: {ckpt_path or 'N/A'}",
        f"Total Samples: {total_samples:.0f}",
        "---",
    ]

    for key, value in aggregated_metrics.items():
        try:
            line = f"{key}: {value:.6f}"
        except (TypeError, ValueError):
            # Non-numeric entries cannot use the ".6f" format spec.
            line = f"{key}: {value}"
        print(line)
        results_str.append(line)

    if results_file_path:
        print(f"Saving results to {results_file_path}...")
        try:
            dirpath = os.path.dirname(results_file_path)
            if dirpath:
                os.makedirs(dirpath, exist_ok=True)
            with open(results_file_path, 'w', encoding='utf-8') as f:
                f.write("\n".join(results_str))
            print("Results saved.")
        except Exception as e:
            # Best-effort: a failed write is reported but does not abort the run.
            print(f"Error saving results: {e}")