import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import AutoTokenizer, AutoModelForCausalLM, PretrainedConfig, \
    AutoModelForSequenceClassification
import math
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
from transformers.activations import get_activation

from dataclasses import dataclass, field



@dataclass
class CASTConfig(PretrainedConfig):
    """Configuration fields, all prefixed ``CAST_``, layered on ``PretrainedConfig``.

    Several fields share their names with the constructor arguments of the
    adapter module defined later in this file (layer-norm placement,
    activation, bottleneck size, dropout); presumably they are forwarded to
    it -- the code that consumes this config is not visible here.

    NOTE(review): the class body defines no ``__init__``, so ``@dataclass``
    generates one from the fields below. Confirm that the state normally set
    up by ``PretrainedConfig``'s own initializer is still populated where the
    rest of the project relies on it.
    """
    # Whether a LayerNorm is applied before the adapter projection.
    CAST_add_layer_norm_before_adapter: bool = False
    # Whether a LayerNorm is applied after the adapter projection.
    CAST_add_layer_norm_after_adapter: bool = True
    # Name of the activation function. Defaults to the empty string; how the
    # empty value is interpreted by consumers is not visible here -- confirm.
    CAST_activation: str = ""
    # Presumably the adapter bottleneck width -- TODO confirm against caller.
    CAST_hidden_size: int = 16
    # Dropout probability; 0.0 presumably means "no dropout".
    CAST_dropout: float = 0.0
    # Usage not visible in this part of the file.
    CAST_side_local: bool = False
    # Usage not visible in this part of the file.
    CAST_act_zip: bool = False



def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    Parameters are gathered from both ``model.named_parameters()`` and
    ``model.LLM.named_parameters()``. Each parameter tensor is counted once
    (by identity), so a ``model.LLM`` that is also a registered submodule of
    ``model`` no longer inflates the totals.

    Side effects (preserved from the original behaviour; callers may rely on
    them):
      * every parameter of ``model`` whose name contains ``"score"`` gets
        ``requires_grad = False``;
      * every parameter of ``model.LLM`` whose name contains ``"classifier"``
        gets ``requires_grad = True``.

    The counts are computed after both adjustments, so they describe the
    state the model is left in.

    Args:
        model: an object exposing ``named_parameters()`` and an ``LLM``
            attribute that exposes ``named_parameters()`` as well.

    Raises:
        AttributeError: if ``model`` has no ``LLM`` attribute.
    """
    # id(param) -> param, so tensors reachable through both iterators are
    # only counted once.
    unique_params = {}
    head_params = {}

    for name, param in model.named_parameters():
        unique_params[id(param)] = param
        if "score" in name:
            head_params[id(param)] = param
            param.requires_grad = False

    for name, param in model.LLM.named_parameters():
        unique_params[id(param)] = param
        if "classifier" in name:
            head_params[id(param)] = param
            param.requires_grad = True

    # Count only after all requires_grad flags have reached their final value.
    all_param = sum(p.numel() for p in unique_params.values())
    trainable_params = sum(
        p.numel() for p in unique_params.values() if p.requires_grad
    )
    classifer_params = sum(p.numel() for p in head_params.values())

    # Guard against a parameter-free model instead of raising ZeroDivisionError.
    trainable_percent = 100 * trainable_params / all_param if all_param else 0.0

    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_percent}"
    )
    print(f"classifer:{classifer_params}")


def _get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name


class Activations(nn.Module):
    """Module wrapper around an activation looked up by name.

    The name is resolved once at construction time via ``get_activation``
    and the result is stored on ``self.f``; ``forward`` simply applies it.
    """

    def __init__(self, activation_type):
        super(Activations, self).__init__()
        resolved_fn = get_activation(activation_type)
        self.f = resolved_fn

    def forward(self, x):
        """Apply the stored activation to ``x`` and return the result."""
        activated = self.f(x)
        return activated


class StandardAdapter(nn.Module):
    """Bottleneck adapter with a residual connection.

    The forward pass computes::

        x + post_norm(dropout(up_proj(act(down_proj(pre_norm(x))))))

    where the two layer norms, the activation and the dropout are each
    optional.

    Args:
        in_features: size of the last input dimension (and of the output).
        adapter_size: width of the bottleneck between the two projections.
        activation: name of the activation applied between the projections;
            it is lower-cased and wrapped in ``Activations``. ``None`` or an
            empty string (the default of ``CASTConfig.CAST_activation``)
            means no activation is applied.
        add_layer_norm_before_adapter: apply ``nn.LayerNorm`` to the input
            before the down projection.
        add_layer_norm_after_adapter: apply ``nn.LayerNorm`` to the adapter
            output before it is added to the residual.
        dropout: dropout probability on the up-projected output; values
            ``<= 0`` disable dropout.
        bias: whether both linear projections carry a bias term.
    """

    def __init__(self,
                 in_features: int,
                 adapter_size: int,
                 activation: Optional[str] = None,
                 add_layer_norm_before_adapter: bool = False,
                 add_layer_norm_after_adapter: bool = False,
                 dropout: float = 0.0,
                 bias: bool = True):

        super(StandardAdapter, self).__init__()

        self.down_proj = nn.Linear(in_features, adapter_size, bias=bias)

        # Keep ``None`` as the "no activation" marker so code inspecting the
        # attribute sees the same value as before; forward() guards the call.
        # An empty name is treated like ``None`` rather than being passed to
        # the activation lookup, which has no entry for it.
        if activation:
            self.activation = Activations(activation.lower())
        else:
            self.activation = None

        self.up_proj = nn.Linear(adapter_size, in_features, bias=bias)
        self.add_layer_norm_before = add_layer_norm_before_adapter
        self.add_layer_norm_after = add_layer_norm_after_adapter
        if self.add_layer_norm_before:
            self.pre_layer_norm = nn.LayerNorm(in_features)
        if self.add_layer_norm_after:
            self.post_layer_norm = nn.LayerNorm(in_features)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        """Return ``x`` plus the adapter transformation of ``x``.

        ``x`` must have ``in_features`` as its last dimension; the result has
        the same shape as ``x``.
        """
        residual = x

        if self.add_layer_norm_before:
            x = self.pre_layer_norm(x)

        x = self.down_proj(x)
        # Previously this call was unconditional and raised TypeError when the
        # adapter was built with the default ``activation=None``.
        if self.activation is not None:
            x = self.activation(x)
        x = self.up_proj(x)

        x = self.dropout(x)

        if self.add_layer_norm_after:
            x = self.post_layer_norm(x)

        return residual + x
    
