import random
import re

import torch

from utils.lora_utils import add_lora_adapters
from utils.model_utils import build_model
from utils.train_utils import train

def truncate_model(model, dataset, rank, alpha, b_var, a_var):
    """Build a fresh LoRA model whose adapters are the leading rank-``rank`` slices of ``model``'s.

    A new model is created with ``build_model(dataset)`` and wrapped with
    ``add_lora_adapters``. For every ``base_layer`` weight, the first ``rank``
    rows of the source ``lora_A`` and the first ``rank`` columns of the source
    ``lora_B`` are copied into the new adapters. Every other non-LoRA parameter
    (e.g. the classifier) is copied over unchanged. The ``base_layer`` weights
    themselves are not copied.

    Args:
        model: Source model; must expose ``config._name_or_path`` and hold
            ``lora_A`` / ``lora_B`` entries next to each ``base_layer.weight``.
        dataset: Forwarded to ``build_model``.
        rank: Number of LoRA rows/columns to keep.
        alpha, b_var, a_var: Forwarded to ``add_lora_adapters``.

    Returns:
        The truncated model, moved to CUDA.

    Raises:
        NotImplementedError: If the model name has no known target modules.
    """
    truncated, _ = build_model(dataset)
    model_name = model.config._name_or_path
    if model_name in ("google/vit-base-patch16-224-in21k", 'roberta-large'):
        target_modules = ['query', 'value']
    elif model_name == "t5-base":
        target_modules = ['SelfAttention.q', 'SelfAttention.v']
    else:
        raise NotImplementedError(f"No LoRA target modules known for model {model_name!r}")

    truncated = add_lora_adapters(truncated, rank, alpha, b_var, a_var, target_modules)

    # state_dict() walks the whole module tree, so snapshot each one once
    # instead of rebuilding it for every parameter. The returned tensors share
    # storage with the parameters, so in-place copies still update the model.
    source_state = model.state_dict()
    target_state = truncated.state_dict()

    for name, _ in truncated.named_parameters():
        if 'base_layer' in name and 'weight' in name:
            # Get LoRA parameter names
            name_A = name.replace('base_layer.weight', 'lora_A')
            name_B = name.replace('base_layer.weight', 'lora_B')

            # A keeps its leading rows, B its leading columns
            target_state[name_A].data.copy_(source_state[name_A].data[:rank, :])
            target_state[name_B].data.copy_(source_state[name_B].data[:, :rank])
        elif 'lora_A' not in name and 'lora_B' not in name:
            # Make sure to copy over the classifier and other non-LoRA parameters
            target_state[name].data.copy_(source_state[name].data)

    return truncated.cuda()

def zero_pad(adapter, rank, pad_B=True):
    """Zero-pad a 2-D adapter matrix so its rank axis has length ``rank``.

    Args:
        adapter: 2-D tensor. With ``pad_B`` truthy the rank axis is the
            second dimension (columns); otherwise it is the first (rows).
        rank: Target size of the rank axis.
        pad_B: Selects which axis is padded (see ``adapter``).

    Returns:
        A new zero-padded tensor with the same dtype and device when the rank
        axis is shorter than ``rank``; otherwise ``adapter`` itself, unchanged.
    """
    n_rows, n_cols = adapter.shape
    current_rank = n_cols if pad_B else n_rows

    # Nothing to do when the adapter already has at least the requested rank
    if current_rank >= rank:
        return adapter

    target_shape = (n_rows, rank) if pad_B else (rank, n_cols)
    padded = torch.zeros(target_shape, dtype=adapter.dtype, device=adapter.device)
    if pad_B:
        padded[:, :current_rank] = adapter
    else:
        padded[:current_rank, :] = adapter

    return padded

