from torch import nn

from ..factorization._interface import BaseFactorization, ShapeHook
from ._sensitivity_base import SensitivityBasedSearch
import torch


class ASVDSearch(SensitivityBasedSearch):
    """
    Sensitivity-based per-layer compression-ratio search from ASVD.

    Every measured ``(layer, param_ratio)`` candidate is ranked by its
    sensitivity loss. A binary search then chooses how many of the
    least-sensitive candidates to apply so that the overall kept fraction of
    parameters (``target_metric="params"``) or estimated FLOPs
    (``target_metric="flops"``) reaches ``ratio_target``.

    @misc{
        yuan2025asvd,
        title={{ASVD}: Activation-aware Singular Value Decomposition for Compressing Large Language Models},
        author={Zhihang Yuan and Yuzhang Shang and Yue Song and Dawei Yang and Qiang Wu and Yan Yan and Guangyu Sun},
        year={2025},
        url={https://openreview.net/forum?id=HyPofygOCT}
    }
    """
    def __init__(self, ratio_target=0.5, sensitivity_loss="ce", measurements_points="asvd_default", target_metric="params", do_latency_adjustment=False, *args, **kwargs):
        """Store search settings and forward the rest to ``SensitivityBasedSearch``.

        Args:
            ratio_target: desired kept fraction of the chosen metric.
            sensitivity_loss: forwarded to the base class.
            measurements_points: forwarded to the base class.
            target_metric: ``"flops"`` weights layers by estimated FLOPs;
                any other value weights them by ``weight.numel()``.
            do_latency_adjustment: post-process the result with a latency
                predictor (vision models only).

        Raises:
            ValueError: latency adjustment requested for a non-vision method
                (checked here if the base class already set ``lrd_method``,
                otherwise in ``initialize_search``).
        """
        super().__init__(
            ratio_target=ratio_target,
            sensitivity_loss=sensitivity_loss,
            measurements_points=measurements_points,
            *args,
            **kwargs,
        )
        self.target_metric = target_metric
        self.do_latency_adjustment = do_latency_adjustment
        # ``lrd_method`` is assigned in ``initialize_search``; only validate now
        # if the base class has already provided it.
        self._check_latency_adjustment_supported(getattr(self, "lrd_method", None))

    def _check_latency_adjustment_supported(self, lrd_method):
        """Raise ``ValueError`` if latency adjustment is requested for a non-vision method."""
        if self.do_latency_adjustment and lrd_method is not None and not lrd_method.vision:
            raise ValueError("Latency adjustment is only supported for vision models.")

    @property
    def requires_decomposed_model_for_search(self):
        return True

    def initialize_search(
        self, lrd_method: BaseFactorization, model: nn.Module, spec_tensor=None
    ):
        """Attach the factorization method and measure per-layer sensitivity.

        ``self.sensitivity_dict`` is filled from ``_get_layer_sensitivity``;
        ``search`` expects it to map layer name -> {param_ratio: loss}.

        Raises:
            ValueError: latency adjustment requested for a non-vision method.
        """
        self.lrd_method = lrd_method
        # Fail fast, before the (potentially expensive) sensitivity measurement.
        self._check_latency_adjustment_supported(lrd_method)
        layer_sensitivity, _ = self._get_layer_sensitivity(model, spec_tensor)
        self.sensitivity_dict = layer_sensitivity

    def get_layer_wise_flops(self, model, input_size=(20, 3, 224, 224)):
        """Estimate a per-layer cost from input shapes seen in one dummy forward pass.

        A ``ShapeHook`` records the input shape of each hooked layer while a
        random batch of shape ``input_size`` is pushed through ``model``. The
        pass runs in eval mode under ``torch.no_grad()`` so that it builds no
        autograd graph and does not update BatchNorm running statistics. The
        model's original device and train/eval mode are restored afterwards,
        even if the forward pass raises.

        Args:
            model: module to probe.
            input_size: shape of the random input batch. The default matches
                the previously hard-coded 20x3x224x224 batch.

        Returns:
            ``(flops_per_layer, input_shapes)``. ``input_shapes[name]`` is the
            shape recorded by the hook. ``flops_per_layer[name]`` is the product
            of its first four entries divided by 1000.
        """
        extractor = ShapeHook(model=model,
            name_omit=self.name_omit, dump_shape=False,
            name_prefix="", white_list=[])
        first_param = next(model.parameters())
        original_device = first_param.device
        # Match the model's floating dtype (e.g. fp16) so the forward does not fail.
        dtype = first_param.dtype if first_param.is_floating_point() else torch.float32
        device = original_device
        if device.type == 'cpu':
            print("Running on CPU")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)  # nn.Module.to is in-place; undone in ``finally``.
        else:
            print("Running on GPU")
        was_training = model.training
        extractor.attach_hooks()
        try:
            model.eval()
            dummy_input = torch.randn(*input_size, device=device, dtype=dtype)
            with torch.no_grad():
                model(dummy_input)
            input_shapes = dict(extractor.input_shape)
        finally:
            extractor.clear_hooks()
            model.train(was_training)
            model.to(original_device)
        flops_per_layer = {
            layer_name: shape[0] * shape[1] * shape[2] / 1000 * shape[3]
            for layer_name, shape in input_shapes.items()
        }
        return flops_per_layer, input_shapes

    def search(self, model: nn.Module):
        """Choose a kept parameter ratio for every layer in ``sensitivity_dict``.

        Candidates ``(layer, param_ratio, loss)`` with ``param_ratio < 1`` are
        sorted from most to least sensitive. Applying the suffix
        ``candidates[split:]`` gives each layer the minimum ratio among its
        applied candidates. The kept fraction therefore grows monotonically
        with ``split``. The largest ``split`` whose kept fraction is at or
        below ``ratio_target`` is selected. If the target is unreachable,
        every candidate is applied (``split == 0``).

        Returns:
            dict mapping layer name -> kept ratio (1.0 = left uncompressed).
            With latency adjustment enabled, entries may be replaced by -1.
        """
        module_dict = dict(model.named_modules())
        default_param_ratio = 1.0
        adjust_latency = self.do_latency_adjustment and self.lrd_method.vision

        # Input shapes are needed both for FLOP weighting and for latency adjustment.
        flops_per_layer, input_shapes = {}, {}
        if self.target_metric == "flops" or adjust_latency:
            flops_per_layer, input_shapes = self.get_layer_wise_flops(model)

        # Per-layer weight of the target metric; constant during the search.
        if self.target_metric == "flops":
            layer_costs = {name: flops_per_layer[name] for name in self.sensitivity_dict}
        else:
            layer_costs = {
                name: module_dict[name].weight.numel() for name in self.sensitivity_dict
            }
        total_cost = sum(layer_costs.values())

        # Candidate list sorted from most to least sensitive (stable sort).
        candidates = []
        for layername, measurements in self.sensitivity_dict.items():
            for param_ratio, loss in measurements.items():
                if param_ratio >= 1:
                    continue
                candidates.append((layername, param_ratio, loss))
        candidates.sort(key=lambda c: -c[2])

        def ratios_from(split):
            # Per-layer kept ratio when candidates[split:] are applied.
            ratios = {name: default_param_ratio for name in self.sensitivity_dict}
            for layername, param_ratio, _ in candidates[split:]:
                ratios[layername] = min(ratios[layername], param_ratio)
            return ratios

        def kept_fraction(ratios):
            compressed = sum(layer_costs[name] * ratio for name, ratio in ratios.items())
            return compressed / total_cost

        # Binary search for the first split whose kept fraction exceeds the target.
        low, high = 0, len(candidates) - 1
        while low < high:
            mid = (low + high) // 2
            if kept_fraction(ratios_from(mid)) > self.ratio_target:
                high = mid
            else:
                low = mid + 1

        # Every split below ``low`` meets the target. Step back one if ``low``
        # itself misses it. An empty candidate list keeps split == 0, which
        # leaves all layers at the default ratio.
        split = low
        if candidates and low > 0 and kept_fraction(ratios_from(low)) > self.ratio_target:
            split = low - 1

        print("=== Searching done, decomposing layers... ===")
        layers_min_ratio = ratios_from(split)

        cumulative_error = sum(
            (
                self.sensitivity_dict[layername][param_ratio]
                for layername, param_ratio in layers_min_ratio.items()
                if param_ratio != default_param_ratio
            ),
            0.0,
        )
        print(f"cumulative error: {cumulative_error}")
        if adjust_latency:
            layers_min_ratio = self._basic_latency_rank_adjustment(layers_min_ratio, input_shapes)
        return layers_min_ratio

    def _basic_latency_rank_adjustment(
        self,
        layerwise_rank_dict,
        input_shapes,
        latency_predictor_path="/workspace/KFAC-SVD/rf_model_kx8_fp32_80us.pkl",
    ):
        """Mark layers as -1 when a latency predictor says compressing them does not pay off.

        For each layer whose value is not -1 and that has a recorded input
        shape, the shape is copied and its index 4 is replaced by
        ``int(value)``. The result is fed to ``latency_predictor.predict``. If
        the prediction is ``>= 1.0`` or ``<= 0``, the layer's value is set to
        -1. ``layerwise_rank_dict`` is updated in place and returned.

        NOTE(review): ``search`` passes per-layer parameter *ratios* (floats in
        (0, 1]) here, so ``int(value)`` truncates them (e.g. 0.5 -> 0). Confirm
        which feature the predictor actually expects at index 4.

        Args:
            layerwise_rank_dict: layer name -> value to check; mutated in place.
            input_shapes: layer name -> recorded input shape (indexable, len > 4).
            latency_predictor_path: joblib file holding the predictor. If it is
                empty or missing, a warning is emitted and the dict is returned
                unchanged.
        """
        import os
        import warnings
        import joblib

        if not latency_predictor_path or not os.path.isfile(latency_predictor_path):
            warnings.warn(
                f"Latency predictor not found at {latency_predictor_path!r}; "
                "skipping latency adjustment."
            )
            return layerwise_rank_dict
        print(
            f"Found latency predictor, {latency_predictor_path.split('/')[-1]}. Start loading..."
        )
        # SECURITY: joblib.load unpickles arbitrary objects; only load trusted files.
        latency_predictor = joblib.load(latency_predictor_path)

        def _get_latency_pred(input_shape, rank):
            # Copy so the shared ``input_shapes`` entry is not overwritten with the rank.
            features = list(input_shape)
            features[4] = rank
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return latency_predictor.predict([features])

        print("Adjusting ranks based on latency predictor...")
        for layer_name, rank in layerwise_rank_dict.items():
            # Skip layers already disabled or without a recorded input shape.
            if rank == -1 or layer_name not in input_shapes:
                continue
            input_shape = input_shapes[layer_name]
            print(f"Layer: {layer_name}, Input shape: {input_shape}, Rank: {rank}")
            lat_predict = _get_latency_pred(input_shape, int(rank)).item()
            # Keep only predictions in (0, 1); presumably latency relative to the
            # uncompressed layer. TODO confirm against the predictor's training target.
            if lat_predict >= 1.0 or lat_predict <= 0:
                # Reassigning an existing key while iterating is safe (no resize).
                layerwise_rank_dict[layer_name] = -1
        return layerwise_rank_dict