import torch
import torch.nn as nn
import transformers
import json
import os

class PEFTEigenvalueCalculator:
    def __init__(self, model: nn.Module, last_k_layers: int = None):
        self.model = model
        self.layer_eigenvalues = {}
        self.hooks = []
        
        
        self.f_l_minus_1_avg = {}
        self.J_l_avg = {}
        self.grad_l_avg = {}  
        self.batch_count = {}
        
        
        self.trainable_layers = {}
        
        
        self.layer_outputs = {}
        self.layer_inputs = {}  
        
        
        self.last_k_layers = last_k_layers
    
    def register_hooks(self):
        
        
        self._find_trainable_layers()
        
        
        for layer_id, module_info in self.trainable_layers.items():
            self._register_layer_hooks(module_info['module'], layer_id)
    
    def _find_trainable_layers(self):
        
        all_trainable_layers = {}
        layer_counter = 0  
        
        def find_linear_layers_recursive(module, parent_name=""):
            nonlocal layer_counter
            for name, child in module.named_children():
                full_name = f"{parent_name}.{name}" if parent_name else name
                
                
                if isinstance(child, nn.Linear):
                    
                    if child.weight.requires_grad:
                        all_trainable_layers[layer_counter] = {
                            'name': full_name,
                            'module': child,
                            'weight_param_name': f"{full_name}.weight"
                        }
                        layer_counter += 1
                else:
                    
                    find_linear_layers_recursive(child, full_name)
        
        
        find_linear_layers_recursive(self.model)
        
        
        if self.last_k_layers is not None:
            
            sorted_layer_ids = sorted(all_trainable_layers.keys())
            k = min(max(1, self.last_k_layers), len(sorted_layer_ids))  
            selected_layer_ids = sorted_layer_ids[-k:]
            
            
            self.trainable_layers = {
                layer_id: all_trainable_layers[layer_id] 
                for layer_id in selected_layer_ids
            }
        else:
            
            self.trainable_layers = all_trainable_layers
        
        
        for layer_id, layer_info in self.trainable_layers.items():
            param_count = sum(p.numel() for p in layer_info['module'].parameters() if p.requires_grad and p is layer_info['module'].weight)
            
        
    
    def _register_layer_hooks(self, module: nn.Module, layer_id: int):
        
        def forward_hook(module, input, output):
            
            f_l_minus_1 = input[0].detach()  
            f_l = output  

            self.layer_outputs[layer_id] = f_l
            
            
            if layer_id not in self.layer_inputs:
                self.layer_inputs[layer_id] = f_l_minus_1
            else:
                
                current_avg = self.layer_inputs[layer_id]
                if current_avg.shape[0] != f_l_minus_1.shape[0]:
                    
                    return
                self.layer_inputs[layer_id] = current_avg + (f_l_minus_1 - current_avg) / self.batch_count[layer_id]
            
            
            if layer_id not in self.f_l_minus_1_avg:
                self.f_l_minus_1_avg[layer_id] = f_l_minus_1.mean(dim=0)  
                self.batch_count[layer_id] = 1
            else:
                
                current_avg = self.f_l_minus_1_avg[layer_id]
                batch_avg = f_l_minus_1.mean(dim=0)
                self.batch_count[layer_id] += 1
                self.f_l_minus_1_avg[layer_id] = current_avg + (batch_avg - current_avg) / self.batch_count[layer_id]
        
        def backward_hook(module, grad_input, grad_output):
            
            if grad_output[0] is not None and layer_id in self.layer_inputs:
                
                
                output_grad = grad_output[0]  
                input_data = self.layer_inputs[layer_id]  

                
                if output_grad.shape[0] != input_data.shape[0]:
                    return
                
                
                
                param_grad = output_grad.T @ input_data  
                
                
                if layer_id not in self.grad_l_avg:
                    self.grad_l_avg[layer_id] = param_grad.clone()
                else:
                    
                    current_avg = self.grad_l_avg[layer_id]
                    self.grad_l_avg[layer_id] = current_avg + (param_grad - current_avg) / self.batch_count[layer_id]
        
        
        forward_handle = module.register_forward_hook(forward_hook)
        backward_handle = module.register_full_backward_hook(backward_hook)
        self.hooks.append(forward_handle)
        self.hooks.append(backward_handle)
    
    def compute_S_matrix_batch_avg(self, 
                                 f_l_minus_1_avg: torch.Tensor, J_l_avg: torch.Tensor,
                                 gamma: float = 1.0, K: float = 1.0) -> torch.Tensor:
                                 

        
        
        
        
        
        
        f = f_l_minus_1_avg.view(-1, 1)         
        J = J_l_avg
        
        
        C = (f @ f.T) / K
        
        B = gamma * (J.T @ J)
        
        S = torch.kron(C, B)
        return S
    

    def compute_min_two_eigenvalues(self, S: torch.Tensor, zero_tol: float = 1e-8, gap_tol: float = 5e-8) -> tuple:
        
        
        
        

        original_dtype = S.dtype
        S_reg = S.float() 
        
        
        eigvals, eigvecs = torch.linalg.eigh(S_reg)

        eigvals = eigvals.to(original_dtype)
        eigvecs = eigvecs.to(original_dtype)
        
        
        neg_mask = eigvals < 0
        if torch.any(neg_mask):
            eigvals[neg_mask] = -eigvals[neg_mask]
            eigvecs[:, neg_mask] = -eigvecs[:, neg_mask]
        
        
        sorted_indices = torch.argsort(eigvals)  
        eigvals = eigvals[sorted_indices]
        eigvecs = eigvecs[:, sorted_indices]

        
        
        
        
        
        
        
        
        idx1 = None
        for i in range(len(eigvals)):
            if eigvals[i] > zero_tol:
                idx1 = i
                break
        
        if idx1 is None:
            print(f"[Warning] No eigenvalues > zero_tol found, returning None.")
            return None, None, None
        
        
        idx2 = None
        for i in range(idx1 + 1, len(eigvals)):
            if eigvals[i] > zero_tol and (eigvals[i] - eigvals[idx1]) >= gap_tol:
                idx2 = i
                break
        
        if idx2 is None:
            print(f"[Warning] Only found 1 eigenvalue > zero_tol with sufficient gap, returning None.")
            return None, None, None
        
        
        lambda1 = eigvals[idx1].item()
        lambda2 = eigvals[idx2].item()
        
        
        if idx1 > 0:
            v1 = eigvecs[:, :idx1].sum(dim=1)
        else:
            v1 = eigvecs[:, 0]

        return lambda1, lambda2, v1
    
    def compute_min_two_eigs_via_kron(self,
                                      f_l_minus_1_avg: torch.Tensor,
                                      J_l_avg: torch.Tensor,
                                      gamma: float = 1.0,
                                      K: float = 1.0,
                                      zero_tol: float = 1e-8,
                                      gap_tol: float = 5e-8) -> tuple:


        
        f = f_l_minus_1_avg.view(-1)
        device = f.device
        original_dtype = f.dtype

        
        f_norm = torch.norm(f)
        if f_norm <= zero_tol:
            print("[Warning] f norm is too small; cannot form meaningful C eigenpair.")
            return None, None, None
        lambda_c = (f_norm * f_norm) / K
        u_c = (f / f_norm).to(dtype=original_dtype)

        
        J = J_l_avg
        B = (J.T @ J)
        if gamma != 1.0:
            B = gamma * B
        B_reg = B.float()

        try:
            eigvals_B, eigvecs_B = torch.linalg.eigh(B_reg)
        except RuntimeError:
            
            eigvals_B, eigvecs_B = torch.linalg.eigh(B_reg + 1e-32 * torch.eye(B_reg.shape[0], device=B_reg.device, dtype=B_reg.dtype))

        eigvals_B = eigvals_B.to(original_dtype)
        eigvecs_B = eigvecs_B.to(original_dtype)

        
        abs_eigvals = eigvals_B.abs()
        sorted_indices = torch.argsort(abs_eigvals)
        eigvals_sorted = eigvals_B[sorted_indices]
        eigvecs_sorted = eigvecs_B[:, sorted_indices]

        
        idx1 = None
        for i in range(len(eigvals_sorted)):
            if eigvals_sorted[i].abs() > zero_tol:
                idx1 = i
                break
        if idx1 is None:
            print("[Warning] No non-zero eigenvalue found for B.")
            return None, None, None

        idx2 = None
        for i in range(idx1 + 1, len(eigvals_sorted)):
            if eigvals_sorted[i].abs() > zero_tol and (eigvals_sorted[i] - eigvals_sorted[idx1]).abs() >= gap_tol:
                idx2 = i
                break
        if idx2 is None:
            print("[Warning] Only one usable eigenvalue found for B.")
            return None, None, None

        mu1 = eigvals_sorted[idx1]
        mu2 = eigvals_sorted[idx2]
        w1 = eigvecs_sorted[:, idx1]

        
        lambda1 = (lambda_c * mu1).abs().item()
        lambda2 = (lambda_c * mu2).abs().item()

        
        v1 = torch.kron(u_c, w1)
        return lambda1, lambda2, v1
    
    def compute_all_jacobians_efficiently(self, final_output):
        
        
        
        layer_outputs = []
        layer_ids = []
        for layer_id in self.trainable_layers.keys():
            if layer_id in self.layer_outputs:
                layer_outputs.append(self.layer_outputs[layer_id])
                layer_ids.append(layer_id)
        
        if not layer_outputs:
            return
        
        
        try:
            
            all_jacobians_list = []
            
            for layer_output in layer_outputs:
                
                layer_output=layer_output.requires_grad_(True)
                jacobian_per_class = []
                for class_idx in range(final_output.shape[1]):  
                    
                    grad = torch.autograd.grad(
                        outputs=final_output[:, class_idx].sum(),  
                        inputs=layer_output,
                        create_graph=False,
                        retain_graph=True,
                        allow_unused=True
                    )[0]
                    
                    if grad is not None:
                        
                        grad_avg = grad.mean(dim=0)  
                        jacobian_per_class.append(grad_avg)
                    else:
                        jacobian_per_class.append(torch.zeros(layer_output.shape[1], device=layer_output.device))
                
                if jacobian_per_class:
                    
                    jacobian = torch.stack(jacobian_per_class, dim=0)
                    all_jacobians_list.append(jacobian)
                else:
                    all_jacobians_list.append(None)
            
            torch.cuda.empty_cache() 
            all_jacobians = all_jacobians_list
        except Exception as e:
            all_jacobians = [None] * len(layer_outputs)
        
        
        for i, layer_id in enumerate(layer_ids):
            if all_jacobians[i] is not None:
                
                J_l_avg = all_jacobians[i]  
                
                
                if layer_id not in self.J_l_avg:
                    self.J_l_avg[layer_id] = J_l_avg
                else:
                    
                    current_avg = self.J_l_avg[layer_id]
                    count = self.batch_count[layer_id]
                    self.J_l_avg[layer_id] = current_avg + (J_l_avg - current_avg) / count
    
    def calculate_stable_rank_per_layer(self, 
                                        g: torch.Tensor, 
                                        lambda1: float, 
                                        lambda2: float, 
                                        v1: torch.Tensor,
                                        eta: float = 0.01,
                                        epsilon: float = 1e-8) -> float:


        if v1 is None or lambda1 is None or lambda2 is None:
            print("[Warning] Missing eigenvalue or eigenvector, using parameter count as stable rank.")
            
            return g.numel()

        
        v1_norm = torch.norm(v1)
        if v1_norm < epsilon:
            print("[Warning] v1 is a zero vector, using parameter count as stable rank.")
            return g.numel()
        v1_unit = v1 / v1_norm

        
        g_norm = torch.norm(g)
        
        
        
        g = g / g_norm

        
        dot_product = torch.dot(g, v1_unit)
        g_parallel = dot_product * v1_unit

        
        g_perpendicular = g - g_parallel

        
        norm_g_parallel_sq = torch.norm(g_parallel, 2)**2
        norm_g_perpendicular_sq = torch.norm(g_perpendicular, 2)**2

        
        
        
        
        
        
        ratio_g = norm_g_perpendicular_sq / norm_g_parallel_sq
        print(f"{ratio_g.item()=}")

        
        ratio_lambda = (lambda1/lambda2) * ((2 - eta * lambda1)/(2 - eta * lambda2))

        
        stable_rank = 1 + ratio_lambda * ratio_g
        
        
        if not torch.isfinite(stable_rank):
            print("[Warning] stable_rank is not finite, using parameter count as stable rank.")
            return g.numel()
        return min(stable_rank, g.numel())

    def compute_all_layer_stable_ranks(self, input_data, eta: float = 0.01):


        
        
        if isinstance(input_data, dict):
            output = self.model(**input_data)
        else:
            output = self.model(input_data)
        
        
        if isinstance(output, dict) and 'logits' in output:
            logits = output['logits']
        else:
            logits = output
        
        
        self.compute_all_jacobians_efficiently(logits)
        
        
        self.layer_stable_ranks = {}
        
        
        for layer_id in self.trainable_layers.keys():
            if layer_id in self.f_l_minus_1_avg and layer_id in self.J_l_avg:
                f_l_minus_1_avg = self.f_l_minus_1_avg[layer_id]
                J_l_avg = self.J_l_avg[layer_id]
                
                
                g_avg = self._get_layer_gradient_avg(layer_id)
                g = g_avg.flatten()  
                
                
                S = self.compute_S_matrix_batch_avg(f_l_minus_1_avg, J_l_avg)
                lambda1, lambda2, v1 = self.compute_min_two_eigenvalues(S)
                
                
                
                
                
                
                
                
                if lambda1 is None or lambda2 is None or v1 is None:
                    print(f"[Warning] Skipping layer {layer_id} due to insufficient eigenvalues, using parameter count as stable rank.")
                    stable_rank = g.numel()
                else:
                    
                    if lambda1 < 0:
                        lambda1 = -lambda1
                        v1 = -v1
                    if lambda2 is not None and lambda2 < 0:
                        lambda2 = -lambda2
                    
                    
                    stable_rank = self.calculate_stable_rank_per_layer(g, lambda1, lambda2, v1, eta)
                
                
                self.layer_stable_ranks[layer_id] = {
                    'stable_rank': stable_rank,
                    
                    
                    
                    
                    
                    
                }
        self.model.zero_grad()
    

    
    def _get_layer_weight(self, layer_id: int) -> torch.Tensor:
        
        weight_param_name = self.trainable_layers[layer_id]['weight_param_name']
        
        for name, param in self.model.named_parameters():
            if name == weight_param_name:
                return param.data
        



    def _get_layer_gradient(self, layer_id: int) -> torch.Tensor:
        
        weight_param_name = self.trainable_layers[layer_id]['weight_param_name']
        
        for name, param in self.model.named_parameters():
            if name == weight_param_name:
                if param.grad is not None:
                    return param.grad.data.flatten()
                else:
                    print(f"[Warning] Gradient is None for layer {layer_id}")
                    return torch.zeros_like(param.data.flatten())
        
        
    
    def _get_layer_gradient_avg(self, layer_id: int) -> torch.Tensor:
        
        if layer_id in self.grad_l_avg:
            return self.grad_l_avg[layer_id]  
        else:
            print(f"[Warning] No averaged gradient found for layer {layer_id}")
            return None
    
    def cleanup(self):
        """Remove every registered hook and empty all per-layer stores.

        After the call ``self.hooks`` is empty and the statistics, layer
        registry, and captured input/output dictionaries hold no entries.
        """
        for handle in self.hooks:
            handle.remove()
        del self.hooks[:]

        stores = (
            self.f_l_minus_1_avg,
            self.J_l_avg,
            self.grad_l_avg,
            self.batch_count,
            self.trainable_layers,
            self.layer_outputs,
            self.layer_inputs,
        )
        for store in stores:
            store.clear()



