import torch
import torch.nn as nn
import torch.nn.functional as F

class Unit(nn.Module):
    """Learnable prompt that is shared by every example in a batch.

    Holds a soft prompt of ``soft_prompt_dim`` embedding vectors and,
    when ``deep_prompt_dim > 0``, a per-layer prefix with ``n_layer * 2``
    slots per prefix token (presumably key and value, following
    P-Tuning v2).

    Args:
        model: Object exposing ``config`` (read for ``n_embd``, ``n_layer``,
            ``n_head``, ``vocab_size`` and ``to_dict()``) and, when
            ``lm_embed_init`` is set, ``_get_wte()`` returning a callable
            that maps index tensors to embeddings.
        soft_prompt_dim: Number of soft prompt vectors.
        deep_prompt_dim: Number of per-layer prefix tokens; ``0`` disables
            the deep prompt.
        dropout_prob: Dropout probability applied to both prompts.
        lm_embed_init: If true, initialise the soft prompt from randomly
            chosen rows of the model's token embedding instead of
            ``U(-0.5, 0.5)``.
    """

    def __init__(self, model, soft_prompt_dim, deep_prompt_dim=0, dropout_prob=0.0, lm_embed_init=False):
        super().__init__()
        self.lm_config = model.config
        if lm_embed_init:
            wte = model._get_wte()
            # Sample the indexes on the embedding's own device so the lookup
            # also works when the model already lives on an accelerator.
            # Falls back to the default device when no weight is exposed.
            wte_device = getattr(getattr(wte, 'weight', None), 'device', None)
            with torch.no_grad():
                init_indexes = torch.randint(
                    self.lm_config.vocab_size, (soft_prompt_dim,), device=wte_device
                )
                init_prompt_value = wte(init_indexes).clone()
        else:
            init_prompt_value = torch.empty(
                soft_prompt_dim, self.lm_config.n_embd, dtype=torch.float32
            ).uniform_(-0.5, 0.5)
        self.soft_prompt = nn.parameter.Parameter(init_prompt_value)

        self.deep_prompt_dim = deep_prompt_dim
        if self.deep_prompt_dim > 0:
            n_hidden = self.lm_config.hidden_size if 'hidden_size' in self.lm_config.to_dict() else self.lm_config.n_embd
            init_deep_value = torch.empty(
                deep_prompt_dim,
                self.lm_config.n_layer * 2,
                self.lm_config.n_head,
                n_hidden // self.lm_config.n_head,
                dtype=torch.float32,
            ).uniform_(-0.5, 0.5)
            self.deep_prompt = nn.parameter.Parameter(init_deep_value)

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, batch_size):
        """Replicate the prompts ``batch_size`` times.

        Args:
            batch_size: Number of copies to produce along the leading axis.

        Returns:
            dict with keys

            * ``'input_embeds'``: ``(batch_size, soft_prompt_dim, n_embd)``.
            * ``'attention_mask'``: all-ones ``(batch_size, soft_prompt_dim)``.
            * ``'past_key_values'``: ``None`` without a deep prompt, else a
              tuple of ``n_layer`` tensors, each
              ``(2, batch_size, n_head, deep_prompt_dim, head_dim)``.
            * ``'past_key_mask'``: ``None`` without a deep prompt, else
              all-ones ``(batch_size, deep_prompt_dim)``.
        """
        prompt = self.dropout(self.soft_prompt.repeat(batch_size, 1, 1))
        prompt_mask = torch.full((prompt.shape[0], prompt.shape[1]), 1, device=prompt.device)

        past_key_values = None
        past_key_mask = None
        if self.deep_prompt_dim > 0:
            deep_prompt = self.dropout(self.deep_prompt.repeat(batch_size, 1, 1, 1, 1))
            # (batch, tokens, 2 * n_layer, heads, head_dim)
            #   -> (2 * n_layer, batch, heads, tokens, head_dim),
            # then cut into per-layer pairs (borrowed from P-Tuning V2).
            past_key_values = deep_prompt.permute([2, 0, 3, 1, 4]).split(2)
            past_key_mask = torch.full((batch_size, self.deep_prompt_dim), 1, device=deep_prompt.device)

        return {
            'input_embeds': prompt,
            'attention_mask': prompt_mask,
            'past_key_values': past_key_values,
            'past_key_mask': past_key_mask
        }

