from copy import deepcopy

import all_utils.dist_training as dist_utils
import torch
from torch import nn
from all_utils.other import get_model_name

from ..factorization._interface import BaseFactorization, get_valid_layers
from ._interface import BaseSearch
import torch.nn.functional as F
import numpy as np
from all_utils.distances import wasserstein_from_logits_3d_fast, jsd_from_logits_3d, bild_loss
import pickle
import os


class LastFeatureHook:
    """Capture the input of a model's final ``nn.Linear`` layer during forward passes.

    All ``nn.Linear`` sub-modules are collected in reverse ``named_modules()``
    order. ``attach_hooks`` registers a forward hook on the first of them
    (i.e. the last Linear in the model) that has at least 10 output features.
    If that layer's name contains ``"head"``, each forward pass stores a
    detached float32 copy of the layer's input on ``model.last_feat``,
    normalised to three dimensions. If the name does not contain ``"head"``
    the hook does nothing and ``model.last_feat`` is never written.
    """

    def __init__(self, model: nn.Module):
        self.model = model

        # Handles returned by register_forward_hook; needed to detach later.
        self.hooks = []
        # A real list (not a one-shot ``reversed`` iterator) so attach_hooks()
        # always starts from the last Linear layer, even when called again
        # after clear_hooks().
        # TODO: use centralized valid Linear functions.
        self.cp_modules = list(
            reversed(
                [
                    (name, module_sub)
                    for name, module_sub in model.named_modules()
                    if isinstance(module_sub, nn.Linear)
                ]
            )
        )

    def _hook_fn(self, layer_name):
        """Build the forward hook for the layer called ``layer_name``."""
        captures_features = "head" in layer_name

        def get_feature_extract_hook(module, inputs, output):
            if not captures_features:
                return
            x = inputs[0].detach().float()
            if x.dim() > 3:
                # Collapse all middle dimensions: (B, ..., D) -> (B, N, D).
                x = x.reshape(x.shape[0], -1, x.shape[-1])
            elif x.dim() == 2:
                # Add a leading dimension: (N, D) -> (1, N, D).
                x = x.unsqueeze(0)
            self.model.last_feat = x.clone()

        return get_feature_extract_hook

    def _register_hooks_recursive(self):
        """Hook the last Linear layer that has at least 10 output features."""
        for name, layer in self.cp_modules:
            if layer.out_features < 10:
                continue  # for some head matrix, such as image-text match head

            self.hooks.append(layer.register_forward_hook(self._hook_fn(name)))
            # Only a single layer is ever hooked.
            return

    def attach_hooks(self):
        """Register the feature-capturing forward hook on the model."""
        self._register_hooks_recursive()

    def clear_hooks(self):
        """Remove every registered hook and forget the stale handles."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

class SensitivityBasedSearch(BaseSearch):
    def __init__(self, eval_data, mixup_fn, name_omit=[], ratio_target=0.5, sensitivity_loss="kl", measurements_points="0.1-0.9", sequence_length=256, use_cache=False):
        self.eval_data = tuple(data for data in eval_data)
        self.name_omit = name_omit
        self.mixup_fn = mixup_fn
        # sensitivity dict needed for search
        self.sensitivity_dict = {}
        self.lrd_method = None
        self.ratio_target = ratio_target
        # specific parameters
        self.sensitivity_loss = sensitivity_loss
        self.sequence_length = sequence_length
        # 1 for energy, 2 for squared energy (2 somehow is terrible)
        if "energy1" in sensitivity_loss:
            self.power_for_energy = 1
        else:
            self.power_for_energy = 2
        self.use_cache = use_cache
        self.sensitivity_cache_dir = "./.cache/sensitivity_cache/"

        if measurements_points == "0.1-0.9":
            self.measurements_points = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        elif measurements_points == "0.2-0.9":
            self.measurements_points = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        elif measurements_points == "0.1-0.9uneven":
            self.measurements_points = [0.1, 0.3, 0.5, 0.7, 0.9]
        elif measurements_points == "asvd_default":
            self.measurements_points = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        elif measurements_points == "gfwsvd":
            self.measurements_points = [i / 20.0 for i in range(1, 20)]
        elif measurements_points == "0.1":
            self.measurements_points = [0.1]
        elif measurements_points == "0.3":
            self.measurements_points = [0.3]
        elif measurements_points == "0.5":
            self.measurements_points = [0.5]
        elif measurements_points == "0.7":
            self.measurements_points = [0.7]
        else:
            raise ValueError(
                f"Unknown measurements points: {measurements_points}. "
                "Use '0.1', '0.3', '0.5', '0.7', '0.1-0.9', '0.2-0.9', '0.1-0.9uneven', 'gfwsvd' or 'asvd_default'."
            )
        print(f"Using measurement points {self.measurements_points} for the sensitivity search.")
    
    @property
    def requires_decomposed_model_for_search(self):
        return True

    def initialize_search(
        self, lrd_method: BaseFactorization, model: nn.Module, spec_tensor=None
    ):
        """Bind the factorization method and populate the sensitivity tables.

        Stores ``lrd_method`` on the instance, then fills
        ``self.sensitivity_dict`` and ``self.size_dict`` from
        ``_get_layer_sensitivity(model, spec_tensor)``.
        """
        # Must be set first: the sensitivity computation reads self.lrd_method.
        self.lrd_method = lrd_method
        self.sensitivity_dict, self.size_dict = self._get_layer_sensitivity(
            model, spec_tensor
        )

    def search(self, model: nn.Module):
        raise NotImplementedError("Subclasses should implement this method.")
    
    def _precompute_original_outputs(self, model: nn.Module):
        """Run the unmodified model over ``self.eval_data`` once and cache the results on CPU.

        The cached tensors are the reference that every modified model is later
        compared against, so the original model only needs one forward pass
        per batch.

        What is cached depends on ``self.sensitivity_loss``:
          * contains ``"mse"``: ``model.last_feat`` as captured by a
            ``LastFeatureHook`` (for vision and language models alike);
          * otherwise: the model output for vision models, or
            ``outputs.logits`` for language models.

        Args:
            model: The original model. It is moved to the current CUDA device,
                switched to eval mode, and left on that device.

        Returns:
            A list with one CPU tensor per batch, in ``self.eval_data`` order.
        """
        dev = torch.device(torch.cuda.current_device())
        model = model.to(dev).eval()

        # Same substring test as the evaluation methods use, so a loss such as
        # "*_msescaled" caches features (not logits) on the vision path too.
        use_features = "mse" in self.sensitivity_loss
        is_vision = self.lrd_method.vision

        # Features are only reachable through a temporary forward hook.
        feature_hook = None
        if use_features:
            feature_hook = LastFeatureHook(model)
            feature_hook.attach_hooks()

        original_outputs = []
        print("Pre-computing original model outputs for reference...")
        try:
            with torch.no_grad():
                for batch in self.eval_data:
                    if is_vision:
                        # Vision batches are (samples, targets) pairs.
                        samples, _ = batch
                        outputs = model(samples.to(dev))
                        reference = model.last_feat if use_features else outputs
                    else:
                        # Language batches are dicts of tensors passed as kwargs.
                        batch = {k: v.to(dev) for k, v in batch.items()}
                        outputs = model(**batch)
                        reference = model.last_feat if use_features else outputs.logits
                    # Store on CPU to save VRAM.
                    original_outputs.append(reference.clone().detach().cpu())
        finally:
            # Remove the hook even when a forward pass fails.
            if feature_hook is not None:
                feature_hook.clear_hooks()

        # The model stays on the GPU; only cached allocator blocks are released.
        torch.cuda.empty_cache()
        print("Pre-computing original model outputs done.")
        return original_outputs

    def _eval_llm(self, cp_model, original_outputs):
        """Score a modified language model against pre-computed reference outputs.

        The loss is selected from ``self.sensitivity_loss`` in this priority
        order: contains ``"mse"``, contains ``"kl"``, contains ``"ppl"``,
        contains ``"wasserstein"``, equals ``"jsd"``, equals ``"bild"``.

        Args:
            cp_model: The modified model, already on the current CUDA device.
                It is called as ``cp_model(**batch)`` and must return an
                object with a ``.logits`` attribute.
            original_outputs: One reference tensor per batch, as produced by
                ``_precompute_original_outputs``.

        Returns:
            For ``ppl`` losses, ``exp`` of the mean token negative
            log-likelihood over all batches with finite logits (``inf`` if no
            batch had finite logits). For every other loss, the mean per-batch
            loss. Under distributed training the value is averaged over ranks.

        Raises:
            ValueError: If ``self.sensitivity_loss`` matches no known loss.
            ZeroDivisionError: If no batch was evaluated.
        """
        loss_name = self.sensitivity_loss
        # Resolve the loss once, up front, so an unknown name fails clearly
        # instead of surfacing later as an unbound ``loss`` variable.
        if "mse" in loss_name:
            mode = "mse"
        elif "kl" in loss_name:
            mode = "kl"
        elif "ppl" in loss_name:
            mode = "ppl"
        elif "wasserstein" in loss_name:
            mode = "wasserstein"
        elif loss_name in ("jsd", "bild"):
            mode = loss_name
        else:
            raise ValueError(f"Unknown sensitivity loss type: {self.sensitivity_loss}")

        dev = torch.device(torch.cuda.current_device())
        # cp_model is already on the correct device
        cp_model.eval()

        temperature = 0.6  # softmax temperature of the KL loss
        # NOTE(review): the original code unconditionally dropped the first 256
        # rows after reshaping to ``sequence_length`` rows, which leaves an
        # empty tensor (NaN loss) whenever sequence_length <= 256, including
        # the default. The warm-up skip is therefore applied only to longer
        # sequences; confirm this matches the intended behaviour.
        context_skip = 256
        skip = context_skip if self.sequence_length > context_skip else 0

        total_loss = 0.0
        num_batches = 0
        nlls = []
        token_nll = torch.nn.CrossEntropyLoss(reduction="none")

        # For MSE, we need a hook on the modified model
        feature_hook = None
        if mode == "mse":
            feature_hook = LastFeatureHook(cp_model)
            feature_hook.attach_hooks()

        try:
            with torch.no_grad():
                # Zip the batches with the pre-computed outputs
                for batch, ref_output in zip(self.eval_data, original_outputs):
                    batch = {k: v.to(dev) for k, v in batch.items()}
                    outputs_cp = cp_model(**batch)

                    if mode == "mse":
                        ref_last_feat = ref_output.to(dev)
                        # MSE relative to the mean squared reference feature.
                        loss = F.mse_loss(cp_model.last_feat, ref_last_feat) / torch.mean(ref_last_feat**2)
                    elif mode == "kl":
                        ref_logits = ref_output.to(dev) / temperature
                        cp_logits = outputs_cp.logits / temperature
                        if torch.isfinite(ref_logits).all() and torch.isfinite(cp_logits).all():
                            # NOTE(review): the reshape mixes batch and vocabulary
                            # entries in one row when the batch size is > 1; kept as-is.
                            probs_target = F.softmax(ref_logits, dim=-1).reshape(self.sequence_length, -1)[skip:]
                            log_probs_cp = F.log_softmax(cp_logits, dim=-1).reshape(self.sequence_length, -1)[skip:]
                            loss = F.kl_div(log_probs_cp, probs_target, reduction='batchmean')
                        else:
                            loss = torch.tensor(0.0)  # not breaking compatibility.
                    elif mode == "ppl":
                        lm_logits = outputs_cp.logits
                        if torch.isfinite(lm_logits).all():
                            # Next-token prediction: logits at t score the token at t + 1.
                            shift_logits = lm_logits[:, :-1, :].contiguous()
                            shift_labels = batch["input_ids"][:, 1:].contiguous()
                            nlls.append(
                                token_nll(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.view(-1))
                            )
                        # Perplexity comes from ``nlls``; the running loss is a
                        # placeholder and must be bound on every path.
                        loss = torch.tensor(0.0)
                    elif mode == "wasserstein":
                        w_distance = wasserstein_from_logits_3d_fast(outputs_cp.logits, ref_output.to(dev))
                        loss = torch.mean(w_distance) / self.sequence_length
                    elif mode == "jsd":
                        loss = jsd_from_logits_3d(outputs_cp.logits, ref_output.to(dev))
                    else:  # mode == "bild"
                        # https://github.com/fpcsong/BiLD/blob/main/trainers/custom_trainer.py#L315
                        logits_s = outputs_cp.logits
                        ref_logits = ref_output.to(dev)
                        t_ld_loss = bild_loss(logits_s, ref_logits, top_k=8, temperature=3.0, student_led=False)
                        s_ld_loss = bild_loss(logits_s, ref_logits, top_k=8, temperature=3.0, student_led=True)
                        loss = torch.mean(t_ld_loss + s_ld_loss)

                    total_loss += loss.item()
                    num_batches += 1
        finally:
            # Remove the hook even when a forward pass fails.
            if feature_hook is not None:
                feature_hook.clear_hooks()

        avg_loss = total_loss / num_batches
        if mode == "ppl":
            # Same substring rule as the branch above, so "*_pplscaled" losses
            # are scaled by the perplexity rather than the placeholder zero.
            if nlls:
                metric = np.exp(torch.cat(nlls, dim=-1).mean().item())
            else:
                # No batch produced finite logits: report unbounded perplexity.
                metric = float("inf")
        else:
            metric = avg_loss

        # Synchronization for distributed training
        if dist_utils.is_dist_initialized():
            metric_tensor = torch.tensor(metric, device=dev)
            dist_utils.dist_barrier()
            metric_tensor = dist_utils.sync_tensor(metric_tensor, "mean")
            dist_utils.dist_barrier()
            return metric_tensor.item()

        return metric
    
    def _eval_vision(self, cp_model, original_outputs):
        """Score a modified vision model against pre-computed reference outputs.

        The loss is chosen from ``self.sensitivity_loss`` in priority order:
        contains ``"mse"`` (feature matching on ``last_feat``), contains
        ``"kl"`` (KL divergence of temperature-0.6 softmax outputs), contains
        ``"ce"`` (cross-entropy against the batch labels).

        Args:
            cp_model: The modified model; moved to the current CUDA device and
                switched to eval mode.
            original_outputs: One reference tensor per batch, as produced by
                ``_precompute_original_outputs``.

        Returns:
            The SUM of the per-batch losses as a Python float (averaged over
            ranks under distributed training).

        Raises:
            ValueError: If ``self.sensitivity_loss`` matches no known loss.
        """
        dev = torch.device(torch.cuda.current_device())
        cp_model = cp_model.to(dev).eval()
        loss_name = self.sensitivity_loss
        use_features = "mse" in loss_name
        temperature = 0.6
        sensitivity = torch.tensor(0.0, device=dev)

        # prepare hooks for last feature extraction needed for MSE
        feature_hook = None
        if use_features:
            feature_hook = LastFeatureHook(cp_model)
            feature_hook.attach_hooks()

        try:
            with torch.no_grad():
                for (samples, targets), orig_out in zip(self.eval_data, original_outputs):
                    outputs_cp = cp_model(samples.to(dev))
                    if use_features:
                        # The reference is cached on CPU; it must be moved to the
                        # device before it is compared with the hooked features.
                        reference = orig_out.to(dev)
                        features_cp = cp_model.last_feat
                        # NOTE(review): normalised by the modified model's feature
                        # energy, whereas _eval_llm normalises by the reference.
                        loss = F.mse_loss(reference, features_cp) / torch.mean(features_cp**2)
                        del reference, features_cp
                    elif "kl" in loss_name:
                        reference = orig_out.to(dev) / temperature
                        probs_target = F.softmax(reference.detach(), dim=-1)
                        log_probs_cp = F.log_softmax(outputs_cp / temperature, dim=-1)
                        loss = F.kl_div(log_probs_cp, probs_target, reduction='batchmean')
                        del reference, probs_target, log_probs_cp
                    elif "ce" in loss_name:
                        # F.cross_entropy expects unnormalised logits and applies
                        # log-softmax itself, so no softmax is applied here.
                        loss = F.cross_entropy(outputs_cp, targets.to(dev), reduction='mean')
                    else:
                        raise ValueError(f"Unknown sensitivity loss type: {self.sensitivity_loss}")
                    sensitivity += loss
                    # Drop per-batch tensors before releasing cached GPU memory.
                    del loss, outputs_cp

                    with torch.cuda.device(torch.cuda.current_device()):
                        torch.cuda.empty_cache()
        finally:
            # Remove the hook even when a forward pass fails.
            if feature_hook is not None:
                feature_hook.clear_hooks()

        dist_utils.dist_barrier()
        if dist_utils.is_dist_initialized():
            sensitivity = dist_utils.sync_tensor(sensitivity, "mean")
        dist_utils.dist_barrier()

        return sensitivity.item()

    def _get_layer_sensitivity(self, model: nn.Module, spec_tensor=None):
        """Return ``(layer_sensitivity, size_dict)``, using the on-disk cache when enabled.

        With ``self.use_cache`` set, both dictionaries are loaded from pickle
        files under ``self.sensitivity_cache_dir``. If either file is missing
        they are computed with ``_compute_layer_sensitivity`` and written back.
        Loaded sensitivities are restricted to ratios within ``[0.1, 0.95]``;
        freshly computed ones are returned unfiltered.

        Args:
            model: Model whose layers are analysed.
            spec_tensor: Forwarded unchanged to ``_compute_layer_sensitivity``.

        Returns:
            Tuple ``(layer_sensitivity, size_dict)``.
        """
        model_name = get_model_name(model)

        if not self.use_cache:
            return self._compute_layer_sensitivity(model, spec_tensor)

        sensitivity_path = f"{self.sensitivity_cache_dir}layer_sensitivity_{self.lrd_method.get_cache_name()}_{self.sensitivity_loss}_{model_name}.pkl"
        size_path = f"{self.sensitivity_cache_dir}size_dict_{model_name}.pkl"

        try:
            # NOTE(security): pickle.load can execute arbitrary code; the cache
            # directory must only ever contain files written by this class.
            with open(sensitivity_path, "rb") as f:
                layer_sensitivity = pickle.load(f)
            with open(size_path, "rb") as f:
                size_dict = pickle.load(f)
        except FileNotFoundError:
            print("No cached sensitivity data found. Proceeding with new calculations.")
        else:
            for layer_name, sensitivity_data in layer_sensitivity.items():
                layer_sensitivity[layer_name] = {
                    k: v for k, v in sensitivity_data.items() if 0.1 <= k <= 0.95
                }
            print("Loaded cached layer sensitivity data.")
            return layer_sensitivity, size_dict

        layer_sensitivity, size_dict = self._compute_layer_sensitivity(model, spec_tensor)

        for path, payload in ((sensitivity_path, layer_sensitivity), (size_path, size_dict)):
            # exist_ok avoids a check-then-create race when several processes
            # write the cache at once. Creating the file's own parent also
            # covers path components that introduce sub-directories.
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(path, "wb") as f:
                pickle.dump(payload, f)
        print("Saved layer sensitivity data to cache.")

        return layer_sensitivity, size_dict

    def _compute_layer_sensitivity(self, model: nn.Module, spec_tensor=None):
        """Compute a rank-ratio -> sensitivity table for every eligible ``nn.Linear`` layer.

        A layer is eligible when none of ``self.name_omit`` occurs in its name
        and it has at least 10 output features. For each such layer:

          * ``"energy"`` losses fill the table from the singular values of the
            factorized weight (see ``_energy_sensitivities``);
          * measurement-based losses temporarily swap the layer for its
            factorized replacement at every ratio in
            ``self.measurements_points`` and evaluate the whole model;
          * ``*_klscaled`` / ``*_msescaled`` / ``*_pplscaled`` losses multiply
            the energy table by the measured metric.

        The original layer is always put back, even if an evaluation raises.

        Args:
            model: Model to analyse. It is moved to the current CUDA device for
                the analysis and back to CPU when the analysis completes.
            spec_tensor: Unused by this method.

        Returns:
            Tuple ``(sensitivity_dict, size_dict)``: per-layer
            ``{ratio: sensitivity}`` tables and per-layer weight element counts.
        """
        loss_name = self.sensitivity_loss
        uses_energy = "energy" in loss_name
        is_scaled = any(tag in loss_name for tag in ("_klscaled", "_msescaled", "_pplscaled"))
        model_forwards_required = (not uses_energy) or is_scaled

        original_outputs = None
        if model_forwards_required:
            original_outputs = self._precompute_original_outputs(model)

        dev = torch.device(torch.cuda.current_device())
        model = model.to(dev)  # Move model to GPU once

        sensitivity_dict = {}
        size_dict = {}

        # Iterate over a snapshot because layers are swapped during the loop.
        # TODO: use centralized valid Linear functions.
        for name, module_sub in list(model.named_modules()):
            if not isinstance(module_sub, nn.Linear):
                continue
            if any(n in name for n in self.name_omit) or module_sub.out_features < 10:
                continue

            print(f"Evaluating sensitivity for layer {name}")

            # Walk the dotted name down to the parent so the layer can be
            # replaced with setattr(parent, leaf, ...).
            base, localname = model, name
            while "." in localname:
                prefix, localname = localname.split(".", 1)
                base = getattr(base, prefix)

            layer_scores = {}
            sensitivity_dict[name] = layer_scores
            size_dict[name] = module_sub.weight.numel()

            factorized_matrix = self.lrd_method.factorize_matrix(
                name=name, matrix=module_sub.weight, ratio=1.0
            )

            if uses_energy:
                layer_scores.update(self._energy_sensitivities(factorized_matrix))

            # Measurement-based sensitivities (or the energy/measurement mixture).
            try:
                if model_forwards_required:
                    for point in self.measurements_points:
                        factorized_matrix.active_rank = int(factorized_matrix.eq_rank * point)

                        seq_replacement = self.lrd_method.create_factorized_sequential(
                            factorized_matrix=factorized_matrix, original_module=module_sub
                        ).to(dev)
                        setattr(base, localname, seq_replacement)

                        if self.lrd_method.vision:
                            metric = self._eval_vision(model, original_outputs)
                        else:
                            metric = self._eval_llm(model, original_outputs)

                        if is_scaled:
                            # NOTE(review): with several measurement points the
                            # table is multiplied once per point (cumulatively);
                            # kept as in the original - confirm this is intended.
                            for energy_ratio in layer_scores:
                                layer_scores[energy_ratio] = layer_scores[energy_ratio] * metric
                        else:
                            layer_scores[point] = metric
            finally:
                # Always restore the original layer so the model is never left
                # with a factorized replacement in place.
                setattr(base, localname, module_sub)

        # Clean up at the end
        model.cpu()
        with torch.cuda.device(dev):
            torch.cuda.empty_cache()

        return sensitivity_dict, size_dict

    def _energy_sensitivities(self, factorized_matrix):
        """Return ``{rank / eq_rank: removed energy}`` from the singular values.

        Covers every integer rank in ``[int(0.1 * eq_rank), int(0.95 * eq_rank))``.
        "Energy" is the sum of singular values raised to
        ``self.power_for_energy``. All prefix sums come from one cumulative
        sum, so values may differ from a naive running sum at floating-point
        rounding level. As a side effect ``factorized_matrix.singular_values``
        is replaced by its float32 version.
        """
        max_rank = factorized_matrix.eq_rank
        factorized_matrix.singular_values = factorized_matrix.singular_values.float()
        powered = torch.pow(factorized_matrix.singular_values, self.power_for_energy)
        num_values = powered.shape[0]

        rank_values = list(range(int(max_rank * 0.1), int(max_rank * 0.95)))
        if not rank_values:
            return {}

        # prefix_energy[k] == powered[:k].sum(); indices are clamped like a slice.
        prefix_energy = torch.cat((powered.new_zeros(1), torch.cumsum(powered, dim=0)))
        total_energy = prefix_energy[-1]
        ranks = torch.tensor(rank_values, device=powered.device).clamp(max=num_values)
        remaining_energy = prefix_energy[ranks]

        loss_name = self.sensitivity_loss
        if loss_name == "energy2_eqoffset":
            # Also discount the energy beyond eq_rank.
            eq_rank_energy_loss = powered[max_rank:].sum()
            removed_energy = (total_energy - remaining_energy - eq_rank_energy_loss) / max_rank
        elif "scaled" in loss_name and "normal_" in loss_name:
            # Normalise by the removed fraction at the first measurement point,
            # floored at 1e-6 to avoid division by ~0.
            reference_rank = min(int(max_rank * self.measurements_points[0]), num_values)
            reference_fraction = 1 - prefix_energy[reference_rank] / total_energy
            floor = reference_fraction.new_tensor(1e-6)
            removed_energy = (1 - remaining_energy / total_energy) / torch.max(reference_fraction, floor)
        elif "normal" in loss_name:
            removed_energy = (total_energy - remaining_energy) / total_energy
        else:
            removed_energy = (total_energy - remaining_energy) / num_values

        values = removed_energy.detach().cpu().tolist()
        return {rank / max_rank: value for rank, value in zip(rank_values, values)}
    
    def get_layer_shape_dict(self, model: nn.Module):
        """Map every valid layer name to its ``(in_features, out_features)`` pair.

        Valid layers are those returned by ``get_valid_layers`` for ``model``
        with ``self.name_omit`` and an empty white list.
        """
        named_layers = get_valid_layers(model, self.name_omit, white_list=[])
        return {
            layer_name: (layer.in_features, layer.out_features)
            for layer_name, layer in named_layers
        }