def flex_lora_model(model, dataset, rank, alpha, b_var, a_var, avg_prod_dict):
    """Build a fresh LoRA model whose adapters are rank-``rank`` SVD factors of given products.

    A new model is created with ``build_model(dataset)`` and wrapped with
    ``add_lora_adapters``. For every parameter name present in
    ``avg_prod_dict``, the stored matrix is factorised with a truncated SVD:
    ``lora_B`` receives ``U[:, :rank] * S[:rank]`` and ``lora_A`` receives
    ``Vh[:rank, :]``. Every other non-LoRA parameter is copied from ``model``.

    Args:
        model: Source model; must expose ``config._name_or_path``.
        dataset: Forwarded to ``build_model``.
        rank: Number of singular directions to keep.
        alpha, b_var, a_var: Forwarded to ``add_lora_adapters``.
        avg_prod_dict: Mapping from parameter name to a 2-D matrix. The keys
            are expected to contain ``base_layer.weight``, since that substring
            is replaced to find the adapter names. Presumably each matrix is
            an averaged ``B @ A`` product; confirm against the caller.

    Returns:
        The reconstructed model, moved to CUDA.

    Raises:
        NotImplementedError: If the model name has no known target modules.
    """
    rebuilt, _ = build_model(dataset)
    model_name = model.config._name_or_path
    if model_name in ("google/vit-base-patch16-224-in21k", 'roberta-large'):
        target_modules = ['query', 'value']
    elif model_name == "t5-base":
        target_modules = ['SelfAttention.q', 'SelfAttention.v']
    else:
        raise NotImplementedError(f"No LoRA target modules known for model {model_name!r}")

    rebuilt = add_lora_adapters(rebuilt, rank, alpha, b_var, a_var, target_modules)

    # state_dict() walks the whole module tree, so snapshot each one once
    # instead of rebuilding it for every parameter. The returned tensors share
    # storage with the parameters, so in-place copies still update the model.
    source_state = model.state_dict()
    target_state = rebuilt.state_dict()

    for name, _ in rebuilt.named_parameters():
        if name in avg_prod_dict:
            prod = avg_prod_dict[name]

            # Split the product back into separate matrices via truncated SVD
            U, S, Vh = torch.linalg.svd(prod, full_matrices=False)
            # Scaling the columns of U by S equals U_r @ diag(S_r) without
            # materialising the dense diagonal matrix.
            B = U[:, :rank] * S[:rank]
            A = Vh[:rank, :]

            name_A = name.replace('base_layer.weight', 'lora_A')
            name_B = name.replace('base_layer.weight', 'lora_B')
            target_state[name_A].data.copy_(A)
            target_state[name_B].data.copy_(B)
        elif 'lora_A' not in name and 'lora_B' not in name:
            # Non-LoRA parameters without a stored product come from the source model
            target_state[name].data.copy_(source_state[name].data)

    return rebuilt.cuda()

""" Heterogeneous Ravan Auxiliary Functions """

def random_ranking(num_heads):
    """Return the ranks ``1..num_heads`` in a random order.

    The shuffle uses the module-level ``random`` generator, so results are
    reproducible under ``random.seed``.
    """
    shuffled_ranks = [position + 1 for position in range(num_heads)]
    random.shuffle(shuffled_ranks)
    return shuffled_ranks

def weight_based_ranking(scaling_factors, heads):
    """Rank heads by ``scaling_factor * Frobenius norm`` in ascending order.

    Args:
        scaling_factors: One multiplier per head, paired positionally with
            ``heads``.
        heads: One tensor per head.

    Returns:
        A list where entry ``i`` is the 1-based rank of head ``i``. Rank 1 is
        the smallest score; ties keep their original relative order.
    """
    scores = []
    for factor, head in zip(scaling_factors, heads):
        scores.append(factor * torch.norm(head, p='fro'))

    head_count = len(heads)
    # Stable sort of head indices by score, smallest first
    ascending = sorted(range(head_count), key=scores.__getitem__)

    ranking = [0] * head_count
    for position, head_idx in enumerate(ascending):
        ranking[head_idx] = position + 1

    return ranking

def get_grad(model, param_name):
    """Return the gradient data of the parameter called ``param_name``.

    Returns ``None`` when no parameter has that name or when the matching
    parameter has no gradient.
    """
    candidates = (param for name, param in model.named_parameters() if name == param_name)
    for param in candidates:
        if param.grad is not None:
            return param.grad.data
    return None

