import math

from functools import partial
from typing import Callable, Optional

import torch
from torch import nn
import torch.nn.functional as F

from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    can_return_tuple,
)
from transformers.models.llama.modeling_llama import (
    LlamaMLP,
    LlamaAttention,
    LlamaModel,
    LlamaConfig,
    LLAMA_INPUTS_DOCSTRING,
    repeat_kv,
    apply_rotary_pos_emb,
    eager_attention_forward,
    logger,
)

from dataclasses import dataclass, field


class LookaheadKVCluster:
    """Pick the prompt key/value entries to retain, ranked by lookahead-token attention.

    The trailing ``lookahead_size`` positions of the projected states are
    treated as lookahead tokens. Their attention over the preceding (prompt)
    positions is summed per position, reduced across the query heads that
    share a KV head, smoothed with a 1D pooling, and the
    ``max_capacity_prompt`` highest-scoring prompt positions are kept.

    ``window_size`` is stored but not read by any method of this class.
    """

    # Accepted spellings for each pooling mode. The short forms let the
    # constructor default ("mean") and the 'max'/'avg' values advertised by
    # LookaheadKVConfig select a pooling instead of raising ValueError.
    _AVG_POOLING = ("avgpool", "avg", "mean")
    _MAX_POOLING = ("maxpool", "max")

    def __init__(self, lookahead_size=8, window_size=32, max_capacity_prompt=128, kernel_size=5, pooling="mean", reduction="mean"):
        self.reset(
            lookahead_size=lookahead_size,
            window_size=window_size,
            max_capacity_prompt=max_capacity_prompt,
            kernel_size=kernel_size,
            pooling=pooling,
            reduction=reduction,
        )

    def reset(self, lookahead_size=8, window_size=32, max_capacity_prompt=128, kernel_size=5, pooling="mean", reduction="mean"):
        """Overwrite every setting; omitted arguments fall back to the defaults."""
        self.lookahead_size = lookahead_size
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        self.kernel_size = kernel_size
        self.pooling = pooling
        self.reduction = reduction

    def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
        """Return the (key, value) prompt entries that should go into the cache.

        Args:
            key_states: 4D tensor ``(bsz, num_kv_heads, seq, head_dim)`` whose
                last ``lookahead_size`` positions are the lookahead tokens.
            query_states: 4D tensor ``(bsz, num_heads, seq, head_dim)``.
            value_states: same layout as ``key_states``.
            attention_mask: unused; a causal mask over the lookahead tokens is
                built internally.
            num_key_value_groups: unused; the group size is derived from the
                head counts of ``query_states`` and ``key_states``.

        Returns:
            ``(keys, values)`` without the lookahead positions. When the
            prompt is shorter than ``max_capacity_prompt`` all prompt entries
            are returned unchanged; otherwise ``max_capacity_prompt`` entries
            per KV head are gathered (in score order, not position order).

        Raises:
            ValueError: if compression is needed and ``pooling`` or
                ``reduction`` is not a supported value.
        """
        lookahead = self.lookahead_size
        query_states_l = query_states[:, :, -lookahead:, :]
        key_states_o = key_states[:, :, :-lookahead, :]
        value_states_o = value_states[:, :, :-lookahead, :]

        bsz, num_heads, lookahead_size, head_dim = query_states_l.shape
        _, num_kv_heads, seq_len, _ = key_states_o.shape
        if seq_len < self.max_capacity_prompt:
            # Prompt already fits the budget: nothing to evict.
            return key_states_o, value_states_o

        # Validate the settings before doing any attention work.
        if self.reduction not in ("mean", "max"):
            raise ValueError('Reduction method not supported')
        if self.pooling in self._AVG_POOLING:
            pool_fn = F.avg_pool1d
        elif self.pooling in self._MAX_POOLING:
            pool_fn = F.max_pool1d
        else:
            raise ValueError('Pooling method not supported')

        # Prompt keys followed by lookahead keys is exactly `key_states`, so
        # no re-concatenation is needed. Each KV head is repeated for the query
        # heads sharing it (same result as transformers' `repeat_kv`).
        num_queries_in_group = num_heads // num_kv_heads
        key_states_repeat = key_states
        if num_queries_in_group > 1:
            key_states_repeat = torch.repeat_interleave(key_states, num_queries_in_group, dim=1)

        attn_weights = torch.matmul(query_states_l, key_states_repeat.transpose(2, 3)) / math.sqrt(head_dim)

        # Lookahead queries see every prompt key, but only lookahead keys at
        # or before their own position.
        future = torch.ones(
            lookahead_size, lookahead_size, dtype=torch.bool, device=attn_weights.device
        ).triu(diagonal=1)
        attn_weights[..., -lookahead_size:].masked_fill_(future, torch.finfo(attn_weights.dtype).min)

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states_l.dtype)

        # Per prompt position: total attention received from the lookahead queries.
        scores = attn_weights[..., :-lookahead_size].sum(dim=-2)
        scores = scores.view(bsz, num_kv_heads, num_queries_in_group, -1)
        if self.reduction == "mean":
            scores = scores.mean(dim=-2)
        else:
            scores = scores.max(dim=-2).values

        # Apply 1D pooling along the sequence axis. With an even kernel the
        # pooled output is one element longer than the prompt; truncating keeps
        # every top-k index a valid gather index.
        pooled = pool_fn(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
        pooled = pooled[..., :seq_len]

        indices = pooled.topk(self.max_capacity_prompt, dim=-1).indices
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        key_states_compress = key_states_o.gather(dim=2, index=indices)
        value_states_compress = value_states_o.gather(dim=2, index=indices)

        return key_states_compress, value_states_compress


def init_lookaheadkv_cluster(self):
    """Attach a fresh ``LookaheadKVCluster`` to ``self`` built from ``self.config``.

    Raises:
        ValueError: naming the first required config attribute that is missing.
    """
    cfg = self.config
    required = ("lookahead_size", "max_capacity_prompt", "kernel_size", "pool_type", "reduction_type")
    for attr in required:
        if hasattr(cfg, attr):
            continue
        raise ValueError(f"Config does not contain attribute: {attr}")

    # window_size is not forwarded, so the cluster keeps its own default for it.
    cluster_settings = {
        "lookahead_size": cfg.lookahead_size,
        "max_capacity_prompt": cfg.max_capacity_prompt,
        "kernel_size": cfg.kernel_size,
        "pooling": cfg.pool_type,
        "reduction": cfg.reduction_type,
    }
    self.lookaheadkv_cluster = LookaheadKVCluster(**cluster_settings)


class LlamaLoRA(nn.Module):
    """Low-rank adapter computing ``scaling * ((dropout(x) @ lora_A) @ lora_B)``.

    ``lora_A`` starts as small Gaussian noise and ``lora_B`` as zeros, so a
    freshly constructed adapter outputs zeros. ``scaling`` is
    ``lora_alpha / rank``.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        rank: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.0,
    ):
        super().__init__()
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / rank

        down_init = torch.randn(input_dim, rank) * 0.01
        up_init = torch.zeros(rank, output_dim)
        self.lora_A = nn.Parameter(down_init)
        self.lora_B = nn.Parameter(up_init)

        # Dropout is only instantiated when it would actually drop something.
        if lora_dropout > 0.0:
            self.dropout = nn.Dropout(p=lora_dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x`` (last dim ``input_dim``)."""
        low_rank = self.dropout(x) @ self.lora_A
        return self.scaling * (low_rank @ self.lora_B)
    

class LlamaMLPLookaheadKV(LlamaMLP):
    """LlamaMLP with optional LoRA branches applied to the trailing tokens only.

    ``config.lora`` selects which of the gate/up/down projections receive an
    adapter. For inputs longer than one token, the adapters' outputs are added
    to the last ``lookahead_size`` positions (``config.mask_num``) of the
    corresponding projection; single-token inputs use the base MLP only.
    """

    def __init__(self, config):
        super().__init__(config)

        self.lookahead_size = config.mask_num
        self.rank = config.lora_rank

        hidden, inner = config.hidden_size, config.intermediate_size
        if "gate" in self.config.lora:
            self.gate_proj_lora = LlamaLoRA(hidden, inner, rank=self.rank)
        if "up" in self.config.lora:
            self.up_proj_lora = LlamaLoRA(hidden, inner, rank=self.rank)
        if "down" in self.config.lora:
            self.down_proj_lora = LlamaLoRA(inner, hidden, rank=self.rank)

    def forward(self, x):
        gate = self.gate_proj(x)
        up = self.up_proj(x)

        # Single-token inputs carry no lookahead tokens: plain MLP.
        if not x.shape[1] > 1:
            return self.down_proj(self.act_fn(gate) * up)

        # Slice selecting the trailing lookahead positions along the sequence axis.
        tail = slice(-self.lookahead_size, None)

        if 'gate' in self.config.lora:
            gate[:, tail, :] += self.gate_proj_lora(x[:, tail, :])
        if 'up' in self.config.lora:
            up[:, tail, :] += self.up_proj_lora(x[:, tail, :])

        activated = self.act_fn(gate) * up
        output = self.down_proj(activated)
        if 'down' in self.config.lora:
            output[:, tail, :] += self.down_proj_lora(activated[:, tail, :])

        return output


class LlamaAttentionLookaheadKV(LlamaAttention):
    """LlamaAttention with LoRA on the lookahead tokens and KV compression at prefill.

    For multi-token inputs the LoRA adapters selected by ``config.lora`` are
    added to the q/k/v/o projections of the last ``config.lookahead_size``
    positions. When a cache is supplied, the prompt's keys/values are
    compressed by ``LookaheadKVCluster`` before being stored, while attention
    for the current step still runs over the uncompressed states.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        # Kept for external readers; `forward` reads the size from the config
        # so it also works when bound onto modules that never ran this __init__.
        self.lookahead_size = config.mask_num
        self.rank = config.lora_rank

        q_dim = config.num_attention_heads * config.head_dim
        kv_dim = config.num_key_value_heads * config.head_dim
        if "q" in self.config.lora:
            self.q_proj_lora = LlamaLoRA(config.hidden_size, q_dim, rank=self.rank)
        if "k" in self.config.lora:
            self.k_proj_lora = LlamaLoRA(config.hidden_size, kv_dim, rank=self.rank)
        if "v" in self.config.lora:
            self.v_proj_lora = LlamaLoRA(config.hidden_size, kv_dim, rank=self.rank)
        if "o" in self.config.lora:
            # The adapter is fed the pre-o_proj tensor, whose width is
            # num_attention_heads * head_dim (equal to hidden_size only when
            # head_dim is the default hidden_size // num_attention_heads).
            self.o_proj_lora = LlamaLoRA(q_dim, config.hidden_size, rank=self.rank)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run attention, compressing the prompt KV entries before caching them.

        Returns:
            ``(attn_output, attn_weights)`` as produced by the selected
            attention implementation, with ``attn_output`` projected by
            ``o_proj`` (plus the optional LoRA branch).

        Raises:
            ValueError: if the config lacks an attribute required by
                ``init_lookaheadkv_cluster``.
        """
        # [LKV] register kv_cluster (rebuilt from the config on every call)
        init_lookaheadkv_cluster(self)

        # One source of truth for the number of trailing lookahead tokens: the
        # same value the cluster and the cache bookkeeping below rely on.
        lookahead_size = self.config.lookahead_size

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # [LKV]  LoRA
        query_states = self.q_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        key_states = self.k_proj(hidden_states)

        if hidden_states.shape[1] > 1:
            if 'q' in self.config.lora:
                query_states_l = self.q_proj_lora(hidden_states[:, -lookahead_size:, :])
                query_states[:, -lookahead_size:, :] += query_states_l
            if 'k' in self.config.lora:
                key_states_l = self.k_proj_lora(hidden_states[:, -lookahead_size:, :])
                key_states[:, -lookahead_size:, :] += key_states_l
            if 'v' in self.config.lora:
                value_states_l = self.v_proj_lora(hidden_states[:, -lookahead_size:, :])
                value_states[:, -lookahead_size:, :] += value_states_l

        query_states = query_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            if query_states.shape[2] > 1:
                # if ratio is given, dynamically adjust here
                # (the length used includes the lookahead tokens; the shared
                # config object is updated as well)
                if self.config.max_capacity_prompt_ratio > 0.0:
                    capacity = round(query_states.shape[2] * self.config.max_capacity_prompt_ratio)
                    self.config.max_capacity_prompt = capacity
                    self.lookaheadkv_cluster.max_capacity_prompt = capacity

                key_states_compress, value_states_compress = self.lookaheadkv_cluster.update_kv(
                    key_states, query_states, value_states, attention_mask, self.num_key_value_groups
                )
                # [LKV] get position embedding of prompt (without lookahead tokens)
                cos, sin = (
                    position_embeddings[0][:, :-lookahead_size, :],
                    position_embeddings[1][:, :-lookahead_size, :],
                )
                cache_position_o = cache_position[:-lookahead_size]
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position_o}
                # Only the compressed entries are cached; the update's return
                # value is ignored so this step attends over the full states.
                past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
            else:
                # NOTE(review): presumably compensates for the lookahead
                # positions counted upstream during prefill — confirm.
                cache_position_o = cache_position - lookahead_size
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position_o}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output_pre = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output_pre)

        if hidden_states.shape[1] > 1:
            if 'o' in self.config.lora:
                out_states_l = self.o_proj_lora(attn_output_pre[:, -lookahead_size:, :])
                attn_output[:, -lookahead_size:, :] += out_states_l
        return attn_output, attn_weights
    

