import torch
import numpy as np
from src.models.model_base import ProbingLM
from transformers import GPTJForCausalLM, AutoTokenizer


class ProbingGPTJ(ProbingLM):
    """Probing wrapper around GPT-J 6B that extracts internal model features.

    Loads the ``EleutherAI/gpt-j-6B`` causal LM and its tokenizer and exposes
    :meth:`get_features`, which runs a single forward pass over a prompt and
    returns logits, cached keys/values, attentions, hidden states and the
    greedy next-token prediction.
    """

    # Hugging Face Hub id used for both the model weights and the tokenizer.
    MODEL_NAME = "EleutherAI/gpt-j-6B"

    def __init__(self, low_resource_mode, device):
        """Load GPT-J and its tokenizer, placing the model on ``device``.

        Args:
            low_resource_mode: If truthy, load the ``float16`` revision of the
                weights with ``torch_dtype=torch.float16`` and
                ``low_cpu_mem_usage=True``; otherwise load the default
                revision with default loading options.
            device: Device the model is moved to (via ``.to``) and that
                inputs are sent to in :meth:`get_features`.
        """
        super().__init__()

        load_kwargs = {}
        if low_resource_mode:
            print("using low_resource_mode and fp16")
            load_kwargs = {
                "revision": "float16",
                "torch_dtype": torch.float16,
                "low_cpu_mem_usage": True,
            }
        self.model = GPTJForCausalLM.from_pretrained(
            self.MODEL_NAME, **load_kwargs,
        ).to(device)

        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        # Reuse EOS as the padding token (the GPT-J tokenizer is presumably
        # shipped without one — confirm if padding behaviour matters).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.device = device
        self.low_resource_mode = low_resource_mode

    def get_features(self, icl_datum, return_keys=None):
        """Run one forward pass over ``icl_datum["icl_prompt"]`` and collect features.

        Args:
            icl_datum: Mapping with an ``"icl_prompt"`` entry holding the
                prompt text to tokenize (a single prompt, not a batch).
            return_keys: Optional container of feature names to return.
                ``None`` returns every feature; names not listed below are
                silently ignored. Only the requested features are extracted,
                so narrowing this avoids large CPU copies (e.g. of every
                layer's hidden states).

        Returns:
            dict of feature name -> value, in the order listed here. Tensors
            are moved to CPU and, unless noted, taken from the first (only)
            batch element:

            - ``"logits"``: ``np.ndarray`` of output logits at every position.
            - ``"past_keys"`` / ``"past_values"``: ``torch.Tensor`` stacking
              each layer's cached keys / values.
            - ``"attentions"``: ``np.ndarray`` stacking, per layer, all heads'
              attention rows for the last position (assumes the HF layout
              ``(batch, heads, query, key)`` — TODO confirm).
            - ``"hidden_states"``: ``torch.Tensor`` stacking every hidden-state
              tensor returned by the model.
            - ``"last_hidden_states"``: ``torch.Tensor`` of the final entry of
              the model's hidden states.
            - ``"input_ids"`` / ``"attention_mask"``: ``np.ndarray`` with the
              batch dimension *kept*.
            - ``"prediction"``: ``str`` token with the highest logit at the
              last position (greedy next-token prediction).
        """
        tokenized = self.tokenizer(icl_datum["icl_prompt"], return_tensors="pt")
        input_ids = tokenized["input_ids"].to(self.device)
        attention_mask = tokenized["attention_mask"].to(self.device)
        with torch.no_grad():
            # Call the module rather than ``.forward`` so registered hooks run.
            outputs = self.model(
                input_ids, attention_mask=attention_mask,
                output_attentions=True, output_hidden_states=True)
        cpu = torch.device("cpu")

        # Each feature is built lazily so only requested ones pay for the
        # device-to-host copies and stacking. Insertion order here defines
        # the order of the returned dict.
        extractors = {
            "logits": lambda: outputs.logits[0].to(cpu).numpy(),
            "past_keys": lambda: torch.stack([
                layer_kv[0][0].to(cpu)
                for layer_kv in outputs.past_key_values
            ]),
            "past_values": lambda: torch.stack([
                layer_kv[1][0].to(cpu)
                for layer_kv in outputs.past_key_values
            ]),
            "attentions": lambda: np.stack([
                layer_attn[0, :, -1].to(cpu).numpy()
                for layer_attn in outputs.attentions
            ]),
            "hidden_states": lambda: torch.stack([
                layer_hidden[0].to(cpu)
                for layer_hidden in outputs.hidden_states
            ]),
            "last_hidden_states": lambda: outputs.hidden_states[-1][0].to(cpu),
            "input_ids": lambda: input_ids.to(cpu).numpy(),
            "attention_mask": lambda: attention_mask.to(cpu).numpy(),
            # ``.item()`` hands the tokenizer a plain int id instead of a
            # 0-d tensor.
            "prediction": lambda: self.tokenizer.convert_ids_to_tokens(
                [outputs.logits[0, -1].argmax().item()])[0],
        }
        return {
            key: build()
            for key, build in extractors.items()
            if return_keys is None or key in return_keys
        }