def freeze_other_heads(model, num_heads, curr_idx):
    """Leave only head ``curr_idx`` trainable and freeze every other parameter.

    A parameter is kept trainable when its name contains ``lora_R_<curr_idx>``
    or ``lora_scaling_<curr_idx>`` and the index is not followed by another
    digit. The digit boundary matters: a plain substring test would treat
    ``lora_R_10`` as belonging to head 1.

    Args:
        model: Module whose parameters are toggled in place.
        num_heads: Unused; kept for interface compatibility.
        curr_idx: Index of the head that stays trainable.
    """
    head_pattern = re.compile(
        rf'lora_(?:R|scaling)_{re.escape(str(curr_idx))}(?!\d)'
    )
    for name, param in model.named_parameters():
        param.requires_grad = head_pattern.search(name) is not None

def unfreeze_all_heads(model):
    """Mark every parameter whose name contains ``lora`` or ``scaling`` as trainable.

    Other parameters keep their current ``requires_grad`` setting.
    """
    markers = ('lora', 'scaling')
    for name, param in model.named_parameters():
        if any(marker in name for marker in markers):
            param.requires_grad = True

def _grad_snapshot(params, param_name):
    """Return a detached copy of the named parameter's gradient, or ``None`` if unavailable."""
    param = params.get(param_name)
    if param is None or param.grad is None:
        return None
    # Clone so the stored gradient cannot be altered by a later pass that
    # zeroes or accumulates gradients in place (train() is opaque here).
    return param.grad.data.clone()


def get_grads(model, num_heads, trainloader, opt, dataset, tokenizer):
    """Collect per-head gradients on the first batch of ``trainloader``.

    For each head index, every other parameter is frozen with
    ``freeze_other_heads`` and ``train`` is run for one step with
    ``update=False`` on the same batch. The gradients of ``lora_R_<i>`` and
    ``lora_scaling_<i>`` are then recorded for every ``base_layer`` weight.

    The model is left with only the last head trainable; callers are expected
    to restore ``requires_grad`` themselves.

    Args:
        model: Model holding ``lora_R_<i>`` / ``lora_scaling_<i>`` parameters
            next to each ``base_layer.weight``.
        num_heads: Number of heads to probe.
        trainloader: Iterable of batches; only the first batch is used.
        opt, dataset, tokenizer: Forwarded to ``train``.

    Returns:
        ``(grad_dict_weights, grad_dict_scaling)``: two dicts keyed by the
        ``base_layer`` weight name, each mapping to a list with one entry per
        head (a gradient tensor, or ``None`` when no gradient was produced).

    Raises:
        IndexError: If ``trainloader`` yields no batches.
    """
    # Only the first batch is needed, so do not materialise the whole loader
    try:
        first_batch = next(iter(trainloader))
    except StopIteration:
        raise IndexError("trainloader yielded no batches") from None
    batch = [first_batch]

    base_names = [
        name for name, _ in model.named_parameters()
        if 'base_layer' in name and 'weight' in name
    ]

    grad_dict_weights = {}
    grad_dict_scaling = {}

    for i in range(num_heads):
        freeze_other_heads(model, num_heads, i)

        train(model, trainloader, opt, dataset, tokenizer=tokenizer, steps=1, update=False, batch=batch)

        # One name -> parameter map per head instead of a full scan per lookup
        params = dict(model.named_parameters())
        for name in base_names:
            name_R = name.replace('base_layer.weight', f'lora_R_{i}')
            name_scaling = name.replace('base_layer.weight', f'lora_scaling_{i}')

            grad_dict_weights.setdefault(name, []).append(_grad_snapshot(params, name_R))
            grad_dict_scaling.setdefault(name, []).append(_grad_snapshot(params, name_scaling))

    return grad_dict_weights, grad_dict_scaling

