"""
FlexLoRA Rank Allocator
Dynamic rank allocation based on importance estimation
"""
import math
import torch
import os
import csv
import time
import re
from typing import Dict, List, Optional


class RankAllocator:
    """
    Dynamic rank allocator for FlexLoRA
    
    Args:
        model: The model with LoRA layers
        lora_r: Initial LoRA rank
        target_rank: Average target rank per layer
        init_warmup: Initial warmup steps
        final_warmup: Final warmup steps before training ends
        mask_interval: Interval steps between rank adjustments
        beta1: EMA coefficient for importance
        beta2: EMA coefficient for uncertainty
        total_step: Total training steps
        target_total_rank: Total target rank across all layers
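        k: Number of top layers to increase rank (stored as ``self.k`` but not read by the allocation logic in this class)
        b: Number of bottom layers to decrease rank; the same number of top layers is increased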
        output_dir: Directory to save logs
        enable_scheduler: Enable dynamic b scheduler
        mode: Importance computation mode ('entropy', 'nuclear', 'frobenius')
    """
    
    def __init__(
        self,
        model,
        lora_r: int,
        target_rank: int,
        init_warmup: int,
        final_warmup: int,
        mask_interval: int,
        beta1: float = 0.85,
        beta2: float = 0.85,
        total_step: Optional[int] = None,
        target_total_rank: Optional[int] = None,
        k: int = 2,
        b: int = 4,
        output_dir: Optional[str] = None,
        enable_scheduler: bool = False,
        mode: str = "entropy"
    ):
        """
        Store the configuration, index the model's LoRA parameters and pick the CSV log path

        No file is created here; only the log path is computed. The EMA
        coefficients are checked last and must lie strictly between 0 and 1.
        """
        # Model handle and run-time bookkeeping
        self.model = model
        self.global_step = 0
        self.rank_pattern: Dict[str, int] = {}

        # Rank budget
        self.lora_init_rank = lora_r
        self.ave_target_rank = target_rank
        self.target_rank = target_total_rank

        # Adjustment schedule
        self.initial_warmup = init_warmup
        self.final_warmup = final_warmup
        self.mask_interval = mask_interval
        self.total_step = total_step
        self.enable_scheduler = enable_scheduler

        # Allocation hyper-parameters
        self.k = k
        self.b = b
        self.initial_b = b
        self.beta1 = beta1
        self.beta2 = beta2
        self.mode = mode

        self.output_dir = output_dir

        # Needs self.model / self.target_rank / self.ave_target_rank set above.
        self.get_lora_param_name()

        # CSV log, named after the construction time (whole seconds)
        log_dir = self.output_dir or "."
        self.csv_path = os.path.join(log_dir, f"rank_log_{int(time.time())}.csv")

        # lora_E names ordered by block index; names without one sort last
        block_index_re = re.compile(r'\.blocks\.(\d+)\.')

        def block_index(param_name):
            found = block_index_re.search(param_name)
            if found is None:
                return float('inf')
            return int(found.group(1))

        lora_E_names = [n for n, _ in model.named_parameters() if "lora_E" in n]
        lora_E_names.sort(key=block_index)
        self.rank_names = lora_E_names

        # An already existing file is assumed to carry its header.
        self._csv_header_written = os.path.exists(self.csv_path)

        assert 0 < self.beta1 < 1
        assert 0 < self.beta2 < 1

    def get_lora_param_name(self):
        """
        Index the model's LoRA parameters

        Sets:
            name_set: Sorted list of name templates, one per ``lora_A``
                parameter, with ``lora_A`` replaced by ``%s``
            total_rank: Sum of ``size(0)`` over all ``lora_A`` parameters
            shape_dict: Shape of every ``lora_A`` / ``lora_B`` parameter by name
            target_rank: Filled with ``ave_target_rank * len(name_set)`` only
                when it is still None
        """
        templates = set()
        shapes = {}
        rank_sum = 0

        for param_name, param in self.model.named_parameters():
            is_lora_A = "lora_A" in param_name
            if is_lora_A:
                templates.add(param_name.replace("lora_A", "%s"))
                rank_sum += param.size(0)
            if is_lora_A or "lora_B" in param_name:
                shapes[param_name] = param.shape

        self.name_set = sorted(templates)
        self.total_rank = rank_sum
        self.shape_dict = shapes

        # No explicit total budget: derive it from the per-layer average.
        if self.target_rank is None:
            self.target_rank = self.ave_target_rank * len(self.name_set)

    def compute_matrix_importance(
        self, 
        name: str, 
        E: torch.Tensor, 
        mode: str = "entropy"
    ) -> float:
        """
        Score one LoRA layer from its ``lora_E`` tensor

        Args:
            name: Parameter name (not used in the computation)
            E: The layer's ``lora_E`` tensor
            mode: One of
                'nuclear'   - mean of the entries of E
                'frobenius' - Frobenius norm of E
                'entropy'   - entropy of the normalised squared entries,
                              divided by log(numel) when E has more than
                              one element

        Returns:
            Importance score as a Python float

        Raises:
            ValueError: If ``mode`` is not one of the values above
        """
        with torch.no_grad():
            if mode == "entropy":
                energy = E ** 2
                # 1e-8 guards against an all-zero E and against log(0).
                probs = energy / (energy.sum() + 1e-8)
                score = -torch.sum(probs * torch.log(probs + 1e-8))
                n_values = E.numel()
                if n_values > 1:
                    # Normalise by the maximum attainable entropy.
                    score = score / math.log(n_values)
                return score.item()

            if mode == "frobenius":
                return E.norm(p='fro').item()

            if mode == "nuclear":
                return E.mean().item()

        raise ValueError(f"Unknown mode: {mode}")

    def mask_to_target_rank(self, model, curr_rank):
        """
        Move rank from the least important LoRA layers to the most important ones

        Every layer is scored with ``compute_matrix_importance`` on its
        ``lora_E`` parameter. Up to ``self.b`` of the lowest-scoring layers
        whose rank is above 1 each lose their lowest-energy direction, and the
        same number of highest-scoring layers each gain one new direction, so
        the summed rank is unchanged. The resulting ranks are stored in
        ``self.rank_pattern`` and appended to the CSV log at ``self.csv_path``
        (its directory is created when missing).

        NOTE(review): resized tensors are installed as brand-new
        ``torch.nn.Parameter`` objects and this method never touches an
        optimizer - presumably the caller has to refresh optimizer state
        afterwards; confirm against the training loop.

        Args:
            model: Model whose ``lora_A`` / ``lora_B`` / ``lora_E`` parameters
                are resized in place
            curr_rank: Unused; kept for interface compatibility

        Returns:
            Always True
        """
        lora_A_list = []
        lora_B_list = []
        lora_E_list = []
        lora_E_names = []

        # The lists are matched by position: the n-th lora_A, lora_B and
        # lora_E yielded by named_parameters() are treated as one layer.
        for n, p in model.named_parameters():
            if "lora_A" in n:
                lora_A_list.append(p)
            if "lora_B" in n:
                lora_B_list.append(p)
            if "lora_E" in n:
                lora_E_list.append(p)
                lora_E_names.append(n)

        num_layers = min(len(lora_A_list), len(lora_B_list), len(lora_E_list))
        importance = [
            self.compute_matrix_importance(
                lora_E_names[i], lora_E_list[i], mode=self.mode
            )
            for i in range(num_layers)
        ]

        # Rank-1 layers cannot be pruned any further.
        shrinkable = [i for i in range(num_layers) if lora_E_list[i].shape[0] > 1]

        # torch.topk raises when k exceeds the number of candidates, so clamp.
        # One shared count for both directions keeps the total rank constant.
        num_moves = max(0, min(int(self.b), len(shrinkable)))

        decrease_idx = []
        increase_idx = []
        if num_moves > 0:
            # Decrease: the least important layers among those with r > 1
            shrinkable_scores = torch.tensor([importance[i] for i in shrinkable])
            lowest = torch.topk(
                shrinkable_scores, num_moves, largest=False
            ).indices.tolist()
            decrease_idx = [shrinkable[j] for j in lowest]

            # Increase: the most important layers among all
            increase_idx = torch.topk(
                torch.tensor(importance), num_moves, largest=True
            ).indices.tolist()

        # Decrease rank
        for i in decrease_idx:
            A, B, E = lora_A_list[i], lora_B_list[i], lora_E_list[i]
            r = E.shape[0]

            # Energy of a direction is its squared lora_E entry (the quantity
            # the 'entropy' importance mode uses), so a large negative entry
            # is NOT a low-energy direction.
            row_energy = E.detach().pow(2).reshape(r, -1).sum(dim=1)
            drop = int(torch.argmin(row_energy))
            keep_indices = torch.tensor(
                [j for j in range(r) if j != drop],
                dtype=torch.long,
                device=A.device,
            )

            A_new = torch.nn.Parameter(A.detach()[keep_indices])
            B_new = torch.nn.Parameter(B.detach()[:, keep_indices])
            E_new = torch.nn.Parameter(E.detach()[keep_indices])

            # Keep the local lists current: the same layer may be grown below.
            lora_A_list[i] = A_new
            lora_B_list[i] = B_new
            lora_E_list[i] = E_new

            self._replace_param(model, A, A_new)
            self._replace_param(model, B, B_new)
            self._replace_param(model, E, E_new)

        # Increase rank
        for i in increase_idx:
            A, B, E = lora_A_list[i], lora_B_list[i], lora_E_list[i]

            # Draw in float32, then cast, so concatenation never promotes a
            # half-precision parameter to float32. The new lora_E entry is
            # zero, so the added direction starts with no contribution.
            # NOTE(review): the new lora_B column uses std 1.0 while the new
            # lora_A row uses 0.02 - confirm this asymmetry is intended.
            new_a = (torch.randn(1, A.shape[1], device=A.device) * 0.02).to(A.dtype)
            new_b = torch.randn(B.shape[0], 1, device=B.device).to(B.dtype)
            new_e = torch.zeros_like(E.detach()[0:1])

            A_new = torch.nn.Parameter(torch.cat([A.detach(), new_a], dim=0))
            B_new = torch.nn.Parameter(torch.cat([B.detach(), new_b], dim=1))
            E_new = torch.nn.Parameter(torch.cat([E.detach(), new_e], dim=0))

            lora_A_list[i] = A_new
            lora_B_list[i] = B_new
            lora_E_list[i] = E_new

            self._replace_param(model, A, A_new)
            self._replace_param(model, B, B_new)
            self._replace_param(model, E, E_new)

        # Record current ranks (one parameter lookup table for all names)
        current_params = dict(model.named_parameters())
        for name in self.rank_names:
            param = current_params.get(name)
            if param is not None:
                self.rank_pattern[name] = param.size(0)

        # Write to CSV: the first write creates/truncates the file and emits
        # the header, later writes append.
        os.makedirs(os.path.dirname(self.csv_path) or ".", exist_ok=True)
        row = [self.global_step] + [
            self.rank_pattern.get(name, 0) for name in self.rank_names
        ]
        file_mode = "a" if self._csv_header_written else "w"
        with open(self.csv_path, file_mode, newline="") as f:
            writer = csv.writer(f)
            if not self._csv_header_written:
                writer.writerow(["step"] + self.rank_names)
            writer.writerow(row)
        self._csv_header_written = True

        return True

    def _replace_param(self, model, old_param, new_param):
        """
        Swap ``old_param`` for ``new_param`` on the module that owns it

        Modules are searched in ``named_modules()`` order and parameters are
        matched by identity. Only the first owner found is updated; when no
        module holds ``old_param`` the call does nothing.
        """
        owners = (
            (module, attr)
            for _, module in model.named_modules()
            for attr, param in module.named_parameters(recurse=False)
            if param is old_param
        )
        first_owner = next(owners, None)
        if first_owner is None:
            return
        module, attr = first_owner
        setattr(module, attr, new_param)

    def update_and_mask(self, model, global_step):
        """
        Record the current step and adjust ranks when the schedule calls for it

        Ranks are adjusted only when all of the following hold: the step is
        before ``total_step - final_warmup``, it is past ``initial_warmup``,
        it falls on a multiple of ``mask_interval`` counted from
        ``initial_warmup``, and ``self.b`` is positive. When the scheduler is
        enabled, ``self.b`` is refreshed on every step before the final warmup.

        Args:
            model: Model passed through to ``mask_to_target_rank``
            global_step: Current training step

        Returns:
            ``(0, result)`` where ``result`` is the return value of
            ``mask_to_target_rank`` if an adjustment ran, otherwise None
        """
        self.global_step = global_step
        no_adjustment = (0, None)

        # Final warmup: ranks are frozen.
        if not global_step < self.total_step - self.final_warmup:
            return no_adjustment

        if self.enable_scheduler:
            self._b_scheduler(global_step)

        # Still inside the initial warmup.
        if not global_step > self.initial_warmup:
            return no_adjustment

        # Not on an adjustment step.
        if (global_step - self.initial_warmup) % self.mask_interval != 0:
            return no_adjustment

        # Nothing left to move.
        if not self.b > 0:
            return no_adjustment

        print(f"[FlexLoRA] Step={global_step}, b={self.b}, adjusting ranks...")
        return 0, self.mask_to_target_rank(model, 0)

    def _b_scheduler(self, global_step):
        """
        Decay ``self.b`` from ``self.initial_b`` to 0 on a cubic schedule

        Progress runs from 0 at ``initial_warmup`` to 1 at
        ``total_step - final_warmup`` and is clamped to [0, 1] outside that
        window. ``self.b`` is set to the rounded interpolation between
        ``initial_b`` and 0 using progress cubed as the weight.
        """
        start_b = self.initial_b
        end_b = 0

        schedule_span = self.total_step - self.final_warmup - self.initial_warmup
        elapsed = global_step - self.initial_warmup

        fraction = elapsed / schedule_span
        fraction = min(max(fraction, 0), 1)

        # Cubing keeps b near its start value early on and drops it late.
        decay = fraction ** 3
        self.b = round(start_b + (end_b - start_b) * decay)

    def get_rank_pattern(self):
        """
        Return the most recently recorded rank per tracked parameter name

        Returns:
            The ``self.rank_pattern`` dict itself (not a copy)
        """
        pattern = self.rank_pattern
        return pattern

    def set_total_step(self, total_step):
        """
        Set total training steps

        The value is checked before it is stored, so a rejected call leaves
        the previous ``self.total_step`` untouched.

        Args:
            total_step: Total number of training steps; must exceed
                ``initial_warmup + final_warmup``

        Raises:
            AssertionError: If ``total_step`` does not leave room for both
                warmup phases
        """
        min_exclusive = self.initial_warmup + self.final_warmup
        assert total_step > min_exclusive, (
            f"total_step ({total_step}) must be greater than "
            f"initial_warmup + final_warmup ({min_exclusive})"
        )
        self.total_step = total_step
