import torch
import torch.nn as nn
import math
from collections import OrderedDict
import numpy as np

from ..factorization._interface import BaseFactorization, ShapeHook
from ._sensitivity_base import SensitivityBasedSearch
from ..factorization._interface import get_eq_rank

def compute_task_loss(sensitivities: torch.Tensor, new_rank: int) -> float:
    """Placeholder for a module-level task-loss helper; not implemented.

    Task-based losses used by the rank search are currently computed by
    ``MEMVITSearch.get_task_loss``. This stub previously fell through and
    returned ``None`` despite its ``float`` annotation; it now fails loudly
    so accidental callers are not handed a silent ``None``.

    Args:
        sensitivities (torch.Tensor): Per-layer sensitivity data (unused).
        new_rank (int): Candidate rank (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "compute_task_loss is not implemented; use MEMVITSearch.get_task_loss instead."
    )

def tau_schedule(N0: float, N_target: float, t: int, gamma: float) -> float:
    """
    Exponential scheduling function from the paper (Eq. 9).
    Calculates the target complexity (e.g. number of parameters) at a given iteration.

    The schedule starts at ``N0`` for ``t == 0`` and decays exponentially towards
    ``N_target``: ``N_target + (N0 - N_target) * exp(-t / gamma)``.

    Args:
        N0 (float): Initial complexity (parameters or FLOPs proxy).
        N_target (float): Final target complexity; may be fractional because it is
            usually ``N0 * ratio``.
        t (int): Current iteration step.
        gamma (float): The decay rate of the schedule; must be positive.

    Returns:
        float: The target complexity for iteration t.

    Raises:
        ValueError: If ``gamma`` is not positive (the schedule would divide by
            zero or grow instead of decay).
    """
    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma}")
    return N_target + (N0 - N_target) * math.exp(-t / gamma)

class MEMVITSearch(SensitivityBasedSearch):
    """Greedy, schedule-driven per-layer rank search for low-rank decomposition.

    ``initialize_search`` collects, for each candidate layer, its weight shape,
    its equivalent rank (``eq_rank``) and either its singular values (energy
    losses) or interpolated task sensitivities. ``search`` then starts every
    layer at ``eq_rank`` and, at each step, applies the single rank reduction
    with the smallest loss. The amount of complexity removed per step follows
    ``tau_schedule``; iteration stops once the complexity (parameters or a
    FLOPs proxy) reaches ``ratio_target`` of its initial value, or after
    ``max_iter`` steps.
    """

    def __init__(self, eval_data, mixup_fn, name_omit=None, ratio_target=0.5, sensitivity_loss="energy1", target_metric="params", measurements_points="0.1-0.9", do_latency_adjustment=False, *args, **kwargs):
        # A fresh list per instance avoids sharing one mutable default across searches.
        name_omit = [] if name_omit is None else name_omit
        super().__init__(
            eval_data=eval_data,
            mixup_fn=mixup_fn,
            name_omit=name_omit,
            ratio_target=ratio_target,
            sensitivity_loss=sensitivity_loss,
            measurements_points=measurements_points,
            *args,
            **kwargs
        )
        # Decay rate of tau_schedule and iteration cap of the greedy loop.
        # (A slower setting previously tried: gamma=200.0, max_iter=5000.)
        self.gamma: float = 80.0
        self.max_iter: int = 500
        # Exponent applied to singular values in compute_energy_loss (provided by the base class).
        # The original approach uses squared energy, but that works badly for vision models.
        self.energy_power: int = self.power_for_energy
        self.target_metric: str = target_metric
        self.do_latency_adjustment = do_latency_adjustment
        # lrd_method is assigned in initialize_search(); only validate here if the base
        # class already provided one, otherwise the check runs in initialize_search().
        lrd_method = getattr(self, "lrd_method", None)
        if lrd_method is not None:
            self._check_latency_adjustment_supported(lrd_method)

    def _check_latency_adjustment_supported(self, lrd_method):
        """Raise ValueError if latency adjustment is requested for a non-vision method."""
        if self.do_latency_adjustment and not lrd_method.vision:
            raise ValueError("Latency adjustment is only supported for vision models.")

    @property
    def requires_decomposed_model_for_search(self):
        # NOTE(review): presumably tells the caller to hand search() an already-decomposed model.
        return True

    def get_layer_wise_flops(self, model):
        """Record per-layer input shapes with a dummy forward pass and derive a FLOPs proxy.

        A random batch of shape (20, 3, 224, 224) is pushed through ``model`` while
        a ``ShapeHook`` records the input shape of each hooked layer. The forward
        pass runs without autograd and in eval mode (the previous train/eval mode
        is restored afterwards) so the random batch cannot update normalization
        statistics.

        Args:
            model (nn.Module): Model to probe; must accept (20, 3, 224, 224) inputs.

        Returns:
            tuple[dict, dict]: ``(flops_per_layer, input_shapes)`` keyed by layer
            name, where the FLOPs proxy is
            ``shape[0] * shape[1] * shape[2] / 1000 * shape[3]`` of the recorded shape.
        """
        extractor = ShapeHook(model=model,
            name_omit=self.name_omit, dump_shape=False,
            name_prefix="", white_list=[])
        extractor.attach_hooks()
        # NOTE(review): hooks are never detached here; confirm whether ShapeHook offers a removal API.
        device = next(model.parameters()).device
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(20, 3, 224, 224).to(device)
                model(dummy_input)
                del dummy_input
        finally:
            model.train(was_training)

        input_shapes = dict(extractor.input_shape)
        flops_per_layer = {
            layer_name: shape[0] * shape[1] * shape[2] / 1000 * shape[3]
            for layer_name, shape in input_shapes.items()
        }
        return flops_per_layer, input_shapes

    def initialize_search(
        self, lrd_method: BaseFactorization, model: nn.Module, spec_tensor=None
    ):
        """Collect the per-layer data that ``search`` operates on.

        With an energy-based ``sensitivity_loss``, every ``nn.Linear`` whose name
        contains none of ``name_omit`` is factorized once at ratio 1.0 and its
        singular values are stored. Otherwise, layer sensitivities from
        ``_get_layer_sensitivity`` are densified with ``rolling_interpolate`` and
        stored in ``self.layer_sensitivity``.

        Sets ``self.lrd_method``, ``self.layer_data`` (name -> {'S', 'eq_rank',
        'shape'}) and ``self.lower_bound`` (0.1 for vision methods, else 0.3),
        the minimum fraction of ``eq_rank`` a layer may be reduced to.

        Raises:
            ValueError: If latency adjustment is enabled for a non-vision method.
        """
        self.lrd_method = lrd_method
        self._check_latency_adjustment_supported(lrd_method)

        # Step 1: Perform SVD once and store singular values (or sensitivities).
        layer_data = OrderedDict()
        if "energy" in self.sensitivity_loss:
            with torch.no_grad():
                # TODO: use centralized valid Linear functions.
                for name, module in model.named_modules():
                    if all(omit not in name for omit in self.name_omit) and isinstance(module, nn.Linear):
                        factorized_matrix = self.lrd_method.factorize_matrix(
                            module.weight, ratio=1.0, name=name
                        )
                        # TODO: change to support ConvAsLinear, shape does not match otherwise.
                        layer_data[name] = {
                            'S': factorized_matrix.singular_values,
                            'eq_rank': factorized_matrix.eq_rank,
                            'shape': module.weight.shape,
                        }
        else:
            from .rail import rolling_interpolate
            layer_sensitivity, _ = self._get_layer_sensitivity(model, spec_tensor)
            for layer_name, sensitivity_data in layer_sensitivity.items():
                matrix = model.get_submodule(layer_name).weight
                eq_rank = get_eq_rank(matrix.shape[0], matrix.shape[1])
                # Replacing the value of an existing key does not resize the dict,
                # so this is safe while iterating over it.
                layer_sensitivity[layer_name] = rolling_interpolate(
                    sensitivity_data,
                    interpolation_points=10,
                    approximate_to_95=False,
                    max_rank_for_8_increments=eq_rank,
                )
                layer_data[layer_name] = {'S': None, 'eq_rank': eq_rank, 'shape': matrix.shape}
            self.layer_sensitivity = layer_sensitivity

        self.lower_bound = 0.1 if self.lrd_method.vision else 0.3
        self.layer_data = layer_data

    def _total_complexity(self, ranks):
        """Return the total complexity of all layers at the given per-layer ranks.

        FLOPs mode scales each layer's FLOPs proxy by ``rank / eq_rank``; params
        mode counts ``rank * (n + d)`` for an ``n x d`` weight factorized as U V^T.
        """
        total = 0
        for name, r in ranks.items():
            if self.target_metric == "flops":
                total += r / self.layer_data[name]['eq_rank'] * self.flops_dict[name]
            else:
                n, d = self.layer_data[name]['shape']
                total += r * (n + d)
        return total

    def _propose_rank(self, name, current_rank, lower_bound_rank, p_to_remove):
        """Return the reduced rank proposed for one layer this step, or None if it cannot shrink.

        ``delta_r`` is the rank drop that, applied to this layer alone, removes
        ``p_to_remove`` units of complexity; the result is clamped at the layer's
        lower-bound rank.
        """
        if current_rank <= 1:
            return None  # Cannot reduce rank further.
        data = self.layer_data[name]
        if self.target_metric == "flops":
            layer_flops = self.flops_dict[name]
            if layer_flops <= 0:
                return None  # Nothing to remove here; also avoids a division by zero.
            # complexity = r / eq_rank * flops  =>  delta_r = p_to_remove * eq_rank / flops
            delta_r = p_to_remove * data['eq_rank'] / layer_flops
        else:
            n, d = data['shape']
            # complexity = r * (n + d)  =>  delta_r = p_to_remove / (n + d)
            delta_r = p_to_remove / (n + d)
        new_rank = max(lower_bound_rank, math.floor(current_rank - delta_r))
        if new_rank >= current_rank:
            return None  # Proposed rank is not a reduction.
        return new_rank

    def search(self, model: nn.Module):
        """Run the scheduled greedy rank search.

        Requires ``initialize_search`` to have been called first.

        Args:
            model (nn.Module): Model probed for per-layer input shapes when the
                target metric is "flops" or latency adjustment is enabled.

        Returns:
            OrderedDict: Layer name -> selected rank. After latency adjustment,
            layers judged not worth decomposing are set to -1.
        """
        use_flops = self.target_metric == "flops"
        needs_latency = self.do_latency_adjustment and self.lrd_method.vision
        input_shapes = None
        if use_flops or needs_latency:
            # The latency adjustment needs the recorded input shapes even in "params" mode.
            self.flops_dict, input_shapes = self.get_layer_wise_flops(model=model)

        # Step 2 & 3: Initialize complexity budget and ranks.
        initial_complexity = 0
        current_ranks = OrderedDict()
        lower_bound_ranks = OrderedDict()
        for name, data in self.layer_data.items():
            if use_flops:
                initial_complexity += self.flops_dict[name]
            else:
                # Dense parameter count; presumably ~eq_rank * (n + d) if eq_rank is the
                # break-even rank of get_eq_rank -- TODO confirm.
                n, d = data['shape']
                initial_complexity += n * d
            current_ranks[name] = data['eq_rank']
            lower_bound_ranks[name] = int(math.floor(data['eq_rank'] * self.lower_bound))

        p_total = initial_complexity
        p_current = initial_complexity
        p_target = p_total * self.ratio_target

        print(f"Initial parameters (decomposed): {p_total:,}")
        print(f"Target parameters (alpha={self.ratio_target}): {int(p_target):,}\n")

        # Step 4 & 5: Iterate until the budget is met or the iteration cap is reached.
        t = 1
        while p_current > p_target and t <= self.max_iter:
            # Step 6: amount to remove = drop of the schedule between steps t-1 and t.
            p_to_remove = (
                tau_schedule(p_total, p_target, t - 1, self.gamma)
                - tau_schedule(p_total, p_target, t, self.gamma)
            )
            if p_to_remove <= 0:
                print("Parameter reduction schedule saturated. Halting.")
                break

            # Steps 7-13: propose a reduced rank per layer and score its loss.
            losses = {}
            proposed_ranks = {}
            for name, data in self.layer_data.items():
                new_rank = self._propose_rank(
                    name, current_ranks[name], lower_bound_ranks[name], p_to_remove
                )
                if new_rank is None:
                    continue
                proposed_ranks[name] = new_rank
                if "energy" in self.sensitivity_loss:
                    losses[name] = self.compute_energy_loss(data['S'], new_rank)
                else:
                    losses[name] = self.get_task_loss(self.layer_sensitivity, self.layer_data, name, new_rank)

            if not losses:
                print("No further rank reduction possible. Halting.")
                break

            # Steps 14 & 15: apply the cheapest reduction (ties go to the earliest layer).
            l_star = min(losses, key=losses.get)
            old_rank_l_star = current_ranks[l_star]
            new_rank_l_star = proposed_ranks[l_star]
            current_ranks[l_star] = new_rank_l_star

            # Step 16: recompute the total complexity from the current ranks.
            p_current = self._total_complexity(current_ranks)

            print(f"Iter {t}: Best layer to compress is '{l_star}'. "
                f"Rank reduced from {old_rank_l_star} to {new_rank_l_star}. "
                f"Current params: {int(p_current):,}")

            # Step 17: Increment iteration counter.
            t += 1

        print("\n--- Compression Finished ---")
        print(f"Final parameter count: {int(p_current):,} (Target was {int(p_target):,})")
        if p_current > 0:
            print(f"Achieved compression ratio: {p_total / p_current:.2f}x")
        else:
            # No layers were collected (or every rank reached 0): avoid ZeroDivisionError.
            print("Achieved compression ratio: n/a")

        if needs_latency:
            current_ranks = self._basic_latency_rank_adjustment(current_ranks, input_shapes=input_shapes)
        return current_ranks

    def compute_energy_loss(self, singular_values: torch.Tensor, new_rank: int) -> float:
        """
        Computes the normalized energy loss for a given new rank.

        Energy is the sum of the singular values raised to ``self.energy_power``.
        The loss is the energy of the discarded values (indices >= new_rank)
        divided by the total energy.

        Args:
            singular_values (torch.Tensor): A 1D tensor of singular values, sorted in descending order.
            new_rank (int): The number of singular values to keep (negative values are clamped to 0).

        Returns:
            float: The normalized energy loss; 0.0 when the total energy is zero.
        """
        new_rank = int(max(0, new_rank))
        total_energy = torch.sum(singular_values**self.energy_power)
        if total_energy == 0:
            return 0.0
        lost_energy = torch.sum(singular_values[new_rank:]**self.energy_power)
        # Normalizing stabilizes this approach for LLMs; without it the search fails.
        return (lost_energy / total_energy).item()

    def get_task_loss(self, layer_sensitivity, layer_data, name: str, new_rank: int) -> float:
        """Look up the interpolated task loss of layer ``name`` at ``new_rank``.

        The rank is converted to a ratio of the layer's ``eq_rank`` (rounded to 5
        decimals) and the sensitivity entry with the closest ratio key is returned.
        """
        eq_rank = layer_data[name]['eq_rank']
        target_ratio = np.round(new_rank / eq_rank, 5)
        sensitivities = layer_sensitivity[name]
        closest_key = min(sensitivities.keys(), key=lambda ratio: abs(ratio - target_ratio))
        return sensitivities[closest_key]

    def _basic_latency_rank_adjustment(self, layerwise_rank_dict, input_shapes, latency_predictor_path=None):
        """Mark layers whose predicted low-rank latency is not beneficial with rank -1.

        For each layer with a recorded input shape, the shape (with element 4
        replaced by the chosen rank) is fed to a pre-trained regressor. Layers
        whose prediction is >= 1.0 or <= 0 are set to -1.

        Args:
            layerwise_rank_dict (dict): Layer name -> rank; updated in place and returned.
            input_shapes (dict): Layer name -> recorded input shape (from get_layer_wise_flops).
            latency_predictor_path (str, optional): joblib file of the regressor.
                Defaults to the project's bundled predictor path.

        Returns:
            dict: ``layerwise_rank_dict``, unchanged if the predictor file is missing.
        """
        import os
        import warnings
        import joblib
        if latency_predictor_path is None:
            latency_predictor_path = "/workspace/KFAC-SVD/rf_model_kx8_fp32_80us.pkl"
        if not latency_predictor_path or not os.path.isfile(latency_predictor_path):
            print(f"Latency predictor not found at '{latency_predictor_path}', skipping latency adjustment.")
            return layerwise_rank_dict
        print(
            f"Found latency predictor, {os.path.basename(latency_predictor_path)}. Start loading..."
        )
        # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
        # only load predictors from trusted locations.
        latency_predictor = joblib.load(latency_predictor_path)

        def _get_latency_pred(latency_predictor, input_shape, rank):
            # Copy so the caller's recorded shape is not overwritten with the rank.
            features = list(input_shape)
            features[4] = rank
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return latency_predictor.predict([features])

        print("Adjusting ranks based on latency predictor...")
        for layer_name, rank in layerwise_rank_dict.items():
            # -1 is the sentinel this method assigns to rejected layers; layers without
            # a recorded shape cannot be scored.
            if rank == -1 or layer_name not in input_shapes:
                continue
            input_shape = input_shapes[layer_name]
            print(f"Layer: {layer_name}, Input shape: {input_shape}, Rank: {rank}")
            lat_predict = _get_latency_pred(latency_predictor, input_shape, int(rank))
            # NOTE(review): the prediction appears to be latency relative to the
            # uncompressed layer; values outside (0, 1) mean no speed-up -- confirm.
            if lat_predict.item() >= 1.0 or lat_predict.item() <= 0:
                layerwise_rank_dict[layer_name] = -1
        return layerwise_rank_dict
    