def collect_rankings(model, num_heads, ranking_method, trainloader, opt, dataset, tokenizer):
    """Compute a per-layer ranking of heads using the chosen method.

    Supported methods:
        * ``'random'``: a random permutation of ``1..num_heads`` per layer.
        * ``'weight'``: heads ranked by ``lora_scaling_<i>`` times the
          Frobenius norm of ``lora_R_<i>``.
        * ``'grad'``: same scoring applied to the gradients returned by
          ``get_grads``.

    Any other value yields an empty dict. In every case, all LoRA and scaling
    parameters are made trainable again before returning.

    Args:
        model: Model holding ``lora_R_<i>`` / ``lora_scaling_<i>`` entries
            next to each ``base_layer.weight``.
        num_heads: Number of heads per layer.
        ranking_method: One of ``'random'``, ``'weight'``, ``'grad'``.
        trainloader, opt, dataset, tokenizer: Only used by ``'grad'``;
            forwarded to ``get_grads``.

    Returns:
        Dict mapping each ``base_layer`` weight name to a list whose entry
        ``i`` is the 1-based rank of head ``i``.
    """
    base_names = [
        name for name, _ in model.named_parameters()
        if 'base_layer' in name and 'weight' in name
    ]
    rankings = {}

    if ranking_method == 'random':
        for name in base_names:
            rankings[name] = random_ranking(num_heads)

    elif ranking_method == 'weight':
        # Snapshot once; state_dict() is rebuilt from the module tree on every call
        state = model.state_dict()
        for name in base_names:
            scaling_factors = []
            heads = []
            for i in range(num_heads):
                name_R = name.replace('base_layer.weight', f'lora_R_{i}')
                name_scaling = name.replace('base_layer.weight', f'lora_scaling_{i}')
                scaling_factors.append(state[name_scaling].data)
                heads.append(state[name_R].data)
            rankings[name] = weight_based_ranking(scaling_factors, heads)

    elif ranking_method == 'grad':
        grad_dict_weights, grad_dict_scaling = get_grads(model, num_heads, trainloader, opt, dataset, tokenizer)
        for name in base_names:
            rankings[name] = weight_based_ranking(grad_dict_scaling[name], grad_dict_weights[name])

    # get_grads leaves all but one head frozen; restore trainability for all heads
    unfreeze_all_heads(model)

    return rankings

def freeze_by_name(model, names):
    """Set ``requires_grad = False`` on every parameter whose name is in ``names``.

    Args:
        model: Module whose parameters are toggled in place.
        names: Iterable of exact parameter names. It is materialised once, so
            one-shot iterables such as generators work and each lookup is
            constant time. A bare string is treated as a single name.
    """
    if isinstance(names, str):
        # Avoid iterating a lone name character by character
        names = (names,)
    frozen_names = frozenset(names)

    for name, parameter in model.named_parameters():
        if name in frozen_names:
            parameter.requires_grad = False

def freeze_from_rankings(model, head_frac, num_heads, ranking_method, trainloader, opt, dataset, tokenizer):
    """Freeze the lowest-ranked heads of every layer, keeping a fraction trainable.

    Rankings come from ``collect_rankings``. For each ``base_layer`` weight,
    ``int(head_frac * num_heads)`` heads are kept; every head whose rank is at
    most ``num_heads`` minus that count has its ``lora_R_<i>`` and
    ``lora_scaling_<i>`` parameters frozen via ``freeze_by_name``.

    Args:
        model: Model to modify in place.
        head_frac: Fraction of heads to keep trainable per layer.
        num_heads: Number of heads per layer.
        ranking_method, trainloader, opt, dataset, tokenizer: Forwarded to
            ``collect_rankings``.
    """
    rankings = collect_rankings(model, num_heads, ranking_method, trainloader, opt, dataset, tokenizer)

    # Heads ranked at or below this cutoff are frozen
    freeze_cutoff = num_heads - int(head_frac * num_heads)

    frozen_names = []
    for name, _ in model.named_parameters():
        if not ('base_layer' in name and 'weight' in name):
            continue

        for head_idx in range(num_heads):
            if rankings[name][head_idx] > freeze_cutoff:
                continue
            frozen_names.append(name.replace('base_layer.weight', f'lora_R_{head_idx}'))
            frozen_names.append(name.replace('base_layer.weight', f'lora_scaling_{head_idx}'))

    freeze_by_name(model, frozen_names)