from functools import partial
from typing import Iterable, Tuple
import torch.nn as nn

from ..models.primitives import DEFAULT_PRIMITIVES

def iter_vit(backbone: nn.Module, mode: str, primitives=DEFAULT_PRIMITIVES) -> Iterable[Tuple]:
    """Enumerate the ``attn.proj`` layer of every block in ``backbone``.

    Args:
        backbone: Object exposing an iterable ``blocks`` attribute; each block
            must provide ``attn.proj`` with an ``out_features`` attribute.
        mode: Either ``"statistics"`` or ``"memory"``; selects the tuple
            layout of the result.
        primitives: Value attached, unchanged and shared, to every layer in
            ``"memory"`` mode.

    Returns:
        A ``zip`` iterator. In ``"statistics"`` mode it yields
        ``(name, layer, dim)``; in ``"memory"`` mode it yields
        ``(name, layer, dim, primitives)``. Names are ``"proj<i>"`` and
        ``dim`` is the layer's ``out_features``.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
    """
    # Fail loudly instead of implicitly returning None for an unknown mode.
    if mode not in ("statistics", "memory"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'statistics' or 'memory'."
        )

    layers = [block.attn.proj for block in backbone.blocks]
    layer_names = [f"proj{i}" for i in range(len(layers))]
    dims = [layer.out_features for layer in layers]

    if mode == "statistics":
        return zip(layer_names, layers, dims)
    # Every layer receives the very same ``primitives`` object.
    return zip(layer_names, layers, dims, [primitives] * len(layers))
    

def iter_vit_qkv(backbone: nn.Module, mode: str, cheem_component: str = "value", primitives=DEFAULT_PRIMITIVES) -> Iterable[Tuple]:
    """Enumerate one of the q/k/v layers of every block in ``backbone``.

    Args:
        backbone: Object exposing an iterable ``blocks`` attribute; each block
            must provide ``attn.separable_qkv`` with ``q``, ``k`` and ``v``
            attributes, each having ``out_features``.
        mode: Either ``"statistics"`` or ``"memory"``; selects the tuple
            layout of the result.
        cheem_component: ``"query"``, ``"key"`` or ``"value"``; picks the
            ``q``, ``k`` or ``v`` attribute respectively.
        primitives: Value attached, unchanged and shared, to every layer in
            ``"memory"`` mode.

    Returns:
        A ``zip`` iterator. In ``"statistics"`` mode it yields
        ``(name, layer, dim)``; in ``"memory"`` mode it yields
        ``(name, layer, dim, primitives)``. Names are
        ``"<cheem_component><i>"`` and ``dim`` is the layer's
        ``out_features``.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
        KeyError: If ``cheem_component`` is not a known component name.
    """
    # Fail loudly instead of implicitly returning None for an unknown mode.
    if mode not in ("statistics", "memory"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'statistics' or 'memory'."
        )

    alias = {"query": "q", "key": "k", "value": "v"}
    # Resolved once up front: the lookup does not depend on the block, and an
    # invalid component is reported even when there are no blocks.
    attr_name = alias[cheem_component]

    layers = [
        getattr(block.attn.separable_qkv, attr_name)
        for block in backbone.blocks
    ]
    layer_names = [f"{cheem_component}{i}" for i in range(len(layers))]
    dims = [layer.out_features for layer in layers]

    if mode == "statistics":
        return zip(layer_names, layers, dims)
    # Every layer receives the very same ``primitives`` object.
    return zip(layer_names, layers, dims, [primitives] * len(layers))
    

def iter_vit_ffn(backbone: nn.Module, mode: str, primitives=DEFAULT_PRIMITIVES) -> Iterable[Tuple]:
    """Enumerate the ``mlp`` module of every block in ``backbone``.

    Args:
        backbone: Object exposing an iterable ``blocks`` attribute; each block
            must provide ``mlp`` with an ``fc2.out_features`` attribute.
        mode: Either ``"statistics"`` or ``"memory"``; selects the tuple
            layout of the result.
        primitives: Value attached, unchanged and shared, to every layer in
            ``"memory"`` mode.

    Returns:
        A ``zip`` iterator. In ``"statistics"`` mode it yields
        ``(name, mlp, dim)``; in ``"memory"`` mode it yields
        ``(name, mlp, dim, primitives)``. Names are ``"mlp<i>"`` and ``dim``
        is ``mlp.fc2.out_features``.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
    """
    # Fail loudly instead of implicitly returning None for an unknown mode.
    if mode not in ("statistics", "memory"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'statistics' or 'memory'."
        )

    layers = [block.mlp for block in backbone.blocks]
    layer_names = [f"mlp{i}" for i in range(len(layers))]
    # The reported width is that of the second linear layer of the MLP.
    dims = [mlp.fc2.out_features for mlp in layers]

    if mode == "statistics":
        return zip(layer_names, layers, dims)
    # Every layer receives the very same ``primitives`` object.
    return zip(layer_names, layers, dims, [primitives] * len(layers))
    

