# encoding=utf-8

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F


class AttnWrapper(torch.nn.Module):
    """Wrap an attention module to record its output and optionally steer it.

    After each call the (possibly modified) first output element is stored in
    ``activations`` and the second output element, if present, in ``present``.
    When ``add_tensor`` is set it is added to the first output element before
    it is recorded and returned.
    """

    def __init__(self, attn):
        super().__init__()
        self.attn = attn            # wrapped attention callable
        self.activations = None     # first output element from the last call
        self.add_tensor = None      # optional tensor added to the first output element
        self.present = None         # second output element from the last call, if any

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            output = (output[0] + self.add_tensor,) + output[1:]

        self.activations = output[0]
        # Tolerate attention callables that return a single-element tuple
        # instead of failing with IndexError.
        self.present = output[1] if len(output) > 1 else None
        return output

    def reset(self):
        """Clear all recorded state and any steering tensor."""
        self.activations = None
        self.add_tensor = None
        # Also drop the last recorded cache entry so no reference to it is
        # held after a reset.
        self.present = None


class BlockOutputWrapper(torch.nn.Module):
    """Wrap a BLOOM decoder block to record (and optionally patch) its internals.

    The wrapped block runs normally. Afterwards the MLP sub-layer is recomputed
    from the recorded attention output, without its residual, so that it can be
    stored in ``mlp_output``. When ``changes`` requests it, the MLP output is
    overwritten, and then the block's hidden-state output is rebuilt from the
    residual plus the MLP output.
    """

    def __init__(self, block, unembed_matrix, norm, apply_residual_connection_post_layernorm, idx, config, changes=None):
        super().__init__()
        self.config = config
        # Optional patch spec, read as changes[layer_idx]["mlp"][token_idx][0, -1, :].
        # NOTE(review): presumably a list of tensors shaped like mlp_output — confirm.
        self.changes = changes

        self.layer_idx = idx
        self.block = block                       # decoder block being wrapped
        self.unembed_matrix = unembed_matrix     # final linear (unembedding) layer
        self.norm = norm                         # final norm layer applied before unembedding

        # Route the block's attention through AttnWrapper so its output and
        # second return element can be read back after each call.
        self.block.self_attention = AttnWrapper(self.block.self_attention)
        self.post_attention_layernorm = self.block.post_attention_layernorm

        self.dense_h_to_4h = self.block.mlp.dense_h_to_4h
        self.gelu_impl = self.block.mlp.gelu_impl
        self.dense_4h_to_h = self.block.mlp.dense_4h_to_h
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm

        self.block_output = None
        self.block_output_unembedded = None
        self.mlp_output_2 = None
        self.mlp_output_unembedded = None
        self.intermediate_res_unembedded = None

        self.attn_mech_output = None
        self.mlp_output = []   # one MLP output appended per forward call

        self.now_token = 0     # number of forward calls since the last external reset

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        # NOTE(review): this is the wrapped block's own output, recorded before
        # any MLP patch below is applied.
        self.block_output = output[0]

        attn_output = self.block.self_attention.activations
        present = self.block.self_attention.present
        self.attn_mech_output = attn_output

        # MLP contribution without the post-MLP residual. The MLP is always fed
        # the post-attention layernorm output; the residual flag only selects
        # which tensor is added back further down. (Previously an extra,
        # immediately-overwritten MLP pass was computed here.)
        mlp_output = self.dense_4h_to_h(
            self.gelu_impl(self.dense_h_to_4h(self.post_attention_layernorm(attn_output)))
        )
        self.mlp_output_2 = mlp_output

        if self.changes and self.layer_idx in self.changes and "mlp" in self.changes[self.layer_idx]:
            # Patch only the first forward call since now_token was last reset
            # (presumably the prompt pass during generation).
            if self.now_token == 0:
                mlp_output[:, -1, :] = self.changes[self.layer_idx]["mlp"][self.now_token][0, -1, :]

        self.mlp_output.append(mlp_output)

        if self.apply_residual_connection_post_layernorm:
            residual = self.post_attention_layernorm(attn_output)
        else:
            residual = attn_output
        outputs = (residual + mlp_output,)

        # use_cache may not be passed as a keyword; treat its absence as False
        # rather than raising KeyError.
        if kwargs.get("use_cache", False):
            outputs += (present,)

        self.now_token += 1

        return outputs

    def attn_add_tensor(self, tensor):
        """Set a tensor to be added to this block's attention output."""
        self.block.self_attention.add_tensor = tensor

    def reset(self):
        """Reset the wrapped attention module's recorded state."""
        self.block.self_attention.reset()


