# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
# with some refactoring
'''import torch


class PrefixEncoder(torch.nn.Module):
    r"""
    The `torch.nn` model to encode the prefix.

    Args:
        config ([`PrefixTuningConfig`]): The configuration of the prefix encoder.

    Example:

    ```py
    >>> from peft import PrefixEncoder, PrefixTuningConfig

    >>> config = PrefixTuningConfig(
    ...     peft_type="PREFIX_TUNING",
    ...     task_type="SEQ_2_SEQ_LM",
    ...     num_virtual_tokens=20,
    ...     token_dim=768,
    ...     num_transformer_submodules=1,
    ...     num_attention_heads=12,
    ...     num_layers=12,
    ...     encoder_hidden_size=768,
    ... )
    >>> prefix_encoder = PrefixEncoder(config)
    ```

    **Attributes**:
        - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder.
        - **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if
          `prefix_projection` is `True`.
        - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings.

    Input shape: (`batch_size`, `num_virtual_tokens`)

    Output shape: (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`)
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        token_dim = config.token_dim
        num_layers = config.num_layers
        encoder_hidden_size = config.encoder_hidden_size
        num_virtual_tokens = config.num_virtual_tokens
        if self.prefix_projection and not config.inference_mode:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
            self.transform = torch.nn.Sequential(
                torch.nn.Linear(token_dim, encoder_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
            )
        else:
            self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.transform(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
        
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(beta * x)``.

    Args:
        beta: Scalar multiplier applied to the input before the sigmoid gate.
            Stored as a plain attribute (not a learnable parameter).
    """

    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        """Return ``x`` scaled element-wise by its sigmoid gate."""
        gate = torch.sigmoid(self.beta * x)
        return x * gate
    
class LeakyGERLU(nn.Module):
    """Activation that applies GELU and then a leaky ReLU (default slope)."""

    def forward(self, x):
        """Return ``leaky_relu(gelu(x))`` computed element-wise."""
        gelu_out = F.gelu(x)
        return F.leaky_relu(gelu_out)
    
class SearchLayer(nn.Module):
    """Softmax-weighted mixture over a fixed set of candidate operations.

    The layer owns seven candidate operations (five ``Linear`` + activation
    pairs, a dropout and a layer norm) and one learnable logit per candidate
    (``alphas``).  ``forward`` applies every candidate to the input and returns
    their sum weighted by ``softmax(alphas)``.  After each forward pass the
    name and weight of the currently strongest candidate is appended to a log
    file.

    Args:
        hidden_dim: Feature width consumed and produced by every candidate.
            Defaults to 1024.
        log_path: File that receives one diagnostic line per forward pass.
            Defaults to ``DEFAULT_LOG_PATH``; pass ``None`` to disable the
            logging (this also avoids the host synchronisation needed to read
            the winning index).

    Attributes:
        operations (nn.ModuleList): The candidates; each carries a ``name``
            string attribute used in the log line.
        alphas (nn.Parameter): One logit per candidate, initialised from a
            standard normal distribution.
    """

    # Default destination of the per-forward diagnostic line.
    DEFAULT_LOG_PATH = '/lus/grand/projects/COMPASS-GLM/xud/anzir/log_s.txt'

    def __init__(self, hidden_dim=1024, log_path=DEFAULT_LOG_PATH):
        super().__init__()
        self.log_path = log_path

        def linear_then(activation):
            # Each activation candidate is preceded by its own square Linear.
            return nn.Sequential(nn.Linear(hidden_dim, hidden_dim), activation)

        # Order matters: it fixes the state-dict indices and the alpha slots.
        named_ops = [
            ("ReLU", linear_then(nn.ReLU())),
            ("Leaky_ReLu", linear_then(nn.LeakyReLU(negative_slope=0.01))),
            ("SiLU", linear_then(nn.SiLU())),
            ("Tanh", linear_then(nn.Tanh())),
            ("Gelu", linear_then(nn.GELU())),
            ("Dropout_0_1", nn.Dropout(0.1)),
            # LayerNorm's second positional argument is ``eps``, not a feature
            # size, so only the normalized shape is passed and eps keeps its
            # default value.
            ("LayerNorm", nn.LayerNorm(hidden_dim)),
        ]
        for op_name, op in named_ops:
            op.name = op_name
        self.operations = nn.ModuleList([op for _, op in named_ops])

        self.alphas = nn.Parameter(torch.randn(len(self.operations)), requires_grad=True)

    def forward(self, x, layer_idx):
        """Return the softmax(alphas)-weighted sum of all candidate outputs.

        Args:
            x: Input tensor whose last dimension equals ``hidden_dim``.
            layer_idx: Identifier of this layer, used only in the log line.

        Returns:
            Tensor with the same shape as ``x``.
        """
        soft_alphas = F.softmax(self.alphas, dim=0)

        output = sum(
            weight * op(x) for weight, op in zip(soft_alphas, self.operations)
        )

        if self.log_path is not None:
            best_op_index = int(torch.argmax(soft_alphas))
            best_op_name = self.operations[best_op_index].name
            best_weight = float(soft_alphas[best_op_index])
            with open(self.log_path, 'a') as f:
                f.write(f"\nLayer {layer_idx}: Best operation: {best_op_name} ({best_weight:.4f})\n")

        return output
        



class PrefixEncoder(nn.Module):
    """Encode virtual-token indices into prefix key/value vectors.

    Two modes are selected by ``config.prefix_projection``:

    * **projection** -- indices are embedded to ``SEARCH_HIDDEN_DIM``
      features, passed through ``NUM_SEARCH_LAYERS`` ``SearchLayer`` modules
      in sequence, and finally mapped by a linear layer to
      ``num_layers * 2 * token_dim`` features.
    * **no projection** -- indices are looked up directly in an embedding
      table of width ``num_layers * 2 * token_dim``.

    In both modes the output has the shape of the input index tensor plus one
    trailing dimension of size ``num_layers * 2 * token_dim``.

    Args:
        config: Object exposing ``prefix_projection``, ``token_dim``,
            ``num_layers`` and ``num_virtual_tokens``.

    Attributes:
        separator_log_path: File that receives a separator line after every
            forward pass in projection mode.  Override on the class or on an
            instance; set to ``None`` to disable the write.
    """

    # Width of the embedding fed to the search layers; SearchLayer() is built
    # with its default width, which this value has to match.
    SEARCH_HIDDEN_DIM = 1024
    # Number of SearchLayer modules chained in projection mode.
    NUM_SEARCH_LAYERS = 6
    # Destination of the per-forward separator line (None disables it).
    separator_log_path = '/location/log_s.txt'

    def __init__(self, config):
        super().__init__()

        self.prefix_projection = config.prefix_projection
        token_dim = config.token_dim
        num_layers = config.num_layers
        num_virtual_tokens = config.num_virtual_tokens
        output_dim = num_layers * 2 * token_dim

        if self.prefix_projection:
            self.embedding = nn.Embedding(num_virtual_tokens, self.SEARCH_HIDDEN_DIM)
            self.search_layers = nn.ModuleList(
                [SearchLayer() for _ in range(self.NUM_SEARCH_LAYERS)]
            )
            self.output = nn.Linear(self.SEARCH_HIDDEN_DIM, output_dim)
        else:
            self.embedding = nn.Embedding(num_virtual_tokens, output_dim)

    def forward(self, prefix):
        """Map virtual-token indices to prefix key/value vectors.

        Args:
            prefix: Integer tensor of virtual-token indices.

        Returns:
            Tensor shaped like ``prefix`` with a trailing dimension of size
            ``num_layers * 2 * token_dim``.
        """
        if not self.prefix_projection:
            return self.embedding(prefix)

        x = self.embedding(prefix)

        # Each search layer receives its position so it can tag its log line.
        for idx, search_layer in enumerate(self.search_layers):
            x = search_layer(x, idx)

        # Separator marking the end of one forward pass in the log.
        if self.separator_log_path is not None:
            with open(self.separator_log_path, 'a') as f:
                f.write("_____________________________________________\n")

        return self.output(x)