"""
Attention conversion helpers
"""
from functools import partial
from tqdm import tqdm
import torch.nn as nn


def convert_attention(model: nn.Module,
                      attention_config: dict,
                      train_attention: bool = False):
    """
    Replace the `self_attn` of every layer in `model` with the configured
    attention module.

    Args:
        model: Model whose layers are located with `traverse_layers`.
        attention_config: Attention settings. Must provide `attention_type`,
            either as a key (plain dict) or as an attribute (attribute-style
            config object). The whole config is handed to
            `convert_llama_attention` for every layer.
        train_attention: Forwarded to each newly built attention module.

    Returns:
        The same `model` object; layers are modified in place. When
        `attention_type` is 'flash_attention_2' a notice is printed and the
        model is returned untouched.
    """
    # The signature advertises a dict, yet attribute access fails on a plain
    # dict. Try the attribute first so attribute-style configs behave exactly
    # as before, then fall back to key lookup.
    try:
        attention_type = attention_config.attention_type
    except AttributeError:
        attention_type = attention_config['attention_type']

    if attention_type == 'flash_attention_2':
        print(f'-> attention_config.attention_type is {attention_type}; not converting attentions')
        return model

    layers = traverse_layers(model)
    for layer in tqdm(layers, desc='Converting attentions...'):
        layer.self_attn = convert_llama_attention(
            layer, attention_config, layers, train_attention,
        )
        # Flag the replacement module as converted.
        layer.self_attn.converted = True
    return model


def toggle_attention(llama_model: nn.Module, train: bool = False):
    """
    Write `train` into the `train_attention` flag of every layer's `self_attn`
    -> Pass train = False when finetuning

    Returns the same `llama_model` object.
    """
    for decoder_layer in traverse_layers(llama_model):
        attention_module = decoder_layer.self_attn
        setattr(attention_module, 'train_attention', train)
    return llama_model


def remove_base_attention(llama_model: nn.Module):
    """
    Drop the `base_attn` attribute (the teacher attention, if it was kept
    after distillation) from every layer's `self_attn`.

    Layers whose `self_attn` has no `base_attn`, or a falsy one, are left
    alone. Returns the same `llama_model` object.
    """
    for decoder_layer in traverse_layers(llama_model):
        has_teacher = bool(getattr(decoder_layer.self_attn, 'base_attn', False))
        if not has_teacher:
            continue
        del decoder_layer.self_attn.base_attn
    return llama_model
        

def traverse_layers(model: nn.Module, verbose: bool = False):
    """
    Return the layers container held by `model`.

    Three attribute paths are tried in order; the first one that resolves wins:
      1. model.model.layers
      2. model.layers                          (base model)
      3. model.base_model.model.model.layers   (PEFT model)

    Args:
        model: Object to search.
        verbose: If True, print every failed lookup and the path finally used.

    Returns:
        Whatever object sits at the first resolvable path.

    Raises:
        AttributeError: If none of the paths resolve. The message lists every
            path that was tried; the last lookup failure is attached as the
            cause.
    """
    candidate_paths = (
        ('model', 'layers'),
        ('layers',),  # if base model
        ('base_model', 'model', 'model', 'layers'),  # If we make a PEFT model
    )
    last_error = None
    for attr_path in candidate_paths:
        node = model
        try:
            for attr_name in attr_path:
                node = getattr(node, attr_name)
        except AttributeError as err:
            if verbose:
                print(err)
            # `err` is unbound when the handler exits, so keep our own reference.
            last_error = err
            continue
        if verbose:
            print('-> Loading from model.' + '.'.join(attr_path))
        return node

    tried = ', '.join('model.' + '.'.join(path) for path in candidate_paths)
    raise AttributeError(
        f'Could not locate layers on {type(model).__name__}; tried: {tried}'
    ) from last_error