class MTUnit(nn.Module):
    """Multi-task prompt table: one learnable prompt per task id.

    Stores ``num_tasks`` soft prompts and, when ``deep_prompt_dim > 0``,
    ``num_tasks`` per-layer prefixes with ``n_layer * 2`` slots per prefix
    token (presumably key and value, following P-Tuning v2). ``forward``
    gathers the rows selected by ``task_ids``.

    Args:
        num_tasks: Number of distinct tasks (rows in the prompt tables).
        model: Object exposing ``config`` (read for ``n_embd``, ``n_layer``,
            ``n_head``, ``vocab_size`` and ``to_dict()``) and, when
            ``lm_embed_init`` is set, ``_get_wte()`` returning a callable
            that maps index tensors to embeddings.
        soft_prompt_dim: Number of soft prompt vectors per task.
        deep_prompt_dim: Number of per-layer prefix tokens per task; ``0``
            disables the deep prompt.
        dropout_prob: Dropout probability applied to both prompts.
        lm_embed_init: If true, initialise the soft prompts from randomly
            chosen rows of the model's token embedding instead of
            ``U(-0.5, 0.5)``.
    """

    def __init__(self, num_tasks, model, soft_prompt_dim, deep_prompt_dim=0, dropout_prob=0.0, lm_embed_init=False):
        super().__init__()
        self.lm_config = model.config
        self.num_tasks = num_tasks

        if lm_embed_init:
            wte = model._get_wte()
            # Sample the indexes on the embedding's own device so the lookup
            # also works when the model already lives on an accelerator.
            # Falls back to the default device when no weight is exposed.
            wte_device = getattr(getattr(wte, 'weight', None), 'device', None)
            with torch.no_grad():
                init_indexes = torch.randint(
                    self.lm_config.vocab_size, (num_tasks, soft_prompt_dim), device=wte_device
                )
                init_prompt_value = wte(init_indexes).clone()
        else:
            init_prompt_value = torch.empty(
                num_tasks, soft_prompt_dim, self.lm_config.n_embd, dtype=torch.float32
            ).uniform_(-0.5, 0.5)
        self.soft_prompt = nn.parameter.Parameter(init_prompt_value)

        self.deep_prompt_dim = deep_prompt_dim
        if self.deep_prompt_dim > 0:
            n_hidden = self.lm_config.hidden_size if 'hidden_size' in self.lm_config.to_dict() else self.lm_config.n_embd
            init_deep_value = torch.empty(
                num_tasks,
                deep_prompt_dim,
                self.lm_config.n_layer * 2,
                self.lm_config.n_head,
                n_hidden // self.lm_config.n_head,
                dtype=torch.float32,
            ).uniform_(-0.5, 0.5)
            self.deep_prompt = nn.parameter.Parameter(init_deep_value)

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, task_ids):
        """Look up the prompts for each entry of ``task_ids``.

        Args:
            task_ids: Indexes into the task axis of the prompt tables, one
                per batch element (assumed to be a 1-D index tensor or
                sequence; verify against the caller).

        Returns:
            dict with keys

            * ``'input_embeds'``: ``(batch, soft_prompt_dim, n_embd)``.
            * ``'attention_mask'``: all-ones ``(batch, soft_prompt_dim)``.
            * ``'past_key_values'``: ``None`` without a deep prompt, else a
              tuple of ``n_layer`` tensors, each
              ``(2, batch, n_head, deep_prompt_dim, head_dim)``.
            * ``'past_key_mask'``: ``None`` without a deep prompt, else
              all-ones ``(batch, deep_prompt_dim)``.
        """
        prompt = self.dropout(self.soft_prompt[task_ids])
        prompt_mask = torch.full((prompt.shape[0], prompt.shape[1]), 1, device=prompt.device)

        past_key_values = None
        past_key_mask = None
        if self.deep_prompt_dim > 0:
            deep_prompt = self.dropout(self.deep_prompt[task_ids])
            # (batch, tokens, 2 * n_layer, heads, head_dim)
            #   -> (2 * n_layer, batch, heads, tokens, head_dim),
            # then cut into per-layer pairs (borrowed from P-Tuning V2).
            past_key_values = deep_prompt.permute([2, 0, 3, 1, 4]).split(2)
            # Batch size is read from the gathered tensor, the same way the
            # soft prompt mask is built above.
            past_key_mask = torch.full(
                (deep_prompt.shape[0], self.deep_prompt_dim), 1, device=deep_prompt.device
            )

        return {
            'input_embeds': prompt,
            'attention_mask': prompt_mask,
            'past_key_values': past_key_values,
            'past_key_mask': past_key_mask
        }

