import transformers
from transformers import AutoTokenizer, EsmForMaskedLM,  AutoModelForCausalLM, Trainer, TrainingArguments
from tokenizers import Tokenizer
from dataclasses import dataclass, field
from typing import  Optional
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)


class ResBlock(nn.Module):
    """Residual block computing ``x + SiLU(Linear(x))``.

    The linear layer maps ``hidden_size`` features back to ``hidden_size``
    features, so the output keeps the input's trailing dimension.
    """

    def __init__(self, hidden_size):
        super(ResBlock, self).__init__()
        # The attribute names ``linear`` and ``act`` are part of the
        # state_dict key layout; keep them stable for checkpoint loading.
        self.linear = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        """Return ``x`` plus its SiLU-activated linear projection.

        ``x``'s last dimension must equal ``hidden_size`` (required by
        ``nn.Linear``).
        """
        update = self.act(self.linear(x))
        return x + update

class WrapModelConfig(PretrainedConfig):
    """Configuration for ``WrapModel``.

    Keeps the wrapped model's configuration next to the depth of the
    auxiliary head. Any additional keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, base_model_config=None, num_layers=1, **kwargs):
        super(WrapModelConfig, self).__init__(**kwargs)
        # Configuration of the wrapped model (WrapModel passes ``model.config``).
        self.base_model_config = base_model_config
        # Number of ResBlock layers stacked in WrapModel's auxiliary head.
        self.num_layers = num_layers


class WrapModel(PreTrainedModel):
    """Wrap a base model and attach an auxiliary 2-output head.

    ``forward`` delegates entirely to the wrapped model. ``assist_acc_head``
    is constructed here but is not used by ``forward`` in this class
    (presumably it is applied elsewhere — verify against callers).
    """

    config_class = WrapModelConfig

    def __init__(self, model, num_layers=1):
        """Build the wrapper around ``model``.

        Args:
            model: base model to wrap. ``model.config`` must expose
                ``hidden_size`` and is stored as
                ``self.config.base_model_config``.
            num_layers: number of ``ResBlock`` layers stacked before the
                final ``nn.Linear(hidden_size, 2)`` of ``assist_acc_head``.
        """
        config = WrapModelConfig(base_model_config=model.config, num_layers=num_layers)
        hidden_size = model.config.hidden_size
        super().__init__(config)

        self.model = model
        # Build a distinct ResBlock per layer. ``[ResBlock(...)] * num_layers``
        # would repeat ONE instance, silently tying all layers' weights.
        # The state_dict key layout ("0.linear.weight", "1.linear.weight", ...)
        # is the same either way, so existing checkpoints still load.
        res_blocks = [ResBlock(hidden_size) for _ in range(num_layers)]
        self.assist_acc_head = nn.Sequential(*res_blocks, nn.Linear(hidden_size, 2))

    def forward(self, input_ids=None, labels=None, **kwargs):
        """Run the wrapped model and return its output unchanged."""
        return self.model(input_ids=input_ids, labels=labels, **kwargs)


 