def iter_vit_mean_var(backbone: nn.Module, mode: str) -> Iterable[Tuple]:
    """Enumerate the ``attn.proj`` layer of every block with mode-specific dims.

    Args:
        backbone: Object exposing an iterable ``blocks`` attribute; each block
            must provide ``attn.proj`` (with ``out_features`` for
            ``"memory"`` mode).
        mode: Either ``"statistics"`` or ``"memory"``; selects the tuple
            layout of the result.

    Returns:
        A ``zip`` iterator. In ``"statistics"`` mode it yields
        ``(name, layer, 1)``; in ``"memory"`` mode it yields
        ``(name, layer, out_features, primitives)`` where ``primitives`` is a
        fresh ``["skip", "reuse", "adapt", "new"]`` list per layer. Names are
        ``"proj<i>"``.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
    """
    # Fail loudly instead of implicitly returning None for an unknown mode.
    if mode not in ("statistics", "memory"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'statistics' or 'memory'."
        )

    layers = [block.attn.proj for block in backbone.blocks]
    layer_names = [f"proj{i}" for i in range(len(layers))]

    if mode == "statistics":
        # The statistics dimension is fixed to 1 for every layer.
        return zip(layer_names, layers, [1] * len(layers))

    memory_dims = [layer.out_features for layer in layers]
    # A separate list per layer so that mutating one entry cannot leak into
    # the others.
    primitives_per_layer = [["skip", "reuse", "adapt", "new"] for _ in layers]
    return zip(layer_names, layers, memory_dims, primitives_per_layer)
    
def iter_ewc_vit(backbone: nn.Module):
    """Select the parameters of ``backbone`` whose name contains ``"attn.proj"``.

    Args:
        backbone: Object providing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        A ``zip`` iterator over ``(name, parameter)`` pairs, in the order
        reported by ``named_parameters()``.
    """
    matched = [
        (name, parameter)
        for name, parameter in backbone.named_parameters()
        if "attn.proj" in name
    ]
    names = [name for name, _ in matched]
    parameters = [parameter for _, parameter in matched]
    return zip(names, parameters)

def iter_ewc_attn_vit(backbone: nn.Module):
    """Select the parameters of ``backbone`` whose name contains ``"attn."``.

    Args:
        backbone: Object providing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        A ``zip`` iterator over ``(name, parameter)`` pairs, in the order
        reported by ``named_parameters()``.
    """
    matched = [
        (name, parameter)
        for name, parameter in backbone.named_parameters()
        if "attn." in name
    ]
    names = [name for name, _ in matched]
    parameters = [parameter for _, parameter in matched]
    return zip(names, parameters)

def iter_supsup_vit(backbone: nn.Module):
    """Enumerate the ``attn.proj`` layer of every block in ``backbone``.

    Args:
        backbone: Object exposing an iterable ``blocks`` attribute; each block
            must provide ``attn.proj``.

    Returns:
        A ``zip`` iterator over ``("proj<i>", layer)`` pairs, one per block.
    """
    projections = [blk.attn.proj for blk in backbone.blocks]
    labels = [f"proj{idx}" for idx in range(len(projections))]
    return zip(labels, projections)

def iter_shift_vit(backbone: nn.Module):
    """Enumerate the patch embedding plus four layers of every block.

    Args:
        backbone: Object exposing ``patch_embed`` and an iterable ``blocks``
            attribute; each block must provide ``attn.qkv``, ``attn.proj``,
            ``mlp.fc1`` and ``mlp.fc2``.

    Returns:
        A ``zip`` iterator over ``(name, layer)`` pairs: first
        ``("proj_patch", patch_embed)``, then for block ``i`` the entries
        ``qkv<i>``, ``proj<i>``, ``mlp_up<i>`` and ``mlp_down<i>`` in that
        order.
    """
    entries = [("proj_patch", backbone.patch_embed)]
    for idx, blk in enumerate(backbone.blocks):
        entries.append((f"qkv{idx}", blk.attn.qkv))
        entries.append((f"proj{idx}", blk.attn.proj))
        entries.append((f"mlp_up{idx}", blk.mlp.fc1))
        entries.append((f"mlp_down{idx}", blk.mlp.fc2))

    names = [name for name, _ in entries]
    modules = [module for _, module in entries]
    return zip(names, modules)

def iter_scale_vit(backbone: nn.Module):
    """Enumerate four layers of every block in ``backbone``.

    Args:
        backbone: Object exposing an iterable ``blocks`` attribute; each block
            must provide ``attn.qkv``, ``attn.proj``, ``mlp.fc1`` and
            ``mlp.fc2``.

    Returns:
        A ``zip`` iterator over ``(name, layer)`` pairs; for block ``i`` the
        entries are ``qkv<i>``, ``proj<i>``, ``mlp_up<i>`` and
        ``mlp_down<i>`` in that order.
    """
    names, modules = [], []
    for idx, blk in enumerate(backbone.blocks):
        per_block = (
            ("qkv", blk.attn.qkv),
            ("proj", blk.attn.proj),
            ("mlp_up", blk.mlp.fc1),
            ("mlp_down", blk.mlp.fc2),
        )
        for prefix, module in per_block:
            names.append(f"{prefix}{idx}")
            modules.append(module)

    return zip(names, modules)


# Variants of ``iter_vit_qkv`` with ``cheem_component`` pre-bound to one of
# the three supported component names.
iter_vit_query, iter_vit_key, iter_vit_value = (
    partial(iter_vit_qkv, cheem_component=component)
    for component in ("query", "key", "value")
)

# Lookup table from a string key to the layer iterator it selects.
BACKBONE_ITERATORS = dict(
    attn_proj=iter_vit,
    query=iter_vit_query,
    key=iter_vit_key,
    value=iter_vit_value,
    ffn=iter_vit_ffn,
)

# Lookup table with a single entry: ``iter_vit`` with its ``primitives``
# argument pre-bound to a three-item list.
DARTS_BACKBONE_ITERATORS = dict(
    attn_proj=partial(iter_vit, primitives=["reuse", "adapt", "new"]),
)