import torch
import torch.nn
import torch.nn as nn
import torch.nn.functional as F


class PrefixEncoder(torch.nn.Module):
    r"""
    Learnable lookup table holding one vector per prefix position.

    Each row has width ``2 * num_hidden_layers * hidden_size``; ``forward``
    simply gathers the rows named by ``prefix``.

    Input shape: (batch-size, prefix-length), integer ids in [0, prompt_length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config):
        super().__init__()
        num_rows = config.prompt_length
        row_width = 2 * config.num_hidden_layers * config.hidden_size
        self.embedding = torch.nn.Embedding(num_rows, row_width)

    def forward(self, prefix: torch.Tensor):
        lookup = self.embedding
        rows = lookup(prefix)
        return rows


class ZeroFC(nn.Module):
    """Low-rank bottleneck MLP whose output is added residually to ``input``.

    ``input_embeddings`` goes through dropout -> Linear(hidden -> rank) ->
    GELU -> Linear(rank -> hidden); the result is summed with ``input``.
    ``input_ids`` is accepted but not used by this module.
    """

    def __init__(self, config):
        super().__init__()
        hidden, rank = config.hidden_size, config.prompt_rank
        layers = [
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(hidden, rank),
            nn.GELU(),
            nn.Linear(rank, hidden),
        ]
        # A single Sequential keeps the state_dict keys as fc.1.* / fc.3.*.
        self.fc = nn.Sequential(*layers)

    def forward(self, input, input_ids, input_embeddings):
        delta = self.fc(input_embeddings)
        return input + delta


class Zero(nn.Module):
    """Kronecker-factored additive prompt, looked up per token id.

    Token id ``i`` receives ``F.linear(torch.kron(a, b)[i], c)`` (a vector of
    size ``hidden_size``). That vector is passed through dropout and added to
    ``input``. The full ``kron(a, b)`` table is never materialised in
    ``forward``; ``fast_kron_embed`` builds only the rows that are needed.
    """

    def __init__(self, config):
        super().__init__()

        # Row counts of the two factors. a_rows * b_rows bounds the largest
        # token id that can be embedded. NOTE(review): presumably sized to
        # cover the model vocabulary -- confirm.
        if config.model_type == "roberta":
            a_rows, b_rows = 256, 200
        else:
            a_rows, b_rows = 360, 360
        self.a = nn.Parameter(torch.zeros(a_rows, config.prompt_rank))
        self.b = nn.Parameter(torch.zeros(b_rows, config.prompt_rank))
        self.c = nn.Parameter(torch.zeros(config.hidden_size, config.prompt_rank**2))
        # NOTE(review): a, b and c all start at zero, so the gradient of the
        # product w.r.t. every factor is also zero. They only train if they
        # are re-initialised elsewhere -- confirm.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Retained for compatibility. It is reset by train()/eval() and never
        # read by forward(); efficient_full() can build such a table if needed.
        self.fused_weight = None

    def forward(self, input, input_ids, input_embeddings):
        """Add the per-token prompt to ``input``.

        Args:
            input: tensor whose trailing dim is ``hidden_size``; the prompt is
                added to it.
            input_ids: integer token ids, each < a1 * b1.
            input_embeddings: unused; accepted so the signature matches ZeroFC.
        """
        embedding = self.fast_kron_embed(input_ids, self.a, self.b)
        prompt = F.linear(embedding, self.c)
        return input + self.dropout(prompt)

    @staticmethod
    def fast_kron_embed(indices, a, b):
        """Return ``torch.kron(a, b)[indices]`` without building the full product.

        Row ``i`` of ``kron(a, b)`` equals ``kron(a[i // b1], b[i % b1])``,
        where ``b1 = b.size(0)``.

        Args:
            indices: integer tensor of any shape; values must be < a1 * b1.
            a: factor of shape (a1, a2).
            b: factor of shape (b1, b2).

        Returns:
            Tensor of shape ``indices.shape + (a2 * b2,)``.
        """
        b1 = b.size(0)

        # Divide by b's row count, not a's: that is the row layout torch.kron
        # uses. Dividing by a1 only coincides when a1 == b1; otherwise distinct
        # ids collide (e.g. 0 and 200 with a 256x200 split).
        a_indices = torch.div(indices, b1, rounding_mode="floor")
        b_indices = indices % b1

        a_rows = F.embedding(a_indices, a)  # (..., a2)
        b_rows = F.embedding(b_indices, b)  # (..., b2)

        # Outer product flattened a-major: element k * b2 + l == a[k] * b[l].
        # Broadcasting works for any rank of `indices`.
        return (a_rows.unsqueeze(-1) * b_rows.unsqueeze(-2)).flatten(-2)

    @staticmethod
    def efficient_full(a, b, c, split_size):
        """Materialise ``F.linear(torch.kron(a, b), c)`` chunk by chunk.

        ``a`` is split along its rows (``torch.split(a, split_size)``). The
        Kronecker intermediate is therefore built one chunk at a time. The
        result has shape ``(a1 * b1, c.size(0))``.
        """
        blocks = torch.split(a, split_size)
        return torch.cat([F.linear(torch.kron(block, b), c) for block in blocks], 0)

    def train(self, mode=True):
        """Set training mode on this module and its children; returns ``self``.

        Raises:
            ValueError: if ``mode`` is not a bool.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        # Any cached fused table is invalid once the mode changes.
        self.fused_weight = None
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self


class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Takes the hidden state at sequence position ``index`` and maps it to
    ``num_labels`` logits via dropout -> dense -> tanh -> dropout -> projection.
    ``classifier_dropout`` is used when set, else ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if config.classifier_dropout is None:
            drop_prob = config.hidden_dropout_prob
        else:
            drop_prob = config.classifier_dropout
        self.dropout = nn.Dropout(drop_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, index=0, **kwargs):
        # Default index 0 selects the <s> token (equiv. to [CLS]).
        pooled = features[:, index, :]
        hidden = torch.tanh(self.dense(self.dropout(pooled)))
        return self.out_proj(self.dropout(hidden))


def efficient_full(a, b, c, split_size):
    """Compute ``F.linear(torch.kron(a, b), c)`` one row-chunk of ``a`` at a time.

    ``a`` is split along dim 0 by ``split_size``. Each chunk's Kronecker
    product with ``b`` is projected by ``c``, and the pieces are concatenated
    along dim 0. The result has shape ``(a.size(0) * b.size(0), c.size(0))``.
    """
    pieces = []
    for chunk in torch.split(a, split_size):
        pieces.append(F.linear(torch.kron(chunk, b), c))
    return torch.cat(pieces, 0)
