from typing import Dict, List, cast
from logging import Logger as LoggerType
import torch
from torch import nn
import torch.utils
from torch.utils.data.dataloader import DataLoader
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    MofNCompleteColumn
)

from prune.halpe import HALPE, TransformerHALPE
from prune.utils import compute_candidate_blocks, compute_candidate_blocks_pct_rank, select_least_important_globally, meanstd_normalize_tensor_list, min_max_normalize_tensor_list, select_candidates_option_a_normalized_final, select_candidates_option_b
from inference.inference import create_inference_metric
from utils.pruning_utils import PruningArguments
from utils.model_config import ModelMetadata
from utils.layer_utils import (LayerType, TransformerConfig, get_layer_config, 
                               resolve_instance_schema_map, TransformerLayerSchema)
from utils.logs import  SharedLogger
from prune.stage_runner import stage_forward
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from utils.layer_utils import LayerSchema


def prune(model_data: ModelMetadata, prune_args: PruningArguments, 
          calibration_dataloader: DataLoader, device: str = "cpu"):
    """Iteratively prune the student model until the target speedup is reached.

    Each iteration of the main loop:

    1. registers hooks on every pruner and runs the calibration data through
       the model with ``stage_forward``;
    2. computes cheap "initial" importances (L2 weight magnitudes) per block;
    3. computes local/global sensitivities per layer, walking the layers from
       last to first;
    4. picks candidate blocks per layer, finalizes calibration for layers that
       have candidates, and computes "exact" importances for those candidates;
    5. prunes the globally least important candidates and updates the
       inference metric.

    The loop ends when the inference metric reports the target speedup, when
    ``prune_args.max_iterations`` is exhausted, or when no block can be
    selected any more.

    Args:
        model_data: Supplies ``student_model`` (the model being pruned),
            ``schema`` (used to locate prunable layers) and ``config`` (used
            to build per-layer configs).
        prune_args: Pruning hyper-parameters. ``num_candidate_blocks`` and
            ``num_blocks_to_prune`` are used as fractions of the total number
            of blocks with finite initial importance.
        calibration_dataloader: Calibration batches, forwarded to
            ``stage_forward``.
        device: Device string handed to the pruners and ``stage_forward`` and
            used to allocate the sensitivity tensor.

    Returns:
        ``model_data.student_model`` (the same object that was passed in;
        presumably modified in place by the pruners - confirm against HALPE).

    Raises:
        AssertionError: If the schema matches no layer of the model.
        ValueError: If the original inference cost is not positive, or if the
            final pruned/original inference values fail the sanity checks
            (missing, negative, zero, or pruned larger than original).

    Side effects visible in this function: gradient checkpointing is disabled,
    ``config._attn_implementation`` is set to ``"eager"``, and the model is
    cast to bfloat16, put in eval mode and moved to CPU.
    """
    # Local import: only needed for the explicit cleanup passes below, and it
    # keeps this function self-contained without touching the file's imports.
    import gc

    columns = (SpinnerColumn(), "[progress.description]{task.description}", BarColumn(),
               TaskProgressColumn(), "Progress:", MofNCompleteColumn(),
               "Elapsed:", TimeElapsedColumn(), "Remaining:", TimeRemainingColumn())

    logger = SharedLogger.get_logger("prune")
    logger.debug(f"Prune function called with device: {device}, type: {type(device)}")
    logger.debug(f"Student model initial device: {model_data.student_model.device}")
    # The model is NOT moved to `device` here; layers are moved on demand by
    # the pruners (move_layer_to_gpu / move_layer_to_cpu) and by stage_forward.
    model_data.student_model.gradient_checkpointing_disable()  # no grads anyway

    model_layers = resolve_instance_schema_map(
        model_data.student_model,
        model_data.schema)
    logger.info(f"Extracted {len(model_layers)} layers from the model with types: {[(name, layer_types.layer_name) for name, _, layer_types in model_layers]}")
    assert len(model_layers) > 0, f"No layers found in the model with one of schema {model_data.schema}. Please check the model and schema."

    # One pruner per transformer layer; other layer types are skipped.
    pruners: List[HALPE] = []
    for layer_name, layer, layer_schema in model_layers:
        layer_config = get_layer_config(model_data.config, layer_schema.layer_type, layer_name)
        if layer_schema.layer_type == LayerType.transformer:
            pruners.append(TransformerHALPE(layer, cast(TransformerConfig, layer_config), 
                                            cast(TransformerLayerSchema, layer_schema), device=device, use_chunking=prune_args.use_chunking, chunk_size=prune_args.chunk_size))
            logger.debug(f"Initialized TransformerHALPE for layer {layer_name} with schema {layer_schema}")
    logger.info(f"Initialized {len(pruners)} pruners for the model layers")

    # The inference metric decides when the target speedup has been reached.
    metric_type = "latency" if prune_args.use_latency else "sparsity"
    inference_metric = create_inference_metric(metric_type, prune_args.inference_speedup)

    original_inference = inference_metric.compute_original_inference(model_layers)
    logger.info(f"Original inference: {original_inference}")

    if original_inference <= 0:
        raise ValueError("Original inference time must be positive")

    max_iterations = prune_args.max_iterations  # hard cap so the loop always terminates
    iteration = 0
    # True when the loop is left because nothing could be selected for pruning.
    stopped_early = False
    # Evaluated once before the loop and once at the end of every completed
    # iteration, i.e. exactly where a `while not ...achieved()` condition
    # would evaluate it; keeping the result lets the summary below report the
    # real reason the loop ended.
    target_reached = inference_metric.is_target_speedup_achieved()

    while not target_reached and iteration < max_iterations:
        iteration += 1
        logger.info(f"Pruning iteration {iteration}")

        # ------------------------------------------------------------------
        # Calibration pass: hooks collect the statistics used for the
        # sensitivities and Hessians.
        # ------------------------------------------------------------------
        logger.debug('Registering hooks to compute sensitivities and Hessian Matrix for each layer...')
        for i, pruner in enumerate(pruners):
            # Register hooks for computing sensitivities
            pruner.register_hook()
            # Initialize MicroPruner hooks to compute hessian matrix
            pruner.initialize_pruners()
            logger.debug(f'Registered hooks for layer {i}')
        logger.info('Passing calibration data through the model to compute layer sensitivities and Hessian Matrix for each layer...')

        model_data.student_model.config._attn_implementation = "eager"  # bypass sdpa/flash
        model_data.student_model.to(dtype=torch.bfloat16).eval()
        # NOTE(review): the autocast device type is hard-coded to "cuda"
        # regardless of `device`, and stage_size=4 is a fixed constant.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            stage_forward(model_data.student_model, 
                          calibration_dataloader, 
                          pruners, stage_size=4, start_layer=0, 
                          end_layer=len(pruners) - 1, 
                          device=device)

        # ------------------------------------------------------------------
        # Initial (cheap) importances, one layer on the GPU at a time.
        # L2 weight magnitude is the active estimator; the pruner method names
        # get_initial_importances_l1_weight_magnitudes and
        # get_initial_importances_hessian_diagonal appeared here as
        # commented-out alternatives.
        # ------------------------------------------------------------------
        initial_importances = []
        with Progress(*columns) as progress:
            task = progress.add_task("[blue]Computing Initial Importances...", total=len(pruners))
            for i, pruner in enumerate(pruners):
                pruner.move_layer_to_gpu()
                pruner.move_all_hessians_to_gpu()
                initial_importances.append(pruner.get_initial_importances_l2_weight_magnitudes(i))
                pruner.move_all_hessians_to_cpu()
                pruner.move_layer_to_cpu()
                progress.update(task, advance=1)
        # Rows are [importance, layer_idx, block_idx, block_type]; kept on CPU.
        initial_importances = torch.cat(initial_importances, dim=0)
        initial_importances = initial_importances.cpu()

        logger.info('Finished passing calibration data through the model to compute layer sensitivities and Hessian Matrix for each layer.')
        model_data.student_model.to("cpu")
        # Aggressive memory cleanup after the forward pass.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # Hooks are only needed during the calibration pass.
        logger.debug('Removing hooks after computing sensitivities and Hessian Matrix for each layer...')
        for i, pruner in enumerate(pruners):
            pruner.remove_hook()
            pruner.remove_micropruner_hooks()
            logger.debug(f'Removed hooks for layer {i}')

        logger.info('Moving all Hessians to CPU to free GPU memory...')
        for pruner in pruners:
            pruner.move_all_hessians_to_cpu()
        logger.info('All Hessians moved to CPU.')

        # ------------------------------------------------------------------
        # Layer sensitivities, computed from the last layer back to the first
        # because every layer's global sensitivity is derived from the global
        # sensitivity of the layer after it.
        # ------------------------------------------------------------------
        logger.info('Computing layer sensitivities...')
        global_sensitivity: torch.Tensor = torch.zeros(len(pruners), device=device, dtype=torch.float32)
        last_layer = len(pruners) - 1
        with Progress(*columns) as progress:
            task = progress.add_task("[blue]Computing Sensitivities...", total=len(pruners))
            for i in range(last_layer, -1, -1):
                pruners[i].compute_local_sensitivity()
                if i == last_layer:
                    # The last layer has no successor: seed with its own
                    # local sensitivity and alpha = 0.
                    pruners[i].set_global_sensitivity(pruners[i].local_sensitivity, 0.0)
                else:
                    pruners[i].set_global_sensitivity(global_sensitivity=pruners[i + 1].get_global_sensitivity(), alpha=prune_args.alpha)
                global_sensitivity[i] = pruners[i].get_global_sensitivity()
                logger.debug(f'Layer {i} global sensitivity: {global_sensitivity[i]} \t Local sensitivity: {pruners[i].local_sensitivity}')
                progress.update(task, advance=1)
        logger.info('Finished computing layer sensitivities.')

        # The sensitivities are used as computed (no clamping, decay schedule
        # or normalization is applied) and written back to the pruners.
        for i, pruner in enumerate(pruners):
            pruner.update_global_sensitivity(global_sensitivity[i].item())

        # ------------------------------------------------------------------
        # Step 1: budget of candidate blocks / blocks to prune.
        # ------------------------------------------------------------------
        logger.info('Computing candidate blocks for pruning in each layer based on sensitivities and initial importances...')

        # Only blocks with a finite initial importance count towards the total
        # (presumably non-finite entries mark blocks that must not be picked -
        # TODO confirm against the importance estimators).
        finite_mask = torch.isfinite(initial_importances[:, 0])
        total_blocks = int(finite_mask.sum().item())
        # Both budgets are fractions of total_blocks, clamped to [1, total_blocks].
        K_tot = int(min(max(1, round(prune_args.num_candidate_blocks * total_blocks)), total_blocks))
        logger.info('Selecting number of blocks to prune for each layer...')
        k = int(min(max(1, round(prune_args.num_blocks_to_prune * total_blocks)), total_blocks))
        logger.info(f"Number of blocks to prune: {k}")
        logger.info(f"Number of candidate blocks: {K_tot}, total blocks: {total_blocks}")
        if K_tot == 0:
            # Only reachable when total_blocks == 0.
            logger.warning("No candidates or nothing left to prune. Please check the model and schema. Stopping pruning loop gracefully.")
            return model_data.student_model

        candidate_blocks_per_layer = select_candidates_option_b(global_sensitivity, initial_importances, total_candidates=K_tot)
        logger.debug(f"Candidate blocks per layer: {candidate_blocks_per_layer}")

        # ------------------------------------------------------------------
        # Step 2: hand every pruner the candidates that belong to its layer
        # (column 1 of the candidate tensor is the layer index).
        # ------------------------------------------------------------------
        with Progress(*columns) as progress:
            task = progress.add_task("[blue]Setting Candidate blocks for each layer...", total=len(pruners))
            for i, pruner in enumerate(pruners):
                layer_candidate_blocks = candidate_blocks_per_layer[candidate_blocks_per_layer[:, 1] == i]
                pruner.set_candidate_blocks(layer_candidate_blocks)
                logger.debug(f'Layer {i} candidate blocks: {len(layer_candidate_blocks)}')
                progress.update(task, advance=1)
        logger.info('Finished computing candidate blocks for pruning in each layer based on sensitivities and initial importances.')

        # ------------------------------------------------------------------
        # Finalize calibration, only for layers that actually have candidates.
        # ------------------------------------------------------------------
        logger.info('Computing inverse Hessian matrices...')
        with Progress(*columns) as progress:
            task = progress.add_task("[blue]Computing Inverse Hessian Matrices...", total=len(pruners))
            for i, pruner in enumerate(pruners):
                candidate_blocks = pruner.get_candidate_blocks()
                if len(candidate_blocks) == 0:
                    progress.update(task, advance=1)
                    continue
                pruner.move_all_hessians_to_gpu()
                pruner.finalize_calibration(
                    min_damping=prune_args.min_damping, 
                    max_damping=prune_args.max_damping, 
                    max_iterative_iterations=prune_args.max_iterative_iterations,
                    iterative_tolerance=prune_args.iterative_tolerance
                )
                pruner.move_all_hessians_to_cpu()

                logger.debug(f'Layer {i} calibration finalized')
                # Release cached GPU memory after each layer.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                progress.update(task, advance=1)
        logger.info('Finished finalizing calibration for all pruners.')

        # ------------------------------------------------------------------
        # Exact importances of the candidate blocks.
        # ------------------------------------------------------------------
        logger.info('Computing exact importance of candidate blocks...')
        exact_importances = []
        with Progress(*columns) as progress:
            task = progress.add_task("[blue]Computing Exact Importances...", total=len(pruners))
            for i, pruner in enumerate(pruners):
                candidate_blocks = pruner.get_candidate_blocks()
                if len(candidate_blocks) == 0:
                    progress.update(task, advance=1)
                    continue
                pruner.move_all_hessians_to_gpu()
                pruner.move_layer_to_gpu()
                exact_importances.append(pruner.compute_exact_importances(i))
                logger.debug(f'Layer {i} exact importances: {exact_importances[-1]}')
                pruner.move_all_hessians_to_cpu()
                pruner.move_layer_to_cpu()
                # Release cached GPU memory after each layer.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                progress.update(task, advance=1)
        logger.info('Finished computing exact importance of candidate blocks.')

        # Check for "nothing computed" BEFORE normalizing, so the graceful
        # return below cannot be preempted by the normalizer being handed an
        # empty list.
        if len(exact_importances) == 0:
            logger.error("exact_importances is empty! No candidate blocks were computed.")
            return model_data.student_model

        # The return value is not used, so the helper is relied on to update
        # the list's tensors in place. (meanstd_normalize_tensor_list is the
        # imported alternative.)
        min_max_normalize_tensor_list(exact_importances, 0)

        # Pick the k least important blocks among the candidates. Rows keep
        # the [importance, layer_idx, block_idx, block_type] layout.
        selected_least_important_blocks = select_least_important_globally(exact_importances, k)

        # Split the selection per layer. The emptiness guard avoids column
        # indexing on an empty selection.
        if len(selected_least_important_blocks) > 0:
            selected_blocks_per_layer = {
                layer_idx: selected_least_important_blocks[selected_least_important_blocks[:, 1] == layer_idx]
                for layer_idx in range(len(pruners))
            }
        else:
            selected_blocks_per_layer = {
                layer_idx: torch.empty(0, 4, device=selected_least_important_blocks.device, dtype=selected_least_important_blocks.dtype)
                for layer_idx in range(len(pruners))
            }

        blocks_per_layer_counts = [selected_blocks_per_layer[i].shape[0] for i in range(len(pruners))]
        logger.debug(f'Number of blocks to prune: {blocks_per_layer_counts}')

        if sum(blocks_per_layer_counts) == 0:
            # The target was not reached at the top of this iteration and
            # nothing has been pruned since, so this is always a "gave up"
            # exit, never a success.
            logger.warning("No blocks selected for pruning. This may indicate that all blocks have been pruned or no suitable candidates remain.")
            logger.warning("No more blocks can be pruned and target speedup not achieved. Stopping pruning.")
            stopped_early = True
            break

        # ------------------------------------------------------------------
        # Prune the selected blocks, one layer on the GPU at a time.
        # ------------------------------------------------------------------
        logger.info('Pruning blocks in each layer...')
        with Progress(*columns) as progress:
            task = progress.add_task("[blue]Pruning Blocks...", total=len(pruners))
            # Secondary bar, re-targeted for every layer that has work to do.
            task_pruned = progress.add_task("[blue]Pruning layer...", total=1, visible=False)
            for i, pruner in enumerate(pruners):
                layer_blocks_tensor = selected_blocks_per_layer[i]
                num_layer_blocks = layer_blocks_tensor.shape[0]
                if num_layer_blocks == 0:
                    logger.warning(f"Skipping pruning for layer {i}: no selected blocks.")
                    progress.update(task, advance=1)
                    continue

                progress.update(task_pruned, completed=0, visible=True, total=num_layer_blocks, description=f"[blue]Pruning layer {i}...")
                progress.start_task(task_pruned)
                pruner.move_layer_to_gpu()
                pruner.move_all_hessians_to_gpu()
                pruner.prune(layer_blocks_tensor, prune_args, i, progress, task_pruned, conditioned_score_max_chunk_size=prune_args.conditioned_score_max_chunk_size)
                progress.stop_task(task_pruned)
                progress.update(task_pruned, completed=num_layer_blocks, visible=False)
                pruner.move_all_hessians_to_cpu()
                pruner.move_layer_to_cpu()
                logger.debug(f'Layer {i} pruned blocks: {pruner.get_pruned_blocks()}')
                progress.update(task, advance=1)
        logger.info('Finished pruning blocks in each layer.')

        inference_metric.compute_pruned_inference(model_layers, [pruner.get_updated_configs() for pruner in pruners])
        logger.info(f"Pruned inference: {inference_metric.get_pruned_inference()}, Original inference: {inference_metric.get_original_inference()}, Speedup: {inference_metric.get_current_speedup()}")

        # NOTE: no per-layer or overall sparsity floor and no
        # speedup-regression check is enforced here; the loop is bounded only
        # by the target speedup, max_iterations and the "nothing selected"
        # exit above.

        if prune_args.do_module_reconstruction:
            # Reconstruct pruned modules
            for pruner in pruners:
                pruner.do_layer_reconstruction()

        # Drop per-iteration state held by the pruners.
        for pruner in pruners:
            pruner.reset_pruner()

        # NOTE(review): reset_pruner() is deliberately invoked a second time,
        # as in the original code; the second pass looks redundant - confirm
        # it is idempotent before dropping it.
        for pruner in pruners:
            pruner.reset_pruner()

        # Final cleanup for this iteration.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        gc.collect()

        target_reached = inference_metric.is_target_speedup_achieved()

    # Report why the loop ended. The speedup may still be None (for example
    # when no pruned inference was ever computed), so it is formatted
    # defensively instead of applying ':.2f' to it directly.
    current_speedup = inference_metric.get_current_speedup()
    speedup_text = f"{current_speedup:.2f}x" if current_speedup is not None else "n/a"
    if target_reached:
        logger.info(f"Achieved target speedup in {iteration} iterations. Final speedup: {speedup_text}")
    elif stopped_early:
        logger.warning(f"Stopped after {iteration} iterations without achieving target speedup because no more blocks could be pruned. Current speedup: {speedup_text}")
    else:
        logger.warning(f"Reached maximum iterations ({max_iterations}) without achieving target speedup. Current speedup: {speedup_text}")

    # Validation checks. Both values are tested for None before any numeric
    # comparison so a missing value surfaces as the intended ValueError.
    pruned_inference = inference_metric.get_pruned_inference()
    reported_original = inference_metric.get_original_inference()
    if pruned_inference is None:
        raise ValueError("Pruned inference is None. Please check the model and schema.")
    if reported_original is None:
        raise ValueError("Original inference is None. Please check the model and schema.")
    if pruned_inference > reported_original:
        raise ValueError("Pruned inference is greater than original inference. Please check the model and schema.")
    if pruned_inference < 0:
        raise ValueError("Pruned inference is less than 0. Please check the model and schema.")
    if reported_original < 0:
        raise ValueError("Original inference is less than 0. Please check the model and schema.")
    if reported_original == 0:
        raise ValueError("Original inference is 0. Please check the model and schema.")

    return model_data.student_model

# def get_pruned_model(logger: LoggerType, pruners: List[HALPE], model_data: ModelMetadata):
#     for i, pruner in enumerate(pruners):
#         pruned_layer_config = pruner.get_updated_configs()
#         pruned_layer = pruner.get_pruned_layer()
#         model_data.student_model.layers[i] = pruned_layer
#     return model_data.student_model