class Projection(nn.Module):
    """Project a hidden vector into a soft prompt and a per-layer deep prompt.

    A single MLP (``self.proj``) emits one flat vector per example. Its first
    ``soft_prompt_params`` entries are reshaped into the soft prompt and the
    remainder into the deep prompt. With ``shared_proj`` the MLP emits
    ``hidden_dim`` values per prompt token and a shared ``self.decode`` MLP
    maps each token to ``n_embd``; otherwise the MLP emits the final values
    directly.

    Args:
        model: Object exposing ``config`` (read for ``n_embd``, ``n_layer``,
            ``n_head`` and ``to_dict()``).
        hidden_dim: Width of the input vector and of the hidden layers.
        soft_prompt_dim: Number of soft prompt tokens; ``0`` disables them.
        deep_prompt_dim: Number of per-layer prefix tokens; ``0`` disables
            them.
        n_layers: Number of ``ReLU -> Linear`` stages in ``self.proj``
            (the last one produces the flat output).
        shared_proj: Whether to decode every token with the shared
            ``self.decode`` MLP.
        dropout_prob: Dropout probability applied to both prompts.
    """

    def __init__(self, model, hidden_dim, soft_prompt_dim, deep_prompt_dim, n_layers=1, shared_proj=False, dropout_prob=0.0):
        super().__init__()
        self.lm_config = model.config
        self.hidden_dim = hidden_dim

        self.soft_tokens = soft_prompt_dim
        self.soft_prompt_dim = soft_prompt_dim
        self.deep_tokens = deep_prompt_dim * self.lm_config.n_layer * 2
        self.deep_prompt_dim = deep_prompt_dim
        self.shared_proj = shared_proj

        # Resolve the hidden width once here; doing it in forward() would
        # serialise the whole config via to_dict() on every call.
        self.n_hidden = self.lm_config.hidden_size if 'hidden_size' in self.lm_config.to_dict() else self.lm_config.n_embd

        if shared_proj:
            self.soft_prompt_params = self.soft_tokens * self.hidden_dim
            self.deep_prompt_params = self.deep_tokens * self.hidden_dim
        else:
            self.soft_prompt_params = self.soft_tokens * self.lm_config.n_embd
            self.deep_prompt_params = self.deep_tokens * self.n_hidden

        # Layer order (and therefore state_dict keys) is
        # ReLU, Linear, ReLU, Linear, ..., ReLU, Linear(out).
        # The first ReLU is NOT in-place: it acts directly on the caller's
        # tensor, and an in-place op there would overwrite the caller's data
        # and fail on a leaf tensor that requires grad. Later ReLUs only see
        # fresh Linear outputs, so in-place is safe for them.
        layers = [nn.ReLU()]
        for _ in range(n_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(hidden_dim, self.soft_prompt_params + self.deep_prompt_params))
        self.proj = nn.Sequential(*layers)

        if self.shared_proj:
            self.decode = nn.Sequential(
                # Out-of-place on purpose. NOTE(review): the soft and deep
                # slices fed to decode are views of the same proj output, so
                # an in-place ReLU here presumably trips autograd's version
                # check for the other slice -- confirm.
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, self.lm_config.n_embd)
            )

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden):
        """Generate prompts from ``hidden``.

        Args:
            hidden: Tensor of shape ``(batch_size, hidden_dim)``. It is not
                modified.

        Returns:
            dict with keys

            * ``'input_embeds'``: ``None`` when ``soft_prompt_dim == 0``, else
              ``(batch, soft_prompt_dim, n_embd)``.
            * ``'attention_mask'``: ``None`` when ``soft_prompt_dim == 0``,
              else all-ones ``(batch, soft_prompt_dim)``.
            * ``'past_key_values'``: ``None`` when ``deep_prompt_dim == 0``,
              else a tuple of ``n_layer`` tensors, each
              ``(2, batch, n_head, deep_prompt_dim, head_dim)``.
            * ``'past_key_mask'``: ``None`` when ``deep_prompt_dim == 0``,
              else all-ones ``(batch, deep_prompt_dim)``.
        """
        # hidden : batch_size x hidden_size
        hidden = self.proj(hidden)
        cfg = self.lm_config

        soft_prompt = None
        soft_prompt_mask = None
        if self.soft_prompt_dim > 0:
            soft_prompt = hidden[:, :self.soft_prompt_params]
            if self.shared_proj:
                soft_prompt = soft_prompt.view(-1, self.soft_tokens, self.hidden_dim)
                soft_prompt = self.decode(soft_prompt)
            else:
                soft_prompt = soft_prompt.view(-1, self.soft_tokens, cfg.n_embd)
            soft_prompt = self.dropout(soft_prompt)
            soft_prompt_mask = torch.full((soft_prompt.shape[0], soft_prompt.shape[1]), 1, device=soft_prompt.device)

        past_key_values = None
        past_key_mask = None
        if self.deep_prompt_dim > 0:
            head_dim = self.n_hidden // cfg.n_head
            deep_prompt = hidden[:, self.soft_prompt_params:]
            if self.shared_proj:
                deep_prompt = deep_prompt.view(-1, self.deep_prompt_dim, cfg.n_layer * 2, self.hidden_dim)
                deep_prompt = self.decode(deep_prompt)
            deep_prompt = deep_prompt.view(-1, self.deep_prompt_dim, cfg.n_layer * 2, cfg.n_head, head_dim)
            deep_prompt = self.dropout(deep_prompt)
            # (batch, tokens, 2 * n_layer, heads, head_dim)
            #   -> (2 * n_layer, batch, heads, tokens, head_dim),
            # then cut into per-layer pairs (borrowed from P-Tuning V2).
            past_key_values = deep_prompt.permute([2, 0, 3, 1, 4]).split(2)
            past_key_mask = torch.full((hidden.shape[0], self.deep_prompt_dim), 1, device=deep_prompt.device)

        return {
            'input_embeds': soft_prompt,
            'attention_mask': soft_prompt_mask,
            'past_key_values': past_key_values,
            'past_key_mask': past_key_mask
        }