class DynamicMeanTracker:
    """Running arithmetic mean that reports how far each new value deviates."""

    def __init__(self):
        """Start with no observations."""
        self.count = 0
        self.mean = None

    def smooth(self, diff: float) -> float:
        """Return ``2 * sigmoid(diff)``; a ``diff`` of zero maps to 1.0."""
        squashed = torch.sigmoid(torch.tensor(diff))
        return 2.0 * squashed.item()

    def update(self, new_value: float) -> float:
        """Fold ``new_value`` into the mean and return its smoothed deviation.

        The deviation is measured against the mean *after* the update, so the
        very first value always yields ``smooth(0)``.
        """
        if self.mean is None:
            self.mean, self.count = new_value, 1
        else:
            previous_total = self.mean * self.count
            self.count += 1
            self.mean = (previous_total + new_value) / self.count

        return self.smooth(new_value - self.mean)

    def get_mean(self) -> float:
        """Return the current mean, or 0.0 before any update."""
        return 0.0 if self.mean is None else self.mean

    def get_count(self) -> int:
        """Return how many values have been folded in."""
        return self.count

    def reset(self):
        """Forget all observations."""
        self.count = 0
        self.mean = None




class StableRankCallback(transformers.TrainerCallback):
    def __init__(self, calculator_class, trainer=None, device='cuda'):
        """Store configuration; the calculator is built in ``on_train_begin``.

        Args:
            calculator_class: Callable invoked as
                ``calculator_class(model, last_k_layers=1)``.
            trainer: Object whose ``get_train_dataloader()`` supplies a sample
                batch and which receives ``dynamic_sampling_factor``.
            device: Device for the tensor used in the distributed reduction.
        """
        self._trainer = trainer
        self.device = device
        self.calculator_class = calculator_class
        self.calculator = None

        self.dummy_input = None
        self.dummy_target = None

        self.ratio_history = []
        self.mean_tracker = DynamicMeanTracker()
        
    def on_train_begin(self, args, state, control, **kwargs):
        """Create the calculator, register its hooks, and capture a sample batch.

        The first batch of the trainer's dataloader is stored as
        ``self.dummy_input``; for dict batches the ``"weight"`` and
        ``"sample_idx"`` entries are dropped if present and the shapes of the
        remaining tensors are printed. If no batch can be obtained, a random
        input is synthesised by ``_create_dummy_input_from_model``.
        """
        model = kwargs.get('model', None)
        if model is None:
            return

        self.calculator = self.calculator_class(model, last_k_layers=1)
        self.calculator.register_hooks()

        try:
            train_dataloader = self._trainer.get_train_dataloader()
            batch = next(iter(train_dataloader), None)
        except Exception as exc:
            # Best effort: report why the real batch is unavailable instead of
            # silently switching to random data.
            print(f"[Warning] Could not fetch a training batch ({exc}); using a random dummy input.")
            batch = None

        if batch is None:
            self._create_dummy_input_from_model(model, args.device)
            return

        self.dummy_input = batch
        if isinstance(batch, dict):
            # These keys are optional; a batch without them must not be
            # discarded in favour of random data.
            batch.pop("weight", None)
            batch.pop("sample_idx", None)

            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    print(f"  {key}: {value.shape} ({value.dtype})")
    
    def _create_dummy_input_from_model(self, model, device):
        """Synthesise a random ``(4, in_features)`` input from the first Linear layer.

        The width is taken from the first ``torch.nn.Linear`` yielded by
        ``model.modules()``; the tensor is created with ``torch.randn`` and
        moved to ``device``.

        Raises:
            ValueError: If the model contains no ``torch.nn.Linear`` module.
        """
        first_linear = next(
            (m for m in model.modules() if isinstance(m, torch.nn.Linear)),
            None,
        )
        if first_linear is None:
            raise ValueError("Cannot find Linear layers in the model to determine input/output dimensions")

        width = first_linear.in_features
        self.dummy_input = torch.randn(4, width).to(device)
    
    def on_epoch_end(self, args, state, control, **kwargs):
        """Compute the last layer's stable-rank ratio and publish it.

        Runs the calculator on ``self.dummy_input`` with ``args.learning_rate``
        as step size, divides the stable rank of the highest layer id by that
        layer's weight element count, takes the minimum across processes when
        ``torch.distributed`` is initialised, appends the ratio to
        ``self.ratio_history``, forwards it to ``_update_sampler_ratio``, and
        resets the accumulated hook data. Any exception is printed with its
        traceback and swallowed.
        """
        model = kwargs.get('model', None)
        if model is None or self.calculator is None:
            return

        try:
            self.calculator.compute_all_layer_stable_ranks(
                self.dummy_input,
                eta=args.learning_rate,
            )

            ranks = self.calculator.layer_stable_ranks
            if ranks:
                last_layer_id = max(ranks)
                stable_rank = ranks[last_layer_id]['stable_rank']

                last_layer_info = self.calculator.trainable_layers[last_layer_id]
                param_count = last_layer_info['module'].weight.numel()

                # The stable rank may arrive as a 0-dim tensor or a plain
                # number. Normalise to a float first so the reduction tensor
                # below is never built from another tensor.
                ratio = float(stable_rank) / param_count

                if torch.distributed.is_initialized():
                    ratio_tensor = torch.tensor(ratio, dtype=torch.float32, device=self.device)
                    torch.distributed.all_reduce(ratio_tensor, op=torch.distributed.ReduceOp.MIN)
                    ratio = ratio_tensor.item()

                self.ratio_history.append(ratio)
                self._update_sampler_ratio(ratio)

                print(f"Epoch {state.epoch}: Last layer stable rank ratio = {ratio:.6f}")
            else:
                print(f"[Warning] No stable rank computed for epoch {state.epoch}")

            self._reset_hook_data()

        except Exception as e:
            print(f"[Error] Failed to compute stable rank: {e}")
            import traceback
            traceback.print_exc()
            return
    
    def _reset_hook_data(self):
        """Empty the calculator's accumulated statistics.

        Hooks and the layer registry are left in place, so collection starts
        afresh on the next forward pass.
        """
        calculator = self.calculator
        if calculator is None:
            return

        stores = (
            calculator.f_l_minus_1_avg,
            calculator.J_l_avg,
            calculator.grad_l_avg,
            calculator.batch_count,
            calculator.layer_outputs,
            calculator.layer_inputs,
        )
        for store in stores:
            store.clear()

        # The results dict may not exist yet if no stable rank has ever been
        # computed; that must not raise here.
        results = getattr(calculator, 'layer_stable_ranks', None)
        if results is not None:
            results.clear()
            self.calculator.layer_stable_ranks.clear()
            

    
    def _update_sampler_ratio(self, factor):
        """Publish the smoothed ``factor`` on the trainer, if one is attached.

        The value is passed through ``self.mean_tracker.update`` and the
        result is assigned to ``dynamic_sampling_factor`` on the trainer.
        """
        if self._trainer is None:
            return
        self._trainer.dynamic_sampling_factor = self.mean_tracker.update(factor)
    
    def on_train_end(self, args, state, control, **kwargs):
        """Release the calculator and write the ratio history to disk.

        On the main process, when at least one ratio was recorded and
        ``args.output_dir`` is set, the history is written to
        ``sr_ratio.json`` inside that directory. Entries are numbered from 1
        in the order they were recorded.
        """
        if self.calculator is not None:
            self.calculator.cleanup()

        should_write = self._is_main_process() and self.ratio_history and args.output_dir
        if not should_write:
            return

        records = []
        for epoch_number, ratio in enumerate(self.ratio_history, start=1):
            records.append({"epoch": epoch_number, "sr_ratio": ratio})
        output_data = {"stable_rank_ratio": records}

        os.makedirs(args.output_dir, exist_ok=True)

        output_file = os.path.join(args.output_dir, "sr_ratio.json")
        with open(output_file, 'w') as f:
            json.dump(output_data, f, indent=2)

        print(f"Stable rank ratio history saved to {output_file}")

    def _is_main_process(self):
        """Return True on rank 0, or whenever ``torch.distributed`` is not initialised."""
        distributed = torch.distributed
        return (not distributed.is_initialized()) or distributed.get_rank() == 0


 