class LlamaModelLookaheadKV(LlamaModel):
    """LlamaModel that appends learned lookahead tokens to the prompt at prefill.

    In eval mode, when no cache exists yet (or the cache is empty),
    ``lookahead_size`` learned embeddings are appended after the input
    sequence, run through the decoder together with it, and removed again from
    the final hidden states. All other calls use the regular decoder pass.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        assert hasattr(config, "mask_num") or hasattr(config, "lookahead_size"), "To run LookaheadKV, you should specify the lookahead_size (mask_num)."
        # `lookahead_size` wins over `mask_num` when the config defines both.
        self.lookahead_size = config.lookahead_size if hasattr(config, "lookahead_size") else config.mask_num
        self.embed_lookahead = nn.Embedding(self.lookahead_size, config.hidden_size)

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        # Determine if we're in prefill or decoding phase. Explicit grouping:
        # `and` binds tighter than `or`, so without it a training-mode call
        # with no cache would evaluate `None.get_seq_length()`.
        cache_is_empty = past_key_values is None or past_key_values.get_seq_length() == 0
        is_prefill = not self.training and cache_is_empty

        forward_fn = self._forward_lookahead if is_prefill else self._forward
        return forward_fn(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **flash_attn_kwargs,
        )

    def _forward_lookahead(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """Prefill pass: append the lookahead embeddings and extend all bookkeeping tensors.

        The returned ``last_hidden_state`` no longer contains the lookahead
        positions; intermediate hidden states and attentions (when requested)
        still do.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Works for both input kinds: batch size comes from the embeddings, and
        # helper tensors go where the original placed them when ids are given.
        bsz = inputs_embeds.shape[0]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        lookahead_offsets = torch.arange(self.lookahead_size, device=device)

        # [LKV] Add lookahead embedding (same learned rows for every batch item)
        inputs_embeds_la = self.embed_lookahead(lookahead_offsets).unsqueeze(0).expand(bsz, -1, -1)
        inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_la], dim=1)

        # [LKV] Modify position ids, attention_mask, cache_position, etc.
        # NOTE(review): with explicit position_ids / cache_position the first
        # lookahead entry repeats the last supplied value, whereas the fallback
        # (both None) numbers the lookahead tokens contiguously after the
        # prompt. Left unchanged because trained weights may rely on it — confirm.
        if position_ids is not None:
            position_ids_la = position_ids[:, -1:] + lookahead_offsets.to(position_ids.device)
            position_ids = torch.cat([position_ids, position_ids_la], dim=1)

        if attention_mask is not None:
            # Lookahead tokens are never padding.
            attention_mask_la = attention_mask.new_ones((attention_mask.shape[0], self.lookahead_size))
            attention_mask = torch.cat([attention_mask, attention_mask_la], dim=1)

        if cache_position is not None:
            cache_position_la = cache_position[-1] + lookahead_offsets.to(cache_position.device)
            cache_position = torch.cat([cache_position, cache_position_la], dim=0)

        return self._run_decoder_layers(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            lookahead_size=self.lookahead_size,
            flash_attn_kwargs=flash_attn_kwargs,
        )

    def _forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """Regular decoder pass without lookahead tokens."""
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return self._run_decoder_layers(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            lookahead_size=0,
            flash_attn_kwargs=flash_attn_kwargs,
        )

    def _run_decoder_layers(
        self,
        inputs_embeds,
        attention_mask,
        position_ids,
        past_key_values,
        use_cache,
        output_attentions,
        output_hidden_states,
        cache_position,
        lookahead_size,
        flash_attn_kwargs,
    ) -> BaseModelOutputWithPast:
        """Shared decoder loop: fill in defaults, run every layer, normalize, package the output.

        ``lookahead_size`` is forwarded to each decoder layer; when positive,
        that many trailing positions are dropped from the final hidden states
        before the output norm.
        """
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # NOTE(review): lookahead_size is passed as an extra trailing
                # positional argument; assumes the decoder layer accepts it — confirm.
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                    lookahead_size,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    lookahead_size=lookahead_size,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                # Attention maps are collected as returned (lookahead rows/columns included).
                all_self_attns += (layer_outputs[1],)

        # [LKV] remove lookahead embeddings from last hidden states for decoding
        if lookahead_size > 0:
            hidden_states = hidden_states[:, :-lookahead_size]
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