class Bloom7BHelper:
    """Load a BLOOM causal LM and wrap each decoder block for inspection.

    Every entry of ``model.transformer.h`` is replaced by a
    ``BlockOutputWrapper``. After a forward pass, per-layer attention, MLP and
    block outputs can then be read back, and the MLP output can optionally be
    patched through ``changes``.
    """

    def __init__(self, model_name, load_kwargs, changes=None):
        """
        Args:
            model_name: name or path passed to ``from_pretrained``.
            load_kwargs: extra keyword arguments for
                ``AutoModelForCausalLM.from_pretrained``.
            changes: optional patch specification handed to every wrapper.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        self.config = self.model.config
        self.changes = changes
        self.apply_residual_connection_post_layernorm = self.model.config.apply_residual_connection_post_layernorm
        self.embedding_layer = self.model.transformer.word_embeddings
        # Snapshot the layers first so the ModuleList is not mutated while it
        # is being iterated.
        for i, layer in enumerate(list(self.model.transformer.h)):
            # lm_head is the final linear layer; transformer.ln_f is the
            # transformer's final norm layer (not a norm inside the block).
            self.model.transformer.h[i] = BlockOutputWrapper(
                layer,
                self.model.lm_head,
                self.model.transformer.ln_f,
                self.apply_residual_connection_post_layernorm,
                i,
                self.config,
                self.changes,
            )

    def generate_text(self, prompt, max_new_tokens=500):
        """Generate a continuation for each prompt with ``do_sample=False``.

        Args:
            prompt: list of prompt strings; they are padded to a common length.
            max_new_tokens: maximum number of tokens generated per prompt.

        Returns:
            A list with one decoded continuation per prompt, with the (padded)
            prompt tokens stripped.
        """
        encoded = self.tokenizer.batch_encode_plus(prompt, padding=True, return_tensors="pt")
        # Keep only what generate() needs, and move tensors to self.device
        # rather than a hard-coded "cuda" so CPU-only machines also work.
        inputs = {
            k: (v.to(self.device) if torch.is_tensor(v) else v)
            for k, v in encoded.items()
            if k in ("input_ids", "attention_mask")
        }
        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=False,
        )
        prompt_len = inputs["input_ids"].shape[-1]
        return self.tokenizer.batch_decode(
            generate_ids[:, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    def get_logits(self, prompt):
        """Run one no-grad forward pass on ``prompt`` and return its logits.

        As a side effect, every wrapped layer records its activations for this
        pass.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(inputs.input_ids.to(self.device), output_hidden_states=True).logits
            return logits

    def set_add_attn_output(self, layer, add_output):
        """Set a tensor to be added to the attention output of ``layer``."""
        self.model.transformer.h[layer].attn_add_tensor(add_output)

    def reset_all(self):
        """Call ``reset()`` on every wrapped layer."""
        for layer in self.model.transformer.h:
            layer.reset()

    def reset_changes(self, changes):
        """Replace the patch specification on the helper and on every layer."""
        # Keep the helper's own copy in sync with what the layers use.
        self.changes = changes
        for layer in self.model.transformer.h:
            layer.changes = changes

    def reset_mlp(self):
        """Clear the per-call MLP outputs recorded by every layer."""
        for layer in self.model.transformer.h:
            layer.mlp_output = []

    def reset_now_token(self):
        """Reset every layer's forward-call counter to 0."""
        for layer in self.model.transformer.h:
            layer.now_token = 0

    def print_decoded_activations(self, decoded_activations, label, topk=10):
        """Return the top-k tokens at index ``[0][-1]`` of ``decoded_activations``.

        Despite its name this does not print anything, and ``label`` is unused;
        both are kept for call compatibility.
        NOTE(review): presumably ``decoded_activations`` is (batch, seq, vocab),
        so this is batch 0, last position — confirm.

        Returns:
            A list of ``(token_string, probability, token_id_tensor)`` tuples,
            most probable first.
        """
        probs = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        values, indices = torch.topk(probs, topk)
        tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
        return list(zip(tokens, values.tolist(), indices))

    def get_top_tokens(self, text, topk=10):
        """Return the single most probable token after each block for ``text``.

        Runs a forward pass. Then, for every wrapped layer, the recorded block
        output is passed through that layer's ``norm`` and ``unembed_matrix``,
        and the top-1 token is kept.
        """
        self.get_logits(text)
        top_tokens = []
        with torch.no_grad():
            for layer in self.model.transformer.h:
                decoded = layer.block_output_unembedded
                if decoded is None:
                    # block_output_unembedded may never have been populated;
                    # derive it from the recorded block output instead of
                    # failing on None.
                    decoded = layer.unembed_matrix(layer.norm(layer.block_output))
                best = self.print_decoded_activations(decoded, 'Block output', topk=topk)[0:1]
                top_tokens.extend(token for token, _prob, _idx in best)
        return top_tokens

    def get_activation(self, layer_idx_list, loc):
        """Collect recorded activations of kind ``loc`` for the given layers.

        ``loc`` is one of "attn", "mlp" or "block"; any other value yields no
        entry for that layer.
        """
        activations = {}
        for layer_idx in layer_idx_list:
            layer = self.model.transformer.h[layer_idx]
            if loc == "attn":
                activations[layer_idx] = {"attn": layer.attn_mech_output}
            elif loc == "mlp":
                activations[layer_idx] = {"mlp": layer.mlp_output}
            elif loc == "block":
                activations[layer_idx] = {"block": layer.block_output}
        return activations

    def get_activation_unembedded(self, layer_idx_list, loc):
        """Collect unembedded activations of kind ``loc`` for the given layers.

        ``loc`` is one of "mlp", "block" or "inter"; any other value yields no
        entry for that layer.
        NOTE(review): the wrapper attributes read here may still be None if
        nothing populates them — confirm before relying on this.
        """
        activations = {}
        for layer_idx in layer_idx_list:
            layer = self.model.transformer.h[layer_idx]
            if loc == "mlp":
                activations[layer_idx] = {"mlp": layer.mlp_output_unembedded}
            elif loc == "block":
                activations[layer_idx] = {"block": layer.block_output_unembedded}
            elif loc == "inter":
                activations[layer_idx] = {"inter": layer.intermediate_res_unembedded}
        return activations
