"""
Data utilities for diffusion model training.
Common collate functions and data handling extracted from diff_*.py files.
"""
import torch
from torch.nn.utils.rnn import pad_sequence


def create_base_collate_fn():
    """Build the basic collate function for diffusion training.

    The returned callable takes a list of per-sample dicts and produces one
    batch dict with keys, in order: ``input_image``, ``target_image``,
    ``prompt``, ``earlier_features``, ``later_features``.

    Every field except ``prompt`` is combined with ``torch.stack``, which adds
    a new leading batch dimension. All samples must therefore hold same-shape
    tensors for that key. ``prompt`` values are kept as a plain Python list.
    """
    def collate_fn(batch):
        def stack(key):
            # New leading dim of size len(batch); shapes must match per key.
            return torch.stack([sample[key] for sample in batch])

        # Dict-literal entries are evaluated top to bottom, so fields are read
        # in the same order as they appear in the output.
        return {
            "input_image": stack("input_image"),
            "target_image": stack("target_image"),
            "prompt": [sample["prompt"] for sample in batch],
            "earlier_features": stack("earlier_features"),
            "later_features": stack("later_features"),
        }

    return collate_fn


def create_adversarial_collate_fn():
    """Create collate function for adversarial training.

    The returned callable starts from the base diffusion batch and adds:

    - ``interval_labels``: the per-sample ``interval_labels`` stacked with
      ``torch.stack``.
    - ``cov_seq`` / ``trt_seq``: the per-sample sequences padded with 0.0
      along their first dimension to the longest one in the batch, batch-first.
    - ``lengths``: ``int64`` tensor with each sample's unpadded
      ``cov_seq.size(0)``.
    - ``delta_t``: float tensor of shape ``(B, 1)``.
    - ``side``: float tensor of shape ``(B, 1)``. It is 1.0 when the sample's
      ``side`` string starts with "R" (case-insensitive), else 0.0.

    No errors are caught here: a sample missing any required key raises
    ``KeyError``.
    """
    # Build the base collate once instead of on every batch.
    base_collate = create_base_collate_fn()

    def collate_fn(batch):
        collated = base_collate(batch)

        # Add adversarial-specific data
        collated["interval_labels"] = torch.stack([b["interval_labels"] for b in batch])

        # Historical data for context encoder. Lengths are taken from cov_seq
        # only. NOTE(review): assumes trt_seq has the same length per sample.
        covs_hist = [b["cov_seq"] for b in batch]
        trts_hist = [b["trt_seq"] for b in batch]
        lengths_hist = torch.tensor([seq.size(0) for seq in covs_hist], dtype=torch.long)

        collated.update({
            "cov_seq": pad_sequence(covs_hist, batch_first=True, padding_value=0.0),
            "trt_seq": pad_sequence(trts_hist, batch_first=True, padding_value=0.0),
            "lengths": lengths_hist,
            "delta_t": torch.tensor([b["delta_t"] for b in batch], dtype=torch.float).unsqueeze(1),
        })

        # Side marker -> 1.0 / 0.0. Prefix matching is the same rule the IPW
        # collate uses. "R"/"L" encode exactly as before, and "Right"/"right"
        # are no longer silently mapped to 0.0.
        collated["side"] = torch.tensor(
            [1.0 if b["side"].upper().startswith("R") else 0.0 for b in batch],
            dtype=torch.float,
        ).unsqueeze(1)

        return collated

    return collate_fn