def update_llama_model_for_lookaheadkv(model):
    """Rebind every decoder layer's attention forward to the LookaheadKV version.

    Each ``self_attn`` module under ``model.model.layers`` gets an instance
    attribute ``forward`` that calls ``LlamaAttentionLookaheadKV.forward`` with
    that module as ``self``. Returns the same ``model`` object.
    """
    # Attention Forward
    for decoder_layer in model.model.layers:
        attention = decoder_layer.self_attn
        attention.forward = partial(LlamaAttentionLookaheadKV.forward, attention)

    return model


def reset_llama_model(model):
    """Point the backbone and every attention layer back at the stock Llama forwards.

    ``model.model.forward`` is replaced by ``LlamaModel.forward`` bound through
    ``partial``, and each layer's ``self_attn.forward`` by
    ``LlamaAttention.forward`` bound to that module. Returns the same ``model``.
    """
    backbone = model.model
    backbone.forward = partial(LlamaModel.forward, backbone)

    for decoder_layer in backbone.layers:
        attention = decoder_layer.self_attn
        attention.forward = LlamaAttention.forward.__get__(attention, type(attention))

    return model



@dataclass
class LookaheadKVConfig:
    """
    Configuration for LookaheadKV.

    The values are copied onto a model's config by ``lookaheadkv_patch_model``.
    """

    max_capacity_prompt: int = field(
        default=1024,
        metadata={"help": "Maximum capacity of the KV cache/prompt."}
    )

    # A fraction, hence float: the default is -1.0 and the attention forward
    # multiplies the sequence length by it when it is positive.
    max_capacity_prompt_ratio: float = field(
        default=-1.0,
        metadata={
            "help": "Fraction of the prefill sequence length to keep in the KV cache. "
                    "When > 0 it overrides max_capacity_prompt; non-positive values disable it."
        }
    )

    window_size: int = field(
        default=64,
        metadata={"help": "Window size (final tokens of input prompt)."}
    )

    pool_type: str = field(
        default="max",
        metadata={
            "help": "Type of pooling to use. Options: 'max', 'avg'."
        }
    )

    # NOTE(review): lookaheadkv_patch_model asserts that kernel_size is odd, so
    # this default does not pass that check; left as-is for compatibility.
    kernel_size: int = field(
        default=64,
        metadata={"help": "Pool kernel size."}
    )

    reduction_type: str = field(
        default="max",
        metadata={
            "help": "Type of reduction to use. Options: 'max', 'mean'."
        }
    )

    lookahead_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of lookahead tokens to consider for LookaheadKV."
        }
    )

def lookaheadkv_patch_model(model, config: LookaheadKVConfig):
    """
    Patch the model for LookaheadKV.

    Copies the LookaheadKV settings from ``config`` onto ``model.config`` and
    returns the same ``model``. ``config.kernel_size`` must be odd.
    """

    assert config.kernel_size % 2 == 1

    # Settings mirrored onto the model config, in assignment order.
    copied_settings = (
        "lookahead_size",
        "window_size",
        "max_capacity_prompt",
        "max_capacity_prompt_ratio",
        "reduction_type",
        "pool_type",
        "kernel_size",
    )
    for setting in copied_settings:
        setattr(model.config, setting, getattr(config, setting))

    return model