def convert_llama_attention(layer: nn.Module,
                            attention_config: dict,
                            layers: list[nn.Module],  # list of layers
                            train_attention: bool = False):
    """
    Build the replacement attention module for a single layer, as specified
    by attention_config.

    Args:
        layer: Layer whose current `self_attn` becomes the `base_attn` of the
            new module; its `layer_idx` is reused.
        attention_config: Mapping unpacked as keyword arguments into
            `get_attention` (so it must contain `attention_type`).
        layers: List of all layers. Not used by this function; kept so the
            call signature stays stable.
        train_attention: Forwarded to the new attention module.

    Returns:
        The newly constructed attention module.

    Raises:
        NotImplementedError: If `get_attention` returns None, i.e. it has no
            class for the configured attention type.
    """
    attention_factory = get_attention(**attention_config)
    if attention_factory is None:
        # Without this guard the call below dies with an unhelpful
        # "'NoneType' object is not callable".
        raise NotImplementedError(
            f"No attention class available for attention_type="
            f"{attention_config['attention_type']!r}"
        )
    base_attn = layer.self_attn
    return attention_factory(
        base_attn=base_attn,
        layer_idx=base_attn.layer_idx,  # Transformers v4.36
        train_attention=train_attention,
    )


def get_attention(attention_type: str, **kwargs: object):
    """
    Get a constructor for the attention class selected by `attention_type`.

    Supported `attention_type` values and the class each one selects:
      'kv_linc'          -> KVLinCAttention
      'kivi_attention'   -> KIVIAttention
      'quarot_attention' -> QuarotAttention
      'resq_attention'   -> ResQAttention
      'gear_attention'   -> GearAttention

    Args:
        attention_type: Key selecting the class. It is also bound onto the
            class as the `attention_type` keyword argument.
        **kwargs: Further keyword arguments to pre-bind onto the class.

    Returns:
        A `functools.partial` of the selected class with `attention_type` and
        `kwargs` bound, or None (after printing a notice) when
        `attention_type` is not one of the supported values.
    """
    # Imports are deferred to call time, inside the branch that needs them.
    if attention_type == 'kv_linc':
        from .attention import KVLinCAttention as attention_cls
    elif attention_type == 'kivi_attention':
        from .attention import KIVIAttention as attention_cls
    elif attention_type == 'quarot_attention':
        from .attention import QuarotAttention as attention_cls
    elif attention_type == 'resq_attention':
        from .attention import ResQAttention as attention_cls
    elif attention_type == 'gear_attention':
        from .attention import GearAttention as attention_cls
    else:
        print(f'-> attention_type {attention_type} not handled... returning None')
        return None

    kwargs['attention_type'] = attention_type
    return partial(attention_cls, **kwargs)


def get_attention_cache(attention_type: str, past_key_values: object, kvquant_config: dict):
    """
    Determine how we store past keys and values when generating.

    Supported `attention_type` values and the cache class each one selects:
      'kv_linc'          -> KVLinCAttentionCache
      'kivi_attention'   -> KIVIAttentionCache
      'quarot_attention' -> QuarotAttentionCache
      'resq_attention'   -> ResQAttentionCache
      'gear_attention'   -> GearAttentionCache

    Args:
        attention_type: Key selecting the cache class.
        past_key_values: Not used by this function; kept so the call
            signature stays stable.
        kvquant_config: Mapping unpacked as keyword arguments into the cache
            class constructor.

    Returns:
        A new instance of the selected cache class.

    Raises:
        NotImplementedError: If `attention_type` is not one of the supported
            values; the message names the offending value.
    """
    # Imports are deferred to call time, inside the branch that needs them.
    if attention_type == 'kv_linc':
        from .attention import KVLinCAttentionCache as cache_cls
    elif attention_type == 'kivi_attention':
        from .attention import KIVIAttentionCache as cache_cls
    elif attention_type == 'quarot_attention':
        from .attention import QuarotAttentionCache as cache_cls
    elif attention_type == 'resq_attention':
        from .attention import ResQAttentionCache as cache_cls
    elif attention_type == 'gear_attention':
        from .attention import GearAttentionCache as cache_cls
    else:
        raise NotImplementedError(
            f'No attention cache implemented for attention_type {attention_type!r}'
        )
    return cache_cls(**kvquant_config)