import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate




def print_trainable_parameters(model):
    """Print the total and trainable parameter counts of ``model``.

    Args:
        model: Any object exposing ``parameters()`` that yields items with a
            ``numel()`` method and a ``requires_grad`` attribute.

    The trainable share is printed as a percentage of the total. A model
    without any parameters reports ``0.0%`` instead of raising
    ``ZeroDivisionError``.
    """
    total_params = 0
    trainable_params = 0
    # Single pass: count every parameter once and tally the trainable subset.
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    # Guard the ratio so parameter-free models do not divide by zero.
    percentage = 100 * trainable_params / total_params if total_params else 0.0
    print(f"Total Parameters: {total_params}")
    print(f"Trainable Parameters: {trainable_params} ({percentage}%)")



def get_parent_module(model, module_name):
    """Resolve the object that owns the attribute named by a dotted path.

    Args:
        model: Root object to start the attribute walk from.
        module_name: Dot-separated attribute path, e.g. ``"a.b.c"``.

    Returns:
        A ``(parent, leaf_name)`` tuple where ``parent`` is reached by
        following every path component except the last, and ``leaf_name`` is
        that last component. A path without dots returns ``(model, path)``.

    Raises:
        AttributeError: If an intermediate component does not exist.
    """
    *ancestors, leaf_name = module_name.split(".")
    owner = model
    for attribute in ancestors:
        owner = getattr(owner, attribute)
    return owner, leaf_name

def _energy_rank(singular_values, energy_ratio):
    """Return the largest per-matrix rank needed to reach ``energy_ratio``.

    ``singular_values`` is a 2-D tensor with one row of singular values per
    matrix. For every row the smallest ``k`` is found whose first ``k``
    squared singular values make up at least ``energy_ratio`` of the row's
    total squared sum; the maximum ``k`` over all rows is returned as an int.

    A row that never reaches the threshold (ratio above 1, rounding just
    below 1.0, or NaN energy from an all-zero matrix) keeps its full rank.
    A plain ``argmax`` over the all-False mask would report rank 1 there.
    """
    power = singular_values ** 2
    cumulative = torch.cumsum(power, dim=1) / power.sum(dim=1, keepdim=True)
    reached = cumulative >= energy_ratio
    first_hit = reached.int().argmax(dim=1) + 1
    full_rank = torch.full_like(first_hit, singular_values.shape[1])
    ranks = torch.where(reached.any(dim=1), first_hit, full_rank)
    return int(ranks.max().item())


def _shared_rank(first, second, energy_ratio):
    """Return the rank satisfying the energy threshold for both matrix batches."""
    return max(
        _energy_rank(torch.linalg.svdvals(first), energy_ratio),
        _energy_rank(torch.linalg.svdvals(second), energy_ratio),
    )


def _truncate_columns(weight, rank):
    """Return a Parameter with the first ``rank`` columns of the rank-``rank`` approximation."""
    u, s, vh = torch.linalg.svd(weight, full_matrices=False)
    sigma = torch.diag_embed(s)[:, :, :rank]
    # vh is already V transposed, so slicing it mirrors V[:, :rank, :rank].mT.
    return torch.nn.Parameter(u @ sigma @ vh[:, :rank, :rank])


def _truncate_rows(weight, rank):
    """Return a Parameter with the first ``rank`` rows of the rank-``rank`` approximation."""
    u, s, vh = torch.linalg.svd(weight, full_matrices=False)
    sigma = torch.diag_embed(s)[:, :rank, :]
    return torch.nn.Parameter(u[:, :rank, :rank] @ sigma @ vh)


def reduced_svd(model, energy_ratio, r_type="none"):
    """Shrink the attention adapter matrices of every layer via truncated SVD.

    For each child of ``model.model.layers`` two ranks are chosen:

    * one shared by ``q_adapt.proj[-1]`` and ``k_adapt.proj[-1]``;
    * one shared by ``v_adapt.proj[-1]`` and ``o_adapt.proj[0]``.

    Each rank is the smallest one that keeps at least ``energy_ratio`` of the
    squared singular-value sum for every matrix in both batches. The q/k/v
    entries are replaced by new Parameters reduced along their last axis, and
    the o entry by one reduced along its second-to-last axis. The model is
    modified in place.

    The matrices are handled as 3-D batches (singular values are reduced over
    ``dim=1`` and three axes are sliced). Real-valued data is assumed.

    Args:
        model: Object exposing ``model.layers.named_children()`` whose layers
            provide ``self_attn.{q,k,v,o}_adapt.proj``.
        energy_ratio: Fraction of squared singular-value energy to retain. If
            it cannot be reached, the full rank is kept.
        r_type: Accepted for interface compatibility; not used here.

    Returns:
        ``(model, low_rank_list)`` where ``low_rank_list`` holds one
        ``[qk_rank, vo_rank]`` pair per layer, in iteration order.
    """
    low_rank_list = []
    # The results are re-wrapped as fresh Parameters, so no autograd graph
    # needs to be recorded through the decompositions.
    with torch.no_grad():
        for name, layer in model.model.layers.named_children():
            attn = layer.self_attn
            # Snapshot all four matrices before anything is replaced.
            Uq = attn.q_adapt.proj[-1]
            Uk = attn.k_adapt.proj[-1]
            Uv = attn.v_adapt.proj[-1]
            UvT = attn.o_adapt.proj[0]

            qk_rank = _shared_rank(Uq, Uk, energy_ratio)
            print(f"{name} / Uq & Uk :{qk_rank}")
            attn.q_adapt.proj[-1] = _truncate_columns(attn.q_adapt.proj[-1], qk_rank)
            attn.k_adapt.proj[-1] = _truncate_columns(attn.k_adapt.proj[-1], qk_rank)

            vo_rank = _shared_rank(Uv, UvT, energy_ratio)
            print(f"{name} / Uv & UvT :{vo_rank}")
            attn.v_adapt.proj[-1] = _truncate_columns(attn.v_adapt.proj[-1], vo_rank)
            attn.o_adapt.proj[0] = _truncate_rows(attn.o_adapt.proj[0], vo_rank)

            low_rank_list.append([qk_rank, vo_rank])

    return model, low_rank_list

