from .compressor import Compressor
import torch.distributed as dist

def create_compression_hook(layer_idx, layer_config):
    compressor = None
    
    def compression_hook(module, input_tensor, output):
        nonlocal compressor
        
        # 1) Convert the pipeline's output into a tuple, so we don't accidentally index a single tensor or an immutable_list.
        if not isinstance(output, tuple):
            output = tuple(output)

        # 2) If output[0] is still some kind of container (e.g. immutable_list), unwrap it.
        #    In many pipeline configs, the first element might be [tensor(...)].
        #    Check if the first item is a list-like object that needs unwrapping.
        if len(output) > 0 and not hasattr(output[0], 'shape'):
            # If output[0] is something like an immutable_list or a nested list, 
            # try unwrapping the tensor from position [0].
            if len(output[0]) > 0 and hasattr(output[0][0], 'shape'):
                # Build a new tuple with the unwrapped tensor as the first element
                output = (output[0][0],) + output[1:]
            else:
                # If it’s still not a tensor, print or handle accordingly:
                print(f"Rank {dist.get_rank()}: Could not unwrap the first output properly.")
                return output  # Fallback early if you can’t proceed.
        
        # Now we should have output[0] as a torch.Tensor.
        # If there’s nothing to compress, just return the tuple.
        if len(output) == 0:
            return output
        
        # print("rank", dist.get_rank(), "input_tensor", input_tensor)
        # print("rank", dist.get_rank(), "output", output)
        
        # 3) Initialize the compressor once (based on the shape of output[0]).
        if compressor is None:
            compressor = Compressor(
                input_shape=output[0].shape,
                forward=layer_config['forward'],
                forward_params=layer_config['forward-params'],
                backward=layer_config['backward'],
                backward_params=layer_config['backward-params'],
                forward_EF=layer_config['forward-EF'],
                backward_EF=layer_config['backward-EF'],
                forward_EF_method=layer_config['forward-EF-method'],
                backward_EF_method=layer_config['backward-EF-method']
            )
        
        # 4) Retrieve indices if the module has them (e.g. for selective compression).
        indices = getattr(module, 'current_indices', None)
        
        # 5) Compress the first element of the output tuple.
        compressed_first = compressor(output[0], indices=indices)
        
        # Keep the rest of the output unmodified (logits, hidden_states, etc.).
        return (compressed_first,) + output[1:]
    
    return compression_hook