import torch
import torch.nn as nn


class Projected_Adaptor(nn.Module):
    """Switch-controlled adaptor applied on top of a frozen LM output head.

    The adaptor perturbs either the hidden state fed to ``lm_head`` or the
    resulting logits. The perturbation is scaled by ``epsilon`` and by the
    switch values set through :meth:`set_value`. ``lm_head.weight`` is always
    detached, so gradients from :meth:`forward` reach only the adaptor's own
    tensors.

    Adaptor classes:
        * ``"multiply"`` -- per switch ``k``, a low-rank update
          ``state @ projector1[k] @ projector2[k].T`` weighted by the switch
          value and summed over switches, added to the hidden state.
        * ``"add"`` -- adds a switch-weighted combination of ``add_vec`` rows
          to the hidden state.
        * ``"offset"`` -- adds a switch-weighted combination of
          ``offset_vec`` rows directly to the logits.

    Args:
        lm_head: module whose ``weight`` maps hidden states to logits via
            ``state @ weight.T``. Given the ``"offset"`` branch, the weight
            is implied to be ``(vocab_size, embed_dim)``.
        adaptor_class: one of ``"multiply"``, ``"add"``, ``"offset"``.
        num_switches: number of independent switches.
        embed_dim: hidden-state dimension.
        vocab_size: number of logits (used by ``"offset"``).
        rank: rank of the ``"multiply"`` projectors; must be positive.
        epsilon: scale applied to the adaptor's contribution.
        init_var: multiplier applied to the standard-normal init of the
            ``"multiply"`` projectors. It scales the std, not the variance.
            ``"add"``/``"offset"`` vectors are initialised unscaled.
        position: stored on the instance; not used by this class.

    Raises:
        NotImplementedError: if ``adaptor_class`` is not recognized.
    """

    # Names of the trainable tensors owned by each adaptor class, in the
    # order they are created and returned by parameters()/state_dict().
    _PARAM_NAMES = {
        "multiply": ("projector1", "projector2"),
        "add": ("add_vec",),
        "offset": ("offset_vec",),
    }

    def __init__(self, lm_head, adaptor_class, num_switches, embed_dim,
                 vocab_size, rank, epsilon, init_var, position="output"):
        super().__init__()
        assert rank > 0, "rank must be positive"
        if adaptor_class == "multiply":
            # Both projectors map embed_dim -> rank for every switch; the
            # second is applied transposed to map back to embed_dim.
            self.projector1 = nn.Parameter(
                torch.randn(num_switches, embed_dim, rank) * init_var)
            self.projector2 = nn.Parameter(
                torch.randn(num_switches, embed_dim, rank) * init_var)
        elif adaptor_class == "add":
            self.add_vec = nn.Parameter(torch.randn(num_switches, embed_dim))
        elif adaptor_class == "offset":
            self.offset_vec = nn.Parameter(
                torch.randn(num_switches, vocab_size))
        else:
            raise NotImplementedError(
                f"unknown adaptor_class {adaptor_class!r}; expected one of "
                f"{sorted(self._PARAM_NAMES)}")

        self.adaptor_class = adaptor_class
        self.rank = rank
        self.lm_head = lm_head
        self.epsilon = epsilon
        self.position = position
        self.num_switches = num_switches
        self.init_var = init_var
        # All-zero switches make forward() bypass the adaptor. This is a
        # plain attribute, not a registered buffer, so .to(device) does not
        # move it.
        self.switch_values = torch.zeros(num_switches)

    def set_value(self, switch_values):
        """Set the switch values used by subsequent forward() calls.

        forward() indexes this as ``[:, :, None, None]`` and ``[:, None]``,
        which implies a ``(batch, num_switches)`` tensor.
        """
        self.switch_values = switch_values

    def _logits(self, hidden):
        """Project ``hidden`` through the detached ``lm_head`` weight."""
        return hidden.matmul(self.lm_head.weight.detach().transpose(0, 1))

    def _adaptor_params(self):
        """Return ``{name: Parameter}`` for this adaptor class's tensors."""
        return {name: getattr(self, name)
                for name in self._PARAM_NAMES[self.adaptor_class]}

    def forward(self, state):
        """Return logits for ``state`` with the adaptor applied.

        The broadcasting below assumes ``state`` is
        ``(batch, seq, embed_dim)`` and ``switch_values`` is
        ``(batch, num_switches)``; confirm against callers. When every
        switch value is zero the adaptor is skipped entirely.
        """
        if self.switch_values.abs().sum() == 0:
            return self._logits(state)
        if self.adaptor_class == "multiply":
            # (batch, 1, seq, embed) @ (1, K, embed, rank) -> per-switch
            # low-rank codes, weighted by each switch value ...
            delta = (state[:, None].matmul(self.projector1[None])
                     * self.switch_values[:, :, None, None])
            # ... mapped back to embed_dim and summed over switches.
            delta = delta.matmul(self.projector2.transpose(1, 2)[None]).sum(1)
            return self._logits(state + self.epsilon * delta)
        if self.adaptor_class == "add":
            add_values = self.switch_values.matmul(self.add_vec)
            # [:, None] broadcasts the per-example vector over positions.
            return self._logits(state + self.epsilon * add_values[:, None])
        if self.adaptor_class == "offset":
            offset_values = self.switch_values.matmul(self.offset_vec)
            return self._logits(state) + self.epsilon * offset_values[:, None]
        raise NotImplementedError(
            f"unknown adaptor_class {self.adaptor_class!r}")

    def regularization_term(self):
        """Return the sum of squares of all adaptor tensors (L2 penalty)."""
        return sum(p.pow(2).sum() for p in self._adaptor_params().values())

    def parameters(self, recurse=True):
        """Return the adaptor's trainable tensors as a list.

        Unlike ``nn.Module.parameters``, ``lm_head`` weights are never
        included, so an optimizer built from this only trains the adaptor.
        ``recurse`` is accepted for signature compatibility with
        ``nn.Module.parameters`` and has no effect.
        """
        return list(self._adaptor_params().values())

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the adaptor tensors keyed by name (``lm_head`` excluded).

        The signature mirrors ``nn.Module.state_dict`` so that a parent
        module can save this one as a child. Keys are ``prefix + name``.
        Values are detached unless ``keep_vars`` is true.

        Args:
            destination: mapping to fill; a new dict is created if None.
            prefix: string prepended to every key.
            keep_vars: if true, store the Parameters themselves.

        Returns:
            The filled ``destination`` mapping.
        """
        if destination is None:
            destination = {}
        for name, param in self._adaptor_params().items():
            destination[prefix + name] = param if keep_vars else param.detach()
        return destination

    def load_state_dict(self, state_dict, strict=True):
        """Copy adaptor tensors from ``state_dict`` into this module in place.

        Values are copied (not aliased) into the existing Parameters. The
        Parameters therefore keep their device, dtype and identity, e.g. for
        optimizers already holding them. All entries are validated before
        anything is modified. Keys not used by this adaptor class are
        ignored.

        Args:
            state_dict: mapping from tensor name to tensor.
            strict: if true, a missing key raises ``KeyError``; otherwise
                missing tensors are left unchanged.

        Raises:
            KeyError: a required key is missing and ``strict`` is true.
            RuntimeError: a provided tensor's shape does not match.
        """
        updates = []
        for name, param in self._adaptor_params().items():
            if name not in state_dict:
                if strict:
                    raise KeyError(name)
                continue
            value = state_dict[name]
            if tuple(value.shape) != tuple(param.shape):
                raise RuntimeError(
                    f"size mismatch for {name}: expected "
                    f"{tuple(param.shape)}, got {tuple(value.shape)}")
            updates.append((param, value))
        with torch.no_grad():
            for param, value in updates:
                param.copy_(value)