def create_ipw_collate_fn(propensity_model_uses_images=False):
    """Build the collate function used for IPW training.

    The returned callable extends the base diffusion batch with
    ``interval_labels`` plus the history fields read by the propensity model:
    ``cov_seq``, ``trt_seq``, ``lengths``, ``delta_t``, ``side`` and
    ``image_seq``.

    Collating the history is best-effort. If any step raises, a message is
    printed and all six history keys are set to ``None`` instead of
    propagating the error.

    Args:
        propensity_model_uses_images: when True, also pad each sample's
            ``image_seq``. If any sample lacks one (or it is None), a warning
            is printed and ``image_seq`` stays ``None``.
    """
    history_keys = ("cov_seq", "trt_seq", "lengths", "delta_t", "side", "image_seq")

    def blank_history(collated):
        # Mark every history field as unavailable for this batch.
        for key in history_keys:
            collated[key] = None

    def collate_fn(batch):
        collated = create_base_collate_fn()(batch)
        collated["interval_labels"] = torch.stack([sample["interval_labels"] for sample in batch])

        try:
            # Compute in the same order as before so the first failing key
            # (and hence the printed message) is unchanged.
            cov_padded = pad_sequence([sample["cov_seq"] for sample in batch],
                                      batch_first=True, padding_value=0.0)
            trt_padded = pad_sequence([sample["trt_seq"] for sample in batch],
                                      batch_first=True, padding_value=0.0)
            seq_lengths = torch.tensor([sample["cov_seq"].size(0) for sample in batch],
                                       dtype=torch.long)
            time_gaps = torch.tensor([sample["delta_t"] for sample in batch],
                                     dtype=torch.float).unsqueeze(1)
            collated["cov_seq"] = cov_padded
            collated["trt_seq"] = trt_padded
            collated["lengths"] = seq_lengths
            collated["delta_t"] = time_gaps

            # Anything whose upper-cased side starts with "R" counts as 1.0.
            side_flags = []
            for sample in batch:
                side_flags.append(1.0 if sample["side"].upper().startswith("R") else 0.0)
            collated["side"] = torch.tensor(side_flags, dtype=torch.float).unsqueeze(1)

            collated["image_seq"] = None
            if propensity_model_uses_images:
                images_present = all(
                    "image_seq" in sample and sample["image_seq"] is not None
                    for sample in batch
                )
                if not images_present:
                    print("Warn: Propensity model uses images, but image_seq missing/None in batch.")
                else:
                    collated["image_seq"] = pad_sequence(
                        [sample["image_seq"] for sample in batch],
                        batch_first=True,
                        padding_value=0.0,
                    )
        except KeyError as e:
            print(f"Error collating IPW history: Missing key {e}.")
            blank_history(collated)
        except Exception as ex:
            print(f"Unexpected error in IPW collate: {ex}")
            blank_history(collated)

        return collated

    return collate_fn


def create_unified_collate_fn(include_adversarial=False, include_ipw=False, propensity_model_uses_images=False):
    """Create unified collate function that handles all cases.

    With both flags False, the returned callable produces only the base
    diffusion batch. If either ``include_adversarial`` or ``include_ipw`` is
    True, it also adds ``interval_labels`` and the history fields
    ``cov_seq``, ``trt_seq`` (zero-padded, batch-first), ``lengths``,
    ``delta_t`` (``(B, 1)`` float), ``side`` (``(B, 1)`` float) and
    ``image_seq``.

    History collation is best-effort. If it raises, a message is printed and
    all six history keys are set to ``None``.

    Args:
        include_adversarial: add labels and history for adversarial training.
        include_ipw: add labels and history for IPW training.
        propensity_model_uses_images: only used with ``include_ipw``. It pads
            ``image_seq`` when every sample has one, else prints a warning
            and leaves it ``None``.
    """
    # Build the base collate and resolve the flags once, not per batch.
    base_collate = create_base_collate_fn()
    needs_history = include_adversarial or include_ipw
    wants_images = include_ipw and propensity_model_uses_images
    history_keys = ("cov_seq", "trt_seq", "lengths", "delta_t", "side", "image_seq")

    def collate_fn(batch):
        collated = base_collate(batch)
        if not needs_history:
            return collated

        collated["interval_labels"] = torch.stack([b["interval_labels"] for b in batch])

        try:
            covs_hist = [b["cov_seq"] for b in batch]
            trts_hist = [b["trt_seq"] for b in batch]
            lengths_hist = torch.tensor([seq.size(0) for seq in covs_hist], dtype=torch.long)

            collated.update({
                "cov_seq": pad_sequence(covs_hist, batch_first=True, padding_value=0.0),
                "trt_seq": pad_sequence(trts_hist, batch_first=True, padding_value=0.0),
                "lengths": lengths_hist,
                "delta_t": torch.tensor([b["delta_t"] for b in batch], dtype=torch.float).unsqueeze(1),
            })

            # Prefix match, identical to create_ipw_collate_fn, so the
            # unified path yields the same side feature for "R" and "Right".
            collated["side"] = torch.tensor(
                [1.0 if b["side"].upper().startswith("R") else 0.0 for b in batch],
                dtype=torch.float,
            ).unsqueeze(1)

            collated["image_seq"] = None
            if wants_images:
                if all("image_seq" in b and b["image_seq"] is not None for b in batch):
                    collated["image_seq"] = pad_sequence([b["image_seq"] for b in batch],
                                                         batch_first=True, padding_value=0.0)
                else:
                    # Previously this fell through silently; surface it like
                    # the dedicated IPW collate does.
                    print("Warn: Propensity model uses images, but image_seq missing/None in batch.")

        except Exception as e:  # KeyError is a subclass; one clause covers both
            print(f"Error collating historical data: {e}")
            # Set to None so models can handle gracefully
            for k in history_keys:
                collated[k] = None

        return collated

    return collate_fn