class GoalEmbed(nn.Module):
    """Encode a set of goals into prompts via max-pooling and ``Projection``.

    Goals arrive either as pre-computed embeddings (``goal_lm_embed=True``)
    or as integer ids looked up in a learned ``nn.Embedding``. They are
    max-pooled over dimension 1, passed through a ``ReLU -> Linear`` encoder
    and handed to ``Projection``.

    Args:
        model: Object exposing ``config`` (read here for ``n_embd``); also
            forwarded to ``Projection``.
        goal_dim: Number of rows in the goal embedding table (used only when
            ``goal_lm_embed`` is false).
        soft_prompt_dim: Forwarded to ``Projection``.
        deep_prompt_dim: Forwarded to ``Projection``.
        goal_lm_embed: If true, ``forward`` consumes ``goals_embeds`` whose
            last dimension is expected to be ``n_embd``; otherwise it
            consumes ``goals_compact`` ids.
        shared_proj: Forwarded to ``Projection``; also selects the
            projection depth (1 layer when shared, 2 otherwise).
        dropout_prob: Forwarded to ``Projection``.
        hidden_dim: Width of the encoded goal vector fed to ``Projection``.
            Defaults to the previously hard-coded 32.
        goal_embed_dim: Width of the learned goal embedding table. Defaults
            to the previously hard-coded 32.
    """

    def __init__(self, model, goal_dim, soft_prompt_dim, deep_prompt_dim=0, goal_lm_embed=False, shared_proj=False, dropout_prob=0.0, hidden_dim=32, goal_embed_dim=32):
        super().__init__()
        self.lm_config = model.config
        self.goal_dim = goal_dim
        self.lm_embed_dim = self.lm_config.n_embd
        self.goal_lm_embed = goal_lm_embed

        self.hidden_dim = hidden_dim
        if self.goal_lm_embed:
            goal_enc_dim = model.config.n_embd
        else:
            self.goals_emb = nn.Embedding(goal_dim, goal_embed_dim)
            goal_enc_dim = goal_embed_dim

        self.goals_enc = nn.Sequential(
            # In-place is safe here: forward() feeds this the fresh tensor
            # produced by the max-pool, never a caller-owned tensor.
            nn.ReLU(inplace=True),
            nn.Linear(goal_enc_dim, self.hidden_dim),
        )
        n_layers = 1 if shared_proj else 2
        self.proj = Projection(model, self.hidden_dim, soft_prompt_dim, deep_prompt_dim, n_layers, shared_proj, dropout_prob)

    def forward(self, goals_embeds=None, goals_mask=None, goals_compact=None):
        """Turn a batch of goal sets into prompts.

        Args:
            goals_embeds: Goal embeddings, required when ``goal_lm_embed`` is
                true (assumed ``(batch, n_goals, n_embd)`` -- TODO confirm).
            goals_mask: Accepted for interface compatibility but not used.
                NOTE(review): every position along dimension 1, padding
                included, takes part in the max-pool -- confirm this is
                intended.
            goals_compact: Integer goal ids, required when ``goal_lm_embed``
                is false (assumed ``(batch, n_goals)`` -- TODO confirm).

        Returns:
            Whatever ``self.proj`` returns for the encoded goals.

        Raises:
            ValueError: If the input required by the configured mode is
                missing.
        """
        if self.goal_lm_embed:
            if goals_embeds is None:
                raise ValueError("goals_embeds is required when goal_lm_embed=True")
            goals_hidden = goals_embeds
        else:
            if goals_compact is None:
                raise ValueError("goals_compact is required when goal_lm_embed=False")
            goals_hidden = self.goals_emb(goals_compact)

        # Element-wise max over the goals axis gives one vector per example.
        pooled = torch.max(goals_hidden, dim=1).values
        return self.proj(self.goals_enc(pooled))
